From 1b74af76b7e5c259d1470dec9d8d68c303dea5db Mon Sep 17 00:00:00 2001 From: Zhaofeng Wu Date: Tue, 5 Oct 2021 06:04:39 -0700 Subject: [PATCH] Allow dataset to be an optional argument for (Distributed)LengthGroupedSampler (#13820) * Allow dataset to be an optional argument for (Distributed)LengthGroupedSampler * Fix --- src/transformers/trainer.py | 4 +-- src/transformers/trainer_pt_utils.py | 51 +++++++++++++++------------- tests/test_trainer_utils.py | 10 +++--- 3 files changed, 35 insertions(+), 30 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 070f200620..c44960564c 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -572,16 +572,16 @@ class Trainer: model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None if self.args.world_size <= 1: return LengthGroupedSampler( - self.train_dataset, self.args.train_batch_size, + dataset=self.train_dataset, lengths=lengths, model_input_name=model_input_name, generator=generator, ) else: return DistributedLengthGroupedSampler( - self.train_dataset, self.args.train_batch_size, + dataset=self.train_dataset, num_replicas=self.args.world_size, rank=self.args.process_index, lengths=lengths, diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py index 164f5f2f95..a08c2ddd64 100644 --- a/src/transformers/trainer_pt_utils.py +++ b/src/transformers/trainer_pt_utils.py @@ -520,25 +520,27 @@ class LengthGroupedSampler(Sampler): def __init__( self, - dataset: Dataset, batch_size: int, + dataset: Optional[Dataset] = None, lengths: Optional[List[int]] = None, model_input_name: Optional[str] = None, generator=None, ): - self.dataset = dataset + if dataset is None and lengths is None: + raise ValueError("One of dataset and lengths must be provided.") + self.batch_size = batch_size - self.model_input_name = model_input_name if model_input_name is not None else "input_ids" if lengths is None: + model_input_name = model_input_name if model_input_name is not None else "input_ids" if ( not (isinstance(dataset[0], dict) or isinstance(dataset[0], BatchEncoding)) - or self.model_input_name not in dataset[0] + or model_input_name not in dataset[0] ): raise ValueError( "Can only automatically infer lengths for datasets whose items are dictionaries with an " - f"'{self.model_input_name}' key." + f"'{model_input_name}' key." ) - lengths = [len(feature[self.model_input_name]) for feature in dataset] + lengths = [len(feature[model_input_name]) for feature in dataset] self.lengths = lengths self.generator = generator @@ -558,8 +560,8 @@ class DistributedLengthGroupedSampler(DistributedSampler): # Copied and adapted from PyTorch DistributedSampler. def __init__( self, - dataset: Dataset, batch_size: int, + dataset: Optional[Dataset] = None, num_replicas: Optional[int] = None, rank: Optional[int] = None, seed: int = 0, @@ -567,6 +569,8 @@ class DistributedLengthGroupedSampler(DistributedSampler): lengths: Optional[List[int]] = None, model_input_name: Optional[str] = None, ): + if dataset is None and lengths is None: + raise ValueError("One of dataset and lengths must be provided.") if num_replicas is None: if not dist.is_available(): raise RuntimeError("Requires distributed package to be available") @@ -575,37 +579,38 @@ class DistributedLengthGroupedSampler(DistributedSampler): if not dist.is_available(): raise RuntimeError("Requires distributed package to be available") rank = dist.get_rank() - self.dataset = dataset + self.batch_size = batch_size self.num_replicas = num_replicas self.rank = rank self.epoch = 0 self.drop_last = drop_last - # If the dataset length is evenly divisible by # of replicas, then there - # is no need to drop any data, since the dataset will be split equally. - if self.drop_last and len(self.dataset) % self.num_replicas != 0: - # Split to nearest available length that is evenly divisible. - # This is to ensure each rank receives the same amount of data when - # using this Sampler. - self.num_samples = math.ceil((len(self.dataset) - self.num_replicas) / self.num_replicas) - else: - self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) - self.total_size = self.num_samples * self.num_replicas - self.seed = seed - self.model_input_name = model_input_name if model_input_name is not None else "input_ids" if lengths is None: + model_input_name = model_input_name if model_input_name is not None else "input_ids" if ( not (isinstance(dataset[0], dict) or isinstance(dataset[0], BatchEncoding)) - or self.model_input_name not in dataset[0] + or model_input_name not in dataset[0] ): raise ValueError( "Can only automatically infer lengths for datasets whose items are dictionaries with an " - f"'{self.model_input_name}' key." + f"'{model_input_name}' key." ) - lengths = [len(feature[self.model_input_name]) for feature in dataset] + lengths = [len(feature[model_input_name]) for feature in dataset] self.lengths = lengths + # If the dataset length is evenly divisible by # of replicas, then there + # is no need to drop any data, since the dataset will be split equally. + if self.drop_last and len(self.lengths) % self.num_replicas != 0: + # Split to nearest available length that is evenly divisible. + # This is to ensure each rank receives the same amount of data when + # using this Sampler. + self.num_samples = math.ceil((len(self.lengths) - self.num_replicas) / self.num_replicas) + else: + self.num_samples = math.ceil(len(self.lengths) / self.num_replicas) + self.total_size = self.num_samples * self.num_replicas + self.seed = seed + def __iter__(self) -> Iterator: # Deterministically shuffle based on epoch and seed g = torch.Generator() diff --git a/tests/test_trainer_utils.py b/tests/test_trainer_utils.py index b6314818f3..8fe8d2e1d2 100644 --- a/tests/test_trainer_utils.py +++ b/tests/test_trainer_utils.py @@ -181,7 +181,7 @@ class TrainerUtilsTest(unittest.TestCase): # Put one bigger than the others to check it ends up in first position lengths[32] = 50 - indices = list(LengthGroupedSampler(lengths, 4, lengths=lengths)) + indices = list(LengthGroupedSampler(4, lengths=lengths)) # The biggest element should be first self.assertEqual(lengths[indices[0]], 50) # The indices should be a permutation of range(100) @@ -196,7 +196,7 @@ class TrainerUtilsTest(unittest.TestCase): # Put one bigger than the others to check it ends up in first position data[3]["input_ids"] = torch.randint(0, 25, (105,)).tolist() - indices = list(LengthGroupedSampler(data, 4)) + indices = list(LengthGroupedSampler(4, dataset=data)) # The biggest element should be first self.assertEqual(len(data[indices[0]]["input_ids"]), 105) # The indices should be a permutation of range(6) @@ -211,7 +211,7 @@ class TrainerUtilsTest(unittest.TestCase): # Put one bigger than the others to check it ends up in first position data[3]["input_ids"] = torch.randint(0, 25, (105,)).tolist() - indices = list(LengthGroupedSampler(data, 4)) + indices = list(LengthGroupedSampler(4, dataset=data)) # The biggest element should be first self.assertEqual(len(data[indices[0]]["input_ids"]), 105) # The indices should be a permutation of range(6) @@ -223,8 +223,8 @@ class TrainerUtilsTest(unittest.TestCase): # Put one bigger than the others to check it ends up in first position lengths[32] = 50 - indices_process_0 = list(DistributedLengthGroupedSampler(lengths, 4, 2, 0, lengths=lengths)) - indices_process_1 = list(DistributedLengthGroupedSampler(lengths, 4, 2, 1, lengths=lengths)) + indices_process_0 = list(DistributedLengthGroupedSampler(4, num_replicas=2, rank=0, lengths=lengths)) + indices_process_1 = list(DistributedLengthGroupedSampler(4, num_replicas=2, rank=1, lengths=lengths)) # The biggest element should be first self.assertEqual(lengths[indices_process_0[0]], 50) # The indices should be a permutation of range(100)