Trainer: delegate default generation values to generation_config (#25987)
This commit is contained in:
@@ -46,12 +46,13 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer):
|
||||
**gen_kwargs,
|
||||
) -> Dict[str, float]:
|
||||
gen_kwargs = gen_kwargs.copy()
|
||||
gen_kwargs["max_length"] = (
|
||||
gen_kwargs["max_length"] if gen_kwargs.get("max_length") is not None else self.args.generation_max_length
|
||||
)
|
||||
gen_kwargs["num_beams"] = (
|
||||
gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams
|
||||
)
|
||||
|
||||
# Use legacy argument setting if a) the option is not explicitly passed; and b) the argument is set in the
|
||||
# training args
|
||||
if gen_kwargs.get("max_length") is None and self.args.generation_max_length is not None:
|
||||
gen_kwargs["max_length"] = self.args.generation_max_length
|
||||
if gen_kwargs.get("num_beams") is None and self.args.generation_num_beams is not None:
|
||||
gen_kwargs["num_beams"] = self.args.generation_num_beams
|
||||
self._gen_kwargs = gen_kwargs
|
||||
|
||||
eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset
|
||||
|
||||
Reference in New Issue
Block a user