Generate: model_kwargs can also be an input to prepare_inputs_for_generation (#20353)

This commit is contained in:
Joao Gante
2022-11-21 16:20:27 +00:00
committed by GitHub
parent d21c97cc0f
commit 4cf38148dc
4 changed files with 15 additions and 11 deletions

View File

@@ -3007,8 +3007,8 @@ class GenerationIntegrationTests(unittest.TestCase):
self.assertTrue(max_score_diff < 1e-5)
def test_validate_generation_inputs(self):
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
model = AutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-t5")
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-roberta")
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-roberta")
encoder_input_str = "Hello world"
input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
@@ -3021,3 +3021,7 @@ class GenerationIntegrationTests(unittest.TestCase):
with self.assertRaisesRegex(ValueError, "foo"):
fake_model_kwargs = {"foo": "bar"}
model.generate(input_ids, **fake_model_kwargs)
# However, valid model_kwargs are accepted
valid_model_kwargs = {"attention_mask": torch.zeros_like(input_ids)}
model.generate(input_ids, **valid_model_kwargs)