VLM generate: tests can't generate image/video tokens (#33623)
This commit is contained in:
@@ -132,7 +132,7 @@ class GenerationTesterMixin:
|
||||
|
||||
return config, input_ids, attention_mask, inputs_dict
|
||||
|
||||
def _get_logits_processor_kwargs(self, do_sample=False):
|
||||
def _get_logits_processor_kwargs(self, do_sample=False, config=None):
|
||||
logits_processor_kwargs = {
|
||||
"bad_words_ids": [[1, 0]],
|
||||
"repetition_penalty": 1.2,
|
||||
@@ -146,6 +146,17 @@ class GenerationTesterMixin:
|
||||
"temperature": 0.7,
|
||||
}
|
||||
)
|
||||
# TODO (joao, raushan): see this comment for a long-term fix
|
||||
# https://github.com/huggingface/transformers/pull/33593#issuecomment-2361824264)
|
||||
# This is a band-aid for VLM models, to ensure they don't generate image/video tokens which would cause them
|
||||
# to crash. On pretrained models this isn't a risk, as they are trained to not generate these tokens.
|
||||
if config is not None:
|
||||
image_token_index = config.image_token_index if hasattr(config, "image_token_index") else None
|
||||
video_token_index = config.video_token_index if hasattr(config, "video_token_index") else None
|
||||
if image_token_index is not None and image_token_index < config.get_text_config().vocab_size:
|
||||
logits_processor_kwargs["bad_words_ids"].append([image_token_index])
|
||||
if video_token_index is not None and video_token_index < config.get_text_config().vocab_size:
|
||||
logits_processor_kwargs["bad_words_ids"].append([video_token_index])
|
||||
|
||||
return logits_processor_kwargs
|
||||
|
||||
@@ -211,7 +222,7 @@ class GenerationTesterMixin:
|
||||
return_dict_in_generate=False,
|
||||
use_cache=True,
|
||||
):
|
||||
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
|
||||
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config)
|
||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
||||
output_generate = model.generate(
|
||||
input_ids,
|
||||
@@ -246,7 +257,7 @@ class GenerationTesterMixin:
|
||||
use_cache=True,
|
||||
):
|
||||
torch.manual_seed(0)
|
||||
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True)
|
||||
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True, config=model.config)
|
||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
||||
output_generate = model.generate(
|
||||
input_ids,
|
||||
@@ -281,7 +292,7 @@ class GenerationTesterMixin:
|
||||
return_dict_in_generate=False,
|
||||
use_cache=True,
|
||||
):
|
||||
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
|
||||
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config)
|
||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
||||
output_generate = model.generate(
|
||||
input_ids,
|
||||
@@ -316,7 +327,7 @@ class GenerationTesterMixin:
|
||||
use_cache=True,
|
||||
):
|
||||
torch.manual_seed(0)
|
||||
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True)
|
||||
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True, config=model.config)
|
||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
||||
output_generate = model.generate(
|
||||
input_ids,
|
||||
@@ -350,7 +361,7 @@ class GenerationTesterMixin:
|
||||
return_dict_in_generate=False,
|
||||
use_cache=True,
|
||||
):
|
||||
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
|
||||
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config)
|
||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
||||
output_generate = model.generate(
|
||||
input_ids,
|
||||
@@ -385,7 +396,7 @@ class GenerationTesterMixin:
|
||||
return_dict_in_generate=False,
|
||||
use_cache=True,
|
||||
):
|
||||
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
|
||||
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config)
|
||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
||||
output_generate = model.generate(
|
||||
input_ids,
|
||||
@@ -424,7 +435,7 @@ class GenerationTesterMixin:
|
||||
"top_k": 5,
|
||||
}
|
||||
|
||||
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
|
||||
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config)
|
||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
||||
output_generate = model.generate(
|
||||
input_ids,
|
||||
@@ -2052,6 +2063,7 @@ class GenerationTesterMixin:
|
||||
)
|
||||
self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist())
|
||||
|
||||
@is_flaky() # assisted generation tests are flaky (minor fp ops differences)
|
||||
def test_assisted_decoding_with_num_logits_to_keep(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if "num_logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()):
|
||||
|
||||
Reference in New Issue
Block a user