Revert frozen training arguments (#25903)
* Revert frozen training arguments * TODO
This commit is contained in:
@@ -18,7 +18,6 @@ Fine-tuning the library models for sequence to sequence.
|
||||
"""
|
||||
# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
|
||||
|
||||
import dataclasses
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
@@ -675,10 +674,14 @@ def main():
|
||||
return result
|
||||
|
||||
# Override the decoding parameters of Seq2SeqTrainer
|
||||
if training_args.generation_max_length is None:
|
||||
training_args = dataclasses.replace(training_args, generation_max_length=data_args.val_max_target_length)
|
||||
if training_args.generation_num_beams is None:
|
||||
training_args = dataclasses.replace(training_args, generation_num_beams=data_args.num_beams)
|
||||
training_args.generation_max_length = (
|
||||
training_args.generation_max_length
|
||||
if training_args.generation_max_length is not None
|
||||
else data_args.val_max_target_length
|
||||
)
|
||||
training_args.generation_num_beams = (
|
||||
data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
|
||||
)
|
||||
|
||||
# Initialize our Trainer
|
||||
trainer = Seq2SeqTrainer(
|
||||
|
||||
Reference in New Issue
Block a user