Generate: validate model_kwargs on TF (and catch typos in generate arguments) (#18651)
This commit is contained in:
@@ -2704,8 +2704,8 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
model.generate(input_ids, force_words_ids=[[[-1]]])
|
||||
|
||||
def test_validate_generation_inputs(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained("patrickvonplaten/t5-tiny-random")
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained("patrickvonplaten/t5-tiny-random")
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
encoder_input_str = "Hello world"
|
||||
input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
|
||||
|
||||
Reference in New Issue
Block a user