From e1ad1886410512915725775af44989f23e8fd674 Mon Sep 17 00:00:00 2001 From: Peter Lin Date: Wed, 18 Jan 2023 06:40:24 -0800 Subject: [PATCH] Fix git model for generate with beam search. (#21071) * Fix git model for generate with beam search. * Update comment * Fix bug on multi batch * Add generate tests * Clean up tests * Fix style Co-authored-by: Niels Rogge --- src/transformers/models/git/modeling_git.py | 5 ++ tests/models/git/test_modeling_git.py | 94 +++++++++++---------- 2 files changed, 53 insertions(+), 46 deletions(-) diff --git a/src/transformers/models/git/modeling_git.py b/src/transformers/models/git/modeling_git.py index e2eee6e831..9bcaa220a0 100644 --- a/src/transformers/models/git/modeling_git.py +++ b/src/transformers/models/git/modeling_git.py @@ -1264,6 +1264,11 @@ class GitModel(GitPreTrainedModel): device=embedding_output.device, ) + # Repeat visual features to match embedding batch size. + projected_visual_features = projected_visual_features.repeat( + embedding_output.size(0) // projected_visual_features.size(0), 1, 1 + ) + # concatenate patch token and text token embeddings hidden_states = torch.cat((projected_visual_features, embedding_output), dim=1) diff --git a/tests/models/git/test_modeling_git.py b/tests/models/git/test_modeling_git.py index 55ede22e8c..67bede12bd 100644 --- a/tests/models/git/test_modeling_git.py +++ b/tests/models/git/test_modeling_git.py @@ -21,6 +21,7 @@ from transformers import GitConfig, GitProcessor, GitVisionConfig, is_torch_avai from transformers.models.auto import get_values from transformers.testing_utils import require_torch, require_vision, slow, torch_device +from ...generation.test_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask @@ -29,14 +30,7 @@ if is_torch_available(): import torch from torch import nn - from transformers import ( - MODEL_FOR_BACKBONE_MAPPING, - MODEL_FOR_CAUSAL_LM_MAPPING, - MODEL_MAPPING, - GitForCausalLM, - GitModel, - GitVisionModel, - ) + from transformers import MODEL_FOR_CAUSAL_LM_MAPPING, GitForCausalLM, GitModel, GitVisionModel from transformers.models.git.modeling_git import GIT_PRETRAINED_MODEL_ARCHIVE_LIST @@ -259,13 +253,9 @@ class GitModelTester: pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) - token_labels = None - if self.use_labels: - token_labels = ids_tensor([self.batch_size, self.text_seq_length], self.num_labels) - config = self.get_config() - return config, input_ids, input_mask, pixel_values, token_labels + return config, input_ids, input_mask, pixel_values def get_config(self): """ @@ -292,7 +282,7 @@ class GitModelTester: pad_token_id=self.pad_token_id, ) - def create_and_check_model(self, config, input_ids, input_mask, pixel_values, token_labels): + def create_and_check_model(self, config, input_ids, input_mask, pixel_values): model = GitModel(config=config) model.to(torch_device) model.eval() @@ -310,7 +300,7 @@ class GitModelTester: result.last_hidden_state.shape, (self.batch_size, self.text_seq_length, self.hidden_size) ) - def create_and_check_for_causal_lm(self, config, input_ids, input_mask, pixel_values, token_labels): + def create_and_check_for_causal_lm(self, config, input_ids, input_mask, pixel_values): model = GitForCausalLM(config=config) model.to(torch_device) model.eval() @@ -331,6 +321,24 @@ class GitModelTester: self.parent.assertEqual(result.loss.shape, ()) self.parent.assertTrue(result.loss.item() > 0) + def _test_beam_search_generate(self, config, input_ids, input_mask, pixel_values): + model = GitForCausalLM(config=config) + model.to(torch_device) + model.eval() + + # generate + generated_ids = model.generate( + input_ids, + attention_mask=input_mask, + pixel_values=pixel_values, + do_sample=False, + max_length=20, + num_beams=2, + num_return_sequences=2, + ) + + self.parent.assertEqual(generated_ids.shape, (self.batch_size * 2, 20)) + def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() @@ -339,7 +347,6 @@ class GitModelTester: input_ids, input_mask, pixel_values, - token_labels, ) = config_and_inputs inputs_dict = { @@ -352,7 +359,7 @@ class GitModelTester: @require_torch -class GitModelTest(ModelTesterMixin, unittest.TestCase): +class GitModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): all_model_classes = (GitModel, GitForCausalLM) if is_torch_available() else () all_generative_model_classes = (GitForCausalLM,) if is_torch_available() else () @@ -383,47 +390,42 @@ class GitModelTest(ModelTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_model(*config_and_inputs) + def test_for_causal_lm(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_causal_lm(*config_and_inputs) + + def test_beam_search_generate(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester._test_beam_search_generate(*config_and_inputs) + def test_model_various_embeddings(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() for type in ["absolute", "relative_key", "relative_key_query"]: config_and_inputs[0].position_embedding_type = type self.model_tester.create_and_check_model(*config_and_inputs) - def test_for_causal_lm(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_for_causal_lm(*config_and_inputs) - - def test_training(self): - if not self.model_tester.is_training: - return - - for model_class in self.all_model_classes: - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.return_dict = True - - if model_class in [ - *get_values(MODEL_MAPPING), - *get_values(MODEL_FOR_BACKBONE_MAPPING), - ]: - continue - - print("Model class:", model_class) - - model = model_class(config) - model.to(torch_device) - model.train() - inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) - for k, v in inputs.items(): - print(k, v.shape) - loss = model(**inputs).loss - loss.backward() - @slow def test_model_from_pretrained(self): for model_name in GIT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: model = GitModel.from_pretrained(model_name) self.assertIsNotNone(model) + @unittest.skip(reason="GIT has pixel values as additional input") + def test_beam_search_generate_dict_outputs_use_cache(self): + pass + + @unittest.skip(reason="GIT has pixel values as additional input") + def test_contrastive_generate(self): + pass + + @unittest.skip(reason="GIT has pixel values as additional input") + def test_contrastive_generate_dict_outputs_use_cache(self): + pass + + @unittest.skip(reason="GIT has pixel values as additional input") + def test_greedy_generate_dict_outputs_use_cache(self): + pass + @require_torch @require_vision