[Flax/run_hybrid_clip] Fix duplicated image paths when captions_per_image exceeds the number of captions available for an image; enable tokenizer truncation
This commit is contained in:
committed by
GitHub
parent
c1c2d68d37
commit
0a22335e66
@@ -224,8 +224,9 @@ class ImageTextDataset(VisionDataset):
|
|||||||
self.image_paths = []
|
self.image_paths = []
|
||||||
|
|
||||||
for example in examples:
|
for example in examples:
|
||||||
self.captions.extend(example["captions"][:captions_per_image])
|
captions_subset = example["captions"][:captions_per_image]
|
||||||
self.image_paths.extend([example["image_path"]] * captions_per_image)
|
self.captions.extend(captions_subset)
|
||||||
|
self.image_paths.extend([example["image_path"]] * len(captions_subset))
|
||||||
|
|
||||||
def _load_image(self, idx: int):
|
def _load_image(self, idx: int):
|
||||||
path = self.image_paths[idx]
|
path = self.image_paths[idx]
|
||||||
@@ -373,7 +374,9 @@ def main():
|
|||||||
def collate_fn(examples):
|
def collate_fn(examples):
|
||||||
pixel_values = torch.stack([example[0] for example in examples]).permute(0, 2, 3, 1).numpy()
|
pixel_values = torch.stack([example[0] for example in examples]).permute(0, 2, 3, 1).numpy()
|
||||||
captions = [example[1] for example in examples]
|
captions = [example[1] for example in examples]
|
||||||
inputs = tokenizer(captions, max_length=data_args.max_seq_length, padding="max_length", return_tensors="np")
|
inputs = tokenizer(
|
||||||
|
captions, max_length=data_args.max_seq_length, padding="max_length", truncation=True, return_tensors="np"
|
||||||
|
)
|
||||||
|
|
||||||
batch = {
|
batch = {
|
||||||
"pixel_values": pixel_values,
|
"pixel_values": pixel_values,
|
||||||
|
|||||||
Reference in New Issue
Block a user