Generate: validate model_kwargs (and catch typos in generate arguments) (#18261)
* validate generate model_kwargs * generate tests -- not all models have an attn mask
This commit is contained in:
@@ -841,6 +841,29 @@ class GenerationMixin:
|
||||
|
||||
return transition_scores
|
||||
|
||||
def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
    """Validate the model kwargs passed to generation; typos in generate arguments are caught here too.

    Args:
        model_kwargs: Keyword arguments that are meant to be forwarded to the model. The dictionary is
            only read, never modified, so callers do not need to pass a copy.

    Raises:
        ValueError: If one or more keys with a non-`None` value are accepted neither by
            `prepare_inputs_for_generation` nor (when that method takes a `**` catch-all) by `forward`.
    """
    # Arguments that are handled before calling any model function are excluded from the check
    ignored_keys = {"decoder_input_ids"} if self.config.is_encoder_decoder else set()

    prepare_params = inspect.signature(self.prepare_inputs_for_generation).parameters
    model_args = set(prepare_params)
    # `kwargs` is often used to handle optional forward pass inputs like `attention_mask`. If
    # `prepare_inputs_for_generation` doesn't accept a `**` catch-all, then a stricter check can be made ;)
    # The catch-all is detected by kind as well as by name so that e.g. `**model_kwargs` is recognized too.
    accepts_extra_kwargs = "kwargs" in model_args or any(
        param.kind is inspect.Parameter.VAR_KEYWORD for param in prepare_params.values()
    )
    if accepts_extra_kwargs:
        model_args |= set(inspect.signature(self.forward).parameters)

    # `None` values are treated as "argument not set" and are never reported
    unused_model_args = [
        key
        for key, value in model_kwargs.items()
        if key not in ignored_keys and value is not None and key not in model_args
    ]

    if unused_model_args:
        raise ValueError(
            f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the"
            " generate arguments will also show up in this list)"
        )
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(
|
||||
self,
|
||||
@@ -1120,6 +1143,9 @@ class GenerationMixin:
|
||||
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
['Paris ist eines der dichtesten besiedelten Gebiete Europas.']
|
||||
```"""
|
||||
# 0. Validate model kwargs
|
||||
self._validate_model_kwargs(model_kwargs.copy())
|
||||
|
||||
# 1. Set generation parameters if not already defined
|
||||
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
|
||||
num_beams = num_beams if num_beams is not None else self.config.num_beams
|
||||
|
||||
@@ -75,21 +75,25 @@ class GenerationTesterMixin:
|
||||
|
||||
def _get_input_ids_and_config(self):
    """Build a small generation fixture from the model tester's common inputs.

    Returns:
        A tuple `(config, input_ids, attention_mask, max_length)`. `input_ids` is the tester input truncated
        to at most `max_batch_size` rows and half of its last-dimension length, `attention_mask` is an
        all-ones `torch.long` tensor shaped like `input_ids` (or `None` for TransfoXL configs), and
        `max_length` allows three tokens beyond the truncated input.
    """
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

    input_ids = inputs_dict[self.input_name]

    # cut to half length & cap the batch size at `max_batch_size`
    max_batch_size = 2
    sequence_length = input_ids.shape[-1] // 2
    input_ids = input_ids[:max_batch_size, :sequence_length]

    # generate max 3 tokens
    max_length = input_ids.shape[-1] + 3
    if config.eos_token_id is not None and config.pad_token_id is None:
        # hack to allow generate for models such as GPT2 as is done in `generate()`
        config.pad_token_id = config.eos_token_id

    # TransfoXL has no attention mask
    if "transfoxl" in config.__class__.__name__.lower():
        attention_mask = None
    else:
        # `input_ids` is already truncated, so the mask is built once with the final shape
        attention_mask = torch.ones_like(input_ids, dtype=torch.long)

    return config, input_ids, attention_mask, max_length
|
||||
|
||||
@staticmethod
|
||||
@@ -252,10 +256,9 @@ class GenerationTesterMixin:
|
||||
)
|
||||
|
||||
kwargs = {}
|
||||
|
||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
||||
output_generate = model.generate(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
do_sample=False,
|
||||
num_beams=1,
|
||||
max_length=max_length,
|
||||
@@ -265,6 +268,7 @@ class GenerationTesterMixin:
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
remove_invalid_values=True,
|
||||
**logits_process_kwargs,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
@@ -278,16 +282,17 @@ class GenerationTesterMixin:
|
||||
kwargs["encoder_outputs"] = encoder_outputs
|
||||
|
||||
with torch.no_grad():
|
||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
||||
output_greedy = model.greedy_search(
|
||||
input_ids,
|
||||
max_length=max_length,
|
||||
attention_mask=attention_mask,
|
||||
logits_processor=logits_processor,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
output_scores=output_scores,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
**kwargs,
|
||||
**model_kwargs,
|
||||
)
|
||||
return output_greedy, output_generate
|
||||
|
||||
@@ -308,13 +313,13 @@ class GenerationTesterMixin:
|
||||
return_dict_in_generate=False,
|
||||
):
|
||||
torch.manual_seed(0)
|
||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
||||
output_generate = model.generate(
|
||||
input_ids,
|
||||
do_sample=True,
|
||||
num_beams=1,
|
||||
max_length=max_length,
|
||||
num_return_sequences=num_return_sequences,
|
||||
attention_mask=attention_mask,
|
||||
output_scores=output_scores,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
@@ -327,7 +332,7 @@ class GenerationTesterMixin:
|
||||
torch.manual_seed(0)
|
||||
kwargs = {}
|
||||
if model.config.is_encoder_decoder:
|
||||
encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs(
|
||||
encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs(
|
||||
model,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
@@ -336,18 +341,16 @@ class GenerationTesterMixin:
|
||||
output_hidden_states=output_hidden_states,
|
||||
)
|
||||
kwargs["encoder_outputs"] = encoder_outputs
|
||||
input_ids_clone = input_ids_clone.repeat_interleave(num_return_sequences, dim=0)
|
||||
else:
|
||||
attention_mask_clone = attention_mask.repeat_interleave(num_return_sequences, dim=0)
|
||||
input_ids_clone = input_ids.repeat_interleave(num_return_sequences, dim=0)
|
||||
elif attention_mask is not None:
|
||||
attention_mask = attention_mask.repeat_interleave(num_return_sequences, dim=0)
|
||||
|
||||
# prevent flaky generation test failures
|
||||
logits_processor.append(InfNanRemoveLogitsProcessor())
|
||||
|
||||
with torch.no_grad():
|
||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
||||
output_sample = model.sample(
|
||||
input_ids_clone,
|
||||
attention_mask=attention_mask_clone,
|
||||
input_ids.repeat_interleave(num_return_sequences, dim=0),
|
||||
max_length=max_length,
|
||||
logits_processor=logits_processor,
|
||||
logits_warper=logits_warper,
|
||||
@@ -356,6 +359,7 @@ class GenerationTesterMixin:
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
**kwargs,
|
||||
**model_kwargs,
|
||||
)
|
||||
return output_sample, output_generate
|
||||
|
||||
@@ -374,9 +378,9 @@ class GenerationTesterMixin:
|
||||
output_hidden_states=False,
|
||||
return_dict_in_generate=False,
|
||||
):
|
||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
||||
output_generate = model.generate(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
do_sample=False,
|
||||
max_length=max_length,
|
||||
output_scores=output_scores,
|
||||
@@ -386,12 +390,13 @@ class GenerationTesterMixin:
|
||||
remove_invalid_values=True,
|
||||
**beam_kwargs,
|
||||
**logits_process_kwargs,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
# beam_search does not automatically interleave `batch_size` dim for `num_beams`
|
||||
kwargs = {}
|
||||
if model.config.is_encoder_decoder:
|
||||
encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs(
|
||||
encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs(
|
||||
model,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
@@ -400,23 +405,22 @@ class GenerationTesterMixin:
|
||||
output_hidden_states=output_hidden_states,
|
||||
)
|
||||
kwargs["encoder_outputs"] = encoder_outputs
|
||||
input_ids_clone = input_ids_clone.repeat_interleave(beam_scorer.num_beams, dim=0)
|
||||
else:
|
||||
attention_mask_clone = attention_mask.repeat_interleave(beam_scorer.num_beams, dim=0)
|
||||
input_ids_clone = input_ids.repeat_interleave(beam_scorer.num_beams, dim=0)
|
||||
elif attention_mask is not None:
|
||||
attention_mask = attention_mask.repeat_interleave(beam_scorer.num_beams, dim=0)
|
||||
|
||||
with torch.no_grad():
|
||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
||||
output_beam_search = model.beam_search(
|
||||
input_ids_clone,
|
||||
input_ids.repeat_interleave(beam_scorer.num_beams, dim=0),
|
||||
beam_scorer,
|
||||
max_length=max_length,
|
||||
attention_mask=attention_mask_clone,
|
||||
logits_processor=logits_processor,
|
||||
output_scores=output_scores,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
**kwargs,
|
||||
**model_kwargs,
|
||||
)
|
||||
return output_generate, output_beam_search
|
||||
|
||||
@@ -437,9 +441,9 @@ class GenerationTesterMixin:
|
||||
return_dict_in_generate=False,
|
||||
):
|
||||
torch.manual_seed(0)
|
||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
||||
output_generate = model.generate(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
do_sample=True,
|
||||
max_length=max_length,
|
||||
output_scores=output_scores,
|
||||
@@ -449,6 +453,7 @@ class GenerationTesterMixin:
|
||||
remove_invalid_values=True,
|
||||
**beam_kwargs,
|
||||
**logits_warper_kwargs,
|
||||
**model_kwargs,
|
||||
)
|
||||
# beam_sample does not automatically interleave `batch_size` dim for `num_beams * num_return_sequences`
|
||||
kwargs = {}
|
||||
@@ -462,7 +467,7 @@ class GenerationTesterMixin:
|
||||
output_hidden_states=output_hidden_states,
|
||||
)
|
||||
kwargs["encoder_outputs"] = encoder_outputs
|
||||
else:
|
||||
elif attention_mask is not None:
|
||||
attention_mask = attention_mask.repeat_interleave(beam_scorer.num_beams * num_return_sequences, dim=0)
|
||||
|
||||
# prevent flaky generation test failures
|
||||
@@ -471,11 +476,11 @@ class GenerationTesterMixin:
|
||||
|
||||
torch.manual_seed(0)
|
||||
with torch.no_grad():
|
||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
||||
output_beam_sample = model.beam_sample(
|
||||
input_ids.repeat_interleave(beam_scorer.num_beams * num_return_sequences, dim=0),
|
||||
beam_scorer,
|
||||
max_length=max_length,
|
||||
attention_mask=attention_mask,
|
||||
logits_warper=logits_warper,
|
||||
logits_processor=logits_processor,
|
||||
output_scores=output_scores,
|
||||
@@ -483,6 +488,7 @@ class GenerationTesterMixin:
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
**kwargs,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
return output_generate, output_beam_sample
|
||||
@@ -502,9 +508,9 @@ class GenerationTesterMixin:
|
||||
output_hidden_states=False,
|
||||
return_dict_in_generate=False,
|
||||
):
|
||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
||||
output_generate = model.generate(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
do_sample=False,
|
||||
max_length=max_length,
|
||||
output_scores=output_scores,
|
||||
@@ -514,12 +520,13 @@ class GenerationTesterMixin:
|
||||
remove_invalid_values=True,
|
||||
**beam_kwargs,
|
||||
**logits_process_kwargs,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
# group_beam_search does not automatically interleave `batch_size` dim for `num_beams`
|
||||
kwargs = {}
|
||||
if model.config.is_encoder_decoder:
|
||||
encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs(
|
||||
encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs(
|
||||
model,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
@@ -528,23 +535,22 @@ class GenerationTesterMixin:
|
||||
output_hidden_states=output_hidden_states,
|
||||
)
|
||||
kwargs["encoder_outputs"] = encoder_outputs
|
||||
input_ids_clone = input_ids_clone.repeat_interleave(beam_scorer.num_beams, dim=0)
|
||||
else:
|
||||
attention_mask_clone = attention_mask.repeat_interleave(beam_scorer.num_beams, dim=0)
|
||||
input_ids_clone = input_ids.repeat_interleave(beam_scorer.num_beams, dim=0)
|
||||
elif attention_mask is not None:
|
||||
attention_mask = attention_mask.repeat_interleave(beam_scorer.num_beams, dim=0)
|
||||
|
||||
with torch.no_grad():
|
||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
||||
output_group_beam_search = model.group_beam_search(
|
||||
input_ids_clone,
|
||||
input_ids.repeat_interleave(beam_scorer.num_beams, dim=0),
|
||||
beam_scorer,
|
||||
max_length=max_length,
|
||||
attention_mask=attention_mask_clone,
|
||||
logits_processor=logits_processor,
|
||||
output_scores=output_scores,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
**kwargs,
|
||||
**model_kwargs,
|
||||
)
|
||||
return output_generate, output_group_beam_search
|
||||
|
||||
@@ -564,9 +570,9 @@ class GenerationTesterMixin:
|
||||
output_hidden_states=False,
|
||||
return_dict_in_generate=False,
|
||||
):
|
||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
||||
output_generate = model.generate(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
do_sample=False,
|
||||
max_length=max_length,
|
||||
output_scores=output_scores,
|
||||
@@ -577,12 +583,13 @@ class GenerationTesterMixin:
|
||||
constraints=constraints,
|
||||
**beam_kwargs,
|
||||
**logits_process_kwargs,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
# constrained_beam_search does not automatically interleave `batch_size` dim for `num_beams`
|
||||
kwargs = {}
|
||||
if model.config.is_encoder_decoder:
|
||||
encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs(
|
||||
encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs(
|
||||
model,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
@@ -591,23 +598,22 @@ class GenerationTesterMixin:
|
||||
output_hidden_states=output_hidden_states,
|
||||
)
|
||||
kwargs["encoder_outputs"] = encoder_outputs
|
||||
input_ids_clone = input_ids_clone.repeat_interleave(constrained_beam_scorer.num_beams, dim=0)
|
||||
else:
|
||||
attention_mask_clone = attention_mask.repeat_interleave(constrained_beam_scorer.num_beams, dim=0)
|
||||
input_ids_clone = input_ids.repeat_interleave(constrained_beam_scorer.num_beams, dim=0)
|
||||
elif attention_mask is not None:
|
||||
attention_mask = attention_mask.repeat_interleave(constrained_beam_scorer.num_beams, dim=0)
|
||||
|
||||
with torch.no_grad():
|
||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
||||
output_group_beam_search = model.constrained_beam_search(
|
||||
input_ids_clone,
|
||||
input_ids.repeat_interleave(constrained_beam_scorer.num_beams, dim=0),
|
||||
constrained_beam_scorer,
|
||||
max_length=max_length,
|
||||
attention_mask=attention_mask_clone,
|
||||
logits_processor=logits_processor,
|
||||
output_scores=output_scores,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
**kwargs,
|
||||
**model_kwargs,
|
||||
)
|
||||
return output_generate, output_group_beam_search
|
||||
|
||||
@@ -1044,12 +1050,7 @@ class GenerationTesterMixin:
|
||||
model = model_class(config).to(torch_device)
|
||||
model.eval()
|
||||
|
||||
output_ids_generate = model.generate(
|
||||
do_sample=False,
|
||||
max_length=max_length,
|
||||
remove_invalid_values=True,
|
||||
)
|
||||
|
||||
output_ids_generate = model.generate(do_sample=False, max_length=max_length, remove_invalid_values=True)
|
||||
self.assertIsNotNone(output_ids_generate)
|
||||
|
||||
def test_group_beam_search_generate(self):
|
||||
@@ -2052,7 +2053,7 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
|
||||
# max_new_tokens and max_length serve the same purpose and must not be used together.
|
||||
with self.assertRaises(ValueError):
|
||||
gpt2_model.generate(decoder_input_ids=input_ids, max_new_tokens=10, max_length=20)
|
||||
gpt2_model.generate(input_ids=input_ids, max_new_tokens=10, max_length=20)
|
||||
|
||||
def test_encoder_decoder_generate_with_inputs_embeds(self):
|
||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||
@@ -2699,3 +2700,19 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
model.generate(input_ids, force_words_ids=[[[-1]]])
|
||||
|
||||
def test_validate_generation_inputs(self):
    """Check that `generate` raises a `ValueError` naming any keyword argument it cannot use."""
    checkpoint = "patrickvonplaten/t5-tiny-random"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

    input_ids = tokenizer("Hello world", return_tensors="pt").input_ids

    # a misspelled generate argument is reported by name (the correct argument is `do_sample`)
    with self.assertRaisesRegex(ValueError, "do_samples"):
        model.generate(input_ids, do_samples=True)

    # an arbitrary keyword that will not be used anywhere is rejected as well
    unknown_kwargs = {"foo": "bar"}
    with self.assertRaisesRegex(ValueError, "foo"):
        model.generate(input_ids, **unknown_kwargs)
|
||||
|
||||
Reference in New Issue
Block a user