Trainer: fail early in the presence of an unsavable generation_config (#29675)
This commit is contained in:
@@ -181,3 +181,22 @@ class Seq2seqTrainerTester(TestCasePlus):
|
||||
assert (
|
||||
metrics["eval_samples"] == dataset_len * num_return_sequences
|
||||
), f"Got {metrics['eval_samples']}, expected: {dataset_len * num_return_sequences}"
|
||||
|
||||
@require_torch
def test_bad_generation_config_fail_early(self):
    # Checks that handing the trainer a bad generation config makes it fail
    # early: the ValueError is expected while the Seq2SeqTrainer is being
    # constructed, inside the `assertRaises` block below.
    checkpoint = "google-t5/t5-small"
    model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
    tokenizer = T5Tokenizer.from_pretrained(checkpoint)
    collator = DataCollatorForSeq2Seq(tokenizer, model=model, return_tensors="pt", padding="longest")

    # bad on purpose: top_p is not compatible with do_sample=False
    invalid_gen_config = GenerationConfig(do_sample=False, top_p=0.9)
    args = Seq2SeqTrainingArguments(".", predict_with_generate=True, generation_config=invalid_gen_config)

    def count_samples(eval_pred):
        # Reports the size of the first dimension of the first element.
        return {"samples": eval_pred[0].shape[0]}

    with self.assertRaises(ValueError) as raised:
        Seq2SeqTrainer(
            model=model,
            args=args,
            tokenizer=tokenizer,
            data_collator=collator,
            compute_metrics=count_samples,
        )

    # The error text must point at the invalid generation config.
    error_text = str(raised.exception)
    self.assertIn("The loaded generation config instance is invalid", error_text)
|
||||
|
||||
Reference in New Issue
Block a user