Allow dataset to be an optional argument for (Distributed)LengthGroupedSampler (#13820)
* Allow dataset to be an optional argument for (Distributed)LengthGroupedSampler
* Fix
This commit is contained in:
@@ -572,16 +572,16 @@ class Trainer:
|
|||||||
model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
|
model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
|
||||||
if self.args.world_size <= 1:
|
if self.args.world_size <= 1:
|
||||||
return LengthGroupedSampler(
|
return LengthGroupedSampler(
|
||||||
self.train_dataset,
|
|
||||||
self.args.train_batch_size,
|
self.args.train_batch_size,
|
||||||
|
dataset=self.train_dataset,
|
||||||
lengths=lengths,
|
lengths=lengths,
|
||||||
model_input_name=model_input_name,
|
model_input_name=model_input_name,
|
||||||
generator=generator,
|
generator=generator,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return DistributedLengthGroupedSampler(
|
return DistributedLengthGroupedSampler(
|
||||||
self.train_dataset,
|
|
||||||
self.args.train_batch_size,
|
self.args.train_batch_size,
|
||||||
|
dataset=self.train_dataset,
|
||||||
num_replicas=self.args.world_size,
|
num_replicas=self.args.world_size,
|
||||||
rank=self.args.process_index,
|
rank=self.args.process_index,
|
||||||
lengths=lengths,
|
lengths=lengths,
|
||||||
|
|||||||
@@ -520,25 +520,27 @@ class LengthGroupedSampler(Sampler):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
dataset: Dataset,
|
|
||||||
batch_size: int,
|
batch_size: int,
|
||||||
|
dataset: Optional[Dataset] = None,
|
||||||
lengths: Optional[List[int]] = None,
|
lengths: Optional[List[int]] = None,
|
||||||
model_input_name: Optional[str] = None,
|
model_input_name: Optional[str] = None,
|
||||||
generator=None,
|
generator=None,
|
||||||
):
|
):
|
||||||
self.dataset = dataset
|
if dataset is None and lengths is None:
|
||||||
|
raise ValueError("One of dataset and lengths must be provided.")
|
||||||
|
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
self.model_input_name = model_input_name if model_input_name is not None else "input_ids"
|
|
||||||
if lengths is None:
|
if lengths is None:
|
||||||
|
model_input_name = model_input_name if model_input_name is not None else "input_ids"
|
||||||
if (
|
if (
|
||||||
not (isinstance(dataset[0], dict) or isinstance(dataset[0], BatchEncoding))
|
not (isinstance(dataset[0], dict) or isinstance(dataset[0], BatchEncoding))
|
||||||
or self.model_input_name not in dataset[0]
|
or model_input_name not in dataset[0]
|
||||||
):
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Can only automatically infer lengths for datasets whose items are dictionaries with an "
|
"Can only automatically infer lengths for datasets whose items are dictionaries with an "
|
||||||
f"'{self.model_input_name}' key."
|
f"'{model_input_name}' key."
|
||||||
)
|
)
|
||||||
lengths = [len(feature[self.model_input_name]) for feature in dataset]
|
lengths = [len(feature[model_input_name]) for feature in dataset]
|
||||||
self.lengths = lengths
|
self.lengths = lengths
|
||||||
self.generator = generator
|
self.generator = generator
|
||||||
|
|
||||||
@@ -558,8 +560,8 @@ class DistributedLengthGroupedSampler(DistributedSampler):
|
|||||||
# Copied and adapted from PyTorch DistributedSampler.
|
# Copied and adapted from PyTorch DistributedSampler.
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
dataset: Dataset,
|
|
||||||
batch_size: int,
|
batch_size: int,
|
||||||
|
dataset: Optional[Dataset] = None,
|
||||||
num_replicas: Optional[int] = None,
|
num_replicas: Optional[int] = None,
|
||||||
rank: Optional[int] = None,
|
rank: Optional[int] = None,
|
||||||
seed: int = 0,
|
seed: int = 0,
|
||||||
@@ -567,6 +569,8 @@ class DistributedLengthGroupedSampler(DistributedSampler):
|
|||||||
lengths: Optional[List[int]] = None,
|
lengths: Optional[List[int]] = None,
|
||||||
model_input_name: Optional[str] = None,
|
model_input_name: Optional[str] = None,
|
||||||
):
|
):
|
||||||
|
if dataset is None and lengths is None:
|
||||||
|
raise ValueError("One of dataset and lengths must be provided.")
|
||||||
if num_replicas is None:
|
if num_replicas is None:
|
||||||
if not dist.is_available():
|
if not dist.is_available():
|
||||||
raise RuntimeError("Requires distributed package to be available")
|
raise RuntimeError("Requires distributed package to be available")
|
||||||
@@ -575,37 +579,38 @@ class DistributedLengthGroupedSampler(DistributedSampler):
|
|||||||
if not dist.is_available():
|
if not dist.is_available():
|
||||||
raise RuntimeError("Requires distributed package to be available")
|
raise RuntimeError("Requires distributed package to be available")
|
||||||
rank = dist.get_rank()
|
rank = dist.get_rank()
|
||||||
self.dataset = dataset
|
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
self.num_replicas = num_replicas
|
self.num_replicas = num_replicas
|
||||||
self.rank = rank
|
self.rank = rank
|
||||||
self.epoch = 0
|
self.epoch = 0
|
||||||
self.drop_last = drop_last
|
self.drop_last = drop_last
|
||||||
# If the dataset length is evenly divisible by # of replicas, then there
|
|
||||||
# is no need to drop any data, since the dataset will be split equally.
|
|
||||||
if self.drop_last and len(self.dataset) % self.num_replicas != 0:
|
|
||||||
# Split to nearest available length that is evenly divisible.
|
|
||||||
# This is to ensure each rank receives the same amount of data when
|
|
||||||
# using this Sampler.
|
|
||||||
self.num_samples = math.ceil((len(self.dataset) - self.num_replicas) / self.num_replicas)
|
|
||||||
else:
|
|
||||||
self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)
|
|
||||||
self.total_size = self.num_samples * self.num_replicas
|
|
||||||
self.seed = seed
|
|
||||||
self.model_input_name = model_input_name if model_input_name is not None else "input_ids"
|
|
||||||
|
|
||||||
if lengths is None:
|
if lengths is None:
|
||||||
|
model_input_name = model_input_name if model_input_name is not None else "input_ids"
|
||||||
if (
|
if (
|
||||||
not (isinstance(dataset[0], dict) or isinstance(dataset[0], BatchEncoding))
|
not (isinstance(dataset[0], dict) or isinstance(dataset[0], BatchEncoding))
|
||||||
or self.model_input_name not in dataset[0]
|
or model_input_name not in dataset[0]
|
||||||
):
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Can only automatically infer lengths for datasets whose items are dictionaries with an "
|
"Can only automatically infer lengths for datasets whose items are dictionaries with an "
|
||||||
f"'{self.model_input_name}' key."
|
f"'{model_input_name}' key."
|
||||||
)
|
)
|
||||||
lengths = [len(feature[self.model_input_name]) for feature in dataset]
|
lengths = [len(feature[model_input_name]) for feature in dataset]
|
||||||
self.lengths = lengths
|
self.lengths = lengths
|
||||||
|
|
||||||
|
# If the dataset length is evenly divisible by # of replicas, then there
|
||||||
|
# is no need to drop any data, since the dataset will be split equally.
|
||||||
|
if self.drop_last and len(self.lengths) % self.num_replicas != 0:
|
||||||
|
# Split to nearest available length that is evenly divisible.
|
||||||
|
# This is to ensure each rank receives the same amount of data when
|
||||||
|
# using this Sampler.
|
||||||
|
self.num_samples = math.ceil((len(self.lengths) - self.num_replicas) / self.num_replicas)
|
||||||
|
else:
|
||||||
|
self.num_samples = math.ceil(len(self.lengths) / self.num_replicas)
|
||||||
|
self.total_size = self.num_samples * self.num_replicas
|
||||||
|
self.seed = seed
|
||||||
|
|
||||||
def __iter__(self) -> Iterator:
|
def __iter__(self) -> Iterator:
|
||||||
# Deterministically shuffle based on epoch and seed
|
# Deterministically shuffle based on epoch and seed
|
||||||
g = torch.Generator()
|
g = torch.Generator()
|
||||||
|
|||||||
@@ -181,7 +181,7 @@ class TrainerUtilsTest(unittest.TestCase):
|
|||||||
# Put one bigger than the others to check it ends up in first position
|
# Put one bigger than the others to check it ends up in first position
|
||||||
lengths[32] = 50
|
lengths[32] = 50
|
||||||
|
|
||||||
indices = list(LengthGroupedSampler(lengths, 4, lengths=lengths))
|
indices = list(LengthGroupedSampler(4, lengths=lengths))
|
||||||
# The biggest element should be first
|
# The biggest element should be first
|
||||||
self.assertEqual(lengths[indices[0]], 50)
|
self.assertEqual(lengths[indices[0]], 50)
|
||||||
# The indices should be a permutation of range(100)
|
# The indices should be a permutation of range(100)
|
||||||
@@ -196,7 +196,7 @@ class TrainerUtilsTest(unittest.TestCase):
|
|||||||
# Put one bigger than the others to check it ends up in first position
|
# Put one bigger than the others to check it ends up in first position
|
||||||
data[3]["input_ids"] = torch.randint(0, 25, (105,)).tolist()
|
data[3]["input_ids"] = torch.randint(0, 25, (105,)).tolist()
|
||||||
|
|
||||||
indices = list(LengthGroupedSampler(data, 4))
|
indices = list(LengthGroupedSampler(4, dataset=data))
|
||||||
# The biggest element should be first
|
# The biggest element should be first
|
||||||
self.assertEqual(len(data[indices[0]]["input_ids"]), 105)
|
self.assertEqual(len(data[indices[0]]["input_ids"]), 105)
|
||||||
# The indices should be a permutation of range(6)
|
# The indices should be a permutation of range(6)
|
||||||
@@ -211,7 +211,7 @@ class TrainerUtilsTest(unittest.TestCase):
|
|||||||
# Put one bigger than the others to check it ends up in first position
|
# Put one bigger than the others to check it ends up in first position
|
||||||
data[3]["input_ids"] = torch.randint(0, 25, (105,)).tolist()
|
data[3]["input_ids"] = torch.randint(0, 25, (105,)).tolist()
|
||||||
|
|
||||||
indices = list(LengthGroupedSampler(data, 4))
|
indices = list(LengthGroupedSampler(4, dataset=data))
|
||||||
# The biggest element should be first
|
# The biggest element should be first
|
||||||
self.assertEqual(len(data[indices[0]]["input_ids"]), 105)
|
self.assertEqual(len(data[indices[0]]["input_ids"]), 105)
|
||||||
# The indices should be a permutation of range(6)
|
# The indices should be a permutation of range(6)
|
||||||
@@ -223,8 +223,8 @@ class TrainerUtilsTest(unittest.TestCase):
|
|||||||
# Put one bigger than the others to check it ends up in first position
|
# Put one bigger than the others to check it ends up in first position
|
||||||
lengths[32] = 50
|
lengths[32] = 50
|
||||||
|
|
||||||
indices_process_0 = list(DistributedLengthGroupedSampler(lengths, 4, 2, 0, lengths=lengths))
|
indices_process_0 = list(DistributedLengthGroupedSampler(4, num_replicas=2, rank=0, lengths=lengths))
|
||||||
indices_process_1 = list(DistributedLengthGroupedSampler(lengths, 4, 2, 1, lengths=lengths))
|
indices_process_1 = list(DistributedLengthGroupedSampler(4, num_replicas=2, rank=1, lengths=lengths))
|
||||||
# The biggest element should be first
|
# The biggest element should be first
|
||||||
self.assertEqual(lengths[indices_process_0[0]], 50)
|
self.assertEqual(lengths[indices_process_0[0]], 50)
|
||||||
# The indices should be a permutation of range(100)
|
# The indices should be a permutation of range(100)
|
||||||
|
|||||||
Reference in New Issue
Block a user