Accepts BatchEncoding in LengthSampler (#11431)
This commit is contained in:
@@ -33,6 +33,7 @@ from torch.utils.data.distributed import DistributedSampler
|
|||||||
from torch.utils.data.sampler import RandomSampler, Sampler
|
from torch.utils.data.sampler import RandomSampler, Sampler
|
||||||
|
|
||||||
from .file_utils import is_sagemaker_dp_enabled, is_sagemaker_mp_enabled, is_torch_tpu_available
|
from .file_utils import is_sagemaker_dp_enabled, is_sagemaker_mp_enabled, is_torch_tpu_available
|
||||||
|
from .tokenization_utils_base import BatchEncoding
|
||||||
from .utils import logging
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
@@ -514,7 +515,10 @@ class LengthGroupedSampler(Sampler):
|
|||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
self.model_input_name = model_input_name if model_input_name is not None else "input_ids"
|
self.model_input_name = model_input_name if model_input_name is not None else "input_ids"
|
||||||
if lengths is None:
|
if lengths is None:
|
||||||
if not isinstance(dataset[0], dict) or self.model_input_name not in dataset[0]:
|
if (
|
||||||
|
not (isinstance(dataset[0], dict) or isinstance(dataset[0], BatchEncoding))
|
||||||
|
or self.model_input_name not in dataset[0]
|
||||||
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Can only automatically infer lengths for datasets whose items are dictionaries with an "
|
"Can only automatically infer lengths for datasets whose items are dictionaries with an "
|
||||||
f"'{self.model_input_name}' key."
|
f"'{self.model_input_name}' key."
|
||||||
@@ -575,7 +579,10 @@ class DistributedLengthGroupedSampler(DistributedSampler):
|
|||||||
self.model_input_name = model_input_name if model_input_name is not None else "input_ids"
|
self.model_input_name = model_input_name if model_input_name is not None else "input_ids"
|
||||||
|
|
||||||
if lengths is None:
|
if lengths is None:
|
||||||
if not isinstance(dataset[0], dict) or self.model_input_name not in dataset[0]:
|
if (
|
||||||
|
not (isinstance(dataset[0], dict) or isinstance(dataset[0], BatchEncoding))
|
||||||
|
or self.model_input_name not in dataset[0]
|
||||||
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Can only automatically infer lengths for datasets whose items are dictionaries with an "
|
"Can only automatically infer lengths for datasets whose items are dictionaries with an "
|
||||||
f"'{self.model_input_name}' key."
|
f"'{self.model_input_name}' key."
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ if is_torch_available():
|
|||||||
from torch.utils.data import IterableDataset
|
from torch.utils.data import IterableDataset
|
||||||
|
|
||||||
from transformers.modeling_outputs import SequenceClassifierOutput
|
from transformers.modeling_outputs import SequenceClassifierOutput
|
||||||
|
from transformers.tokenization_utils_base import BatchEncoding
|
||||||
from transformers.trainer_pt_utils import (
|
from transformers.trainer_pt_utils import (
|
||||||
DistributedLengthGroupedSampler,
|
DistributedLengthGroupedSampler,
|
||||||
DistributedSamplerWithLoop,
|
DistributedSamplerWithLoop,
|
||||||
@@ -185,6 +186,36 @@ class TrainerUtilsTest(unittest.TestCase):
|
|||||||
# The indices should be a permutation of range(100)
|
# The indices should be a permutation of range(100)
|
||||||
self.assertEqual(list(sorted(indices)), list(range(100)))
|
self.assertEqual(list(sorted(indices)), list(range(100)))
|
||||||
|
|
||||||
|
def test_group_by_length_with_dict(self):
|
||||||
|
# Get some inputs of random lengths
|
||||||
|
data = []
|
||||||
|
for _ in range(6):
|
||||||
|
input_ids = torch.randint(0, 25, (100,)).tolist()
|
||||||
|
data.append({"input_ids": input_ids})
|
||||||
|
# Put one bigger than the others to check it ends up in first position
|
||||||
|
data[3]["input_ids"] = torch.randint(0, 25, (105,)).tolist()
|
||||||
|
|
||||||
|
indices = list(LengthGroupedSampler(data, 4))
|
||||||
|
# The biggest element should be first
|
||||||
|
self.assertEqual(len(data[indices[0]]["input_ids"]), 105)
|
||||||
|
# The indices should be a permutation of range(6)
|
||||||
|
self.assertEqual(list(sorted(indices)), list(range(6)))
|
||||||
|
|
||||||
|
def test_group_by_length_with_batch_encoding(self):
|
||||||
|
# Get some inputs of random lengths
|
||||||
|
data = []
|
||||||
|
for _ in range(6):
|
||||||
|
input_ids = torch.randint(0, 25, (100,)).tolist()
|
||||||
|
data.append(BatchEncoding({"input_ids": input_ids}))
|
||||||
|
# Put one bigger than the others to check it ends up in first position
|
||||||
|
data[3]["input_ids"] = torch.randint(0, 25, (105,)).tolist()
|
||||||
|
|
||||||
|
indices = list(LengthGroupedSampler(data, 4))
|
||||||
|
# The biggest element should be first
|
||||||
|
self.assertEqual(len(data[indices[0]]["input_ids"]), 105)
|
||||||
|
# The indices should be a permutation of range(6)
|
||||||
|
self.assertEqual(list(sorted(indices)), list(range(6)))
|
||||||
|
|
||||||
def test_distributed_length_grouped(self):
|
def test_distributed_length_grouped(self):
|
||||||
# Get some inputs of random lengths
|
# Get some inputs of random lengths
|
||||||
lengths = torch.randint(0, 25, (100,)).tolist()
|
lengths = torch.randint(0, 25, (100,)).tolist()
|
||||||
|
|||||||
Reference in New Issue
Block a user