Trainer: delegate default generation values to generation_config (#25987)
This commit is contained in:
@@ -46,12 +46,13 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer):
|
|||||||
**gen_kwargs,
|
**gen_kwargs,
|
||||||
) -> Dict[str, float]:
|
) -> Dict[str, float]:
|
||||||
gen_kwargs = gen_kwargs.copy()
|
gen_kwargs = gen_kwargs.copy()
|
||||||
gen_kwargs["max_length"] = (
|
|
||||||
gen_kwargs["max_length"] if gen_kwargs.get("max_length") is not None else self.args.generation_max_length
|
# Use legacy argument setting if a) the option is not explicitly passed; and b) the argument is set in the
|
||||||
)
|
# training args
|
||||||
gen_kwargs["num_beams"] = (
|
if gen_kwargs.get("max_length") is None and self.args.generation_max_length is not None:
|
||||||
gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams
|
gen_kwargs["max_length"] = self.args.generation_max_length
|
||||||
)
|
if gen_kwargs.get("num_beams") is None and self.args.generation_num_beams is not None:
|
||||||
|
gen_kwargs["num_beams"] = self.args.generation_num_beams
|
||||||
self._gen_kwargs = gen_kwargs
|
self._gen_kwargs = gen_kwargs
|
||||||
|
|
||||||
eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset
|
eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset
|
||||||
|
|||||||
@@ -149,11 +149,17 @@ class Seq2SeqTrainer(Trainer):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
gen_kwargs = gen_kwargs.copy()
|
gen_kwargs = gen_kwargs.copy()
|
||||||
if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
|
|
||||||
|
# Use legacy argument setting if a) the option is not explicitly passed; and b) the argument is set in the
|
||||||
|
# training args
|
||||||
|
if (
|
||||||
|
gen_kwargs.get("max_length") is None
|
||||||
|
and gen_kwargs.get("max_new_tokens") is None
|
||||||
|
and self.args.generation_max_length is not None
|
||||||
|
):
|
||||||
gen_kwargs["max_length"] = self.args.generation_max_length
|
gen_kwargs["max_length"] = self.args.generation_max_length
|
||||||
gen_kwargs["num_beams"] = (
|
if gen_kwargs.get("num_beams") is None and self.args.generation_num_beams is not None:
|
||||||
gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams
|
gen_kwargs["num_beams"] = self.args.generation_num_beams
|
||||||
)
|
|
||||||
self._gen_kwargs = gen_kwargs
|
self._gen_kwargs = gen_kwargs
|
||||||
|
|
||||||
return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
|
return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
|
||||||
@@ -206,11 +212,17 @@ class Seq2SeqTrainer(Trainer):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
gen_kwargs = gen_kwargs.copy()
|
gen_kwargs = gen_kwargs.copy()
|
||||||
if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
|
|
||||||
|
# Use legacy argument setting if a) the option is not explicitly passed; and b) the argument is set in the
|
||||||
|
# training args
|
||||||
|
if (
|
||||||
|
gen_kwargs.get("max_length") is None
|
||||||
|
and gen_kwargs.get("max_new_tokens") is None
|
||||||
|
and self.args.generation_max_length is not None
|
||||||
|
):
|
||||||
gen_kwargs["max_length"] = self.args.generation_max_length
|
gen_kwargs["max_length"] = self.args.generation_max_length
|
||||||
gen_kwargs["num_beams"] = (
|
if gen_kwargs.get("num_beams") is None and self.args.generation_num_beams is not None:
|
||||||
gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams
|
gen_kwargs["num_beams"] = self.args.generation_num_beams
|
||||||
)
|
|
||||||
self._gen_kwargs = gen_kwargs
|
self._gen_kwargs = gen_kwargs
|
||||||
|
|
||||||
return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
|
return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
|
||||||
@@ -256,16 +268,14 @@ class Seq2SeqTrainer(Trainer):
|
|||||||
|
|
||||||
# XXX: adapt synced_gpus for fairscale as well
|
# XXX: adapt synced_gpus for fairscale as well
|
||||||
# Priority (handled in generate):
|
# Priority (handled in generate):
|
||||||
# gen_kwargs > model.generation_config > default GenerationConfig()
|
# non-`None` gen_kwargs > model.generation_config > default GenerationConfig()
|
||||||
|
|
||||||
if len(gen_kwargs) == 0 and hasattr(self, "_gen_kwargs"):
|
if len(gen_kwargs) == 0 and hasattr(self, "_gen_kwargs"):
|
||||||
gen_kwargs = self._gen_kwargs.copy()
|
gen_kwargs = self._gen_kwargs.copy()
|
||||||
|
if "num_beams" in gen_kwargs and gen_kwargs["num_beams"] is None:
|
||||||
|
gen_kwargs.pop("num_beams")
|
||||||
|
if "max_length" in gen_kwargs and gen_kwargs["max_length"] is None:
|
||||||
|
gen_kwargs.pop("max_length")
|
||||||
|
|
||||||
if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
|
|
||||||
gen_kwargs["max_length"] = self.model.config.max_length
|
|
||||||
gen_kwargs["num_beams"] = (
|
|
||||||
gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.model.config.num_beams
|
|
||||||
)
|
|
||||||
default_synced_gpus = True if is_deepspeed_zero3_enabled() else False
|
default_synced_gpus = True if is_deepspeed_zero3_enabled() else False
|
||||||
gen_kwargs["synced_gpus"] = (
|
gen_kwargs["synced_gpus"] = (
|
||||||
gen_kwargs["synced_gpus"] if gen_kwargs.get("synced_gpus") is not None else default_synced_gpus
|
gen_kwargs["synced_gpus"] if gen_kwargs.get("synced_gpus") is not None else default_synced_gpus
|
||||||
|
|||||||
Reference in New Issue
Block a user