[Flax/run_hybrid_clip] Fix duplicated image paths when captions_per_image exceeds the number of available captions; enable tokenizer truncation
This commit is contained in:
committed by
GitHub
parent
c1c2d68d37
commit
0a22335e66
@@ -224,8 +224,9 @@ class ImageTextDataset(VisionDataset):
|
||||
self.image_paths = []
|
||||
|
||||
for example in examples:
|
||||
self.captions.extend(example["captions"][:captions_per_image])
|
||||
self.image_paths.extend([example["image_path"]] * captions_per_image)
|
||||
captions_subset = example["captions"][:captions_per_image]
|
||||
self.captions.extend(captions_subset)
|
||||
self.image_paths.extend([example["image_path"]] * len(captions_subset))
|
||||
|
||||
def _load_image(self, idx: int):
|
||||
path = self.image_paths[idx]
|
||||
@@ -373,7 +374,9 @@ def main():
|
||||
def collate_fn(examples):
|
||||
pixel_values = torch.stack([example[0] for example in examples]).permute(0, 2, 3, 1).numpy()
|
||||
captions = [example[1] for example in examples]
|
||||
inputs = tokenizer(captions, max_length=data_args.max_seq_length, padding="max_length", return_tensors="np")
|
||||
inputs = tokenizer(
|
||||
captions, max_length=data_args.max_seq_length, padding="max_length", truncation=True, return_tensors="np"
|
||||
)
|
||||
|
||||
batch = {
|
||||
"pixel_values": pixel_values,
|
||||
|
||||
Reference in New Issue
Block a user