From c2cd02ac625cd0ab64cf42124ad71ce9158fb67c Mon Sep 17 00:00:00 2001 From: Takuya Makino Date: Fri, 30 Apr 2021 21:27:46 +0900 Subject: [PATCH] Accepts BatchEncoding in LengthSampler (#11431) --- src/transformers/trainer_pt_utils.py | 11 ++++++++-- tests/test_trainer_utils.py | 31 ++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py index 5791ac6c35..62cc1aa480 100644 --- a/src/transformers/trainer_pt_utils.py +++ b/src/transformers/trainer_pt_utils.py @@ -33,6 +33,7 @@ from torch.utils.data.distributed import DistributedSampler from torch.utils.data.sampler import RandomSampler, Sampler from .file_utils import is_sagemaker_dp_enabled, is_sagemaker_mp_enabled, is_torch_tpu_available +from .tokenization_utils_base import BatchEncoding from .utils import logging @@ -514,7 +515,10 @@ class LengthGroupedSampler(Sampler): self.batch_size = batch_size self.model_input_name = model_input_name if model_input_name is not None else "input_ids" if lengths is None: - if not isinstance(dataset[0], dict) or self.model_input_name not in dataset[0]: + if ( + not (isinstance(dataset[0], dict) or isinstance(dataset[0], BatchEncoding)) + or self.model_input_name not in dataset[0] + ): raise ValueError( "Can only automatically infer lengths for datasets whose items are dictionaries with an " f"'{self.model_input_name}' key." @@ -575,7 +579,10 @@ class DistributedLengthGroupedSampler(DistributedSampler): self.model_input_name = model_input_name if model_input_name is not None else "input_ids" if lengths is None: - if not isinstance(dataset[0], dict) or self.model_input_name not in dataset[0]: + if ( + not (isinstance(dataset[0], dict) or isinstance(dataset[0], BatchEncoding)) + or self.model_input_name not in dataset[0] + ): raise ValueError( "Can only automatically infer lengths for datasets whose items are dictionaries with an " f"'{self.model_input_name}' key." diff --git a/tests/test_trainer_utils.py b/tests/test_trainer_utils.py index 8ce951703b..b543a1ebca 100644 --- a/tests/test_trainer_utils.py +++ b/tests/test_trainer_utils.py @@ -27,6 +27,7 @@ if is_torch_available(): from torch.utils.data import IterableDataset from transformers.modeling_outputs import SequenceClassifierOutput + from transformers.tokenization_utils_base import BatchEncoding from transformers.trainer_pt_utils import ( DistributedLengthGroupedSampler, DistributedSamplerWithLoop, @@ -185,6 +186,36 @@ class TrainerUtilsTest(unittest.TestCase): # The indices should be a permutation of range(100) self.assertEqual(list(sorted(indices)), list(range(100))) + def test_group_by_length_with_dict(self): + # Get some inputs of random lengths + data = [] + for _ in range(6): + input_ids = torch.randint(0, 25, (100,)).tolist() + data.append({"input_ids": input_ids}) + # Put one bigger than the others to check it ends up in first position + data[3]["input_ids"] = torch.randint(0, 25, (105,)).tolist() + + indices = list(LengthGroupedSampler(data, 4)) + # The biggest element should be first + self.assertEqual(len(data[indices[0]]["input_ids"]), 105) + # The indices should be a permutation of range(6) + self.assertEqual(list(sorted(indices)), list(range(6))) + + def test_group_by_length_with_batch_encoding(self): + # Get some inputs of random lengths + data = [] + for _ in range(6): + input_ids = torch.randint(0, 25, (100,)).tolist() + data.append(BatchEncoding({"input_ids": input_ids})) + # Put one bigger than the others to check it ends up in first position + data[3]["input_ids"] = torch.randint(0, 25, (105,)).tolist() + + indices = list(LengthGroupedSampler(data, 4)) + # The biggest element should be first + self.assertEqual(len(data[indices[0]]["input_ids"]), 105) + # The indices should be a permutation of range(6) + self.assertEqual(list(sorted(indices)), list(range(6))) + def test_distributed_length_grouped(self): # Get some inputs of random lengths lengths = torch.randint(0, 25, (100,)).tolist()