Accepts BatchEncoding in LengthSampler (#11431)
This commit is contained in:
@@ -27,6 +27,7 @@ if is_torch_available():
|
||||
from torch.utils.data import IterableDataset
|
||||
|
||||
from transformers.modeling_outputs import SequenceClassifierOutput
|
||||
from transformers.tokenization_utils_base import BatchEncoding
|
||||
from transformers.trainer_pt_utils import (
|
||||
DistributedLengthGroupedSampler,
|
||||
DistributedSamplerWithLoop,
|
||||
@@ -185,6 +186,36 @@ class TrainerUtilsTest(unittest.TestCase):
|
||||
# The indices should be a permutation of range(100)
|
||||
self.assertEqual(list(sorted(indices)), list(range(100)))
|
||||
|
||||
def test_group_by_length_with_dict(self):
|
||||
# Get some inputs of random lengths
|
||||
data = []
|
||||
for _ in range(6):
|
||||
input_ids = torch.randint(0, 25, (100,)).tolist()
|
||||
data.append({"input_ids": input_ids})
|
||||
# Put one bigger than the others to check it ends up in first position
|
||||
data[3]["input_ids"] = torch.randint(0, 25, (105,)).tolist()
|
||||
|
||||
indices = list(LengthGroupedSampler(data, 4))
|
||||
# The biggest element should be first
|
||||
self.assertEqual(len(data[indices[0]]["input_ids"]), 105)
|
||||
# The indices should be a permutation of range(6)
|
||||
self.assertEqual(list(sorted(indices)), list(range(6)))
|
||||
|
||||
def test_group_by_length_with_batch_encoding(self):
|
||||
# Get some inputs of random lengths
|
||||
data = []
|
||||
for _ in range(6):
|
||||
input_ids = torch.randint(0, 25, (100,)).tolist()
|
||||
data.append(BatchEncoding({"input_ids": input_ids}))
|
||||
# Put one bigger than the others to check it ends up in first position
|
||||
data[3]["input_ids"] = torch.randint(0, 25, (105,)).tolist()
|
||||
|
||||
indices = list(LengthGroupedSampler(data, 4))
|
||||
# The biggest element should be first
|
||||
self.assertEqual(len(data[indices[0]]["input_ids"]), 105)
|
||||
# The indices should be a permutation of range(6)
|
||||
self.assertEqual(list(sorted(indices)), list(range(6)))
|
||||
|
||||
def test_distributed_length_grouped(self):
|
||||
# Get some inputs of random lengths
|
||||
lengths = torch.randint(0, 25, (100,)).tolist()
|
||||
|
||||
Reference in New Issue
Block a user