Remove subclass for sortish sampler (#9907)
* Remove subclass for sortish sampler * Use old Seq2SeqTrainer in script * Styling
This commit is contained in:
@@ -17,14 +17,10 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
import torch
|
||||
from packaging import version
|
||||
from torch import nn
|
||||
from torch.utils.data import DistributedSampler, RandomSampler
|
||||
from torch.utils.data.dataset import Dataset
|
||||
|
||||
from .file_utils import is_torch_tpu_available
|
||||
from .trainer import Trainer
|
||||
from .trainer_pt_utils import get_tpu_sampler
|
||||
from .trainer_utils import PredictionOutput
|
||||
from .training_args import ParallelMode
|
||||
from .utils import logging
|
||||
|
||||
|
||||
@@ -36,24 +32,6 @@ logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class Seq2SeqTrainer(Trainer):
|
||||
def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
    """Select the sampler used to draw examples from ``self.train_dataset``.

    Returns:
        ``None`` when the training dataset is a ``torch.utils.data.IterableDataset``;
        the result of ``get_tpu_sampler`` when a TPU is available; the dataset's
        sortish sampler when ``self.args.sortish_sampler`` is set; otherwise a
        ``RandomSampler`` if ``self.args.local_rank == -1``, else a
        ``DistributedSampler``.
    """
    if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
        return None

    if is_torch_tpu_available():
        return get_tpu_sampler(self.train_dataset)

    if self.args.sortish_sampler:
        # The sampler built here must be returned; if the call's result is
        # dropped, the method falls through to the random/distributed sampler
        # and the ``sortish_sampler`` flag has no effect.
        # NOTE(review): assumes ``make_sortish_sampler`` returns a Sampler
        # rather than configuring the dataset in place -- confirm against the
        # dataset implementation.
        return self.train_dataset.make_sortish_sampler(
            self.args.per_device_train_batch_size,
            distributed=(self.args.parallel_mode == ParallelMode.DISTRIBUTED),
        )

    # local_rank == -1 selects the plain random sampler; any other value
    # selects the distributed one.
    if self.args.local_rank == -1:
        return RandomSampler(self.train_dataset)
    return DistributedSampler(self.train_dataset)
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
eval_dataset: Optional[Dataset] = None,
|
||||
|
||||
Reference in New Issue
Block a user