Generate: model_kwargs can also be an input to prepare_inputs_for_generation (#20353)
This commit is contained in:
@@ -3007,8 +3007,8 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
self.assertTrue(max_score_diff < 1e-5)
|
||||
|
||||
def test_validate_generation_inputs(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-roberta")
|
||||
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-roberta")
|
||||
|
||||
encoder_input_str = "Hello world"
|
||||
input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
|
||||
@@ -3021,3 +3021,7 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
with self.assertRaisesRegex(ValueError, "foo"):
|
||||
fake_model_kwargs = {"foo": "bar"}
|
||||
model.generate(input_ids, **fake_model_kwargs)
|
||||
|
||||
# However, valid model_kwargs are accepted
|
||||
valid_model_kwargs = {"attention_mask": torch.zeros_like(input_ids)}
|
||||
model.generate(input_ids, **valid_model_kwargs)
|
||||
|
||||
Reference in New Issue
Block a user