Generate: validate model_kwargs on TF (and catch typos in generate arguments) (#18651)

This commit is contained in:
Joao Gante
2022-09-02 16:25:26 +01:00
committed by GitHub
parent c5be7cae59
commit 9196f48b95
4 changed files with 214 additions and 139 deletions

View File

@@ -2704,8 +2704,8 @@ class GenerationIntegrationTests(unittest.TestCase):
model.generate(input_ids, force_words_ids=[[[-1]]])
def test_validate_generation_inputs(self):
tokenizer = AutoTokenizer.from_pretrained("patrickvonplaten/t5-tiny-random")
model = AutoModelForSeq2SeqLM.from_pretrained("patrickvonplaten/t5-tiny-random")
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
model = AutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-t5")
encoder_input_str = "Hello world"
input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids