Fix git model for generate with beam search. (#21071)
* Fix git model for generate with beam search. * Update comment * Fix bug on multi batch * Add generate tests * Clean up tests * Fix style Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
This commit is contained in:
@@ -1264,6 +1264,11 @@ class GitModel(GitPreTrainedModel):
|
|||||||
device=embedding_output.device,
|
device=embedding_output.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Repeat visual features to match embedding batch size.
|
||||||
|
projected_visual_features = projected_visual_features.repeat(
|
||||||
|
embedding_output.size(0) // projected_visual_features.size(0), 1, 1
|
||||||
|
)
|
||||||
|
|
||||||
# concatenate patch token and text token embeddings
|
# concatenate patch token and text token embeddings
|
||||||
hidden_states = torch.cat((projected_visual_features, embedding_output), dim=1)
|
hidden_states = torch.cat((projected_visual_features, embedding_output), dim=1)
|
||||||
|
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ from transformers import GitConfig, GitProcessor, GitVisionConfig, is_torch_avai
|
|||||||
from transformers.models.auto import get_values
|
from transformers.models.auto import get_values
|
||||||
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
|
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
|
||||||
|
|
||||||
|
from ...generation.test_utils import GenerationTesterMixin
|
||||||
from ...test_configuration_common import ConfigTester
|
from ...test_configuration_common import ConfigTester
|
||||||
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
|
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
|
||||||
|
|
||||||
@@ -29,14 +30,7 @@ if is_torch_available():
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from transformers import (
|
from transformers import MODEL_FOR_CAUSAL_LM_MAPPING, GitForCausalLM, GitModel, GitVisionModel
|
||||||
MODEL_FOR_BACKBONE_MAPPING,
|
|
||||||
MODEL_FOR_CAUSAL_LM_MAPPING,
|
|
||||||
MODEL_MAPPING,
|
|
||||||
GitForCausalLM,
|
|
||||||
GitModel,
|
|
||||||
GitVisionModel,
|
|
||||||
)
|
|
||||||
from transformers.models.git.modeling_git import GIT_PRETRAINED_MODEL_ARCHIVE_LIST
|
from transformers.models.git.modeling_git import GIT_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||||
|
|
||||||
|
|
||||||
@@ -259,13 +253,9 @@ class GitModelTester:
|
|||||||
|
|
||||||
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
||||||
|
|
||||||
token_labels = None
|
|
||||||
if self.use_labels:
|
|
||||||
token_labels = ids_tensor([self.batch_size, self.text_seq_length], self.num_labels)
|
|
||||||
|
|
||||||
config = self.get_config()
|
config = self.get_config()
|
||||||
|
|
||||||
return config, input_ids, input_mask, pixel_values, token_labels
|
return config, input_ids, input_mask, pixel_values
|
||||||
|
|
||||||
def get_config(self):
|
def get_config(self):
|
||||||
"""
|
"""
|
||||||
@@ -292,7 +282,7 @@ class GitModelTester:
|
|||||||
pad_token_id=self.pad_token_id,
|
pad_token_id=self.pad_token_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_and_check_model(self, config, input_ids, input_mask, pixel_values, token_labels):
|
def create_and_check_model(self, config, input_ids, input_mask, pixel_values):
|
||||||
model = GitModel(config=config)
|
model = GitModel(config=config)
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
model.eval()
|
model.eval()
|
||||||
@@ -310,7 +300,7 @@ class GitModelTester:
|
|||||||
result.last_hidden_state.shape, (self.batch_size, self.text_seq_length, self.hidden_size)
|
result.last_hidden_state.shape, (self.batch_size, self.text_seq_length, self.hidden_size)
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_and_check_for_causal_lm(self, config, input_ids, input_mask, pixel_values, token_labels):
|
def create_and_check_for_causal_lm(self, config, input_ids, input_mask, pixel_values):
|
||||||
model = GitForCausalLM(config=config)
|
model = GitForCausalLM(config=config)
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
model.eval()
|
model.eval()
|
||||||
@@ -331,6 +321,24 @@ class GitModelTester:
|
|||||||
self.parent.assertEqual(result.loss.shape, ())
|
self.parent.assertEqual(result.loss.shape, ())
|
||||||
self.parent.assertTrue(result.loss.item() > 0)
|
self.parent.assertTrue(result.loss.item() > 0)
|
||||||
|
|
||||||
|
def _test_beam_search_generate(self, config, input_ids, input_mask, pixel_values):
|
||||||
|
model = GitForCausalLM(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
# generate
|
||||||
|
generated_ids = model.generate(
|
||||||
|
input_ids,
|
||||||
|
attention_mask=input_mask,
|
||||||
|
pixel_values=pixel_values,
|
||||||
|
do_sample=False,
|
||||||
|
max_length=20,
|
||||||
|
num_beams=2,
|
||||||
|
num_return_sequences=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.parent.assertEqual(generated_ids.shape, (self.batch_size * 2, 20))
|
||||||
|
|
||||||
def prepare_config_and_inputs_for_common(self):
|
def prepare_config_and_inputs_for_common(self):
|
||||||
config_and_inputs = self.prepare_config_and_inputs()
|
config_and_inputs = self.prepare_config_and_inputs()
|
||||||
|
|
||||||
@@ -339,7 +347,6 @@ class GitModelTester:
|
|||||||
input_ids,
|
input_ids,
|
||||||
input_mask,
|
input_mask,
|
||||||
pixel_values,
|
pixel_values,
|
||||||
token_labels,
|
|
||||||
) = config_and_inputs
|
) = config_and_inputs
|
||||||
|
|
||||||
inputs_dict = {
|
inputs_dict = {
|
||||||
@@ -352,7 +359,7 @@ class GitModelTester:
|
|||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class GitModelTest(ModelTesterMixin, unittest.TestCase):
|
class GitModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||||
|
|
||||||
all_model_classes = (GitModel, GitForCausalLM) if is_torch_available() else ()
|
all_model_classes = (GitModel, GitForCausalLM) if is_torch_available() else ()
|
||||||
all_generative_model_classes = (GitForCausalLM,) if is_torch_available() else ()
|
all_generative_model_classes = (GitForCausalLM,) if is_torch_available() else ()
|
||||||
@@ -383,47 +390,42 @@ class GitModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_for_causal_lm(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_for_causal_lm(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_beam_search_generate(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester._test_beam_search_generate(*config_and_inputs)
|
||||||
|
|
||||||
def test_model_various_embeddings(self):
|
def test_model_various_embeddings(self):
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
for type in ["absolute", "relative_key", "relative_key_query"]:
|
for type in ["absolute", "relative_key", "relative_key_query"]:
|
||||||
config_and_inputs[0].position_embedding_type = type
|
config_and_inputs[0].position_embedding_type = type
|
||||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||||
|
|
||||||
def test_for_causal_lm(self):
|
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
|
||||||
self.model_tester.create_and_check_for_causal_lm(*config_and_inputs)
|
|
||||||
|
|
||||||
def test_training(self):
|
|
||||||
if not self.model_tester.is_training:
|
|
||||||
return
|
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
|
||||||
config.return_dict = True
|
|
||||||
|
|
||||||
if model_class in [
|
|
||||||
*get_values(MODEL_MAPPING),
|
|
||||||
*get_values(MODEL_FOR_BACKBONE_MAPPING),
|
|
||||||
]:
|
|
||||||
continue
|
|
||||||
|
|
||||||
print("Model class:", model_class)
|
|
||||||
|
|
||||||
model = model_class(config)
|
|
||||||
model.to(torch_device)
|
|
||||||
model.train()
|
|
||||||
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
|
||||||
for k, v in inputs.items():
|
|
||||||
print(k, v.shape)
|
|
||||||
loss = model(**inputs).loss
|
|
||||||
loss.backward()
|
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_model_from_pretrained(self):
|
def test_model_from_pretrained(self):
|
||||||
for model_name in GIT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
for model_name in GIT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||||
model = GitModel.from_pretrained(model_name)
|
model = GitModel.from_pretrained(model_name)
|
||||||
self.assertIsNotNone(model)
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
|
@unittest.skip(reason="GIT has pixel values as additional input")
|
||||||
|
def test_beam_search_generate_dict_outputs_use_cache(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="GIT has pixel values as additional input")
|
||||||
|
def test_contrastive_generate(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="GIT has pixel values as additional input")
|
||||||
|
def test_contrastive_generate_dict_outputs_use_cache(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="GIT has pixel values as additional input")
|
||||||
|
def test_greedy_generate_dict_outputs_use_cache(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@require_vision
|
@require_vision
|
||||||
|
|||||||
Reference in New Issue
Block a user