Remove subclass for sortish sampler (#9907)
* Remove subclass for sortish sampler * Use old Seq2SeqTrainer in script * Styling
This commit is contained in:
@@ -20,6 +20,8 @@ from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import transformers
|
||||
from seq2seq_trainer import Seq2SeqTrainer
|
||||
from seq2seq_training_args import Seq2SeqTrainingArguments
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForSeq2SeqLM,
|
||||
@@ -27,8 +29,6 @@ from transformers import (
|
||||
HfArgumentParser,
|
||||
MBartTokenizer,
|
||||
MBartTokenizerFast,
|
||||
Seq2SeqTrainer,
|
||||
Seq2SeqTrainingArguments,
|
||||
set_seed,
|
||||
)
|
||||
from transformers.trainer_utils import EvaluationStrategy, is_main_process
|
||||
@@ -286,6 +286,7 @@ def main():
|
||||
trainer = Seq2SeqTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
data_args=data_args,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
data_collator=Seq2SeqDataCollator(
|
||||
@@ -323,9 +324,7 @@ def main():
|
||||
if training_args.do_eval:
|
||||
logger.info("*** Evaluate ***")
|
||||
|
||||
metrics = trainer.evaluate(
|
||||
metric_key_prefix="val", max_length=data_args.val_max_target_length, num_beams=data_args.eval_beams
|
||||
)
|
||||
metrics = trainer.evaluate(metric_key_prefix="val")
|
||||
metrics["val_n_objs"] = data_args.n_val
|
||||
metrics["val_loss"] = round(metrics["val_loss"], 4)
|
||||
|
||||
@@ -337,12 +336,7 @@ def main():
|
||||
if training_args.do_predict:
|
||||
logger.info("*** Predict ***")
|
||||
|
||||
test_output = trainer.predict(
|
||||
test_dataset=test_dataset,
|
||||
metric_key_prefix="test",
|
||||
max_length=data_args.val_max_target_length,
|
||||
num_beams=data_args.eval_beams,
|
||||
)
|
||||
test_output = trainer.predict(test_dataset=test_dataset, metric_key_prefix="test")
|
||||
metrics = test_output.metrics
|
||||
metrics["test_n_objs"] = data_args.n_test
|
||||
|
||||
|
||||
Reference in New Issue
Block a user