From 115d97dd2f752880715cd01aa915286e3d9a5442 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Mon, 1 Feb 2021 08:06:32 -0500 Subject: [PATCH] Remove subclass for sortish sampler (#9907) * Remove subclass for sortish sampler * Use old Seq2SeqTrainer in script * Styling --- examples/seq2seq/finetune_trainer.py | 16 +++++----------- src/transformers/trainer_seq2seq.py | 22 ---------------------- 2 files changed, 5 insertions(+), 33 deletions(-) diff --git a/examples/seq2seq/finetune_trainer.py b/examples/seq2seq/finetune_trainer.py index 89dd80395a..37573e50ba 100755 --- a/examples/seq2seq/finetune_trainer.py +++ b/examples/seq2seq/finetune_trainer.py @@ -20,6 +20,8 @@ from dataclasses import dataclass, field from typing import Optional import transformers +from seq2seq_trainer import Seq2SeqTrainer +from seq2seq_training_args import Seq2SeqTrainingArguments from transformers import ( AutoConfig, AutoModelForSeq2SeqLM, @@ -27,8 +29,6 @@ from transformers import ( HfArgumentParser, MBartTokenizer, MBartTokenizerFast, - Seq2SeqTrainer, - Seq2SeqTrainingArguments, set_seed, ) from transformers.trainer_utils import EvaluationStrategy, is_main_process @@ -286,6 +286,7 @@ def main(): trainer = Seq2SeqTrainer( model=model, args=training_args, + data_args=data_args, train_dataset=train_dataset, eval_dataset=eval_dataset, data_collator=Seq2SeqDataCollator( @@ -323,9 +324,7 @@ def main(): if training_args.do_eval: logger.info("*** Evaluate ***") - metrics = trainer.evaluate( - metric_key_prefix="val", max_length=data_args.val_max_target_length, num_beams=data_args.eval_beams - ) + metrics = trainer.evaluate(metric_key_prefix="val") metrics["val_n_objs"] = data_args.n_val metrics["val_loss"] = round(metrics["val_loss"], 4) @@ -337,12 +336,7 @@ def main(): if training_args.do_predict: logger.info("*** Predict ***") - test_output = trainer.predict( - test_dataset=test_dataset, - metric_key_prefix="test", - max_length=data_args.val_max_target_length, - num_beams=data_args.eval_beams, - ) + test_output = trainer.predict(test_dataset=test_dataset, metric_key_prefix="test") metrics = test_output.metrics metrics["test_n_objs"] = data_args.n_test diff --git a/src/transformers/trainer_seq2seq.py b/src/transformers/trainer_seq2seq.py index 9ca93535e8..b4399c80ed 100644 --- a/src/transformers/trainer_seq2seq.py +++ b/src/transformers/trainer_seq2seq.py @@ -17,14 +17,10 @@ from typing import Any, Dict, List, Optional, Tuple, Union import torch from packaging import version from torch import nn -from torch.utils.data import DistributedSampler, RandomSampler from torch.utils.data.dataset import Dataset -from .file_utils import is_torch_tpu_available from .trainer import Trainer -from .trainer_pt_utils import get_tpu_sampler from .trainer_utils import PredictionOutput -from .training_args import ParallelMode from .utils import logging @@ -36,24 +32,6 @@ logger = logging.get_logger(__name__) class Seq2SeqTrainer(Trainer): - def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]: - if isinstance(self.train_dataset, torch.utils.data.IterableDataset): - return None - elif is_torch_tpu_available(): - return get_tpu_sampler(self.train_dataset) - else: - if self.args.sortish_sampler: - self.train_dataset.make_sortish_sampler( - self.args.per_device_train_batch_size, - distributed=(self.args.parallel_mode == ParallelMode.DISTRIBUTED), - ) - - return ( - RandomSampler(self.train_dataset) - if self.args.local_rank == -1 - else DistributedSampler(self.train_dataset) - ) - def evaluate( self, eval_dataset: Optional[Dataset] = None,