Generate: validate model_kwargs (and catch typos in generate arguments) (#18261)

* validate generate model_kwargs

* generate tests -- not all models have an attn mask
This commit is contained in:
Joao Gante
2022-08-12 14:53:51 +01:00
committed by GitHub
parent 2156619f10
commit ed1924e801
2 changed files with 91 additions and 48 deletions

View File

@@ -75,21 +75,25 @@ class GenerationTesterMixin:
def _get_input_ids_and_config(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_ids = inputs_dict[self.input_name]
attention_mask = torch.ones_like(input_ids, dtype=torch.long)
# cut to half length & take max batch_size 3
max_batch_size = 2
sequence_length = input_ids.shape[-1] // 2
input_ids = input_ids[:max_batch_size, :sequence_length]
attention_mask = attention_mask[:max_batch_size, :sequence_length]
# generate max 3 tokens
max_length = input_ids.shape[-1] + 3
if config.eos_token_id is not None and config.pad_token_id is None:
# hack to allow generate for models such as GPT2 as is done in `generate()`
config.pad_token_id = config.eos_token_id
# TransfoXL has no attention mask
if "transfoxl" in config.__class__.__name__.lower():
attention_mask = None
else:
attention_mask = torch.ones_like(input_ids, dtype=torch.long)[:max_batch_size, :sequence_length]
return config, input_ids, attention_mask, max_length
@staticmethod
@@ -252,10 +256,9 @@ class GenerationTesterMixin:
)
kwargs = {}
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
output_generate = model.generate(
input_ids,
attention_mask=attention_mask,
do_sample=False,
num_beams=1,
max_length=max_length,
@@ -265,6 +268,7 @@ class GenerationTesterMixin:
return_dict_in_generate=return_dict_in_generate,
remove_invalid_values=True,
**logits_process_kwargs,
**model_kwargs,
)
if model.config.is_encoder_decoder:
@@ -278,16 +282,17 @@ class GenerationTesterMixin:
kwargs["encoder_outputs"] = encoder_outputs
with torch.no_grad():
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
output_greedy = model.greedy_search(
input_ids,
max_length=max_length,
attention_mask=attention_mask,
logits_processor=logits_processor,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
output_scores=output_scores,
return_dict_in_generate=return_dict_in_generate,
**kwargs,
**model_kwargs,
)
return output_greedy, output_generate
@@ -308,13 +313,13 @@ class GenerationTesterMixin:
return_dict_in_generate=False,
):
torch.manual_seed(0)
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
output_generate = model.generate(
input_ids,
do_sample=True,
num_beams=1,
max_length=max_length,
num_return_sequences=num_return_sequences,
attention_mask=attention_mask,
output_scores=output_scores,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
@@ -327,7 +332,7 @@ class GenerationTesterMixin:
torch.manual_seed(0)
kwargs = {}
if model.config.is_encoder_decoder:
encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs(
encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs(
model,
input_ids,
attention_mask,
@@ -336,18 +341,16 @@ class GenerationTesterMixin:
output_hidden_states=output_hidden_states,
)
kwargs["encoder_outputs"] = encoder_outputs
input_ids_clone = input_ids_clone.repeat_interleave(num_return_sequences, dim=0)
else:
attention_mask_clone = attention_mask.repeat_interleave(num_return_sequences, dim=0)
input_ids_clone = input_ids.repeat_interleave(num_return_sequences, dim=0)
elif attention_mask is not None:
attention_mask = attention_mask.repeat_interleave(num_return_sequences, dim=0)
# prevent flaky generation test failures
logits_processor.append(InfNanRemoveLogitsProcessor())
with torch.no_grad():
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
output_sample = model.sample(
input_ids_clone,
attention_mask=attention_mask_clone,
input_ids.repeat_interleave(num_return_sequences, dim=0),
max_length=max_length,
logits_processor=logits_processor,
logits_warper=logits_warper,
@@ -356,6 +359,7 @@ class GenerationTesterMixin:
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
**kwargs,
**model_kwargs,
)
return output_sample, output_generate
@@ -374,9 +378,9 @@ class GenerationTesterMixin:
output_hidden_states=False,
return_dict_in_generate=False,
):
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
output_generate = model.generate(
input_ids,
attention_mask=attention_mask,
do_sample=False,
max_length=max_length,
output_scores=output_scores,
@@ -386,12 +390,13 @@ class GenerationTesterMixin:
remove_invalid_values=True,
**beam_kwargs,
**logits_process_kwargs,
**model_kwargs,
)
# beam_search does not automatically interleave `batch_size` dim for `num_beams`
kwargs = {}
if model.config.is_encoder_decoder:
encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs(
encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs(
model,
input_ids,
attention_mask,
@@ -400,23 +405,22 @@ class GenerationTesterMixin:
output_hidden_states=output_hidden_states,
)
kwargs["encoder_outputs"] = encoder_outputs
input_ids_clone = input_ids_clone.repeat_interleave(beam_scorer.num_beams, dim=0)
else:
attention_mask_clone = attention_mask.repeat_interleave(beam_scorer.num_beams, dim=0)
input_ids_clone = input_ids.repeat_interleave(beam_scorer.num_beams, dim=0)
elif attention_mask is not None:
attention_mask = attention_mask.repeat_interleave(beam_scorer.num_beams, dim=0)
with torch.no_grad():
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
output_beam_search = model.beam_search(
input_ids_clone,
input_ids.repeat_interleave(beam_scorer.num_beams, dim=0),
beam_scorer,
max_length=max_length,
attention_mask=attention_mask_clone,
logits_processor=logits_processor,
output_scores=output_scores,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
**kwargs,
**model_kwargs,
)
return output_generate, output_beam_search
@@ -437,9 +441,9 @@ class GenerationTesterMixin:
return_dict_in_generate=False,
):
torch.manual_seed(0)
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
output_generate = model.generate(
input_ids,
attention_mask=attention_mask,
do_sample=True,
max_length=max_length,
output_scores=output_scores,
@@ -449,6 +453,7 @@ class GenerationTesterMixin:
remove_invalid_values=True,
**beam_kwargs,
**logits_warper_kwargs,
**model_kwargs,
)
# beam_search does not automatically interleave `batch_size` dim for `num_beams * num_return_sequences`
kwargs = {}
@@ -462,7 +467,7 @@ class GenerationTesterMixin:
output_hidden_states=output_hidden_states,
)
kwargs["encoder_outputs"] = encoder_outputs
else:
elif attention_mask is not None:
attention_mask = attention_mask.repeat_interleave(beam_scorer.num_beams * num_return_sequences, dim=0)
# prevent flaky generation test failures
@@ -471,11 +476,11 @@ class GenerationTesterMixin:
torch.manual_seed(0)
with torch.no_grad():
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
output_beam_sample = model.beam_sample(
input_ids.repeat_interleave(beam_scorer.num_beams * num_return_sequences, dim=0),
beam_scorer,
max_length=max_length,
attention_mask=attention_mask,
logits_warper=logits_warper,
logits_processor=logits_processor,
output_scores=output_scores,
@@ -483,6 +488,7 @@ class GenerationTesterMixin:
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
**kwargs,
**model_kwargs,
)
return output_generate, output_beam_sample
@@ -502,9 +508,9 @@ class GenerationTesterMixin:
output_hidden_states=False,
return_dict_in_generate=False,
):
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
output_generate = model.generate(
input_ids,
attention_mask=attention_mask,
do_sample=False,
max_length=max_length,
output_scores=output_scores,
@@ -514,12 +520,13 @@ class GenerationTesterMixin:
remove_invalid_values=True,
**beam_kwargs,
**logits_process_kwargs,
**model_kwargs,
)
# group_beam_search does not automatically interleave `batch_size` dim for `num_beams`
kwargs = {}
if model.config.is_encoder_decoder:
encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs(
encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs(
model,
input_ids,
attention_mask,
@@ -528,23 +535,22 @@ class GenerationTesterMixin:
output_hidden_states=output_hidden_states,
)
kwargs["encoder_outputs"] = encoder_outputs
input_ids_clone = input_ids_clone.repeat_interleave(beam_scorer.num_beams, dim=0)
else:
attention_mask_clone = attention_mask.repeat_interleave(beam_scorer.num_beams, dim=0)
input_ids_clone = input_ids.repeat_interleave(beam_scorer.num_beams, dim=0)
elif attention_mask is not None:
attention_mask = attention_mask.repeat_interleave(beam_scorer.num_beams, dim=0)
with torch.no_grad():
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
output_group_beam_search = model.group_beam_search(
input_ids_clone,
input_ids.repeat_interleave(beam_scorer.num_beams, dim=0),
beam_scorer,
max_length=max_length,
attention_mask=attention_mask_clone,
logits_processor=logits_processor,
output_scores=output_scores,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
**kwargs,
**model_kwargs,
)
return output_generate, output_group_beam_search
@@ -564,9 +570,9 @@ class GenerationTesterMixin:
output_hidden_states=False,
return_dict_in_generate=False,
):
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
output_generate = model.generate(
input_ids,
attention_mask=attention_mask,
do_sample=False,
max_length=max_length,
output_scores=output_scores,
@@ -577,12 +583,13 @@ class GenerationTesterMixin:
constraints=constraints,
**beam_kwargs,
**logits_process_kwargs,
**model_kwargs,
)
# group_beam_search does not automatically interleave `batch_size` dim for `num_beams`
kwargs = {}
if model.config.is_encoder_decoder:
encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs(
encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs(
model,
input_ids,
attention_mask,
@@ -591,23 +598,22 @@ class GenerationTesterMixin:
output_hidden_states=output_hidden_states,
)
kwargs["encoder_outputs"] = encoder_outputs
input_ids_clone = input_ids_clone.repeat_interleave(constrained_beam_scorer.num_beams, dim=0)
else:
attention_mask_clone = attention_mask.repeat_interleave(constrained_beam_scorer.num_beams, dim=0)
input_ids_clone = input_ids.repeat_interleave(constrained_beam_scorer.num_beams, dim=0)
elif attention_mask is not None:
attention_mask = attention_mask.repeat_interleave(constrained_beam_scorer.num_beams, dim=0)
with torch.no_grad():
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
output_group_beam_search = model.constrained_beam_search(
input_ids_clone,
input_ids.repeat_interleave(constrained_beam_scorer.num_beams, dim=0),
constrained_beam_scorer,
max_length=max_length,
attention_mask=attention_mask_clone,
logits_processor=logits_processor,
output_scores=output_scores,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
**kwargs,
**model_kwargs,
)
return output_generate, output_group_beam_search
@@ -1044,12 +1050,7 @@ class GenerationTesterMixin:
model = model_class(config).to(torch_device)
model.eval()
output_ids_generate = model.generate(
do_sample=False,
max_length=max_length,
remove_invalid_values=True,
)
output_ids_generate = model.generate(do_sample=False, max_length=max_length, remove_invalid_values=True)
self.assertIsNotNone(output_ids_generate)
def test_group_beam_search_generate(self):
@@ -2052,7 +2053,7 @@ class GenerationIntegrationTests(unittest.TestCase):
# max_new_tokens and max_length serve the same purpose and must not be used together.
with self.assertRaises(ValueError):
gpt2_model.generate(decoder_input_ids=input_ids, max_new_tokens=10, max_length=20)
gpt2_model.generate(input_ids=input_ids, max_new_tokens=10, max_length=20)
def test_encoder_decoder_generate_with_inputs_embeds(self):
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
@@ -2699,3 +2700,19 @@ class GenerationIntegrationTests(unittest.TestCase):
with self.assertRaises(ValueError):
model.generate(input_ids, force_words_ids=[[[-1]]])
def test_validate_generation_inputs(self):
tokenizer = AutoTokenizer.from_pretrained("patrickvonplaten/t5-tiny-random")
model = AutoModelForSeq2SeqLM.from_pretrained("patrickvonplaten/t5-tiny-random")
encoder_input_str = "Hello world"
input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
# typos are quickly detected (the correct argument is `do_sample`)
with self.assertRaisesRegex(ValueError, "do_samples"):
model.generate(input_ids, do_samples=True)
# arbitrary arguments that will not be used anywhere are also not accepted
with self.assertRaisesRegex(ValueError, "foo"):
fake_model_kwargs = {"foo": "bar"}
model.generate(input_ids, **fake_model_kwargs)