From d9c62047a8d75e18d2849d345ab3394875a712ef Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Fri, 16 Apr 2021 16:01:58 -0400 Subject: [PATCH] Trainer support for IterableDataset for evaluation and predict (#11286) * Bulk of the work * Polish and tests * Update QA Trainer * Avoid breaking the predict method * Deprecation warnings * Store real eval dataloder * Get eval dataset reference before wrap --- src/transformers/trainer.py | 560 ++++++++++++++++++++------- src/transformers/trainer_callback.py | 4 +- src/transformers/trainer_pt_utils.py | 80 ++++ src/transformers/trainer_utils.py | 7 + src/transformers/training_args.py | 3 + src/transformers/utils/notebook.py | 3 + tests/test_trainer.py | 59 ++- tests/test_trainer_utils.py | 61 ++- 8 files changed, 608 insertions(+), 169 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index cab3bbb246..f3fd3e232a 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -49,7 +49,7 @@ import torch from packaging import version from torch import nn from torch.utils.data.dataloader import DataLoader -from torch.utils.data.dataset import Dataset +from torch.utils.data.dataset import Dataset, IterableDataset from torch.utils.data.distributed import DistributedSampler from torch.utils.data.sampler import RandomSampler, SequentialSampler @@ -85,18 +85,22 @@ from .trainer_pt_utils import ( LabelSmoother, LengthGroupedSampler, SequentialDistributedSampler, + ShardSampler, distributed_broadcast_scalars, distributed_concat, + find_batch_size, get_parameter_names, nested_concat, nested_detach, nested_numpify, + nested_truncate, nested_xla_mesh_reduce, reissue_pt_warnings, ) from .trainer_utils import ( PREFIX_CHECKPOINT_DIR, BestRun, + EvalLoopOutput, EvalPrediction, HPSearchBackend, PredictionOutput, @@ -381,11 +385,8 @@ class Trainer: if args.max_steps > 0: logger.info("max_steps is given, it will override any value given in num_train_epochs") - # Enforce rules on using datasets with no __len__ if train_dataset is not None and not isinstance(train_dataset, collections.abc.Sized) and args.max_steps <= 0: raise ValueError("train_dataset does not implement __len__, max_steps has to be specified") - if eval_dataset is not None and not isinstance(eval_dataset, collections.abc.Sized): - raise ValueError("eval_dataset must implement __len__") self._signature_columns = None if is_datasets_available(): @@ -591,19 +592,33 @@ class Trainer: ) def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.sampler.Sampler]: - if is_torch_tpu_available(): - return SequentialDistributedSampler(eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()) - elif is_sagemaker_mp_enabled(): - return SequentialDistributedSampler( - eval_dataset, - num_replicas=smp.dp_size(), - rank=smp.dp_rank(), - batch_size=self.args.per_device_eval_batch_size, - ) - elif self.args.local_rank != -1: - return SequentialDistributedSampler(eval_dataset) - else: + # Deprecated code + if self.args.use_legacy_prediction_loop: + if is_torch_tpu_available(): + return SequentialDistributedSampler( + eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal() + ) + elif is_sagemaker_mp_enabled(): + return SequentialDistributedSampler( + eval_dataset, + num_replicas=smp.dp_size(), + rank=smp.dp_rank(), + batch_size=self.args.per_device_eval_batch_size, + ) + elif self.args.local_rank != -1: + return SequentialDistributedSampler(eval_dataset) + else: + return SequentialSampler(eval_dataset) + + if self.args.world_size <= 1: return SequentialSampler(eval_dataset) + else: + return ShardSampler( + eval_dataset, + batch_size=self.args.per_device_eval_batch_size, + num_processes=self.args.world_size, + process_index=self.args.process_index, + ) def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader: """ @@ -618,11 +633,27 @@ class Trainer: """ if eval_dataset is None and self.eval_dataset is None: raise ValueError("Trainer: evaluation requires an eval_dataset.") - elif eval_dataset is not None and not isinstance(eval_dataset, collections.abc.Sized): - raise ValueError("eval_dataset must implement __len__") elif is_datasets_available() and isinstance(eval_dataset, datasets.Dataset): self._remove_unused_columns(eval_dataset, description="evaluation") eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset + + if isinstance(eval_dataset, torch.utils.data.dataset.IterableDataset): + if self.args.world_size > 1: + eval_dataset = IterableDatasetShard( + eval_dataset, + batch_size=self.args.eval_batch_size, + drop_last=self.args.dataloader_drop_last, + num_processes=self.args.world_size, + process_index=self.args.process_index, + ) + return DataLoader( + eval_dataset, + batch_size=self.args.eval_batch_size, + collate_fn=self.data_collator, + num_workers=self.args.dataloader_num_workers, + pin_memory=self.args.dataloader_pin_memory, + ) + eval_sampler = self._get_eval_sampler(eval_dataset) return DataLoader( @@ -646,10 +677,26 @@ class Trainer: The test dataset to use. If it is an :obj:`datasets.Dataset`, columns not accepted by the ``model.forward()`` method are automatically removed. It must implement :obj:`__len__`. """ - if not isinstance(test_dataset, collections.abc.Sized): - raise ValueError("test_dataset must implement __len__") - elif is_datasets_available() and isinstance(test_dataset, datasets.Dataset): + if is_datasets_available() and isinstance(test_dataset, datasets.Dataset): self._remove_unused_columns(test_dataset, description="test") + + if isinstance(test_dataset, torch.utils.data.dataset.IterableDataset): + if self.args.world_size > 1: + test_dataset = IterableDatasetShard( + test_dataset, + batch_size=self.args.eval_batch_size, + drop_last=self.args.dataloader_drop_last, + num_processes=self.args.world_size, + process_index=self.args.process_index, + ) + return DataLoader( + test_dataset, + batch_size=self.args.eval_batch_size, + collate_fn=self.data_collator, + num_workers=self.args.dataloader_num_workers, + pin_memory=self.args.dataloader_pin_memory, + ) + test_sampler = self._get_eval_sampler(test_dataset) # We use the same batch_size as for eval. @@ -983,7 +1030,7 @@ class Trainer: else: # see __init__. max_steps is set when the dataset has no __len__ max_steps = self.args.max_steps - num_train_epochs = 1 + num_train_epochs = int(self.args.num_train_epochs) num_update_steps_per_epoch = max_steps delay_optimizer_creation = self.sharded_ddp is not None and self.sharded_ddp != ShardedDDPOption.SIMPLE @@ -1794,13 +1841,11 @@ class Trainer: # memory metrics - must set up as early as possible self._memory_tracker.start() - if eval_dataset is not None and not isinstance(eval_dataset, collections.abc.Sized): - raise ValueError("eval_dataset must implement __len__") - eval_dataloader = self.get_eval_dataloader(eval_dataset) start_time = time.time() - output = self.prediction_loop( + eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop + output = eval_loop( eval_dataloader, description="Evaluation", # No point gathering the predictions if there are no metrics, otherwise we defer to @@ -1810,8 +1855,7 @@ class Trainer: metric_key_prefix=metric_key_prefix, ) - n_samples = len(eval_dataset if eval_dataset is not None else self.eval_dataset) - output.metrics.update(speed_metrics(metric_key_prefix, start_time, n_samples)) + output.metrics.update(speed_metrics(metric_key_prefix, start_time, output.num_samples)) self.log(output.metrics) if self.args.tpu_metrics_debug or self.args.debug: @@ -1860,20 +1904,352 @@ class Trainer: # memory metrics - must set up as early as possible self._memory_tracker.start() - if test_dataset is not None and not isinstance(test_dataset, collections.abc.Sized): - raise ValueError("test_dataset must implement __len__") - test_dataloader = self.get_test_dataloader(test_dataset) start_time = time.time() - output = self.prediction_loop( + eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop + output = eval_loop( test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix ) - output.metrics.update(speed_metrics(metric_key_prefix, start_time, len(test_dataset))) + output.metrics.update(speed_metrics(metric_key_prefix, start_time, output.num_samples)) self._memory_tracker.stop_and_update_metrics(output.metrics) - return output + return PredictionOutput(predictions=output.predictions, label_ids=output.label_ids, metrics=output.metrics) + + def evaluation_loop( + self, + dataloader: DataLoader, + description: str, + prediction_loss_only: Optional[bool] = None, + ignore_keys: Optional[List[str]] = None, + metric_key_prefix: str = "eval", + ) -> EvalLoopOutput: + """ + Prediction/evaluation loop, shared by :obj:`Trainer.evaluate()` and :obj:`Trainer.predict()`. + + Works both with or without labels. + """ + prediction_loss_only = ( + prediction_loss_only if prediction_loss_only is not None else self.args.prediction_loss_only + ) + + # if eval is called w/o train init deepspeed here + if self.args.deepspeed and not self.deepspeed: + + # XXX: eval doesn't have `resume_from_checkpoint` arg but we should be able to do eval + # from the checkpoint eventually + deepspeed_engine, _, _ = deepspeed_init(self, num_training_steps=0, resume_from_checkpoint=None) + self.model = deepspeed_engine.module + self.model_wrapped = deepspeed_engine + self.deepspeed = deepspeed_engine + # XXX: we don't need optim/sched for inference, but this needs to be sorted out, since + # for example the Z3-optimizer is a must for zero3 to work even for inference - what we + # don't need is the deepspeed basic optimizer which is self.optimizer.optimizer + deepspeed_engine.optimizer.optimizer = None + deepspeed_engine.lr_scheduler = None + + model = self._wrap_model(self.model, training=False) + + # if full fp16 is wanted on eval and this ``evaluation`` or ``predict`` isn't called while + # ``train`` is running, half it first and then put on device + if not self.is_in_train and self.args.fp16_full_eval: + model = model.half().to(self.args.device) + + batch_size = dataloader.batch_size + + logger.info(f"***** Running {description} *****") + if isinstance(dataloader.dataset, collections.abc.Sized): + logger.info(f" Num examples = {self.num_examples(dataloader)}") + else: + logger.info(" Num examples: Unknown") + logger.info(f" Batch size = {batch_size}") + + model.eval() + + self.callback_handler.eval_dataloader = dataloader + # Do this before wrapping. + eval_dataset = dataloader.dataset + + if is_torch_tpu_available(): + dataloader = pl.ParallelLoader(dataloader, [self.args.device]).per_device_loader(self.args.device) + + if self.args.past_index >= 0: + self._past = None + + # Initialize containers + # losses/preds/labels on GPU/TPU (accumulated for eval_accumulation_steps) + losses_host = None + preds_host = None + labels_host = None + # losses/preds/labels on CPU (final containers) + all_losses = None + all_preds = None + all_labels = None + # Will be useful when we have an iterable dataset so don't know its length. + + observed_num_examples = 0 + # Main evaluation loop + for step, inputs in enumerate(dataloader): + # Update the observed num examples + observed_batch_size = find_batch_size(inputs) + if observed_batch_size is not None: + observed_num_examples += observed_batch_size + + # Prediction step + loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys) + + # Update containers on host + if loss is not None: + losses = self._nested_gather(loss.repeat(batch_size)) + losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0) + if logits is not None: + logits = self._pad_across_processes(logits) + logits = self._nested_gather(logits) + preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100) + if labels is not None: + labels = self._pad_across_processes(labels) + labels = self._nested_gather(labels) + labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100) + self.control = self.callback_handler.on_prediction_step(self.args, self.state, self.control) + + # Gather all tensors and put them back on the CPU if we have done enough accumulation steps. + if self.args.eval_accumulation_steps is not None and (step + 1) % self.args.eval_accumulation_steps == 0: + if losses_host is not None: + losses = nested_numpify(losses_host) + all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0) + if preds_host is not None: + logits = nested_numpify(preds_host) + all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100) + if labels_host is not None: + labels = nested_numpify(labels_host) + all_labels = ( + labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100) + ) + + # Set back to None to begin a new accumulation + losses_host, preds_host, labels_host = None, None, None + + if self.args.past_index and hasattr(self, "_past"): + # Clean the state at the end of the evaluation loop + delattr(self, "_past") + + # Gather all remaining tensors and put them back on the CPU + if losses_host is not None: + losses = nested_numpify(losses_host) + all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0) + if preds_host is not None: + logits = nested_numpify(preds_host) + all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100) + if labels_host is not None: + labels = nested_numpify(labels_host) + all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100) + + # Number of samples + if not isinstance(eval_dataset, IterableDataset): + num_samples = len(eval_dataset) + elif isinstance(eval_dataset, IterableDatasetShard): + num_samples = eval_dataset.num_examples + else: + num_samples = observed_num_examples + + # Number of losses has been rounded to a multiple of batch_size and in a distributed training, the number of + # samplers has been rounded to a multiple of batch_size, so we truncate. + if all_losses is not None: + all_losses = all_losses[:num_samples] + if all_preds is not None: + all_preds = nested_truncate(all_preds, num_samples) + if all_labels is not None: + all_labels = nested_truncate(all_labels, num_samples) + + # Metrics! + if self.compute_metrics is not None and all_preds is not None and all_labels is not None: + metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels)) + else: + metrics = {} + + # To be JSON-serializable, we need to remove numpy types or zero-d tensors + metrics = denumpify_detensorize(metrics) + + if all_losses is not None: + metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item() + + # Prefix all keys with metric_key_prefix + '_' + for key in list(metrics.keys()): + if not key.startswith(f"{metric_key_prefix}_"): + metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key) + + return EvalLoopOutput(predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=num_samples) + + def _nested_gather(self, tensors, name=None): + """ + Gather value of `tensors` (tensor or list/tuple of nested tensors) and convert them to numpy before + concatenating them to `gathered` + """ + if tensors is None: + return + if is_torch_tpu_available(): + if name is None: + name = "nested_gather" + tensors = nested_xla_mesh_reduce(tensors, name) + elif is_sagemaker_mp_enabled(): + tensors = smp_gather(tensors) + elif self.args.local_rank != -1: + tensors = distributed_concat(tensors) + return tensors + + # Copied from Accelerate. + def _pad_across_processes(self, tensor, pad_index=-100): + """ + Recursively pad the tensors in a nested list/tuple/dictionary of tensors from all devices to the same size so + they can safely be gathered. + """ + if isinstance(tensor, (list, tuple)): + return type(tensor)(self._pad_across_processes(t, pad_index=pad_index) for t in tensor) + elif isinstance(tensor, dict): + return type(tensor)({k: self._pad_across_processes(v, pad_index=pad_index) for k, v in tensor.items()}) + elif not isinstance(tensor, torch.Tensor): + raise TypeError( + f"Can't pad the values of type {type(tensor)}, only of nested list/tuple/dicts of tensors." + ) + + if len(tensor.shape) < 2: + return tensor + # Gather all sizes + size = torch.tensor(tensor.shape, device=tensor.device)[None] + sizes = self._nested_gather(size).cpu() + + max_size = max(s[1] for s in sizes) + if tensor.shape[1] == max_size: + return tensor + + # Then pad to the maximum size + old_size = tensor.shape + new_size = list(old_size) + new_size[1] = max_size + new_tensor = tensor.new_zeros(tuple(new_size)) + pad_index + new_tensor[:, : old_size[1]] = tensor + return new_tensor + + def prediction_step( + self, + model: nn.Module, + inputs: Dict[str, Union[torch.Tensor, Any]], + prediction_loss_only: bool, + ignore_keys: Optional[List[str]] = None, + ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + Perform an evaluation step on :obj:`model` using obj:`inputs`. + + Subclass and override to inject custom behavior. + + Args: + model (:obj:`nn.Module`): + The model to evaluate. + inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`): + The inputs and targets of the model. + + The dictionary will be unpacked before being fed to the model. Most models expect the targets under the + argument :obj:`labels`. Check your model's documentation for all accepted arguments. + prediction_loss_only (:obj:`bool`): + Whether or not to return the loss only. + ignore_keys (:obj:`Lst[str]`, `optional`): + A list of keys in the output of your model (if it is a dictionary) that should be ignored when + gathering predictions. + + Return: + Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, + logits and labels (each being optional). + """ + has_labels = all(inputs.get(k) is not None for k in self.label_names) + inputs = self._prepare_inputs(inputs) + if ignore_keys is None: + if hasattr(self.model, "config"): + ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", []) + else: + ignore_keys = [] + + # labels may be popped when computing the loss (label smoothing for instance) so we grab them first. + if has_labels: + labels = nested_detach(tuple(inputs.get(name) for name in self.label_names)) + if len(labels) == 1: + labels = labels[0] + else: + labels = None + + with torch.no_grad(): + if is_sagemaker_mp_enabled(): + raw_outputs = smp_forward_only(model, inputs) + if has_labels: + if isinstance(raw_outputs, dict): + loss_mb = raw_outputs["loss"] + logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys + ["loss"]) + else: + loss_mb = raw_outputs[0] + logits_mb = raw_outputs[1:] + + loss = loss_mb.reduce_mean().detach().cpu() + logits = smp_nested_concat(logits_mb) + else: + loss = None + if isinstance(raw_outputs, dict): + logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys) + else: + logits_mb = raw_outputs + logits = smp_nested_concat(logits_mb) + else: + if has_labels: + loss, outputs = self.compute_loss(model, inputs, return_outputs=True) + loss = loss.mean().detach() + if isinstance(outputs, dict): + logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"]) + else: + logits = outputs[1:] + else: + loss = None + if self.use_amp: + with autocast(): + outputs = model(**inputs) + else: + outputs = model(**inputs) + if isinstance(outputs, dict): + logits = tuple(v for k, v in outputs.items() if k not in ignore_keys) + else: + logits = outputs + # TODO: this needs to be fixed and made cleaner later. + if self.args.past_index >= 0: + self._past = outputs[self.args.past_index - 1] + + if prediction_loss_only: + return (loss, None, None) + + logits = nested_detach(logits) + if len(logits) == 1: + logits = logits[0] + + return (loss, logits, labels) + + def floating_point_ops(self, inputs: Dict[str, Union[torch.Tensor, Any]]): + """ + For models that inherit from :class:`~transformers.PreTrainedModel`, uses that method to compute the number of + floating point operations for every backward + forward pass. If using another model, either implement such a + method in the model or subclass and override this method. + + Args: + inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`): + The inputs and targets of the model. + + Returns: + :obj:`int`: The number of floating-point operations. + """ + if hasattr(self.model, "floating_point_ops"): + return self.model.floating_point_ops(inputs) + else: + return 0 + + # + # Deprecated code + # def prediction_loop( self, @@ -2015,119 +2391,3 @@ class Trainer: tensors = distributed_concat(tensors) return nested_numpify(tensors) - - def prediction_step( - self, - model: nn.Module, - inputs: Dict[str, Union[torch.Tensor, Any]], - prediction_loss_only: bool, - ignore_keys: Optional[List[str]] = None, - ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: - """ - Perform an evaluation step on :obj:`model` using obj:`inputs`. - - Subclass and override to inject custom behavior. - - Args: - model (:obj:`nn.Module`): - The model to evaluate. - inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`): - The inputs and targets of the model. - - The dictionary will be unpacked before being fed to the model. Most models expect the targets under the - argument :obj:`labels`. Check your model's documentation for all accepted arguments. - prediction_loss_only (:obj:`bool`): - Whether or not to return the loss only. - ignore_keys (:obj:`Lst[str]`, `optional`): - A list of keys in the output of your model (if it is a dictionary) that should be ignored when - gathering predictions. - - Return: - Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, - logits and labels (each being optional). - """ - has_labels = all(inputs.get(k) is not None for k in self.label_names) - inputs = self._prepare_inputs(inputs) - if ignore_keys is None: - if hasattr(self.model, "config"): - ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", []) - else: - ignore_keys = [] - - # labels may be popped when computing the loss (label smoothing for instance) so we grab them first. - if has_labels: - labels = nested_detach(tuple(inputs.get(name) for name in self.label_names)) - if len(labels) == 1: - labels = labels[0] - else: - labels = None - - with torch.no_grad(): - if is_sagemaker_mp_enabled(): - raw_outputs = smp_forward_only(model, inputs) - if has_labels: - if isinstance(raw_outputs, dict): - loss_mb = raw_outputs["loss"] - logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys + ["loss"]) - else: - loss_mb = raw_outputs[0] - logits_mb = raw_outputs[1:] - - loss = loss_mb.reduce_mean().detach().cpu() - logits = smp_nested_concat(logits_mb) - else: - loss = None - if isinstance(raw_outputs, dict): - logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys) - else: - logits_mb = raw_outputs - logits = smp_nested_concat(logits_mb) - else: - if has_labels: - loss, outputs = self.compute_loss(model, inputs, return_outputs=True) - loss = loss.mean().detach() - if isinstance(outputs, dict): - logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"]) - else: - logits = outputs[1:] - else: - loss = None - if self.use_amp: - with autocast(): - outputs = model(**inputs) - else: - outputs = model(**inputs) - if isinstance(outputs, dict): - logits = tuple(v for k, v in outputs.items() if k not in ignore_keys) - else: - logits = outputs - # TODO: this needs to be fixed and made cleaner later. - if self.args.past_index >= 0: - self._past = outputs[self.args.past_index - 1] - - if prediction_loss_only: - return (loss, None, None) - - logits = nested_detach(logits) - if len(logits) == 1: - logits = logits[0] - - return (loss, logits, labels) - - def floating_point_ops(self, inputs: Dict[str, Union[torch.Tensor, Any]]): - """ - For models that inherit from :class:`~transformers.PreTrainedModel`, uses that method to compute the number of - floating point operations for every backward + forward pass. If using another model, either implement such a - method in the model or subclass and override this method. - - Args: - inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`): - The inputs and targets of the model. - - Returns: - :obj:`int`: The number of floating-point operations. - """ - if hasattr(self.model, "floating_point_ops"): - return self.model.floating_point_ops(inputs) - else: - return 0 diff --git a/src/transformers/trainer_callback.py b/src/transformers/trainer_callback.py index 151dbf52a0..e760ab55c1 100644 --- a/src/transformers/trainer_callback.py +++ b/src/transformers/trainer_callback.py @@ -15,7 +15,7 @@ """ Callbacks to use with the Trainer class and customize the training loop. """ - +import collections import dataclasses import json from dataclasses import dataclass @@ -469,7 +469,7 @@ class ProgressCallback(TrainerCallback): self.current_step = state.global_step def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs): - if state.is_local_process_zero: + if state.is_local_process_zero and isinstance(eval_dataloader.dataset, collections.abc.Sized): if self.prediction_bar is None: self.prediction_bar = tqdm(total=len(eval_dataloader), leave=self.training_bar is None) self.prediction_bar.update(1) diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py index c81f98c744..0b58904c00 100644 --- a/src/transformers/trainer_pt_utils.py +++ b/src/transformers/trainer_pt_utils.py @@ -102,6 +102,26 @@ def nested_concat(tensors, new_tensors, padding_index=-100): raise TypeError(f"Unsupported type for concatenation: got {type(tensors)}") +def find_batch_size(tensors): + """ + Find the first dimension of a tensor in a nested list/tuple/dict of tensors. + """ + if isinstance(tensors, (list, tuple)): + for t in tensors: + result = find_batch_size(t) + if result is not None: + return result + elif isinstance(tensors, dict): + for key, value in tensors.items(): + result = find_batch_size(value) + if result is not None: + return result + elif isinstance(tensors, torch.Tensor): + return tensors.shape[0] if len(tensors.shape) >= 1 else None + elif isinstance(tensors, np.ndarray): + return tensors.shape[0] if len(tensors.shape) >= 1 else None + + def nested_numpify(tensors): "Numpify `tensors` (even if it's a nested list/tuple of tensors)." if isinstance(tensors, (list, tuple)): @@ -222,6 +242,10 @@ class SequentialDistributedSampler(Sampler): """ def __init__(self, dataset, num_replicas=None, rank=None, batch_size=None): + warnings.warn( + "SequentialDistributedSampler is deprecated and will be removed in v5 of Tranformers.", + FutureWarning, + ) if num_replicas is None: if not dist.is_available(): raise RuntimeError("Requires distributed package to be available") @@ -338,6 +362,10 @@ class DistributedTensorGatherer: """ def __init__(self, world_size, num_samples, make_multiple_of=None, padding_index=-100): + warnings.warn( + "DistributedTensorGatherer is deprecated and will be removed in v5 of Tranformers.", + FutureWarning, + ) self.world_size = world_size self.num_samples = num_samples total_size = world_size if make_multiple_of is None else world_size * make_multiple_of @@ -576,6 +604,55 @@ class DistributedLengthGroupedSampler(DistributedSampler): return iter(indices) +class ShardSampler(Sampler): + """ + Sampler that shards batches between several processes. Dispatches indices batch by batch: on 2 processes with batch + size 4, the first two batches are :obj:`[0, 1, 2, 3, 4, 5, 6, 7]` and :obj:`[8, 9, 10, 11, 12, 13, 14, 15]`, which + shard into :obj:`[0, 1, 2, 3]` and :obj:`[8, 9, 10, 11]` for GPU-0 and :obj:`[4, 5, 6, 7]` and :obj:`[12, 13, 14, + 15]` for GPU-1. + + The sampler thus yields :obj:`[0, 1, 2, 3, 8, 9, 10, 11]` on GPU-0 and :obj:`[4, 5, 6, 7, 12, 13, 14, 15]` on + GPU-1. + """ + + def __init__( + self, + dataset: Dataset, + batch_size: int = 1, + drop_last: bool = False, + num_processes: int = 1, + process_index: int = 0, + ): + self.dataset = dataset + self.batch_size = batch_size + self.drop_last = drop_last + self.num_processes = num_processes + self.process_index = process_index + + self.total_batch_size = total_batch_size = batch_size * num_processes + + num_batches = len(dataset) // total_batch_size if drop_last else math.ceil(len(dataset) / total_batch_size) + self.total_num_samples = num_batches * total_batch_size + + def __iter__(self): + indices = list(range(len(self.dataset))) + + # Add extra samples to make it evenly divisible. While loop is there in the edge case we have a tiny dataset + # and it needs to be done several times. + while len(indices) < self.total_num_samples: + indices += indices[: (self.total_num_samples - len(indices))] + + result = [] + for batch_start in range(self.batch_size * self.process_index, self.total_num_samples, self.total_batch_size): + result += indices[batch_start : batch_start + self.batch_size] + + return iter(result) + + def __len__(self): + # Each shard only sees a fraction of total_num_samples. + return self.total_num_samples // self.num_processes + + class IterableDatasetShard(IterableDataset): """ Wraps a PyTorch :obj:`IterableDataset` to generate samples for one of the processes only. Instances of this class @@ -634,6 +711,7 @@ class IterableDatasetShard(IterableDataset): self.process_index = process_index self.seed = seed self.epoch = 0 + self.num_examples = 0 def set_epoch(self, epoch): self.epoch = epoch @@ -641,6 +719,7 @@ class IterableDatasetShard(IterableDataset): self.dataset.set_epoch(epoch) def __iter__(self): + self.num_examples = 0 if ( not hasattr(self.dataset, "set_epoch") and hasattr(self.dataset, "generator") @@ -653,6 +732,7 @@ class IterableDatasetShard(IterableDataset): first_batch = None current_batch = [] for element in self.dataset: + self.num_examples += 1 current_batch.append(element) # Wait to have a full batch before yielding elements. if len(current_batch) == real_batch_size: diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py index 71df8bc8de..53d2cf7f15 100644 --- a/src/transformers/trainer_utils.py +++ b/src/transformers/trainer_utils.py @@ -77,6 +77,13 @@ class EvalPrediction(NamedTuple): label_ids: np.ndarray +class EvalLoopOutput(NamedTuple): + predictions: Union[np.ndarray, Tuple[np.ndarray]] + label_ids: Optional[np.ndarray] + metrics: Optional[Dict[str, float]] + num_samples: Optional[int] + + class PredictionOutput(NamedTuple): predictions: Union[np.ndarray, Tuple[np.ndarray]] label_ids: Optional[np.ndarray] diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 772c24bc2d..be98825e22 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -524,6 +524,9 @@ class TrainingArguments: skip_memory_metrics: bool = field( default=False, metadata={"help": "Whether or not to skip adding of memory profiler reports to metrics."} ) + use_legacy_prediction_loop: bool = field( + default=False, metadata={"help": "Whether or not to use the legacy prediction_loop in the Trainer."} + ) _n_gpu: int = field(init=False, repr=False, default=-1) mp_parameters: str = field( default="", diff --git a/src/transformers/utils/notebook.py b/src/transformers/utils/notebook.py index 91e85a5d7a..bcc67bef40 100644 --- a/src/transformers/utils/notebook.py +++ b/src/transformers/utils/notebook.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import collections import time from typing import Optional @@ -286,6 +287,8 @@ class NotebookProgressCallback(TrainerCallback): self._force_next_update = False def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs): + if not isinstance(eval_dataloader.dataset, collections.abc.Sized): + return if self.prediction_bar is None: if self.training_tracker is not None: self.prediction_bar = self.training_tracker.add_child(len(eval_dataloader)) diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 53f5f0b1ca..f3ebf14a87 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -819,35 +819,64 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): ) self.assertEqual(len(dataset), 31) - def test_trainer_iterable_dataset(self): + def test_training_iterable_dataset(self): config = RegressionModelConfig() model = RegressionPreTrainedModel(config) train_dataset = SampleIterableDataset() - args = RegressionTrainingArguments(output_dir="./examples", max_steps=2) + args = RegressionTrainingArguments(output_dir="./examples", max_steps=4) trainer = Trainer(model=model, args=args, train_dataset=train_dataset) trainer.train() + self.assertEqual(trainer.state.global_step, 4) loader = trainer.get_train_dataloader() self.assertIsInstance(loader, torch.utils.data.DataLoader) self.assertIsInstance(loader.sampler, torch.utils.data.dataloader._InfiniteConstantSampler) - # Exception if giving iterable dataset and no max_steps - with self.assertRaises(ValueError): - args1 = RegressionTrainingArguments(output_dir="./examples") - _ = Trainer(model=model, args=args1, train_dataset=train_dataset) + def test_evaluation_iterable_dataset(self): + config = RegressionModelConfig(a=1.5, b=2.5) + model = RegressionPreTrainedModel(config) + eval_dataset = SampleIterableDataset() - # Exception if eval_dataset is iterable in __init__ - with self.assertRaises(ValueError): - _ = Trainer(model=model, args=args, train_dataset=train_dataset, eval_dataset=train_dataset) + args = RegressionTrainingArguments(output_dir="./examples") + trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset, compute_metrics=AlmostAccuracy()) + results = trainer.evaluate() - # Exception if predicting with iterable dataset - with self.assertRaises(ValueError): - trainer.predict(train_dataset) + x, y = trainer.eval_dataset.dataset.x, trainer.eval_dataset.dataset.ys[0] + pred = 1.5 * x + 2.5 + expected_loss = ((pred - y) ** 2).mean() + self.assertAlmostEqual(results["eval_loss"], expected_loss) + expected_acc = AlmostAccuracy()((pred, y))["accuracy"] + self.assertAlmostEqual(results["eval_accuracy"], expected_acc) - # Exception if evaluating with iterable dataset - with self.assertRaises(ValueError): - trainer.evaluate(train_dataset) + # With a number of elements not a round multiple of the batch size + eval_dataset = SampleIterableDataset(length=66) + results = trainer.evaluate(eval_dataset) + + x, y = eval_dataset.dataset.x, eval_dataset.dataset.ys[0] + pred = 1.5 * x + 2.5 + expected_loss = ((pred - y) ** 2).mean() + self.assertAlmostEqual(results["eval_loss"], expected_loss) + expected_acc = AlmostAccuracy()((pred, y))["accuracy"] + self.assertAlmostEqual(results["eval_accuracy"], expected_acc) + + def test_predict_iterable_dataset(self): + config = RegressionModelConfig(a=1.5, b=2.5) + model = RegressionPreTrainedModel(config) + eval_dataset = SampleIterableDataset() + + args = RegressionTrainingArguments(output_dir="./examples") + trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset, compute_metrics=AlmostAccuracy()) + + preds = trainer.predict(trainer.eval_dataset).predictions + x = eval_dataset.dataset.x + self.assertTrue(np.allclose(preds, 1.5 * x + 2.5)) + + # With a number of elements not a round multiple of the batch size + test_dataset = SampleIterableDataset(length=66) + preds = trainer.predict(test_dataset).predictions + x = test_dataset.dataset.x + self.assertTrue(np.allclose(preds, 1.5 * x + 2.5)) def test_num_train_epochs_in_training(self): # len(train_dl) < gradient_accumulation_steps shouldn't give ``ZeroDivisionError`` when ``max_steps`` is given. diff --git a/tests/test_trainer_utils.py b/tests/test_trainer_utils.py index 8657a9e640..8ce951703b 100644 --- a/tests/test_trainer_utils.py +++ b/tests/test_trainer_utils.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy import unittest import numpy as np @@ -34,6 +35,7 @@ if is_torch_available(): LabelSmoother, LengthGroupedSampler, SequentialDistributedSampler, + ShardSampler, get_parameter_names, ) @@ -283,6 +285,10 @@ class TrainerUtilsTest(unittest.TestCase): # All shards have the same number of samples self.assertEqual(len(shard), len(shard_lists[0])) + for shard in shards: + # All shards know the total number of samples + self.assertEqual(shard.num_examples, len(reference)) + observed = [] for idx in range(0, len(shard_lists[0]), batch_size): for shard in shard_lists: @@ -295,11 +301,62 @@ class TrainerUtilsTest(unittest.TestCase): reference += reference self.assertListEqual(observed, reference[: len(observed)]) + # Check equivalence between IterableDataset and ShardSampler + dataset.generator.manual_seed(epoch) + reference = list(dataset) + + sampler_shards = [ + ShardSampler( + reference, batch_size=batch_size, drop_last=drop_last, num_processes=num_processes, process_index=i + ) + for i in range(num_processes) + ] + for shard, sampler_shard in zip(shard_lists, sampler_shards): + self.assertListEqual(shard, list(sampler_shard)) + def test_iterable_dataset_shard(self): dataset = RandomIterableDataset() self.check_iterable_dataset_shard(dataset, 4, drop_last=True, num_processes=2, epoch=0) - self.check_iterable_dataset_shard(dataset, 4, drop_last=True, num_processes=2, epoch=0) + self.check_iterable_dataset_shard(dataset, 4, drop_last=False, num_processes=2, epoch=0) self.check_iterable_dataset_shard(dataset, 4, drop_last=True, num_processes=3, epoch=42) - self.check_iterable_dataset_shard(dataset, 4, drop_last=True, num_processes=3, epoch=42) + self.check_iterable_dataset_shard(dataset, 4, drop_last=False, num_processes=3, epoch=42) + + def check_shard_sampler(self, dataset, batch_size, drop_last, num_processes=2): + shards = [ + ShardSampler( + dataset, batch_size=batch_size, drop_last=drop_last, num_processes=num_processes, process_index=i + ) + for i in range(num_processes) + ] + shard_lists = [list(shard) for shard in shards] + + for shard in shard_lists: + # All shards have a number of samples that is a round multiple of batch size + self.assertTrue(len(shard) % batch_size == 0) + # All shards have the same number of samples + self.assertEqual(len(shard), len(shard_lists[0])) + + observed = [] + for idx in range(0, len(shard_lists[0]), batch_size): + for shard in shard_lists: + observed += shard[idx : idx + batch_size] + + # If drop_last is False we loop through samples at the beginning to have a size that is a round multiple of + # batch_size + reference = copy.copy(dataset) + if not drop_last: + while len(reference) < len(observed): + reference += reference + self.assertListEqual(observed, reference[: len(observed)]) + + def test_shard_sampler(self): + for n_elements in [64, 123]: + dataset = list(range(n_elements)) + + self.check_shard_sampler(dataset, 4, drop_last=True, num_processes=2) + self.check_shard_sampler(dataset, 4, drop_last=False, num_processes=2) + + self.check_shard_sampler(dataset, 4, drop_last=True, num_processes=3) + self.check_shard_sampler(dataset, 4, drop_last=False, num_processes=3)