Fix git model for generate with beam search. (#21071)
* Fix git model for generate with beam search.
* Update comment
* Fix bug on multi batch
* Add generate tests
* Clean up tests
* Fix style

Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
This commit is contained in:
@@ -21,6 +21,7 @@ from transformers import GitConfig, GitProcessor, GitVisionConfig, is_torch_avai
|
||||
from transformers.models.auto import get_values
|
||||
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
|
||||
|
||||
@@ -29,14 +30,7 @@ if is_torch_available():
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from transformers import (
|
||||
MODEL_FOR_BACKBONE_MAPPING,
|
||||
MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||
MODEL_MAPPING,
|
||||
GitForCausalLM,
|
||||
GitModel,
|
||||
GitVisionModel,
|
||||
)
|
||||
from transformers import MODEL_FOR_CAUSAL_LM_MAPPING, GitForCausalLM, GitModel, GitVisionModel
|
||||
from transformers.models.git.modeling_git import GIT_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||
|
||||
|
||||
@@ -259,13 +253,9 @@ class GitModelTester:
|
||||
|
||||
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
||||
|
||||
token_labels = None
|
||||
if self.use_labels:
|
||||
token_labels = ids_tensor([self.batch_size, self.text_seq_length], self.num_labels)
|
||||
|
||||
config = self.get_config()
|
||||
|
||||
return config, input_ids, input_mask, pixel_values, token_labels
|
||||
return config, input_ids, input_mask, pixel_values
|
||||
|
||||
def get_config(self):
|
||||
"""
|
||||
@@ -292,7 +282,7 @@ class GitModelTester:
|
||||
pad_token_id=self.pad_token_id,
|
||||
)
|
||||
|
||||
def create_and_check_model(self, config, input_ids, input_mask, pixel_values, token_labels):
|
||||
def create_and_check_model(self, config, input_ids, input_mask, pixel_values):
|
||||
model = GitModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
@@ -310,7 +300,7 @@ class GitModelTester:
|
||||
result.last_hidden_state.shape, (self.batch_size, self.text_seq_length, self.hidden_size)
|
||||
)
|
||||
|
||||
def create_and_check_for_causal_lm(self, config, input_ids, input_mask, pixel_values, token_labels):
|
||||
def create_and_check_for_causal_lm(self, config, input_ids, input_mask, pixel_values):
|
||||
model = GitForCausalLM(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
@@ -331,6 +321,24 @@ class GitModelTester:
|
||||
self.parent.assertEqual(result.loss.shape, ())
|
||||
self.parent.assertTrue(result.loss.item() > 0)
|
||||
|
||||
def _test_beam_search_generate(self, config, input_ids, input_mask, pixel_values):
|
||||
model = GitForCausalLM(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
# generate
|
||||
generated_ids = model.generate(
|
||||
input_ids,
|
||||
attention_mask=input_mask,
|
||||
pixel_values=pixel_values,
|
||||
do_sample=False,
|
||||
max_length=20,
|
||||
num_beams=2,
|
||||
num_return_sequences=2,
|
||||
)
|
||||
|
||||
self.parent.assertEqual(generated_ids.shape, (self.batch_size * 2, 20))
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
|
||||
@@ -339,7 +347,6 @@ class GitModelTester:
|
||||
input_ids,
|
||||
input_mask,
|
||||
pixel_values,
|
||||
token_labels,
|
||||
) = config_and_inputs
|
||||
|
||||
inputs_dict = {
|
||||
@@ -352,7 +359,7 @@ class GitModelTester:
|
||||
|
||||
|
||||
@require_torch
|
||||
class GitModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
class GitModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (GitModel, GitForCausalLM) if is_torch_available() else ()
|
||||
all_generative_model_classes = (GitForCausalLM,) if is_torch_available() else ()
|
||||
@@ -383,47 +390,42 @@ class GitModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
def test_for_causal_lm(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_causal_lm(*config_and_inputs)
|
||||
|
||||
def test_beam_search_generate(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester._test_beam_search_generate(*config_and_inputs)
|
||||
|
||||
def test_model_various_embeddings(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
for type in ["absolute", "relative_key", "relative_key_query"]:
|
||||
config_and_inputs[0].position_embedding_type = type
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
def test_for_causal_lm(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_causal_lm(*config_and_inputs)
|
||||
|
||||
def test_training(self):
|
||||
if not self.model_tester.is_training:
|
||||
return
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.return_dict = True
|
||||
|
||||
if model_class in [
|
||||
*get_values(MODEL_MAPPING),
|
||||
*get_values(MODEL_FOR_BACKBONE_MAPPING),
|
||||
]:
|
||||
continue
|
||||
|
||||
print("Model class:", model_class)
|
||||
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.train()
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||
for k, v in inputs.items():
|
||||
print(k, v.shape)
|
||||
loss = model(**inputs).loss
|
||||
loss.backward()
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_name in GIT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||
model = GitModel.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
@unittest.skip(reason="GIT has pixel values as additional input")
|
||||
def test_beam_search_generate_dict_outputs_use_cache(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="GIT has pixel values as additional input")
|
||||
def test_contrastive_generate(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="GIT has pixel values as additional input")
|
||||
def test_contrastive_generate_dict_outputs_use_cache(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="GIT has pixel values as additional input")
|
||||
def test_greedy_generate_dict_outputs_use_cache(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_vision
|
||||
|
||||
Reference in New Issue
Block a user