Do not remove half seq length in generation tests (#30016)
* remove seq length from generation tests
* style and quality
* [test_all] & PR suggestion

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update tests/generation/test_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* [test all] remove unused variables

---------

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
b4fd49b6c5
commit
b1cd48740e
@@ -82,43 +82,35 @@ class GenerationTesterMixin:
|
||||
model_tester = None
|
||||
all_generative_model_classes = ()
|
||||
input_name = "input_ids"
|
||||
max_new_tokens = 3
|
||||
|
||||
def _get_input_ids_and_config(self, batch_size=2):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
input_ids = inputs_dict[self.input_name]
|
||||
|
||||
# cut to half length & take max batch_size 3
|
||||
sequence_length = input_ids.shape[-1] // 2
|
||||
input_ids = input_ids[:batch_size, :sequence_length]
|
||||
input_ids = input_ids[:batch_size]
|
||||
|
||||
# generate max 3 tokens
|
||||
if config.is_encoder_decoder:
|
||||
max_length = 4
|
||||
else:
|
||||
max_length = input_ids.shape[-1] + 3
|
||||
if config.eos_token_id is not None and config.pad_token_id is None:
|
||||
# hack to allow generate for models such as GPT2 as is done in `generate()`
|
||||
if isinstance(config.eos_token_id, int):
|
||||
config.eos_token_id = [config.eos_token_id]
|
||||
config.pad_token_id = config.eos_token_id[0]
|
||||
attention_mask = torch.ones_like(input_ids, dtype=torch.long)[:batch_size, :sequence_length]
|
||||
attention_mask = torch.ones_like(input_ids, dtype=torch.long)
|
||||
|
||||
# It is important to set the eos_token_id to None to ensure that no sequences
|
||||
# shorter than `max_length` can be generated
|
||||
config.eos_token_id = None
|
||||
config.forced_eos_token_id = None
|
||||
|
||||
return config, input_ids, attention_mask, max_length
|
||||
return config, input_ids, attention_mask
|
||||
|
||||
@staticmethod
|
||||
def _get_logits_processor_and_warper_kwargs(
|
||||
input_length,
|
||||
forced_bos_token_id=None,
|
||||
forced_eos_token_id=None,
|
||||
max_length=None,
|
||||
):
|
||||
process_kwargs = {
|
||||
"min_length": input_length + 1 if max_length is None else max_length - 1,
|
||||
"bad_words_ids": [[1, 0]],
|
||||
"repetition_penalty": 1.2,
|
||||
"remove_invalid_values": True,
|
||||
@@ -185,7 +177,6 @@ class GenerationTesterMixin:
|
||||
model,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
max_length,
|
||||
output_scores=False,
|
||||
output_logits=False,
|
||||
output_attentions=False,
|
||||
@@ -196,7 +187,6 @@ class GenerationTesterMixin:
|
||||
input_ids.shape[-1],
|
||||
forced_bos_token_id=model.config.forced_bos_token_id,
|
||||
forced_eos_token_id=model.config.forced_eos_token_id,
|
||||
max_length=max_length,
|
||||
)
|
||||
|
||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
||||
@@ -204,7 +194,7 @@ class GenerationTesterMixin:
|
||||
input_ids,
|
||||
do_sample=False,
|
||||
num_beams=1,
|
||||
max_length=max_length,
|
||||
max_new_tokens=self.max_new_tokens,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
output_scores=output_scores,
|
||||
@@ -221,7 +211,6 @@ class GenerationTesterMixin:
|
||||
model,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
max_length,
|
||||
num_return_sequences,
|
||||
logits_warper_kwargs,
|
||||
process_kwargs,
|
||||
@@ -237,7 +226,7 @@ class GenerationTesterMixin:
|
||||
input_ids,
|
||||
do_sample=True,
|
||||
num_beams=1,
|
||||
max_length=max_length,
|
||||
max_new_tokens=self.max_new_tokens,
|
||||
num_return_sequences=num_return_sequences,
|
||||
output_scores=output_scores,
|
||||
output_logits=output_logits,
|
||||
@@ -256,7 +245,6 @@ class GenerationTesterMixin:
|
||||
model,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
max_length,
|
||||
beam_kwargs,
|
||||
logits_process_kwargs,
|
||||
output_scores=False,
|
||||
@@ -269,7 +257,7 @@ class GenerationTesterMixin:
|
||||
output_generate = model.generate(
|
||||
input_ids,
|
||||
do_sample=False,
|
||||
max_length=max_length,
|
||||
max_new_tokens=self.max_new_tokens,
|
||||
output_scores=output_scores,
|
||||
output_logits=output_logits,
|
||||
output_attentions=output_attentions,
|
||||
@@ -287,7 +275,6 @@ class GenerationTesterMixin:
|
||||
model,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
max_length,
|
||||
beam_kwargs,
|
||||
logits_warper_kwargs,
|
||||
output_scores=False,
|
||||
@@ -301,7 +288,7 @@ class GenerationTesterMixin:
|
||||
output_generate = model.generate(
|
||||
input_ids,
|
||||
do_sample=True,
|
||||
max_length=max_length,
|
||||
max_new_tokens=self.max_new_tokens,
|
||||
output_scores=output_scores,
|
||||
output_logits=output_logits,
|
||||
output_attentions=output_attentions,
|
||||
@@ -319,7 +306,6 @@ class GenerationTesterMixin:
|
||||
model,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
max_length,
|
||||
beam_kwargs,
|
||||
logits_process_kwargs,
|
||||
output_scores=False,
|
||||
@@ -332,7 +318,7 @@ class GenerationTesterMixin:
|
||||
output_generate = model.generate(
|
||||
input_ids,
|
||||
do_sample=False,
|
||||
max_length=max_length,
|
||||
max_new_tokens=self.max_new_tokens,
|
||||
output_scores=output_scores,
|
||||
output_logits=output_logits,
|
||||
output_attentions=output_attentions,
|
||||
@@ -350,7 +336,6 @@ class GenerationTesterMixin:
|
||||
model,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
max_length,
|
||||
constraints,
|
||||
beam_kwargs,
|
||||
logits_process_kwargs,
|
||||
@@ -364,7 +349,7 @@ class GenerationTesterMixin:
|
||||
output_generate = model.generate(
|
||||
input_ids,
|
||||
do_sample=False,
|
||||
max_length=max_length,
|
||||
max_new_tokens=self.max_new_tokens,
|
||||
output_scores=output_scores,
|
||||
output_logits=output_logits,
|
||||
output_attentions=output_attentions,
|
||||
@@ -383,7 +368,6 @@ class GenerationTesterMixin:
|
||||
model,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
max_length,
|
||||
output_scores=False,
|
||||
output_logits=False,
|
||||
output_attentions=False,
|
||||
@@ -399,7 +383,6 @@ class GenerationTesterMixin:
|
||||
input_ids.shape[-1],
|
||||
forced_bos_token_id=model.config.forced_bos_token_id,
|
||||
forced_eos_token_id=model.config.forced_eos_token_id,
|
||||
max_length=max_length,
|
||||
)
|
||||
|
||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
||||
@@ -407,7 +390,7 @@ class GenerationTesterMixin:
|
||||
input_ids,
|
||||
do_sample=False,
|
||||
num_beams=1,
|
||||
max_length=max_length,
|
||||
max_new_tokens=self.max_new_tokens,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
output_scores=output_scores,
|
||||
@@ -422,18 +405,19 @@ class GenerationTesterMixin:
|
||||
|
||||
def test_greedy_generate(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
output_generate = self._greedy_generate(
|
||||
model=model, input_ids=input_ids, attention_mask=attention_mask, max_length=max_length
|
||||
)
|
||||
output_generate = self._greedy_generate(model=model, input_ids=input_ids, attention_mask=attention_mask)
|
||||
|
||||
self.assertTrue(output_generate.shape[-1] == max_length)
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
||||
else:
|
||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
|
||||
|
||||
def test_greedy_generate_dict_outputs(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
|
||||
config.use_cache = False
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
@@ -441,7 +425,6 @@ class GenerationTesterMixin:
|
||||
model=model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
max_length=max_length,
|
||||
output_scores=True,
|
||||
output_logits=True,
|
||||
output_hidden_states=True,
|
||||
@@ -450,20 +433,21 @@ class GenerationTesterMixin:
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
|
||||
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
|
||||
# Retrocompatibility check
|
||||
self.assertIsInstance(output_generate, GreedySearchEncoderDecoderOutput)
|
||||
else:
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
|
||||
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
|
||||
# Retrocompatibility check
|
||||
self.assertIsInstance(output_generate, GreedySearchDecoderOnlyOutput)
|
||||
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == max_length)
|
||||
self._check_outputs(output_generate, input_ids, model.config)
|
||||
|
||||
def test_greedy_generate_dict_outputs_use_cache(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
|
||||
if not hasattr(config, "use_cache"):
|
||||
self.skipTest("This model doesn't support caching")
|
||||
@@ -475,7 +459,6 @@ class GenerationTesterMixin:
|
||||
model=model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
max_length=max_length,
|
||||
output_scores=True,
|
||||
output_logits=True,
|
||||
output_hidden_states=True,
|
||||
@@ -483,57 +466,54 @@ class GenerationTesterMixin:
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == max_length)
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
|
||||
else:
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
|
||||
self._check_outputs(output_generate, input_ids, model.config, use_cache=True)
|
||||
|
||||
def test_sample_generate(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
if model.config.is_encoder_decoder:
|
||||
max_length = 4
|
||||
|
||||
process_kwargs, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(
|
||||
input_ids.shape[-1],
|
||||
forced_bos_token_id=model.config.forced_bos_token_id,
|
||||
forced_eos_token_id=model.config.forced_eos_token_id,
|
||||
max_length=max_length,
|
||||
)
|
||||
|
||||
output_generate = self._sample_generate(
|
||||
model=model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
max_length=max_length,
|
||||
num_return_sequences=1,
|
||||
logits_warper_kwargs=logits_warper_kwargs,
|
||||
process_kwargs=process_kwargs,
|
||||
)
|
||||
|
||||
self.assertTrue(output_generate.shape[-1] == max_length)
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
||||
else:
|
||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
|
||||
|
||||
def test_sample_generate_dict_output(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
|
||||
config.use_cache = False
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
if model.config.is_encoder_decoder:
|
||||
max_length = 4
|
||||
|
||||
process_kwargs, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(
|
||||
input_ids.shape[-1],
|
||||
forced_bos_token_id=model.config.forced_bos_token_id,
|
||||
forced_eos_token_id=model.config.forced_eos_token_id,
|
||||
max_length=max_length,
|
||||
)
|
||||
|
||||
output_generate = self._sample_generate(
|
||||
model=model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
max_length=max_length,
|
||||
num_return_sequences=2,
|
||||
logits_warper_kwargs=logits_warper_kwargs,
|
||||
process_kwargs=process_kwargs,
|
||||
@@ -545,30 +525,28 @@ class GenerationTesterMixin:
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
|
||||
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
|
||||
# Retrocompatibility check
|
||||
self.assertIsInstance(output_generate, SampleEncoderDecoderOutput)
|
||||
else:
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
|
||||
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
|
||||
# Retrocompatibility check
|
||||
self.assertIsInstance(output_generate, SampleDecoderOnlyOutput)
|
||||
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == max_length)
|
||||
self._check_outputs(output_generate, input_ids, model.config, num_return_sequences=2)
|
||||
|
||||
def test_beam_search_generate(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
if model.config.is_encoder_decoder:
|
||||
max_length = 4
|
||||
|
||||
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
|
||||
input_ids.shape[-1],
|
||||
config.forced_bos_token_id,
|
||||
config.forced_eos_token_id,
|
||||
max_length,
|
||||
)
|
||||
beam_kwargs = self._get_beam_kwargs()
|
||||
|
||||
@@ -576,36 +554,33 @@ class GenerationTesterMixin:
|
||||
model=model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
max_length=max_length,
|
||||
beam_kwargs=beam_kwargs,
|
||||
logits_process_kwargs=logits_process_kwargs,
|
||||
)
|
||||
|
||||
self.assertTrue(output_generate.shape[-1] == max_length)
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
||||
else:
|
||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
|
||||
|
||||
def test_beam_search_generate_dict_output(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
|
||||
# disable cache
|
||||
config.use_cache = False
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
if model.config.is_encoder_decoder:
|
||||
max_length = 4
|
||||
|
||||
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
|
||||
input_ids.shape[-1],
|
||||
config.forced_bos_token_id,
|
||||
config.forced_eos_token_id,
|
||||
max_length,
|
||||
)
|
||||
beam_kwargs = self._get_beam_kwargs()
|
||||
output_generate = self._beam_search_generate(
|
||||
model=model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
max_length=max_length,
|
||||
beam_kwargs=beam_kwargs,
|
||||
logits_process_kwargs=logits_process_kwargs,
|
||||
output_scores=True,
|
||||
@@ -615,15 +590,16 @@ class GenerationTesterMixin:
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
|
||||
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
|
||||
# Retrocompatibility check
|
||||
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
|
||||
else:
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
|
||||
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
|
||||
# Retrocompatibility check
|
||||
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
|
||||
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == max_length)
|
||||
self._check_outputs(
|
||||
output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"]
|
||||
)
|
||||
@@ -631,20 +607,16 @@ class GenerationTesterMixin:
|
||||
def test_beam_search_generate_dict_outputs_use_cache(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
# enable cache
|
||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
|
||||
if not hasattr(config, "use_cache"):
|
||||
self.skipTest("This model doesn't support caching")
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
if model.config.is_encoder_decoder:
|
||||
max_length = 4
|
||||
|
||||
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
|
||||
input_ids.shape[-1],
|
||||
config.forced_bos_token_id,
|
||||
config.forced_eos_token_id,
|
||||
max_length,
|
||||
)
|
||||
|
||||
beam_kwargs = self._get_beam_kwargs()
|
||||
@@ -656,7 +628,6 @@ class GenerationTesterMixin:
|
||||
model=model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
max_length=max_length,
|
||||
beam_kwargs=beam_kwargs,
|
||||
logits_process_kwargs=logits_process_kwargs,
|
||||
output_scores=True,
|
||||
@@ -666,7 +637,10 @@ class GenerationTesterMixin:
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == max_length)
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
|
||||
else:
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
|
||||
self._check_outputs(
|
||||
output_generate, input_ids, model.config, use_cache=True, num_return_sequences=beam_kwargs["num_beams"]
|
||||
)
|
||||
@@ -681,7 +655,7 @@ class GenerationTesterMixin:
|
||||
if model_class._no_split_modules is None:
|
||||
continue
|
||||
|
||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
|
||||
model = model_class(config).eval()
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
@@ -691,32 +665,32 @@ class GenerationTesterMixin:
|
||||
new_model.generate(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
max_length=max_length,
|
||||
max_new_tokens=self.max_new_tokens,
|
||||
num_beams=2,
|
||||
)
|
||||
|
||||
def test_beam_sample_generate(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
|
||||
_, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(input_ids.shape[-1])
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
max_length = 4
|
||||
beam_kwargs = self._get_beam_kwargs()
|
||||
|
||||
output_generate = self._beam_sample_generate(
|
||||
model=model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
max_length=max_length,
|
||||
beam_kwargs=beam_kwargs,
|
||||
logits_warper_kwargs=logits_warper_kwargs,
|
||||
)
|
||||
|
||||
self.assertTrue(output_generate.shape[-1] == max_length)
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
||||
else:
|
||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
|
||||
|
||||
if "inputs_embeds" in set(inspect.signature(model.prepare_inputs_for_generation).parameters):
|
||||
input_embeds = model.get_input_embeddings()(input_ids)
|
||||
beam_kwargs.update({"inputs_embeds": input_embeds})
|
||||
@@ -724,7 +698,6 @@ class GenerationTesterMixin:
|
||||
model=model,
|
||||
input_ids=None,
|
||||
attention_mask=attention_mask,
|
||||
max_length=max_length,
|
||||
beam_kwargs=beam_kwargs,
|
||||
logits_warper_kwargs=logits_warper_kwargs,
|
||||
)
|
||||
@@ -733,23 +706,19 @@ class GenerationTesterMixin:
|
||||
|
||||
def test_beam_sample_generate_dict_output(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
|
||||
# disable cache
|
||||
config.use_cache = False
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
_, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(input_ids.shape[-1])
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
max_length = 4
|
||||
beam_kwargs = self._get_beam_kwargs()
|
||||
|
||||
output_generate = self._beam_sample_generate(
|
||||
model=model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
max_length=max_length,
|
||||
beam_kwargs=beam_kwargs,
|
||||
logits_warper_kwargs=logits_warper_kwargs,
|
||||
output_scores=True,
|
||||
@@ -760,21 +729,22 @@ class GenerationTesterMixin:
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
|
||||
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
|
||||
# Retrocompatibility check
|
||||
self.assertIsInstance(output_generate, BeamSampleEncoderDecoderOutput)
|
||||
else:
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
|
||||
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
|
||||
# Retrocompatibility check
|
||||
self.assertIsInstance(output_generate, BeamSampleDecoderOnlyOutput)
|
||||
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == max_length)
|
||||
self._check_outputs(
|
||||
output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"]
|
||||
)
|
||||
|
||||
def test_generate_without_input_ids(self):
|
||||
config, _, _, max_length = self._get_input_ids_and_config()
|
||||
config, _, _ = self._get_input_ids_and_config()
|
||||
|
||||
# if no bos token id => cannot generate from None
|
||||
if config.bos_token_id is None:
|
||||
@@ -788,22 +758,20 @@ class GenerationTesterMixin:
|
||||
model = model_class(config).to(torch_device)
|
||||
model.eval()
|
||||
|
||||
output_ids_generate = model.generate(do_sample=False, max_length=max_length, remove_invalid_values=True)
|
||||
output_ids_generate = model.generate(
|
||||
do_sample=False, max_new_tokens=self.max_new_tokens, remove_invalid_values=True
|
||||
)
|
||||
self.assertIsNotNone(output_ids_generate)
|
||||
|
||||
def test_group_beam_search_generate(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
if model.config.is_encoder_decoder:
|
||||
max_length = 4
|
||||
|
||||
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
|
||||
input_ids.shape[-1],
|
||||
config.forced_bos_token_id,
|
||||
config.forced_eos_token_id,
|
||||
max_length,
|
||||
)
|
||||
|
||||
# check `generate()` and `group_beam_search()` are equal
|
||||
@@ -812,11 +780,13 @@ class GenerationTesterMixin:
|
||||
model=model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
max_length=max_length,
|
||||
beam_kwargs=beam_kwargs,
|
||||
logits_process_kwargs=logits_process_kwargs,
|
||||
)
|
||||
self.assertTrue(output_generate.shape[-1] == max_length)
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
||||
else:
|
||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
|
||||
|
||||
# check `group_beam_search` for higher than 1 `num_return_sequences`
|
||||
num_return_sequences = 2
|
||||
@@ -825,26 +795,24 @@ class GenerationTesterMixin:
|
||||
model=model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
max_length=max_length,
|
||||
beam_kwargs=beam_kwargs,
|
||||
logits_process_kwargs=logits_process_kwargs,
|
||||
)
|
||||
self.assertTrue(output_generate.shape[-1] == max_length)
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
||||
else:
|
||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
|
||||
|
||||
def test_group_beam_search_generate_dict_output(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
config.use_cache = False
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
if model.config.is_encoder_decoder:
|
||||
max_length = 4
|
||||
|
||||
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
|
||||
input_ids.shape[-1],
|
||||
config.forced_bos_token_id,
|
||||
config.forced_eos_token_id,
|
||||
max_length,
|
||||
)
|
||||
|
||||
beam_kwargs = self._get_diverse_beam_kwargs()
|
||||
@@ -852,7 +820,6 @@ class GenerationTesterMixin:
|
||||
model=model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
max_length=max_length,
|
||||
beam_kwargs=beam_kwargs,
|
||||
logits_process_kwargs=logits_process_kwargs,
|
||||
output_scores=True,
|
||||
@@ -862,15 +829,16 @@ class GenerationTesterMixin:
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
|
||||
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
|
||||
# Retrocompatibility check
|
||||
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
|
||||
else:
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
|
||||
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
|
||||
# Retrocompatibility check
|
||||
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
|
||||
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == max_length)
|
||||
self._check_outputs(
|
||||
output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"]
|
||||
)
|
||||
@@ -879,16 +847,14 @@ class GenerationTesterMixin:
|
||||
@is_flaky()
|
||||
def test_constrained_beam_search_generate(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
max_length = 20
|
||||
|
||||
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
|
||||
input_ids.shape[-1],
|
||||
config.forced_bos_token_id,
|
||||
config.forced_eos_token_id,
|
||||
max_length,
|
||||
)
|
||||
|
||||
# Sample constraints
|
||||
@@ -905,12 +871,16 @@ class GenerationTesterMixin:
|
||||
model=model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
max_length=max_length,
|
||||
constraints=constraints,
|
||||
beam_kwargs=beam_kwargs,
|
||||
logits_process_kwargs=logits_process_kwargs,
|
||||
)
|
||||
self.assertTrue(output_generate.shape[-1] == max_length)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
||||
else:
|
||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
|
||||
|
||||
for generation_output in output_generate:
|
||||
self._check_sequence_inside_sequence(force_tokens, generation_output)
|
||||
|
||||
@@ -921,39 +891,37 @@ class GenerationTesterMixin:
|
||||
PhrasalConstraint(force_tokens),
|
||||
]
|
||||
|
||||
max_length = 20
|
||||
beam_kwargs = self._get_constrained_beam_kwargs(num_return_sequences=2)
|
||||
|
||||
output_generate = self._constrained_beam_search_generate(
|
||||
model=model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
max_length=max_length,
|
||||
constraints=constraints,
|
||||
beam_kwargs=beam_kwargs,
|
||||
logits_process_kwargs=logits_process_kwargs,
|
||||
)
|
||||
self.assertTrue(output_generate.shape[-1] == max_length)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
||||
else:
|
||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
|
||||
|
||||
for generation_output in output_generate:
|
||||
self._check_sequence_inside_sequence(force_tokens, generation_output)
|
||||
|
||||
def test_constrained_beam_search_generate_dict_output(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
|
||||
# disable cache
|
||||
config.use_cache = False
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
if model.config.is_encoder_decoder:
|
||||
max_length = 20
|
||||
|
||||
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
|
||||
input_ids.shape[-1],
|
||||
config.forced_bos_token_id,
|
||||
config.forced_eos_token_id,
|
||||
max_length,
|
||||
)
|
||||
|
||||
# Sample constraints
|
||||
@@ -969,7 +937,6 @@ class GenerationTesterMixin:
|
||||
model=model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
max_length=max_length,
|
||||
constraints=constraints,
|
||||
beam_kwargs=beam_kwargs,
|
||||
logits_process_kwargs=logits_process_kwargs,
|
||||
@@ -981,15 +948,16 @@ class GenerationTesterMixin:
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
|
||||
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
|
||||
# Retrocompatibility check
|
||||
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
|
||||
else:
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
|
||||
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
|
||||
# Retrocompatibility check
|
||||
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
|
||||
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == max_length)
|
||||
self._check_outputs(
|
||||
output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"]
|
||||
)
|
||||
@@ -1000,7 +968,7 @@ class GenerationTesterMixin:
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
|
||||
self.skipTest("Won't fix: old model with different cache format")
|
||||
|
||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
|
||||
# NOTE: contrastive search only works with cache on at the moment.
|
||||
if not hasattr(config, "use_cache"):
|
||||
@@ -1011,9 +979,12 @@ class GenerationTesterMixin:
|
||||
# test old generation output for backwards compatibility
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
output_generate = self._contrastive_generate(
|
||||
model=model, input_ids=input_ids, attention_mask=attention_mask, max_length=max_length
|
||||
model=model, input_ids=input_ids, attention_mask=attention_mask
|
||||
)
|
||||
self.assertTrue(output_generate.shape[-1] == max_length)
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
||||
else:
|
||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
|
||||
|
||||
def test_contrastive_generate_dict_outputs_use_cache(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
@@ -1021,7 +992,7 @@ class GenerationTesterMixin:
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
|
||||
self.skipTest("Won't fix: old model with different cache format")
|
||||
|
||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
|
||||
# NOTE: contrastive search only works with cache on at the moment.
|
||||
if not hasattr(config, "use_cache"):
|
||||
@@ -1034,7 +1005,6 @@ class GenerationTesterMixin:
|
||||
model=model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
max_length=max_length,
|
||||
output_scores=True,
|
||||
output_logits=True,
|
||||
output_hidden_states=True,
|
||||
@@ -1042,7 +1012,10 @@ class GenerationTesterMixin:
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == max_length)
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
|
||||
else:
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
|
||||
self._check_outputs(output_generate, input_ids, model.config, use_cache=True)
|
||||
|
||||
def test_contrastive_generate_low_memory(self):
|
||||
@@ -1053,7 +1026,7 @@ class GenerationTesterMixin:
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["gptbigcode", "jamba"]):
|
||||
self.skipTest("TODO: fix me")
|
||||
|
||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1)
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config(batch_size=1)
|
||||
|
||||
# NOTE: contrastive search only works with cache on at the moment.
|
||||
if not hasattr(config, "use_cache"):
|
||||
@@ -1070,7 +1043,7 @@ class GenerationTesterMixin:
|
||||
top_k=4,
|
||||
penalty_alpha=0.6,
|
||||
low_memory=True,
|
||||
max_length=max_length,
|
||||
max_new_tokens=self.max_new_tokens,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
|
||||
@@ -1079,7 +1052,7 @@ class GenerationTesterMixin:
|
||||
top_k=4,
|
||||
penalty_alpha=0.6,
|
||||
low_memory=False,
|
||||
max_length=max_length,
|
||||
max_new_tokens=self.max_new_tokens,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
self.assertListEqual(low_output.tolist(), high_output.tolist())
|
||||
@@ -1102,7 +1075,7 @@ class GenerationTesterMixin:
|
||||
]
|
||||
):
|
||||
self.skipTest("May fix in the future: need model-specific fixes")
|
||||
config, input_ids, _, _ = self._get_input_ids_and_config(batch_size=2)
|
||||
config, input_ids, _ = self._get_input_ids_and_config(batch_size=2)
|
||||
# batch_size=1 is ok, but batch_size>1 will cause non-identical output
|
||||
|
||||
config.use_cache = True
|
||||
@@ -1150,7 +1123,7 @@ class GenerationTesterMixin:
|
||||
self.skipTest("May fix in the future: need model-specific fixes")
|
||||
|
||||
# enable cache
|
||||
config, input_ids, attention_mask, _ = self._get_input_ids_and_config(batch_size=1)
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config(batch_size=1)
|
||||
|
||||
# NOTE: assisted generation only works with cache on at the moment.
|
||||
if not hasattr(config, "use_cache"):
|
||||
@@ -1213,7 +1186,7 @@ class GenerationTesterMixin:
|
||||
self.skipTest("May fix in the future: need model-specific fixes")
|
||||
|
||||
# enable cache
|
||||
config, input_ids, attention_mask, _ = self._get_input_ids_and_config(batch_size=1)
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config(batch_size=1)
|
||||
|
||||
# NOTE: assisted generation only works with cache on at the moment.
|
||||
if not hasattr(config, "use_cache"):
|
||||
@@ -1273,7 +1246,7 @@ class GenerationTesterMixin:
|
||||
self.skipTest("May fix in the future: need model-specific fixes")
|
||||
|
||||
# enable cache
|
||||
config, input_ids, attention_mask, _ = self._get_input_ids_and_config(batch_size=1)
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config(batch_size=1)
|
||||
|
||||
# NOTE: assisted generation only works with cache on at the moment.
|
||||
if not hasattr(config, "use_cache"):
|
||||
@@ -1311,7 +1284,7 @@ class GenerationTesterMixin:
|
||||
"""Test designed for encoder-decoder models to ensure the attention head masking is used."""
|
||||
attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
# We want to test only encoder-decoder models
|
||||
if not config.is_encoder_decoder:
|
||||
continue
|
||||
@@ -1358,7 +1331,7 @@ class GenerationTesterMixin:
|
||||
# - The model must be a decoder-only architecture (encoder-based architectures use right-padding)
|
||||
decoder_only_classes = []
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, _, _, _ = self._get_input_ids_and_config()
|
||||
config, _, _ = self._get_input_ids_and_config()
|
||||
if config.is_encoder_decoder:
|
||||
continue
|
||||
else:
|
||||
@@ -1391,7 +1364,7 @@ class GenerationTesterMixin:
|
||||
return model_kwargs
|
||||
|
||||
for model_class in decoder_only_classes:
|
||||
config, input_ids, attention_mask, _ = self._get_input_ids_and_config()
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
signature = inspect.signature(model.forward).parameters.keys()
|
||||
|
||||
@@ -1485,7 +1458,7 @@ class GenerationTesterMixin:
|
||||
# When supported, tests that the decoder model can generate from `inputs_embeds` instead of `input_ids`
|
||||
# if fails, you should probably update the `prepare_inputs_for_generation` function
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_ids, _, _ = self._get_input_ids_and_config()
|
||||
config, input_ids, _ = self._get_input_ids_and_config()
|
||||
|
||||
# Ignore:
|
||||
# a) eos (to always output 20 tokens) and pad (so we don't try to infer the attn mask from the input_ids,
|
||||
@@ -1616,7 +1589,7 @@ class GenerationTesterMixin:
|
||||
if not model_class._supports_cache_class:
|
||||
self.skipTest("This model does not support the new cache format")
|
||||
|
||||
config, input_ids, attention_mask, _ = self._get_input_ids_and_config()
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
config.use_cache = True
|
||||
config.is_decoder = True
|
||||
|
||||
|
||||
Reference in New Issue
Block a user