VLMs: enable generation tests (#33533)
* add tests * fix whisper * update * nit * add qwen2-vl * more updates! * better this way * fix this one * fix more tests * fix final tests, hope so * fix led * Update tests/generation/test_utils.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * pr comments * not pass pixels and extra for low-mem tests, very flaky because of visio tower --------- Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
committed by
GitHub
parent
e40bb4845e
commit
d7975a5874
@@ -98,10 +98,22 @@ class GenerationTesterMixin:
|
||||
|
||||
def _get_input_ids_and_config(self, batch_size=2):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
input_ids = inputs_dict[self.input_name]
|
||||
# TODO: @raushan or @gante, use `model.main_input_name` as the main input instead of relyinn on `input_ids`
|
||||
input_ids = inputs_dict.pop(self.input_name)[:batch_size, :]
|
||||
inputs_dict.pop("attention_mask", None)
|
||||
|
||||
input_ids = input_ids[:batch_size]
|
||||
# we don't want encoder-decoder models to start from filled decoder ids
|
||||
inputs_dict.pop("decoder_input_ids", None)
|
||||
inputs_dict.pop("decoder_attention_mask", None)
|
||||
|
||||
# we'll set cache use in each test differently
|
||||
inputs_dict.pop("use_cache", None)
|
||||
|
||||
inputs_dict = {
|
||||
k: v[:batch_size, ...]
|
||||
for k, v in inputs_dict.items()
|
||||
if "head_mask" not in k and isinstance(v, torch.Tensor)
|
||||
}
|
||||
if config.eos_token_id is not None and config.pad_token_id is None:
|
||||
# hack to allow generate for models such as GPT2 as is done in `generate()`
|
||||
if isinstance(config.eos_token_id, int):
|
||||
@@ -118,7 +130,7 @@ class GenerationTesterMixin:
|
||||
config.eos_token_id = None
|
||||
config.forced_eos_token_id = None
|
||||
|
||||
return config, input_ids, attention_mask
|
||||
return config, input_ids, attention_mask, inputs_dict
|
||||
|
||||
def _get_logits_processor_kwargs(self, do_sample=False):
|
||||
logits_processor_kwargs = {
|
||||
@@ -191,6 +203,7 @@ class GenerationTesterMixin:
|
||||
model,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
inputs_dict,
|
||||
output_scores=False,
|
||||
output_logits=False,
|
||||
output_attentions=False,
|
||||
@@ -213,6 +226,7 @@ class GenerationTesterMixin:
|
||||
use_cache=use_cache,
|
||||
**logits_processor_kwargs,
|
||||
**model_kwargs,
|
||||
**inputs_dict,
|
||||
)
|
||||
|
||||
return output_generate
|
||||
@@ -222,6 +236,7 @@ class GenerationTesterMixin:
|
||||
model,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
inputs_dict,
|
||||
num_return_sequences,
|
||||
output_scores=False,
|
||||
output_logits=False,
|
||||
@@ -247,6 +262,7 @@ class GenerationTesterMixin:
|
||||
use_cache=use_cache,
|
||||
**logits_processor_kwargs,
|
||||
**model_kwargs,
|
||||
**inputs_dict,
|
||||
)
|
||||
|
||||
return output_generate
|
||||
@@ -256,6 +272,7 @@ class GenerationTesterMixin:
|
||||
model,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
inputs_dict,
|
||||
beam_kwargs,
|
||||
output_scores=False,
|
||||
output_logits=False,
|
||||
@@ -279,6 +296,7 @@ class GenerationTesterMixin:
|
||||
**beam_kwargs,
|
||||
**logits_processor_kwargs,
|
||||
**model_kwargs,
|
||||
**inputs_dict,
|
||||
)
|
||||
|
||||
return output_generate
|
||||
@@ -288,6 +306,7 @@ class GenerationTesterMixin:
|
||||
model,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
inputs_dict,
|
||||
beam_kwargs,
|
||||
output_scores=False,
|
||||
output_logits=False,
|
||||
@@ -312,6 +331,7 @@ class GenerationTesterMixin:
|
||||
**beam_kwargs,
|
||||
**logits_processor_kwargs,
|
||||
**model_kwargs,
|
||||
**inputs_dict,
|
||||
)
|
||||
|
||||
return output_generate
|
||||
@@ -321,6 +341,7 @@ class GenerationTesterMixin:
|
||||
model,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
inputs_dict,
|
||||
beam_kwargs,
|
||||
output_scores=False,
|
||||
output_logits=False,
|
||||
@@ -344,6 +365,7 @@ class GenerationTesterMixin:
|
||||
**beam_kwargs,
|
||||
**logits_processor_kwargs,
|
||||
**model_kwargs,
|
||||
**inputs_dict,
|
||||
)
|
||||
|
||||
return output_generate
|
||||
@@ -353,6 +375,7 @@ class GenerationTesterMixin:
|
||||
model,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
inputs_dict,
|
||||
constraints,
|
||||
beam_kwargs,
|
||||
output_scores=False,
|
||||
@@ -378,6 +401,7 @@ class GenerationTesterMixin:
|
||||
**beam_kwargs,
|
||||
**logits_processor_kwargs,
|
||||
**model_kwargs,
|
||||
**inputs_dict,
|
||||
)
|
||||
|
||||
return output_generate
|
||||
@@ -387,6 +411,7 @@ class GenerationTesterMixin:
|
||||
model,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
inputs_dict,
|
||||
output_scores=False,
|
||||
output_logits=False,
|
||||
output_attentions=False,
|
||||
@@ -415,6 +440,7 @@ class GenerationTesterMixin:
|
||||
**logits_processor_kwargs,
|
||||
**model_kwargs,
|
||||
**contrastive_search_kwargs,
|
||||
**inputs_dict,
|
||||
)
|
||||
|
||||
return output_generate
|
||||
@@ -422,10 +448,12 @@ class GenerationTesterMixin:
|
||||
@pytest.mark.generate
|
||||
def test_greedy_generate(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
output_generate = self._greedy_generate(model=model, input_ids=input_ids, attention_mask=attention_mask)
|
||||
output_generate = self._greedy_generate(
|
||||
model=model, input_ids=input_ids, attention_mask=attention_mask, inputs_dict=inputs_dict
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
||||
@@ -435,13 +463,14 @@ class GenerationTesterMixin:
|
||||
@pytest.mark.generate
|
||||
def test_greedy_generate_dict_outputs(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
output_generate = self._greedy_generate(
|
||||
model=model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
inputs_dict=inputs_dict,
|
||||
output_scores=True,
|
||||
output_logits=True,
|
||||
output_hidden_states=True,
|
||||
@@ -466,7 +495,7 @@ class GenerationTesterMixin:
|
||||
@pytest.mark.generate
|
||||
def test_greedy_generate_dict_outputs_use_cache(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
|
||||
|
||||
if not hasattr(config, "use_cache"):
|
||||
self.skipTest(reason="This model doesn't support caching")
|
||||
@@ -479,6 +508,7 @@ class GenerationTesterMixin:
|
||||
model=model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
inputs_dict=inputs_dict,
|
||||
output_scores=True,
|
||||
output_logits=True,
|
||||
output_hidden_states=True,
|
||||
@@ -497,13 +527,14 @@ class GenerationTesterMixin:
|
||||
@pytest.mark.generate
|
||||
def test_sample_generate(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
output_generate = self._sample_generate(
|
||||
model=model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
inputs_dict=inputs_dict,
|
||||
num_return_sequences=1,
|
||||
)
|
||||
|
||||
@@ -515,13 +546,14 @@ class GenerationTesterMixin:
|
||||
@pytest.mark.generate
|
||||
def test_sample_generate_dict_output(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
output_generate = self._sample_generate(
|
||||
model=model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
inputs_dict=inputs_dict,
|
||||
num_return_sequences=2,
|
||||
output_scores=True,
|
||||
output_logits=True,
|
||||
@@ -547,7 +579,7 @@ class GenerationTesterMixin:
|
||||
@pytest.mark.generate
|
||||
def test_beam_search_generate(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
|
||||
@@ -556,6 +588,7 @@ class GenerationTesterMixin:
|
||||
model=model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
inputs_dict=inputs_dict,
|
||||
beam_kwargs=beam_kwargs,
|
||||
)
|
||||
|
||||
@@ -567,7 +600,7 @@ class GenerationTesterMixin:
|
||||
@pytest.mark.generate
|
||||
def test_beam_search_generate_dict_output(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
beam_kwargs = self._get_beam_kwargs()
|
||||
@@ -575,6 +608,7 @@ class GenerationTesterMixin:
|
||||
model=model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
inputs_dict=inputs_dict,
|
||||
beam_kwargs=beam_kwargs,
|
||||
output_scores=True,
|
||||
output_logits=True,
|
||||
@@ -602,7 +636,7 @@ class GenerationTesterMixin:
|
||||
def test_beam_search_generate_dict_outputs_use_cache(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
# enable cache
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
|
||||
|
||||
if not hasattr(config, "use_cache"):
|
||||
self.skipTest(reason="This model doesn't support caching")
|
||||
@@ -618,6 +652,7 @@ class GenerationTesterMixin:
|
||||
model=model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
inputs_dict=inputs_dict,
|
||||
beam_kwargs=beam_kwargs,
|
||||
output_scores=True,
|
||||
output_logits=True,
|
||||
@@ -647,7 +682,7 @@ class GenerationTesterMixin:
|
||||
if model_class._no_split_modules is None:
|
||||
continue
|
||||
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
|
||||
|
||||
model = model_class(config).eval()
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
@@ -659,12 +694,13 @@ class GenerationTesterMixin:
|
||||
attention_mask=attention_mask,
|
||||
max_new_tokens=self.max_new_tokens,
|
||||
num_beams=2,
|
||||
**inputs_dict,
|
||||
)
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_beam_sample_generate(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
beam_kwargs = self._get_beam_kwargs()
|
||||
@@ -672,6 +708,7 @@ class GenerationTesterMixin:
|
||||
model=model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
inputs_dict=inputs_dict,
|
||||
beam_kwargs=beam_kwargs,
|
||||
)
|
||||
|
||||
@@ -680,28 +717,34 @@ class GenerationTesterMixin:
|
||||
else:
|
||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
|
||||
|
||||
prepare_inputs_for_generation_args = set(inspect.signature(model.prepare_inputs_for_generation).parameters)
|
||||
# `inputs_embeds` input is well supported when `cache_positions` is used, because it means the modeling
|
||||
# code is up to date with our most recent standards
|
||||
if (
|
||||
"inputs_embeds" in prepare_inputs_for_generation_args
|
||||
and "cache_positions" in prepare_inputs_for_generation_args
|
||||
):
|
||||
input_embeds = model.get_input_embeddings()(input_ids)
|
||||
beam_kwargs.update({"inputs_embeds": input_embeds})
|
||||
output_generate2 = self._beam_sample_generate(
|
||||
model=model,
|
||||
input_ids=None,
|
||||
attention_mask=attention_mask,
|
||||
beam_kwargs=beam_kwargs,
|
||||
# for VLMs inputs embeds won't match input ids unless images are encoded and merged with ids properly
|
||||
# no quick fix available, since obtaining image embeddings step is very model-specific
|
||||
if any(name in model.__class__.__name__.lower() for name in ("blip", "llava", "paligemma")):
|
||||
prepare_inputs_for_generation_args = set(
|
||||
inspect.signature(model.prepare_inputs_for_generation).parameters
|
||||
)
|
||||
# `inputs_embeds` input is well supported when `cache_positions` is used, because it means the modeling
|
||||
# code is up to date with our most recent standards
|
||||
if (
|
||||
"inputs_embeds" in prepare_inputs_for_generation_args
|
||||
and "cache_positions" in prepare_inputs_for_generation_args
|
||||
):
|
||||
input_embeds = model.get_input_embeddings()(input_ids)
|
||||
beam_kwargs.update({"inputs_embeds": input_embeds})
|
||||
output_generate2 = self._beam_sample_generate(
|
||||
model=model,
|
||||
input_ids=None,
|
||||
attention_mask=attention_mask,
|
||||
inputs_dict={},
|
||||
beam_kwargs=beam_kwargs,
|
||||
)
|
||||
|
||||
torch.testing.assert_close(output_generate[:, input_embeds.shape[1] :], output_generate2)
|
||||
torch.testing.assert_close(output_generate[:, input_embeds.shape[1] :], output_generate2)
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_beam_sample_generate_dict_output(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
beam_kwargs = self._get_beam_kwargs()
|
||||
@@ -710,6 +753,7 @@ class GenerationTesterMixin:
|
||||
model=model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
inputs_dict=inputs_dict,
|
||||
beam_kwargs=beam_kwargs,
|
||||
output_scores=True,
|
||||
output_logits=True,
|
||||
@@ -736,7 +780,7 @@ class GenerationTesterMixin:
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_generate_without_input_ids(self):
|
||||
config, _, _ = self._get_input_ids_and_config()
|
||||
config, _, _, _ = self._get_input_ids_and_config()
|
||||
|
||||
# if no bos token id => cannot generate from None
|
||||
if config.bos_token_id is None:
|
||||
@@ -758,7 +802,7 @@ class GenerationTesterMixin:
|
||||
@pytest.mark.generate
|
||||
def test_group_beam_search_generate(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
# check `generate()` and `group_beam_search()` are equal
|
||||
@@ -767,6 +811,7 @@ class GenerationTesterMixin:
|
||||
model=model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
inputs_dict=inputs_dict,
|
||||
beam_kwargs=beam_kwargs,
|
||||
)
|
||||
if model.config.is_encoder_decoder:
|
||||
@@ -781,6 +826,7 @@ class GenerationTesterMixin:
|
||||
model=model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
inputs_dict=inputs_dict,
|
||||
beam_kwargs=beam_kwargs,
|
||||
)
|
||||
if model.config.is_encoder_decoder:
|
||||
@@ -791,7 +837,7 @@ class GenerationTesterMixin:
|
||||
@pytest.mark.generate
|
||||
def test_group_beam_search_generate_dict_output(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
beam_kwargs = self._get_diverse_beam_kwargs()
|
||||
@@ -799,6 +845,7 @@ class GenerationTesterMixin:
|
||||
model=model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
inputs_dict=inputs_dict,
|
||||
beam_kwargs=beam_kwargs,
|
||||
output_scores=True,
|
||||
output_logits=True,
|
||||
@@ -827,7 +874,7 @@ class GenerationTesterMixin:
|
||||
@pytest.mark.generate
|
||||
def test_constrained_beam_search_generate(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
|
||||
@@ -845,6 +892,7 @@ class GenerationTesterMixin:
|
||||
model=model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
inputs_dict=inputs_dict,
|
||||
constraints=constraints,
|
||||
beam_kwargs=beam_kwargs,
|
||||
)
|
||||
@@ -870,6 +918,7 @@ class GenerationTesterMixin:
|
||||
model=model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
inputs_dict=inputs_dict,
|
||||
constraints=constraints,
|
||||
beam_kwargs=beam_kwargs,
|
||||
)
|
||||
@@ -885,7 +934,7 @@ class GenerationTesterMixin:
|
||||
@pytest.mark.generate
|
||||
def test_constrained_beam_search_generate_dict_output(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
|
||||
@@ -902,6 +951,7 @@ class GenerationTesterMixin:
|
||||
model=model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
inputs_dict=inputs_dict,
|
||||
constraints=constraints,
|
||||
beam_kwargs=beam_kwargs,
|
||||
output_scores=True,
|
||||
@@ -937,7 +987,7 @@ class GenerationTesterMixin:
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
|
||||
self.skipTest(reason="Won't fix: old model with different cache format")
|
||||
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
|
||||
|
||||
# NOTE: contrastive search only works with cache on at the moment.
|
||||
if not hasattr(config, "use_cache"):
|
||||
@@ -947,7 +997,11 @@ class GenerationTesterMixin:
|
||||
# test old generation output for backwards compatibility
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
output_generate = self._contrastive_generate(
|
||||
model=model, input_ids=input_ids, attention_mask=attention_mask, use_cache=True
|
||||
model=model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
inputs_dict=inputs_dict,
|
||||
use_cache=True,
|
||||
)
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
||||
@@ -964,7 +1018,7 @@ class GenerationTesterMixin:
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
|
||||
self.skipTest(reason="Won't fix: old model with different cache format")
|
||||
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
|
||||
|
||||
# NOTE: contrastive search only works with cache on at the moment.
|
||||
if not hasattr(config, "use_cache"):
|
||||
@@ -976,6 +1030,7 @@ class GenerationTesterMixin:
|
||||
model=model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
inputs_dict=inputs_dict,
|
||||
output_scores=True,
|
||||
output_logits=True,
|
||||
output_hidden_states=True,
|
||||
@@ -1003,7 +1058,7 @@ class GenerationTesterMixin:
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["gptbigcode"]):
|
||||
self.skipTest(reason="TODO: fix me")
|
||||
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config(batch_size=1)
|
||||
config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config(batch_size=1)
|
||||
|
||||
# NOTE: contrastive search only works with cache on at the moment.
|
||||
if not hasattr(config, "use_cache"):
|
||||
@@ -1021,6 +1076,7 @@ class GenerationTesterMixin:
|
||||
low_memory=True,
|
||||
max_new_tokens=self.max_new_tokens,
|
||||
attention_mask=attention_mask,
|
||||
**inputs_dict,
|
||||
use_cache=True,
|
||||
)
|
||||
|
||||
@@ -1031,6 +1087,7 @@ class GenerationTesterMixin:
|
||||
low_memory=False,
|
||||
max_new_tokens=self.max_new_tokens,
|
||||
attention_mask=attention_mask,
|
||||
**inputs_dict,
|
||||
use_cache=True,
|
||||
)
|
||||
self.assertListEqual(low_output.tolist(), high_output.tolist())
|
||||
@@ -1055,7 +1112,7 @@ class GenerationTesterMixin:
|
||||
]
|
||||
):
|
||||
self.skipTest(reason="May fix in the future: need model-specific fixes")
|
||||
config, input_ids, _ = self._get_input_ids_and_config(batch_size=2)
|
||||
config, input_ids, _, _ = self._get_input_ids_and_config(batch_size=2)
|
||||
# batch_size=1 is ok, but batch_size>1 will cause non-identical output
|
||||
|
||||
config.use_cache = True
|
||||
@@ -1065,7 +1122,12 @@ class GenerationTesterMixin:
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
|
||||
low_output = model.generate(
|
||||
input_ids, max_new_tokens=8, num_beams=5, early_stopping=True, low_memory=True, use_cache=True
|
||||
input_ids,
|
||||
max_new_tokens=8,
|
||||
num_beams=5,
|
||||
early_stopping=True,
|
||||
low_memory=True,
|
||||
use_cache=True,
|
||||
)
|
||||
|
||||
high_output = model.generate(
|
||||
@@ -1114,7 +1176,7 @@ class GenerationTesterMixin:
|
||||
self.skipTest(reason="May fix in the future: need model-specific fixes")
|
||||
|
||||
# enable cache
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config(batch_size=1)
|
||||
config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config(batch_size=1)
|
||||
|
||||
# NOTE: assisted generation only works with cache on at the moment.
|
||||
if not hasattr(config, "use_cache"):
|
||||
@@ -1140,7 +1202,9 @@ class GenerationTesterMixin:
|
||||
"return_dict_in_generate": True,
|
||||
"use_cache": True,
|
||||
}
|
||||
output_greedy = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
|
||||
output_greedy = model.generate(
|
||||
input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict
|
||||
)
|
||||
|
||||
# test with the same assistant model or randomly init one
|
||||
# in the first case all candidate tokens are accepted, in the second none is accepted
|
||||
@@ -1152,7 +1216,9 @@ class GenerationTesterMixin:
|
||||
assistant_model.generation_config.num_assistant_tokens = 2 # see b)
|
||||
assistant_model.generation_config.num_assistant_tokens_schedule = "constant" # see b)
|
||||
generation_kwargs.update({"assistant_model": assistant_model})
|
||||
output_assisted = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
|
||||
output_assisted = model.generate(
|
||||
input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict
|
||||
)
|
||||
|
||||
# The two outputs must match and their shape must be as expected
|
||||
|
||||
@@ -1187,7 +1253,7 @@ class GenerationTesterMixin:
|
||||
self.skipTest(reason="May fix in the future: need model-specific fixes")
|
||||
|
||||
# enable cache
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config(batch_size=1)
|
||||
config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config(batch_size=1)
|
||||
|
||||
# NOTE: assisted generation only works with cache on at the moment.
|
||||
if not hasattr(config, "use_cache"):
|
||||
@@ -1214,10 +1280,14 @@ class GenerationTesterMixin:
|
||||
"use_cache": True,
|
||||
}
|
||||
|
||||
output_greedy = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
|
||||
output_greedy = model.generate(
|
||||
input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict
|
||||
)
|
||||
|
||||
generation_kwargs.update({"prompt_lookup_num_tokens": 2}) # see b)
|
||||
output_prompt_lookup = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
|
||||
output_prompt_lookup = model.generate(
|
||||
input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict
|
||||
)
|
||||
|
||||
# The two outputs must match and their shape must be as expected
|
||||
|
||||
@@ -1239,7 +1309,7 @@ class GenerationTesterMixin:
|
||||
self.skipTest("DoLa is not supported for models that don't return layerwise hidden states")
|
||||
|
||||
# enable cache if the model is not openai-gpt, xlnet, cpm, or xlm
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
|
||||
|
||||
# Encoder-decoder models are not supported
|
||||
if config.is_encoder_decoder:
|
||||
@@ -1267,7 +1337,7 @@ class GenerationTesterMixin:
|
||||
}
|
||||
generation_kwargs.update({"dola_layers": "low"})
|
||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
||||
output_dola = model.generate(input_ids, **model_kwargs, **generation_kwargs)
|
||||
output_dola = model.generate(input_ids, **model_kwargs, **generation_kwargs, **inputs_dict)
|
||||
self._check_outputs(output_dola, input_ids, model.config, use_cache=hasattr(config, "use_cache"))
|
||||
|
||||
@pytest.mark.generate
|
||||
@@ -1296,7 +1366,7 @@ class GenerationTesterMixin:
|
||||
self.skipTest(reason="May fix in the future: need model-specific fixes")
|
||||
|
||||
# enable cache
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config(batch_size=1)
|
||||
config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config(batch_size=1)
|
||||
|
||||
# NOTE: assisted generation only works with cache on at the moment.
|
||||
if not hasattr(config, "use_cache"):
|
||||
@@ -1326,9 +1396,11 @@ class GenerationTesterMixin:
|
||||
"return_dict_in_generate": True,
|
||||
"use_cache": True,
|
||||
}
|
||||
output_assisted = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
|
||||
output_assisted = model.generate(
|
||||
input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict
|
||||
)
|
||||
|
||||
self._check_outputs(output_assisted, input_ids, model.config, use_cache=True)
|
||||
self._check_outputs(output_assisted, input_ids, config, use_cache=True)
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_prompt_lookup_decoding_stops_at_eos(self):
|
||||
@@ -1364,7 +1436,7 @@ class GenerationTesterMixin:
|
||||
"""Test designed for encoder-decoder models to ensure the attention head masking is used."""
|
||||
attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
|
||||
# We want to test only encoder-decoder models
|
||||
if not config.is_encoder_decoder:
|
||||
continue
|
||||
@@ -1394,6 +1466,7 @@ class GenerationTesterMixin:
|
||||
return_dict_in_generate=True,
|
||||
remove_invalid_values=True,
|
||||
**{name: mask},
|
||||
**inputs_dict,
|
||||
)
|
||||
# We check the state of decoder_attentions and cross_attentions just from the last step
|
||||
attn_weights = out[attn_name] if attn_name == attention_names[0] else out[attn_name][-1]
|
||||
@@ -1416,7 +1489,7 @@ class GenerationTesterMixin:
|
||||
# - The model must be a decoder-only architecture (encoder-based architectures use right-padding)
|
||||
decoder_only_classes = []
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, _, _ = self._get_input_ids_and_config()
|
||||
config, _, _, _ = self._get_input_ids_and_config()
|
||||
if config.is_encoder_decoder:
|
||||
continue
|
||||
else:
|
||||
@@ -1449,7 +1522,7 @@ class GenerationTesterMixin:
|
||||
return model_kwargs
|
||||
|
||||
for model_class in decoder_only_classes:
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
config, input_ids, attention_mask, _ = self._get_input_ids_and_config()
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
signature = inspect.signature(model.forward).parameters.keys()
|
||||
|
||||
@@ -1462,7 +1535,9 @@ class GenerationTesterMixin:
|
||||
|
||||
# With left-padding (length 32)
|
||||
# can hardcode pad_token to be 0 as we'll do attn masking anyway
|
||||
pad_token_id = config.pad_token_id if getattr(config, "pad_token_id") is not None else 0
|
||||
pad_token_id = (
|
||||
config.get_text_config().pad_token_id if config.get_text_config().pad_token_id is not None else 0
|
||||
)
|
||||
pad_size = (input_ids.shape[0], 32)
|
||||
padding = torch.ones(pad_size, dtype=input_ids.dtype, device=torch_device) * pad_token_id
|
||||
padded_input_ids = torch.cat((padding, input_ids), dim=1)
|
||||
@@ -1550,7 +1625,7 @@ class GenerationTesterMixin:
|
||||
# When supported, tests that the decoder model can generate from `inputs_embeds` instead of `input_ids`
|
||||
# if fails, you should probably update the `prepare_inputs_for_generation` function
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_ids, _ = self._get_input_ids_and_config()
|
||||
config, input_ids, _, _ = self._get_input_ids_and_config()
|
||||
|
||||
# Ignore:
|
||||
# a) eos (to always output 20 tokens) and pad (so we don't try to infer the attn mask from the input_ids,
|
||||
@@ -1572,25 +1647,23 @@ class GenerationTesterMixin:
|
||||
continue
|
||||
|
||||
# Traditional way of generating text
|
||||
outputs_from_ids = model.generate(input_ids)
|
||||
self.assertEqual(outputs_from_ids.shape, (2, 20))
|
||||
outputs_from_ids = model.generate(input_ids, max_new_tokens=5)
|
||||
self.assertEqual(outputs_from_ids.shape, (input_ids.shape[0], input_ids.shape[1] + 5))
|
||||
|
||||
# Same thing, but from input embeddings (`input_ids` is passed so the prompt is present in the output)
|
||||
inputs_embeds = model.get_input_embeddings()(input_ids)
|
||||
outputs_from_embeds = model.generate(input_ids, inputs_embeds=inputs_embeds)
|
||||
outputs_from_embeds = model.generate(input_ids, inputs_embeds=inputs_embeds, max_new_tokens=5)
|
||||
self.assertListEqual(outputs_from_ids.tolist(), outputs_from_embeds.tolist())
|
||||
|
||||
# But if we pass different inputs_embeds, we should get different outputs
|
||||
torch.manual_seed(0)
|
||||
random_embeds = torch.rand_like(inputs_embeds)
|
||||
outputs_from_rand_embeds = model.generate(input_ids, inputs_embeds=random_embeds)
|
||||
outputs_from_rand_embeds = model.generate(input_ids, inputs_embeds=random_embeds, max_new_tokens=5)
|
||||
with self.assertRaises(AssertionError):
|
||||
self.assertListEqual(outputs_from_rand_embeds.tolist(), outputs_from_embeds.tolist())
|
||||
|
||||
# input_ids is not a required input -- if we don't pass it, the newly generated tokens will be the same
|
||||
outputs_from_embeds_wo_ids = model.generate(
|
||||
inputs_embeds=inputs_embeds, max_new_tokens=20 - inputs_embeds.shape[1]
|
||||
)
|
||||
outputs_from_embeds_wo_ids = model.generate(inputs_embeds=inputs_embeds, max_new_tokens=5)
|
||||
self.assertListEqual(
|
||||
outputs_from_embeds[:, inputs_embeds.shape[1] :].tolist(),
|
||||
outputs_from_embeds_wo_ids.tolist(),
|
||||
@@ -1607,7 +1680,7 @@ class GenerationTesterMixin:
|
||||
if not model_class._supports_static_cache:
|
||||
self.skipTest(reason="This model does not support the static cache format")
|
||||
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
|
||||
if config.is_encoder_decoder:
|
||||
self.skipTest(reason="This model is encoder-decoder and has Encoder-Decoder Cache")
|
||||
|
||||
@@ -1621,27 +1694,30 @@ class GenerationTesterMixin:
|
||||
max_cache_len = 30
|
||||
|
||||
# here we force to not stop at eos and go until max-length
|
||||
model.generation_config.eos_token_id = model.config.eos_token_id = -1
|
||||
model.generation_config.eos_token_id = model.config.get_text_config().eos_token_id = -1
|
||||
generation_kwargs = {
|
||||
"max_length": max_cache_len,
|
||||
"cache_implementation": "static",
|
||||
"return_dict_in_generate": True, # Required to return `past_key_values`
|
||||
}
|
||||
|
||||
text_config = model.config.get_text_config()
|
||||
head_dim = (
|
||||
model.config.head_dim
|
||||
if hasattr(model.config, "head_dim")
|
||||
else model.config.hidden_size // model.config.num_attention_heads
|
||||
text_config.head_dim
|
||||
if hasattr(text_config, "head_dim")
|
||||
else text_config.hidden_size // text_config.num_attention_heads
|
||||
)
|
||||
num_key_value_heads = (
|
||||
model.config.num_attention_heads
|
||||
if getattr(config, "num_key_value_heads", None) is None
|
||||
else model.config.num_key_value_heads
|
||||
text_config.num_attention_heads
|
||||
if getattr(text_config, "num_key_value_heads", None) is None
|
||||
else text_config.num_key_value_heads
|
||||
)
|
||||
num_hidden_layers = config.num_hidden_layers
|
||||
num_hidden_layers = text_config.num_hidden_layers
|
||||
|
||||
inputs_embeds = model.get_input_embeddings()(input_ids)
|
||||
outputs = model.generate(inputs_embeds=inputs_embeds, attention_mask=attention_mask, **generation_kwargs)
|
||||
outputs = model.generate(
|
||||
inputs_embeds=inputs_embeds, attention_mask=attention_mask, **generation_kwargs, **inputs_dict
|
||||
)
|
||||
|
||||
# we should get `max_length` in shape, not `max_length - embeds_length`
|
||||
cache_shape = (batch_size, num_key_value_heads, max_cache_len, head_dim)
|
||||
@@ -1742,7 +1818,7 @@ class GenerationTesterMixin:
|
||||
if not model_class._supports_cache_class:
|
||||
self.skipTest(reason="This model does not support the new cache format")
|
||||
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
generation_kwargs = {
|
||||
@@ -1757,7 +1833,9 @@ class GenerationTesterMixin:
|
||||
# Sets seed before calling `generate` for the case with do_sample=True
|
||||
seed = torch.randint(0, 1000000, (1,)).item()
|
||||
set_seed(seed)
|
||||
legacy_results = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
|
||||
legacy_results = model.generate(
|
||||
input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict
|
||||
)
|
||||
set_seed(seed)
|
||||
if config.is_encoder_decoder:
|
||||
cache_cls = EncoderDecoderCache
|
||||
@@ -1766,7 +1844,11 @@ class GenerationTesterMixin:
|
||||
cache_cls = DynamicCache
|
||||
past_key_values = cache_cls()
|
||||
new_results = model.generate(
|
||||
input_ids, attention_mask=attention_mask, past_key_values=past_key_values, **generation_kwargs
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
**generation_kwargs,
|
||||
**inputs_dict,
|
||||
)
|
||||
|
||||
# The two sets of generated sequences must match, despite the cache format between forward passes being
|
||||
@@ -1810,7 +1892,7 @@ class GenerationTesterMixin:
|
||||
if not model_class._supports_static_cache:
|
||||
self.skipTest(reason="This model does not support the static cache format")
|
||||
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
|
||||
if config.is_encoder_decoder:
|
||||
self.skipTest(reason="This model is encoder-decoder and has Encoder-Decoder Cache")
|
||||
|
||||
@@ -1838,7 +1920,7 @@ class GenerationTesterMixin:
|
||||
else config.num_key_value_heads
|
||||
)
|
||||
num_hidden_layers = config.num_hidden_layers
|
||||
results = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
|
||||
results = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict)
|
||||
|
||||
cache_shape = (batch_size, num_key_value_heads, max_cache_len, head_dim)
|
||||
self.assertTrue(isinstance(results.past_key_values, StaticCache))
|
||||
@@ -1852,7 +1934,7 @@ class GenerationTesterMixin:
|
||||
if not model_class._supports_quantized_cache:
|
||||
self.skipTest(reason="This model does not support the quantized cache format")
|
||||
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
|
||||
config.is_decoder = True
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
@@ -1865,7 +1947,7 @@ class GenerationTesterMixin:
|
||||
"use_cache": True,
|
||||
}
|
||||
|
||||
results = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
|
||||
results = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict)
|
||||
self.assertTrue(isinstance(results.past_key_values, QuantoQuantizedCache))
|
||||
|
||||
# passing past key values of different type should raise Error
|
||||
@@ -1931,7 +2013,7 @@ class GenerationTesterMixin:
|
||||
if "num_logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()):
|
||||
self.skipTest(reason="This model does not support `num_logits_to_keep` argument.")
|
||||
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
|
||||
config.use_cache = True
|
||||
config.is_decoder = True
|
||||
|
||||
@@ -1946,10 +2028,12 @@ class GenerationTesterMixin:
|
||||
|
||||
# Setting num_logits_to_keep at 0 keeps all logits (old behavior)
|
||||
with_all_logits = model.generate(
|
||||
input_ids, attention_mask=attention_mask, **generation_kwargs, num_logits_to_keep=0
|
||||
input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict, num_logits_to_keep=0
|
||||
)
|
||||
# By default, num_logits_to_keep is automatically set to 1 if not provided (new behavior)
|
||||
without_all_logits = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
|
||||
without_all_logits = model.generate(
|
||||
input_ids, attention_mask=attention_mask, **inputs_dict, **generation_kwargs
|
||||
)
|
||||
self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist())
|
||||
|
||||
def test_assisted_decoding_with_num_logits_to_keep(self):
|
||||
@@ -1959,7 +2043,7 @@ class GenerationTesterMixin:
|
||||
if model_class._is_stateful:
|
||||
self.skipTest(reason="Stateful models don't support assisted generation")
|
||||
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config(batch_size=1)
|
||||
config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config(batch_size=1)
|
||||
config.use_cache = True
|
||||
config.is_decoder = True
|
||||
|
||||
@@ -1976,10 +2060,12 @@ class GenerationTesterMixin:
|
||||
|
||||
# Setting num_logits_to_keep at 0 keeps all logits (old behavior)
|
||||
with_all_logits = model.generate(
|
||||
input_ids, attention_mask=attention_mask, **generation_kwargs, num_logits_to_keep=0
|
||||
input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict, num_logits_to_keep=0
|
||||
)
|
||||
# By default, num_logits_to_keep is automatically set to 1 if not provided (new behavior)
|
||||
without_all_logits = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
|
||||
without_all_logits = model.generate(
|
||||
input_ids, attention_mask=attention_mask, **inputs_dict, **generation_kwargs
|
||||
)
|
||||
self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist())
|
||||
|
||||
def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
|
||||
|
||||
Reference in New Issue
Block a user