Remove subclass for sortish sampler (#9907)

* Remove subclass for sortish sampler

* Use old Seq2SeqTrainer in script

* Styling
This commit is contained in:
Sylvain Gugger
2021-02-01 08:06:32 -05:00
committed by GitHub
parent 1682804ebd
commit 115d97dd2f
2 changed files with 5 additions and 33 deletions

View File

@@ -17,14 +17,10 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from packaging import version
from torch import nn
from torch.utils.data import DistributedSampler, RandomSampler
from torch.utils.data.dataset import Dataset
from .file_utils import is_torch_tpu_available
from .trainer import Trainer
from .trainer_pt_utils import get_tpu_sampler
from .trainer_utils import PredictionOutput
from .training_args import ParallelMode
from .utils import logging
@@ -36,24 +32,6 @@ logger = logging.get_logger(__name__)
class Seq2SeqTrainer(Trainer):
def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
return None
elif is_torch_tpu_available():
return get_tpu_sampler(self.train_dataset)
else:
if self.args.sortish_sampler:
self.train_dataset.make_sortish_sampler(
self.args.per_device_train_batch_size,
distributed=(self.args.parallel_mode == ParallelMode.DISTRIBUTED),
)
return (
RandomSampler(self.train_dataset)
if self.args.local_rank == -1
else DistributedSampler(self.train_dataset)
)
def evaluate(
self,
eval_dataset: Optional[Dataset] = None,