Revert frozen training arguments (#25903)

* Revert frozen training arguments

* TODO
This commit is contained in:
Zach Mueller
2023-09-01 11:24:12 -04:00
committed by GitHub
parent 69c5b8f186
commit be0e189bd3
9 changed files with 31 additions and 58 deletions

View File

@@ -18,7 +18,6 @@ Fine-tuning the library models for sequence to sequence.
"""
# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
import dataclasses
import logging
import os
import sys
@@ -675,10 +674,14 @@ def main():
return result
# Override the decoding parameters of Seq2SeqTrainer
if training_args.generation_max_length is None:
training_args = dataclasses.replace(training_args, generation_max_length=data_args.val_max_target_length)
if training_args.generation_num_beams is None:
training_args = dataclasses.replace(training_args, generation_num_beams=data_args.num_beams)
training_args.generation_max_length = (
training_args.generation_max_length
if training_args.generation_max_length is not None
else data_args.val_max_target_length
)
training_args.generation_num_beams = (
data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
)
# Initialize our Trainer
trainer = Seq2SeqTrainer(