Trainer: fail early in the presence of an unsavable generation_config (#29675)

This commit is contained in:
Joao Gante
2024-03-15 12:59:10 +00:00
committed by GitHub
parent f62407f788
commit c47fcd0830
3 changed files with 52 additions and 18 deletions

View File

@@ -181,3 +181,22 @@ class Seq2seqTrainerTester(TestCasePlus):
assert (
metrics["eval_samples"] == dataset_len * num_return_sequences
), f"Got {metrics['eval_samples']}, expected: {dataset_len * num_return_sequences}"
@require_torch
def test_bad_generation_config_fail_early(self):
# Tests that a bad geneartion config causes the trainer to fail early
model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-small")
tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-small")
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, return_tensors="pt", padding="longest")
gen_config = GenerationConfig(do_sample=False, top_p=0.9) # bad: top_p is not compatible with do_sample=False
training_args = Seq2SeqTrainingArguments(".", predict_with_generate=True, generation_config=gen_config)
with self.assertRaises(ValueError) as exc:
_ = Seq2SeqTrainer(
model=model,
args=training_args,
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=lambda x: {"samples": x[0].shape[0]},
)
self.assertIn("The loaded generation config instance is invalid", str(exc.exception))