Generate: model_kwargs can also be an input to prepare_inputs_for_generation (#20353)
This commit is contained in:
@@ -194,9 +194,9 @@ class FlaxGenerationMixin:
|
|||||||
"""Validates model kwargs for generation. Generate argument typos will also be caught here."""
|
"""Validates model kwargs for generation. Generate argument typos will also be caught here."""
|
||||||
unused_model_args = []
|
unused_model_args = []
|
||||||
model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
|
model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
|
||||||
# `kwargs` if often used to handle optional forward pass inputs like `attention_mask`. If
|
# `kwargs`/`model_kwargs` is often used to handle optional forward pass inputs like `attention_mask`. If
|
||||||
# `prepare_inputs_for_generation` doesn't accept `kwargs`, then a stricter check can be made ;)
|
# `prepare_inputs_for_generation` doesn't accept them, then a stricter check can be made ;)
|
||||||
if "kwargs" in model_args:
|
if "kwargs" in model_args or "model_kwargs" in model_args:
|
||||||
model_args |= set(inspect.signature(self.__call__).parameters)
|
model_args |= set(inspect.signature(self.__call__).parameters)
|
||||||
for key, value in model_kwargs.items():
|
for key, value in model_kwargs.items():
|
||||||
if value is not None and key not in model_args:
|
if value is not None and key not in model_args:
|
||||||
|
|||||||
@@ -1445,9 +1445,9 @@ class TFGenerationMixin:
|
|||||||
|
|
||||||
unused_model_args = []
|
unused_model_args = []
|
||||||
model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
|
model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
|
||||||
# `kwargs` if often used to handle optional forward pass inputs like `attention_mask`. If
|
# `kwargs`/`model_kwargs` is often used to handle optional forward pass inputs like `attention_mask`. If
|
||||||
# `prepare_inputs_for_generation` doesn't accept `kwargs`, then a stricter check can be made ;)
|
# `prepare_inputs_for_generation` doesn't accept them, then a stricter check can be made ;)
|
||||||
if "kwargs" in model_args:
|
if "kwargs" in model_args or "model_kwargs" in model_args:
|
||||||
model_args |= set(inspect.signature(self.call).parameters)
|
model_args |= set(inspect.signature(self.call).parameters)
|
||||||
for key, value in model_kwargs.items():
|
for key, value in model_kwargs.items():
|
||||||
if value is not None and key not in model_args:
|
if value is not None and key not in model_args:
|
||||||
|
|||||||
@@ -981,9 +981,9 @@ class GenerationMixin:
|
|||||||
|
|
||||||
unused_model_args = []
|
unused_model_args = []
|
||||||
model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
|
model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
|
||||||
# `kwargs` if often used to handle optional forward pass inputs like `attention_mask`. If
|
# `kwargs`/`model_kwargs` is often used to handle optional forward pass inputs like `attention_mask`. If
|
||||||
# `prepare_inputs_for_generation` doesn't accept `kwargs`, then a stricter check can be made ;)
|
# `prepare_inputs_for_generation` doesn't accept them, then a stricter check can be made ;)
|
||||||
if "kwargs" in model_args:
|
if "kwargs" in model_args or "model_kwargs" in model_args:
|
||||||
model_args |= set(inspect.signature(self.forward).parameters)
|
model_args |= set(inspect.signature(self.forward).parameters)
|
||||||
for key, value in model_kwargs.items():
|
for key, value in model_kwargs.items():
|
||||||
if value is not None and key not in model_args:
|
if value is not None and key not in model_args:
|
||||||
|
|||||||
@@ -3007,8 +3007,8 @@ class GenerationIntegrationTests(unittest.TestCase):
|
|||||||
self.assertTrue(max_score_diff < 1e-5)
|
self.assertTrue(max_score_diff < 1e-5)
|
||||||
|
|
||||||
def test_validate_generation_inputs(self):
|
def test_validate_generation_inputs(self):
|
||||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-roberta")
|
||||||
model = AutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-t5")
|
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-roberta")
|
||||||
|
|
||||||
encoder_input_str = "Hello world"
|
encoder_input_str = "Hello world"
|
||||||
input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
|
input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
|
||||||
@@ -3021,3 +3021,7 @@ class GenerationIntegrationTests(unittest.TestCase):
|
|||||||
with self.assertRaisesRegex(ValueError, "foo"):
|
with self.assertRaisesRegex(ValueError, "foo"):
|
||||||
fake_model_kwargs = {"foo": "bar"}
|
fake_model_kwargs = {"foo": "bar"}
|
||||||
model.generate(input_ids, **fake_model_kwargs)
|
model.generate(input_ids, **fake_model_kwargs)
|
||||||
|
|
||||||
|
# However, valid model_kwargs are accepted
|
||||||
|
valid_model_kwargs = {"attention_mask": torch.zeros_like(input_ids)}
|
||||||
|
model.generate(input_ids, **valid_model_kwargs)
|
||||||
|
|||||||
Reference in New Issue
Block a user