From 0a22335e660b74935a6eb85099933738f495d1ca Mon Sep 17 00:00:00 2001 From: Eduardo Gonzalez Ponferrada Date: Wed, 1 Sep 2021 22:49:49 -0700 Subject: [PATCH] [Flax/run_hybrid_clip] Fix duplicating images when captions_per_image exceeds the number of captions, enable truncation --- .../jax-projects/hybrid_clip/run_hybrid_clip.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/examples/research_projects/jax-projects/hybrid_clip/run_hybrid_clip.py b/examples/research_projects/jax-projects/hybrid_clip/run_hybrid_clip.py index b9200e0b2b..8d2648811d 100644 --- a/examples/research_projects/jax-projects/hybrid_clip/run_hybrid_clip.py +++ b/examples/research_projects/jax-projects/hybrid_clip/run_hybrid_clip.py @@ -224,8 +224,9 @@ class ImageTextDataset(VisionDataset): self.image_paths = [] for example in examples: - self.captions.extend(example["captions"][:captions_per_image]) - self.image_paths.extend([example["image_path"]] * captions_per_image) + captions_subset = example["captions"][:captions_per_image] + self.captions.extend(captions_subset) + self.image_paths.extend([example["image_path"]] * len(captions_subset)) def _load_image(self, idx: int): path = self.image_paths[idx] @@ -373,7 +374,9 @@ def main(): def collate_fn(examples): pixel_values = torch.stack([example[0] for example in examples]).permute(0, 2, 3, 1).numpy() captions = [example[1] for example in examples] - inputs = tokenizer(captions, max_length=data_args.max_seq_length, padding="max_length", return_tensors="np") + inputs = tokenizer( + captions, max_length=data_args.max_seq_length, padding="max_length", truncation=True, return_tensors="np" + ) batch = { "pixel_values": pixel_values,