🚨🚨🚨 Replace DataLoader logic for Accelerate in Trainer, remove unneeded tests 🚨🚨🚨 (#24028)
* Working integration * Fix failing test * Revert label host logic * Bring it back!
This commit is contained in:
@@ -61,7 +61,6 @@ from huggingface_hub import Repository, create_repo
|
|||||||
from packaging import version
|
from packaging import version
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
|
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
|
||||||
|
|
||||||
from . import __version__
|
from . import __version__
|
||||||
from .configuration_utils import PretrainedConfig
|
from .configuration_utils import PretrainedConfig
|
||||||
@@ -73,7 +72,7 @@ from .modelcard import TrainingSummary
|
|||||||
from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
|
from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
|
||||||
from .models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES
|
from .models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES
|
||||||
from .optimization import Adafactor, get_scheduler
|
from .optimization import Adafactor, get_scheduler
|
||||||
from .pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_10, is_torch_less_than_1_11
|
from .pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_10
|
||||||
from .tokenization_utils_base import PreTrainedTokenizerBase
|
from .tokenization_utils_base import PreTrainedTokenizerBase
|
||||||
from .trainer_callback import (
|
from .trainer_callback import (
|
||||||
CallbackHandler,
|
CallbackHandler,
|
||||||
@@ -85,14 +84,11 @@ from .trainer_callback import (
|
|||||||
TrainerState,
|
TrainerState,
|
||||||
)
|
)
|
||||||
from .trainer_pt_utils import (
|
from .trainer_pt_utils import (
|
||||||
DistributedLengthGroupedSampler,
|
|
||||||
DistributedSamplerWithLoop,
|
|
||||||
DistributedTensorGatherer,
|
DistributedTensorGatherer,
|
||||||
IterableDatasetShard,
|
IterableDatasetShard,
|
||||||
LabelSmoother,
|
LabelSmoother,
|
||||||
LengthGroupedSampler,
|
LengthGroupedSampler,
|
||||||
SequentialDistributedSampler,
|
SequentialDistributedSampler,
|
||||||
ShardSampler,
|
|
||||||
distributed_broadcast_scalars,
|
distributed_broadcast_scalars,
|
||||||
distributed_concat,
|
distributed_concat,
|
||||||
find_batch_size,
|
find_batch_size,
|
||||||
@@ -102,7 +98,6 @@ from .trainer_pt_utils import (
|
|||||||
nested_concat,
|
nested_concat,
|
||||||
nested_detach,
|
nested_detach,
|
||||||
nested_numpify,
|
nested_numpify,
|
||||||
nested_truncate,
|
|
||||||
nested_xla_mesh_reduce,
|
nested_xla_mesh_reduce,
|
||||||
reissue_pt_warnings,
|
reissue_pt_warnings,
|
||||||
)
|
)
|
||||||
@@ -812,20 +807,6 @@ class Trainer:
|
|||||||
if self.train_dataset is None or not has_length(self.train_dataset):
|
if self.train_dataset is None or not has_length(self.train_dataset):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
generator = None
|
|
||||||
if self.args.world_size <= 1:
|
|
||||||
generator = torch.Generator()
|
|
||||||
# for backwards compatibility, we generate a seed here (which is sampled from a generator seeded with
|
|
||||||
# `args.seed`) if data_seed isn't provided.
|
|
||||||
# Further on in this method, we default to `args.seed` instead.
|
|
||||||
if self.args.data_seed is None:
|
|
||||||
seed = int(torch.empty((), dtype=torch.int64).random_().item())
|
|
||||||
else:
|
|
||||||
seed = self.args.data_seed
|
|
||||||
generator.manual_seed(seed)
|
|
||||||
|
|
||||||
seed = self.args.data_seed if self.args.data_seed is not None else self.args.seed
|
|
||||||
|
|
||||||
# Build the sampler.
|
# Build the sampler.
|
||||||
if self.args.group_by_length:
|
if self.args.group_by_length:
|
||||||
if is_datasets_available() and isinstance(self.train_dataset, datasets.Dataset):
|
if is_datasets_available() and isinstance(self.train_dataset, datasets.Dataset):
|
||||||
@@ -837,47 +818,15 @@ class Trainer:
|
|||||||
else:
|
else:
|
||||||
lengths = None
|
lengths = None
|
||||||
model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
|
model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
|
||||||
if self.args.world_size <= 1:
|
|
||||||
return LengthGroupedSampler(
|
return LengthGroupedSampler(
|
||||||
self.args.train_batch_size * self.args.gradient_accumulation_steps,
|
self.args.train_batch_size * self.args.gradient_accumulation_steps,
|
||||||
dataset=self.train_dataset,
|
dataset=self.train_dataset,
|
||||||
lengths=lengths,
|
lengths=lengths,
|
||||||
model_input_name=model_input_name,
|
model_input_name=model_input_name,
|
||||||
generator=generator,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return DistributedLengthGroupedSampler(
|
|
||||||
self.args.train_batch_size * self.args.gradient_accumulation_steps,
|
|
||||||
dataset=self.train_dataset,
|
|
||||||
num_replicas=self.args.world_size,
|
|
||||||
rank=self.args.process_index,
|
|
||||||
lengths=lengths,
|
|
||||||
model_input_name=model_input_name,
|
|
||||||
seed=seed,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
if self.args.world_size <= 1:
|
return RandomSampler(self.train_dataset)
|
||||||
return RandomSampler(self.train_dataset, generator=generator)
|
|
||||||
elif (
|
|
||||||
self.args.parallel_mode in [ParallelMode.TPU, ParallelMode.SAGEMAKER_MODEL_PARALLEL]
|
|
||||||
and not self.args.dataloader_drop_last
|
|
||||||
):
|
|
||||||
# Use a loop for TPUs when drop_last is False to have all batches have the same size.
|
|
||||||
return DistributedSamplerWithLoop(
|
|
||||||
self.train_dataset,
|
|
||||||
batch_size=self.args.per_device_train_batch_size,
|
|
||||||
num_replicas=self.args.world_size,
|
|
||||||
rank=self.args.process_index,
|
|
||||||
seed=seed,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return DistributedSampler(
|
|
||||||
self.train_dataset,
|
|
||||||
num_replicas=self.args.world_size,
|
|
||||||
rank=self.args.process_index,
|
|
||||||
seed=seed,
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_train_dataloader(self) -> DataLoader:
|
def get_train_dataloader(self) -> DataLoader:
|
||||||
"""
|
"""
|
||||||
@@ -898,36 +847,19 @@ class Trainer:
|
|||||||
else:
|
else:
|
||||||
data_collator = self._get_collator_with_removed_columns(data_collator, description="training")
|
data_collator = self._get_collator_with_removed_columns(data_collator, description="training")
|
||||||
|
|
||||||
if isinstance(train_dataset, torch.utils.data.IterableDataset):
|
dataloader_params = {
|
||||||
if self.args.world_size > 1:
|
"batch_size": self._train_batch_size,
|
||||||
train_dataset = IterableDatasetShard(
|
"collate_fn": data_collator,
|
||||||
train_dataset,
|
"num_workers": self.args.dataloader_num_workers,
|
||||||
batch_size=self._train_batch_size,
|
"pin_memory": self.args.dataloader_pin_memory,
|
||||||
drop_last=self.args.dataloader_drop_last,
|
}
|
||||||
num_processes=self.args.world_size,
|
|
||||||
process_index=self.args.process_index,
|
|
||||||
)
|
|
||||||
|
|
||||||
return DataLoader(
|
if not isinstance(train_dataset, torch.utils.data.IterableDataset):
|
||||||
train_dataset,
|
dataloader_params["sampler"] = self._get_train_sampler()
|
||||||
batch_size=self._train_batch_size,
|
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||||
collate_fn=data_collator,
|
dataloader_params["worker_init_fn"] = seed_worker
|
||||||
num_workers=self.args.dataloader_num_workers,
|
|
||||||
pin_memory=self.args.dataloader_pin_memory,
|
|
||||||
)
|
|
||||||
|
|
||||||
train_sampler = self._get_train_sampler()
|
return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
|
||||||
|
|
||||||
return DataLoader(
|
|
||||||
train_dataset,
|
|
||||||
batch_size=self._train_batch_size,
|
|
||||||
sampler=train_sampler,
|
|
||||||
collate_fn=data_collator,
|
|
||||||
drop_last=self.args.dataloader_drop_last,
|
|
||||||
num_workers=self.args.dataloader_num_workers,
|
|
||||||
pin_memory=self.args.dataloader_pin_memory,
|
|
||||||
worker_init_fn=seed_worker,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]:
|
def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]:
|
||||||
# Deprecated code
|
# Deprecated code
|
||||||
@@ -943,20 +875,13 @@ class Trainer:
|
|||||||
rank=smp.dp_rank(),
|
rank=smp.dp_rank(),
|
||||||
batch_size=self.args.per_device_eval_batch_size,
|
batch_size=self.args.per_device_eval_batch_size,
|
||||||
)
|
)
|
||||||
elif self.args.parallel_mode == ParallelMode.DISTRIBUTED:
|
|
||||||
return SequentialDistributedSampler(eval_dataset)
|
|
||||||
else:
|
else:
|
||||||
return SequentialSampler(eval_dataset)
|
return SequentialSampler(eval_dataset)
|
||||||
|
|
||||||
if self.args.world_size <= 1:
|
if self.args.world_size <= 1:
|
||||||
return SequentialSampler(eval_dataset)
|
return SequentialSampler(eval_dataset)
|
||||||
else:
|
else:
|
||||||
return ShardSampler(
|
return None
|
||||||
eval_dataset,
|
|
||||||
batch_size=self.args.per_device_eval_batch_size,
|
|
||||||
num_processes=self.args.world_size,
|
|
||||||
process_index=self.args.process_index,
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
|
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
|
||||||
"""
|
"""
|
||||||
@@ -979,34 +904,18 @@ class Trainer:
|
|||||||
else:
|
else:
|
||||||
data_collator = self._get_collator_with_removed_columns(data_collator, description="evaluation")
|
data_collator = self._get_collator_with_removed_columns(data_collator, description="evaluation")
|
||||||
|
|
||||||
if isinstance(eval_dataset, torch.utils.data.IterableDataset):
|
dataloader_params = {
|
||||||
if self.args.world_size > 1:
|
"batch_size": self.args.eval_batch_size,
|
||||||
eval_dataset = IterableDatasetShard(
|
"collate_fn": data_collator,
|
||||||
eval_dataset,
|
"num_workers": self.args.dataloader_num_workers,
|
||||||
batch_size=self.args.per_device_eval_batch_size,
|
"pin_memory": self.args.dataloader_pin_memory,
|
||||||
drop_last=self.args.dataloader_drop_last,
|
}
|
||||||
num_processes=self.args.world_size,
|
|
||||||
process_index=self.args.process_index,
|
|
||||||
)
|
|
||||||
return DataLoader(
|
|
||||||
eval_dataset,
|
|
||||||
batch_size=self.args.eval_batch_size,
|
|
||||||
collate_fn=data_collator,
|
|
||||||
num_workers=self.args.dataloader_num_workers,
|
|
||||||
pin_memory=self.args.dataloader_pin_memory,
|
|
||||||
)
|
|
||||||
|
|
||||||
eval_sampler = self._get_eval_sampler(eval_dataset)
|
if not isinstance(eval_dataset, torch.utils.data.IterableDataset):
|
||||||
|
dataloader_params["sampler"] = self._get_eval_sampler(eval_dataset)
|
||||||
|
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||||
|
|
||||||
return DataLoader(
|
return self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params))
|
||||||
eval_dataset,
|
|
||||||
sampler=eval_sampler,
|
|
||||||
batch_size=self.args.eval_batch_size,
|
|
||||||
collate_fn=data_collator,
|
|
||||||
drop_last=self.args.dataloader_drop_last,
|
|
||||||
num_workers=self.args.dataloader_num_workers,
|
|
||||||
pin_memory=self.args.dataloader_pin_memory,
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
|
def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
|
||||||
"""
|
"""
|
||||||
@@ -1026,35 +935,19 @@ class Trainer:
|
|||||||
else:
|
else:
|
||||||
data_collator = self._get_collator_with_removed_columns(data_collator, description="test")
|
data_collator = self._get_collator_with_removed_columns(data_collator, description="test")
|
||||||
|
|
||||||
if isinstance(test_dataset, torch.utils.data.IterableDataset):
|
dataloader_params = {
|
||||||
if self.args.world_size > 1:
|
"batch_size": self.args.eval_batch_size,
|
||||||
test_dataset = IterableDatasetShard(
|
"collate_fn": data_collator,
|
||||||
test_dataset,
|
"num_workers": self.args.dataloader_num_workers,
|
||||||
batch_size=self.args.eval_batch_size,
|
"pin_memory": self.args.dataloader_pin_memory,
|
||||||
drop_last=self.args.dataloader_drop_last,
|
}
|
||||||
num_processes=self.args.world_size,
|
|
||||||
process_index=self.args.process_index,
|
|
||||||
)
|
|
||||||
return DataLoader(
|
|
||||||
test_dataset,
|
|
||||||
batch_size=self.args.eval_batch_size,
|
|
||||||
collate_fn=data_collator,
|
|
||||||
num_workers=self.args.dataloader_num_workers,
|
|
||||||
pin_memory=self.args.dataloader_pin_memory,
|
|
||||||
)
|
|
||||||
|
|
||||||
test_sampler = self._get_eval_sampler(test_dataset)
|
if not isinstance(test_dataset, torch.utils.data.IterableDataset):
|
||||||
|
dataloader_params["sampler"] = self._get_eval_sampler(test_dataset)
|
||||||
|
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||||
|
|
||||||
# We use the same batch_size as for eval.
|
# We use the same batch_size as for eval.
|
||||||
return DataLoader(
|
return self.accelerator.prepare(DataLoader(test_dataset, **dataloader_params))
|
||||||
test_dataset,
|
|
||||||
sampler=test_sampler,
|
|
||||||
batch_size=self.args.eval_batch_size,
|
|
||||||
collate_fn=data_collator,
|
|
||||||
drop_last=self.args.dataloader_drop_last,
|
|
||||||
num_workers=self.args.dataloader_num_workers,
|
|
||||||
pin_memory=self.args.dataloader_pin_memory,
|
|
||||||
)
|
|
||||||
|
|
||||||
def create_optimizer_and_scheduler(self, num_training_steps: int):
|
def create_optimizer_and_scheduler(self, num_training_steps: int):
|
||||||
"""
|
"""
|
||||||
@@ -1864,26 +1757,11 @@ class Trainer:
|
|||||||
# Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.
|
# Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.
|
||||||
if not args.ignore_data_skip:
|
if not args.ignore_data_skip:
|
||||||
for epoch in range(epochs_trained):
|
for epoch in range(epochs_trained):
|
||||||
is_random_sampler = hasattr(train_dataloader, "sampler") and isinstance(
|
|
||||||
train_dataloader.sampler, RandomSampler
|
|
||||||
)
|
|
||||||
if is_torch_less_than_1_11 or not is_random_sampler:
|
|
||||||
# We just need to begin an iteration to create the randomization of the sampler.
|
|
||||||
# That was before PyTorch 1.11 however...
|
|
||||||
for _ in train_dataloader:
|
for _ in train_dataloader:
|
||||||
break
|
break
|
||||||
else:
|
|
||||||
# Otherwise we need to call the whooooole sampler cause there is some random operation added
|
|
||||||
# AT THE VERY END!
|
|
||||||
_ = list(train_dataloader.sampler)
|
|
||||||
|
|
||||||
total_batched_samples = 0
|
total_batched_samples = 0
|
||||||
for epoch in range(epochs_trained, num_train_epochs):
|
for epoch in range(epochs_trained, num_train_epochs):
|
||||||
if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
|
|
||||||
train_dataloader.sampler.set_epoch(epoch)
|
|
||||||
elif hasattr(train_dataloader, "dataset") and isinstance(train_dataloader.dataset, IterableDatasetShard):
|
|
||||||
train_dataloader.dataset.set_epoch(epoch)
|
|
||||||
|
|
||||||
if is_torch_tpu_available():
|
if is_torch_tpu_available():
|
||||||
parallel_loader = pl.ParallelLoader(train_dataloader, [args.device]).per_device_loader(args.device)
|
parallel_loader = pl.ParallelLoader(train_dataloader, [args.device]).per_device_loader(args.device)
|
||||||
epoch_iterator = parallel_loader
|
epoch_iterator = parallel_loader
|
||||||
@@ -3250,27 +3128,29 @@ class Trainer:
|
|||||||
|
|
||||||
# Update containers on host
|
# Update containers on host
|
||||||
if loss is not None:
|
if loss is not None:
|
||||||
losses = self._nested_gather(loss.repeat(batch_size))
|
losses = self.accelerator.gather_for_metrics((loss.repeat(batch_size)))
|
||||||
losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
|
losses_host = losses if losses_host is None else nested_concat(losses_host, losses, padding_index=-100)
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
labels = self._pad_across_processes(labels)
|
labels = self.accelerator.pad_across_processes(labels)
|
||||||
if inputs_decode is not None:
|
if inputs_decode is not None:
|
||||||
inputs_decode = self._pad_across_processes(inputs_decode)
|
inputs_decode = self.accelerator.pad_across_processes(inputs_decode)
|
||||||
inputs_decode = self._nested_gather(inputs_decode)
|
inputs_decode = self.accelerator.gather_for_metrics((inputs_decode))
|
||||||
inputs_host = (
|
inputs_host = (
|
||||||
inputs_decode
|
inputs_decode
|
||||||
if inputs_host is None
|
if inputs_host is None
|
||||||
else nested_concat(inputs_host, inputs_decode, padding_index=-100)
|
else nested_concat(inputs_host, inputs_decode, padding_index=-100)
|
||||||
)
|
)
|
||||||
if logits is not None:
|
if logits is not None:
|
||||||
logits = self._pad_across_processes(logits)
|
logits = self.accelerator.pad_across_processes(logits)
|
||||||
if self.preprocess_logits_for_metrics is not None:
|
if self.preprocess_logits_for_metrics is not None:
|
||||||
logits = self.preprocess_logits_for_metrics(logits, labels)
|
logits = self.preprocess_logits_for_metrics(logits, labels)
|
||||||
logits = self._nested_gather(logits)
|
logits = self.accelerator.gather_for_metrics((logits))
|
||||||
preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
|
preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
|
||||||
|
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
labels = self._nested_gather(labels)
|
labels = self.accelerator.gather_for_metrics((labels))
|
||||||
labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
|
labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
|
||||||
|
|
||||||
self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)
|
self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)
|
||||||
|
|
||||||
# Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
|
# Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
|
||||||
@@ -3303,19 +3183,13 @@ class Trainer:
|
|||||||
|
|
||||||
# Gather all remaining tensors and put them back on the CPU
|
# Gather all remaining tensors and put them back on the CPU
|
||||||
if losses_host is not None:
|
if losses_host is not None:
|
||||||
losses = nested_numpify(losses_host)
|
all_losses = nested_numpify(losses_host)
|
||||||
all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
|
|
||||||
if preds_host is not None:
|
if preds_host is not None:
|
||||||
logits = nested_numpify(preds_host)
|
all_preds = nested_numpify(preds_host)
|
||||||
all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
|
|
||||||
if inputs_host is not None:
|
if inputs_host is not None:
|
||||||
inputs_decode = nested_numpify(inputs_host)
|
all_inputs = nested_numpify(inputs_host)
|
||||||
all_inputs = (
|
|
||||||
inputs_decode if all_inputs is None else nested_concat(all_inputs, inputs_decode, padding_index=-100)
|
|
||||||
)
|
|
||||||
if labels_host is not None:
|
if labels_host is not None:
|
||||||
labels = nested_numpify(labels_host)
|
all_labels = nested_numpify(labels_host)
|
||||||
all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)
|
|
||||||
|
|
||||||
# Number of samples
|
# Number of samples
|
||||||
if has_length(eval_dataset):
|
if has_length(eval_dataset):
|
||||||
@@ -3332,17 +3206,6 @@ class Trainer:
|
|||||||
if num_samples == 0 and observed_num_examples > 0:
|
if num_samples == 0 and observed_num_examples > 0:
|
||||||
num_samples = observed_num_examples
|
num_samples = observed_num_examples
|
||||||
|
|
||||||
# Number of losses has been rounded to a multiple of batch_size and in a distributed training, the number of
|
|
||||||
# samplers has been rounded to a multiple of batch_size, so we truncate.
|
|
||||||
if all_losses is not None:
|
|
||||||
all_losses = all_losses[:num_samples]
|
|
||||||
if all_preds is not None:
|
|
||||||
all_preds = nested_truncate(all_preds, num_samples)
|
|
||||||
if all_labels is not None:
|
|
||||||
all_labels = nested_truncate(all_labels, num_samples)
|
|
||||||
if all_inputs is not None:
|
|
||||||
all_inputs = nested_truncate(all_inputs, num_samples)
|
|
||||||
|
|
||||||
# Metrics!
|
# Metrics!
|
||||||
if self.compute_metrics is not None and all_preds is not None and all_labels is not None:
|
if self.compute_metrics is not None and all_preds is not None and all_labels is not None:
|
||||||
if args.include_inputs_for_metrics:
|
if args.include_inputs_for_metrics:
|
||||||
|
|||||||
@@ -798,9 +798,9 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
def test_train_and_eval_dataloaders(self):
|
def test_train_and_eval_dataloaders(self):
|
||||||
n_gpu = max(1, torch.cuda.device_count())
|
n_gpu = max(1, torch.cuda.device_count())
|
||||||
trainer = get_regression_trainer(learning_rate=0.1, per_device_train_batch_size=16)
|
trainer = get_regression_trainer(learning_rate=0.1, per_device_train_batch_size=16)
|
||||||
self.assertEqual(trainer.get_train_dataloader().batch_size, 16 * n_gpu)
|
self.assertEqual(trainer.get_train_dataloader().total_batch_size, 16 * n_gpu)
|
||||||
trainer = get_regression_trainer(learning_rate=0.1, per_device_eval_batch_size=16)
|
trainer = get_regression_trainer(learning_rate=0.1, per_device_eval_batch_size=16)
|
||||||
self.assertEqual(trainer.get_eval_dataloader().batch_size, 16 * n_gpu)
|
self.assertEqual(trainer.get_eval_dataloader().total_batch_size, 16 * n_gpu)
|
||||||
|
|
||||||
# Check drop_last works
|
# Check drop_last works
|
||||||
trainer = get_regression_trainer(
|
trainer = get_regression_trainer(
|
||||||
@@ -833,67 +833,6 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
trainer.train()
|
trainer.train()
|
||||||
trainer.evaluate()
|
trainer.evaluate()
|
||||||
|
|
||||||
def test_sampler_seed(self):
|
|
||||||
# nb: we don't want to inherit from IterableDataset to hit the right code path
|
|
||||||
class DummyDataset(torch.utils.data.Dataset):
|
|
||||||
def __init__(self, length: int = 101):
|
|
||||||
self.length = length
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return self.length
|
|
||||||
|
|
||||||
def __getitem__(self, i):
|
|
||||||
if (i < 0) or (i >= self.length):
|
|
||||||
raise IndexError
|
|
||||||
return {"input_ids": [i]}
|
|
||||||
|
|
||||||
class DummyModel(PreTrainedModel):
|
|
||||||
def __init__(self, num_params: int):
|
|
||||||
super().__init__(PretrainedConfig())
|
|
||||||
# Add some (unused) params. the point here is that randomness in model_init shouldn't influence
|
|
||||||
# data loader order.
|
|
||||||
self.params = nn.Parameter(torch.randn(num_params))
|
|
||||||
|
|
||||||
def forward(self, input_ids, labels=None):
|
|
||||||
if labels is not None:
|
|
||||||
return torch.tensor(0.0, device=input_ids.device), input_ids
|
|
||||||
else:
|
|
||||||
return input_ids
|
|
||||||
|
|
||||||
def _get_first_data_sample(num_params, seed, data_seed, **kwargs):
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdir:
|
|
||||||
trainer = Trainer(
|
|
||||||
model_init=lambda: DummyModel(num_params),
|
|
||||||
args=TrainingArguments(
|
|
||||||
output_dir=tmpdir,
|
|
||||||
**kwargs,
|
|
||||||
seed=seed,
|
|
||||||
data_seed=data_seed,
|
|
||||||
local_rank=-1,
|
|
||||||
),
|
|
||||||
train_dataset=DummyDataset(),
|
|
||||||
)
|
|
||||||
|
|
||||||
return next(iter(trainer.get_train_dataloader()))
|
|
||||||
|
|
||||||
# test that the seed is passed to the sampler
|
|
||||||
# the codepath we want to hit is world_size <= 1, and both group_by_length
|
|
||||||
for group_by_length in [True, False]:
|
|
||||||
sample42_1 = _get_first_data_sample(num_params=10, seed=42, data_seed=42, group_by_length=group_by_length)
|
|
||||||
sample42_2 = _get_first_data_sample(num_params=11, seed=42, data_seed=42, group_by_length=group_by_length)
|
|
||||||
self.assertTrue(torch.equal(sample42_1["input_ids"], sample42_2["input_ids"]))
|
|
||||||
|
|
||||||
# should get same samples with different seed, so long as data_seed is the same
|
|
||||||
sample42_3 = _get_first_data_sample(num_params=11, seed=11, data_seed=42, group_by_length=group_by_length)
|
|
||||||
self.assertTrue(torch.equal(sample42_1["input_ids"], sample42_3["input_ids"]))
|
|
||||||
|
|
||||||
# make sure we have some randomness in the samples if data_seed is different
|
|
||||||
others = [
|
|
||||||
_get_first_data_sample(num_params=i, seed=42, data_seed=i, group_by_length=group_by_length)
|
|
||||||
for i in range(10)
|
|
||||||
]
|
|
||||||
self.assertTrue(any(not torch.equal(sample42_1["input_ids"], sample["input_ids"]) for sample in others))
|
|
||||||
|
|
||||||
@require_torch_multi_gpu
|
@require_torch_multi_gpu
|
||||||
def test_data_is_not_parallelized_when_model_is_parallel(self):
|
def test_data_is_not_parallelized_when_model_is_parallel(self):
|
||||||
model = RegressionModel()
|
model = RegressionModel()
|
||||||
@@ -907,9 +846,9 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
self.assertEqual(trainer.args.n_gpu, 1)
|
self.assertEqual(trainer.args.n_gpu, 1)
|
||||||
|
|
||||||
# The batch size of the training and evaluation dataloaders should be 16, not 16 * n_gpu
|
# The batch size of the training and evaluation dataloaders should be 16, not 16 * n_gpu
|
||||||
self.assertEqual(trainer.get_train_dataloader().batch_size, 16)
|
self.assertEqual(trainer.get_train_dataloader().total_batch_size, 16)
|
||||||
self.assertEqual(len(trainer.get_train_dataloader()), 64 // 16)
|
self.assertEqual(len(trainer.get_train_dataloader()), 64 // 16)
|
||||||
self.assertEqual(trainer.get_eval_dataloader().batch_size, 16)
|
self.assertEqual(trainer.get_eval_dataloader().total_batch_size, 16)
|
||||||
self.assertEqual(len(trainer.get_eval_dataloader()), 64 // 16)
|
self.assertEqual(len(trainer.get_eval_dataloader()), 64 // 16)
|
||||||
|
|
||||||
def test_evaluate(self):
|
def test_evaluate(self):
|
||||||
@@ -1742,26 +1681,6 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
self.assertIsInstance(loader, torch.utils.data.DataLoader)
|
self.assertIsInstance(loader, torch.utils.data.DataLoader)
|
||||||
self.assertIsInstance(loader.sampler, torch.utils.data.dataloader._InfiniteConstantSampler)
|
self.assertIsInstance(loader.sampler, torch.utils.data.dataloader._InfiniteConstantSampler)
|
||||||
|
|
||||||
def test_training_finite_iterable_dataset(self):
|
|
||||||
config = RegressionModelConfig()
|
|
||||||
model = RegressionPreTrainedModel(config)
|
|
||||||
|
|
||||||
batch_size = 1
|
|
||||||
num_samples = 10
|
|
||||||
|
|
||||||
available_steps = num_samples // batch_size
|
|
||||||
|
|
||||||
data = FiniteIterableDataset(length=num_samples)
|
|
||||||
train_args = TrainingArguments(
|
|
||||||
"..",
|
|
||||||
max_steps=available_steps + 1, # set a higher number than actually available
|
|
||||||
per_device_train_batch_size=batch_size,
|
|
||||||
)
|
|
||||||
trainer = Trainer(model, train_dataset=data, args=train_args)
|
|
||||||
with self.assertLogs("transformers.trainer", level="WARNING") as logs:
|
|
||||||
trainer.train()
|
|
||||||
self.assertIn(f"stopping training at step {available_steps}!", logs.output[0])
|
|
||||||
|
|
||||||
def test_evaluation_iterable_dataset(self):
|
def test_evaluation_iterable_dataset(self):
|
||||||
config = RegressionModelConfig(a=1.5, b=2.5)
|
config = RegressionModelConfig(a=1.5, b=2.5)
|
||||||
model = RegressionPreTrainedModel(config)
|
model = RegressionPreTrainedModel(config)
|
||||||
|
|||||||
Reference in New Issue
Block a user