Trainer support for IterableDataset for evaluation and predict (#11286)
* Bulk of the work * Polish and tests * Update QA Trainer * Avoid breaking the predict method * Deprecation warnings * Store real eval dataloader * Get eval dataset reference before wrap
This commit is contained in:
@@ -49,7 +49,7 @@ import torch
|
||||
from packaging import version
|
||||
from torch import nn
|
||||
from torch.utils.data.dataloader import DataLoader
|
||||
from torch.utils.data.dataset import Dataset
|
||||
from torch.utils.data.dataset import Dataset, IterableDataset
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from torch.utils.data.sampler import RandomSampler, SequentialSampler
|
||||
|
||||
@@ -85,18 +85,22 @@ from .trainer_pt_utils import (
|
||||
LabelSmoother,
|
||||
LengthGroupedSampler,
|
||||
SequentialDistributedSampler,
|
||||
ShardSampler,
|
||||
distributed_broadcast_scalars,
|
||||
distributed_concat,
|
||||
find_batch_size,
|
||||
get_parameter_names,
|
||||
nested_concat,
|
||||
nested_detach,
|
||||
nested_numpify,
|
||||
nested_truncate,
|
||||
nested_xla_mesh_reduce,
|
||||
reissue_pt_warnings,
|
||||
)
|
||||
from .trainer_utils import (
|
||||
PREFIX_CHECKPOINT_DIR,
|
||||
BestRun,
|
||||
EvalLoopOutput,
|
||||
EvalPrediction,
|
||||
HPSearchBackend,
|
||||
PredictionOutput,
|
||||
@@ -381,11 +385,8 @@ class Trainer:
|
||||
if args.max_steps > 0:
|
||||
logger.info("max_steps is given, it will override any value given in num_train_epochs")
|
||||
|
||||
# Enforce rules on using datasets with no __len__
|
||||
if train_dataset is not None and not isinstance(train_dataset, collections.abc.Sized) and args.max_steps <= 0:
|
||||
raise ValueError("train_dataset does not implement __len__, max_steps has to be specified")
|
||||
if eval_dataset is not None and not isinstance(eval_dataset, collections.abc.Sized):
|
||||
raise ValueError("eval_dataset must implement __len__")
|
||||
|
||||
self._signature_columns = None
|
||||
if is_datasets_available():
|
||||
@@ -591,8 +592,12 @@ class Trainer:
|
||||
)
|
||||
|
||||
def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.sampler.Sampler]:
|
||||
# Deprecated code
|
||||
if self.args.use_legacy_prediction_loop:
|
||||
if is_torch_tpu_available():
|
||||
return SequentialDistributedSampler(eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
|
||||
return SequentialDistributedSampler(
|
||||
eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()
|
||||
)
|
||||
elif is_sagemaker_mp_enabled():
|
||||
return SequentialDistributedSampler(
|
||||
eval_dataset,
|
||||
@@ -605,6 +610,16 @@ class Trainer:
|
||||
else:
|
||||
return SequentialSampler(eval_dataset)
|
||||
|
||||
if self.args.world_size <= 1:
|
||||
return SequentialSampler(eval_dataset)
|
||||
else:
|
||||
return ShardSampler(
|
||||
eval_dataset,
|
||||
batch_size=self.args.per_device_eval_batch_size,
|
||||
num_processes=self.args.world_size,
|
||||
process_index=self.args.process_index,
|
||||
)
|
||||
|
||||
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
|
||||
"""
|
||||
Returns the evaluation :class:`~torch.utils.data.DataLoader`.
|
||||
@@ -618,11 +633,27 @@ class Trainer:
|
||||
"""
|
||||
if eval_dataset is None and self.eval_dataset is None:
|
||||
raise ValueError("Trainer: evaluation requires an eval_dataset.")
|
||||
elif eval_dataset is not None and not isinstance(eval_dataset, collections.abc.Sized):
|
||||
raise ValueError("eval_dataset must implement __len__")
|
||||
elif is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
|
||||
self._remove_unused_columns(eval_dataset, description="evaluation")
|
||||
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
|
||||
|
||||
if isinstance(eval_dataset, torch.utils.data.dataset.IterableDataset):
|
||||
if self.args.world_size > 1:
|
||||
eval_dataset = IterableDatasetShard(
|
||||
eval_dataset,
|
||||
batch_size=self.args.eval_batch_size,
|
||||
drop_last=self.args.dataloader_drop_last,
|
||||
num_processes=self.args.world_size,
|
||||
process_index=self.args.process_index,
|
||||
)
|
||||
return DataLoader(
|
||||
eval_dataset,
|
||||
batch_size=self.args.eval_batch_size,
|
||||
collate_fn=self.data_collator,
|
||||
num_workers=self.args.dataloader_num_workers,
|
||||
pin_memory=self.args.dataloader_pin_memory,
|
||||
)
|
||||
|
||||
eval_sampler = self._get_eval_sampler(eval_dataset)
|
||||
|
||||
return DataLoader(
|
||||
@@ -646,10 +677,26 @@ class Trainer:
|
||||
The test dataset to use. If it is an :obj:`datasets.Dataset`, columns not accepted by the
|
||||
``model.forward()`` method are automatically removed. It must implement :obj:`__len__`.
|
||||
"""
|
||||
if not isinstance(test_dataset, collections.abc.Sized):
|
||||
raise ValueError("test_dataset must implement __len__")
|
||||
elif is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
|
||||
if is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
|
||||
self._remove_unused_columns(test_dataset, description="test")
|
||||
|
||||
if isinstance(test_dataset, torch.utils.data.dataset.IterableDataset):
|
||||
if self.args.world_size > 1:
|
||||
test_dataset = IterableDatasetShard(
|
||||
test_dataset,
|
||||
batch_size=self.args.eval_batch_size,
|
||||
drop_last=self.args.dataloader_drop_last,
|
||||
num_processes=self.args.world_size,
|
||||
process_index=self.args.process_index,
|
||||
)
|
||||
return DataLoader(
|
||||
test_dataset,
|
||||
batch_size=self.args.eval_batch_size,
|
||||
collate_fn=self.data_collator,
|
||||
num_workers=self.args.dataloader_num_workers,
|
||||
pin_memory=self.args.dataloader_pin_memory,
|
||||
)
|
||||
|
||||
test_sampler = self._get_eval_sampler(test_dataset)
|
||||
|
||||
# We use the same batch_size as for eval.
|
||||
@@ -983,7 +1030,7 @@ class Trainer:
|
||||
else:
|
||||
# see __init__. max_steps is set when the dataset has no __len__
|
||||
max_steps = self.args.max_steps
|
||||
num_train_epochs = 1
|
||||
num_train_epochs = int(self.args.num_train_epochs)
|
||||
num_update_steps_per_epoch = max_steps
|
||||
|
||||
delay_optimizer_creation = self.sharded_ddp is not None and self.sharded_ddp != ShardedDDPOption.SIMPLE
|
||||
@@ -1794,13 +1841,11 @@ class Trainer:
|
||||
# memory metrics - must set up as early as possible
|
||||
self._memory_tracker.start()
|
||||
|
||||
if eval_dataset is not None and not isinstance(eval_dataset, collections.abc.Sized):
|
||||
raise ValueError("eval_dataset must implement __len__")
|
||||
|
||||
eval_dataloader = self.get_eval_dataloader(eval_dataset)
|
||||
start_time = time.time()
|
||||
|
||||
output = self.prediction_loop(
|
||||
eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
|
||||
output = eval_loop(
|
||||
eval_dataloader,
|
||||
description="Evaluation",
|
||||
# No point gathering the predictions if there are no metrics, otherwise we defer to
|
||||
@@ -1810,8 +1855,7 @@ class Trainer:
|
||||
metric_key_prefix=metric_key_prefix,
|
||||
)
|
||||
|
||||
n_samples = len(eval_dataset if eval_dataset is not None else self.eval_dataset)
|
||||
output.metrics.update(speed_metrics(metric_key_prefix, start_time, n_samples))
|
||||
output.metrics.update(speed_metrics(metric_key_prefix, start_time, output.num_samples))
|
||||
self.log(output.metrics)
|
||||
|
||||
if self.args.tpu_metrics_debug or self.args.debug:
|
||||
@@ -1860,20 +1904,352 @@ class Trainer:
|
||||
# memory metrics - must set up as early as possible
|
||||
self._memory_tracker.start()
|
||||
|
||||
if test_dataset is not None and not isinstance(test_dataset, collections.abc.Sized):
|
||||
raise ValueError("test_dataset must implement __len__")
|
||||
|
||||
test_dataloader = self.get_test_dataloader(test_dataset)
|
||||
start_time = time.time()
|
||||
|
||||
output = self.prediction_loop(
|
||||
eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
|
||||
output = eval_loop(
|
||||
test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
|
||||
)
|
||||
output.metrics.update(speed_metrics(metric_key_prefix, start_time, len(test_dataset)))
|
||||
output.metrics.update(speed_metrics(metric_key_prefix, start_time, output.num_samples))
|
||||
|
||||
self._memory_tracker.stop_and_update_metrics(output.metrics)
|
||||
|
||||
return output
|
||||
return PredictionOutput(predictions=output.predictions, label_ids=output.label_ids, metrics=output.metrics)
|
||||
|
||||
def evaluation_loop(
    self,
    dataloader: DataLoader,
    description: str,
    prediction_loss_only: Optional[bool] = None,
    ignore_keys: Optional[List[str]] = None,
    metric_key_prefix: str = "eval",
) -> EvalLoopOutput:
    """
    Prediction/evaluation loop, shared by :obj:`Trainer.evaluate()` and :obj:`Trainer.predict()`.

    Works both with or without labels, and with datasets that do not implement :obj:`__len__` (the number of
    samples is then taken from what was actually observed while iterating).

    Args:
        dataloader (:obj:`DataLoader`):
            The dataloader to iterate over.
        description (:obj:`str`):
            Name of the run, only used in the log header.
        prediction_loss_only (:obj:`bool`, `optional`):
            Whether to only compute the loss. Defaults to :obj:`self.args.prediction_loss_only` when :obj:`None`.
        ignore_keys (:obj:`List[str]`, `optional`):
            Forwarded to :obj:`prediction_step`.
        metric_key_prefix (:obj:`str`, `optional`, defaults to :obj:`"eval"`):
            Prefix added to every metric key that does not already start with it.

    Returns:
        :obj:`EvalLoopOutput`: The predictions, label ids, metrics and number of samples.
    """
    prediction_loss_only = (
        prediction_loss_only if prediction_loss_only is not None else self.args.prediction_loss_only
    )

    # if eval is called w/o train init deepspeed here
    if self.args.deepspeed and not self.deepspeed:

        # XXX: eval doesn't have `resume_from_checkpoint` arg but we should be able to do eval
        # from the checkpoint eventually
        deepspeed_engine, _, _ = deepspeed_init(self, num_training_steps=0, resume_from_checkpoint=None)
        self.model = deepspeed_engine.module
        self.model_wrapped = deepspeed_engine
        self.deepspeed = deepspeed_engine
        # XXX: we don't need optim/sched for inference, but this needs to be sorted out, since
        # for example the Z3-optimizer is a must for zero3 to work even for inference - what we
        # don't need is the deepspeed basic optimizer which is self.optimizer.optimizer
        deepspeed_engine.optimizer.optimizer = None
        deepspeed_engine.lr_scheduler = None

    model = self._wrap_model(self.model, training=False)

    # if full fp16 is wanted on eval and this ``evaluation`` or ``predict`` isn't called while
    # ``train`` is running, half it first and then put on device
    if not self.is_in_train and self.args.fp16_full_eval:
        model = model.half().to(self.args.device)

    batch_size = dataloader.batch_size

    logger.info(f"***** Running {description} *****")
    if isinstance(dataloader.dataset, collections.abc.Sized):
        logger.info(f" Num examples = {self.num_examples(dataloader)}")
    else:
        logger.info(" Num examples: Unknown")
    logger.info(f" Batch size = {batch_size}")

    model.eval()

    self.callback_handler.eval_dataloader = dataloader
    # Do this before wrapping: the TPU per-device loader below replaces `dataloader`.
    eval_dataset = dataloader.dataset

    if is_torch_tpu_available():
        dataloader = pl.ParallelLoader(dataloader, [self.args.device]).per_device_loader(self.args.device)

    if self.args.past_index >= 0:
        self._past = None

    # Initialize containers
    # losses/preds/labels on GPU/TPU (accumulated for eval_accumulation_steps)
    losses_host = None
    preds_host = None
    labels_host = None
    # losses/preds/labels on CPU (final containers)
    all_losses = None
    all_preds = None
    all_labels = None

    # Will be useful when we have an iterable dataset so don't know its length.
    observed_num_examples = 0
    # Main evaluation loop
    for step, inputs in enumerate(dataloader):
        # Update the observed num examples
        observed_batch_size = find_batch_size(inputs)
        if observed_batch_size is not None:
            observed_num_examples += observed_batch_size

        # Prediction step
        loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)

        # Update containers on host
        if loss is not None:
            # One loss entry per nominal sample of the batch; the surplus is truncated to `num_samples` below.
            losses = self._nested_gather(loss.repeat(batch_size))
            losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
        if logits is not None:
            logits = self._pad_across_processes(logits)
            logits = self._nested_gather(logits)
            preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
        if labels is not None:
            labels = self._pad_across_processes(labels)
            labels = self._nested_gather(labels)
            labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
        self.control = self.callback_handler.on_prediction_step(self.args, self.state, self.control)

        # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
        if self.args.eval_accumulation_steps is not None and (step + 1) % self.args.eval_accumulation_steps == 0:
            if losses_host is not None:
                losses = nested_numpify(losses_host)
                all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
            if preds_host is not None:
                logits = nested_numpify(preds_host)
                all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
            if labels_host is not None:
                labels = nested_numpify(labels_host)
                all_labels = (
                    labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)
                )

            # Set back to None to begin a new accumulation
            losses_host, preds_host, labels_host = None, None, None

    # Clean the state at the end of the evaluation loop. Only the attribute's presence is checked: a truthiness
    # test on `past_index` would skip the cleanup for `past_index == 0` even though `_past` is set for it above.
    if hasattr(self, "_past"):
        delattr(self, "_past")

    # Gather all remaining tensors and put them back on the CPU
    if losses_host is not None:
        losses = nested_numpify(losses_host)
        all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
    if preds_host is not None:
        logits = nested_numpify(preds_host)
        all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
    if labels_host is not None:
        labels = nested_numpify(labels_host)
        all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)

    # Number of samples
    if not isinstance(eval_dataset, IterableDataset):
        num_samples = len(eval_dataset)
    elif isinstance(eval_dataset, IterableDatasetShard):
        # The shard counts the elements it reads from the wrapped dataset while iterating.
        num_samples = eval_dataset.num_examples
    else:
        num_samples = observed_num_examples

    # Number of losses has been rounded to a multiple of batch_size and in a distributed training, the number of
    # samplers has been rounded to a multiple of batch_size, so we truncate.
    if all_losses is not None:
        all_losses = all_losses[:num_samples]
    if all_preds is not None:
        all_preds = nested_truncate(all_preds, num_samples)
    if all_labels is not None:
        all_labels = nested_truncate(all_labels, num_samples)

    # Metrics!
    if self.compute_metrics is not None and all_preds is not None and all_labels is not None:
        metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))
    else:
        metrics = {}

    # To be JSON-serializable, we need to remove numpy types or zero-d tensors
    metrics = denumpify_detensorize(metrics)

    if all_losses is not None:
        metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()

    # Prefix all keys with metric_key_prefix + '_'
    for key in list(metrics.keys()):
        if not key.startswith(f"{metric_key_prefix}_"):
            metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)

    return EvalLoopOutput(predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=num_samples)
|
||||
|
||||
def _nested_gather(self, tensors, name=None):
    """
    Gather `tensors` (tensor or list/tuple of nested tensors) across processes using whichever backend is active.

    Returns :obj:`None` when `tensors` is :obj:`None`, and `tensors` unchanged when neither TPU, SageMaker model
    parallelism nor a distributed `local_rank` is in use. `name` is only passed to the TPU mesh reduce and
    defaults to ``"nested_gather"``.
    """
    if tensors is None:
        return None

    if is_torch_tpu_available():
        reduce_name = "nested_gather" if name is None else name
        return nested_xla_mesh_reduce(tensors, reduce_name)

    if is_sagemaker_mp_enabled():
        return smp_gather(tensors)

    if self.args.local_rank != -1:
        return distributed_concat(tensors)

    return tensors
|
||||
|
||||
# Copied from Accelerate.
|
||||
def _pad_across_processes(self, tensor, pad_index=-100):
    """
    Recursively pad the tensors in a nested list/tuple/dictionary of tensors from all devices to the same size so
    they can safely be gathered.

    Only the second dimension is padded (with `pad_index`); tensors with fewer than two dimensions are returned
    as they are. Raises :obj:`TypeError` for leaves that are not tensors.
    """
    if isinstance(tensor, (list, tuple)):
        return type(tensor)(self._pad_across_processes(item, pad_index=pad_index) for item in tensor)
    if isinstance(tensor, dict):
        return type(tensor)(
            {key: self._pad_across_processes(value, pad_index=pad_index) for key, value in tensor.items()}
        )
    if not isinstance(tensor, torch.Tensor):
        raise TypeError(
            f"Can't pad the values of type {type(tensor)}, only of nested list/tuple/dicts of tensors."
        )

    # Nothing to pad along dimension 1 when it does not exist.
    if tensor.dim() < 2:
        return tensor

    # Gather all sizes: one row per process once gathered.
    local_shape = torch.tensor(tensor.shape, device=tensor.device)[None]
    all_shapes = self._nested_gather(local_shape).cpu()

    max_size = max(shape[1] for shape in all_shapes)
    if tensor.shape[1] == max_size:
        return tensor

    # Then pad to the maximum size
    old_size = tensor.shape
    target_size = list(old_size)
    target_size[1] = max_size
    padded = tensor.new_zeros(tuple(target_size)) + pad_index
    padded[:, : old_size[1]] = tensor
    return padded
|
||||
|
||||
def prediction_step(
    self,
    model: nn.Module,
    inputs: Dict[str, Union[torch.Tensor, Any]],
    prediction_loss_only: bool,
    ignore_keys: Optional[List[str]] = None,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
    """
    Perform an evaluation step on :obj:`model` using :obj:`inputs`.

    Subclass and override to inject custom behavior.

    Args:
        model (:obj:`nn.Module`):
            The model to evaluate.
        inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
            The inputs and targets of the model.

            The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
            argument :obj:`labels`. Check your model's documentation for all accepted arguments.
        prediction_loss_only (:obj:`bool`):
            Whether or not to return the loss only.
        ignore_keys (:obj:`List[str]`, `optional`):
            A list of keys in the output of your model (if it is a dictionary) that should be ignored when
            gathering predictions.

    Return:
        Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss,
        logits and labels (each being optional).
    """
    # Labels count as present only if every configured label name has a non-None entry.
    has_labels = all(inputs.get(k) is not None for k in self.label_names)
    inputs = self._prepare_inputs(inputs)

    if ignore_keys is None:
        ignore_keys = (
            getattr(self.model.config, "keys_to_ignore_at_inference", [])
            if hasattr(self.model, "config")
            else []
        )

    # labels may be popped when computing the loss (label smoothing for instance) so we grab them first.
    labels = None
    if has_labels:
        labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
        if len(labels) == 1:
            labels = labels[0]

    with torch.no_grad():
        if is_sagemaker_mp_enabled():
            raw_outputs = smp_forward_only(model, inputs)
            if has_labels:
                if isinstance(raw_outputs, dict):
                    micro_loss = raw_outputs["loss"]
                    micro_logits = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys + ["loss"])
                else:
                    micro_loss = raw_outputs[0]
                    micro_logits = raw_outputs[1:]

                loss = micro_loss.reduce_mean().detach().cpu()
                logits = smp_nested_concat(micro_logits)
            else:
                loss = None
                if isinstance(raw_outputs, dict):
                    micro_logits = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys)
                else:
                    micro_logits = raw_outputs
                logits = smp_nested_concat(micro_logits)
        elif has_labels:
            loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
            loss = loss.mean().detach()
            if isinstance(outputs, dict):
                logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
            else:
                # Non-dict outputs: the first element is skipped, the rest are kept as logits.
                logits = outputs[1:]
        else:
            loss = None
            if self.use_amp:
                with autocast():
                    outputs = model(**inputs)
            else:
                outputs = model(**inputs)

            if isinstance(outputs, dict):
                logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
            else:
                logits = outputs
            # TODO: this needs to be fixed and made cleaner later.
            if self.args.past_index >= 0:
                self._past = outputs[self.args.past_index - 1]

    if prediction_loss_only:
        return (loss, None, None)

    logits = nested_detach(logits)
    if len(logits) == 1:
        logits = logits[0]

    return (loss, logits, labels)
|
||||
|
||||
def floating_point_ops(self, inputs: Dict[str, Union[torch.Tensor, Any]]):
    """
    For models that inherit from :class:`~transformers.PreTrainedModel`, uses that method to compute the number of
    floating point operations for every backward + forward pass. If using another model, either implement such a
    method in the model or subclass and override this method.

    Args:
        inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
            The inputs and targets of the model.

    Returns:
        :obj:`int`: The number of floating-point operations, or 0 when the model exposes no
        :obj:`floating_point_ops` method.
    """
    if not hasattr(self.model, "floating_point_ops"):
        return 0
    return self.model.floating_point_ops(inputs)
|
||||
|
||||
#
|
||||
# Deprecated code
|
||||
#
|
||||
|
||||
def prediction_loop(
|
||||
self,
|
||||
@@ -2015,119 +2391,3 @@ class Trainer:
|
||||
tensors = distributed_concat(tensors)
|
||||
|
||||
return nested_numpify(tensors)
|
||||
|
||||
def prediction_step(
|
||||
self,
|
||||
model: nn.Module,
|
||||
inputs: Dict[str, Union[torch.Tensor, Any]],
|
||||
prediction_loss_only: bool,
|
||||
ignore_keys: Optional[List[str]] = None,
|
||||
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
"""
|
||||
Perform an evaluation step on :obj:`model` using obj:`inputs`.
|
||||
|
||||
Subclass and override to inject custom behavior.
|
||||
|
||||
Args:
|
||||
model (:obj:`nn.Module`):
|
||||
The model to evaluate.
|
||||
inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
|
||||
The inputs and targets of the model.
|
||||
|
||||
The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
|
||||
argument :obj:`labels`. Check your model's documentation for all accepted arguments.
|
||||
prediction_loss_only (:obj:`bool`):
|
||||
Whether or not to return the loss only.
|
||||
ignore_keys (:obj:`Lst[str]`, `optional`):
|
||||
A list of keys in the output of your model (if it is a dictionary) that should be ignored when
|
||||
gathering predictions.
|
||||
|
||||
Return:
|
||||
Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss,
|
||||
logits and labels (each being optional).
|
||||
"""
|
||||
has_labels = all(inputs.get(k) is not None for k in self.label_names)
|
||||
inputs = self._prepare_inputs(inputs)
|
||||
if ignore_keys is None:
|
||||
if hasattr(self.model, "config"):
|
||||
ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
|
||||
else:
|
||||
ignore_keys = []
|
||||
|
||||
# labels may be popped when computing the loss (label smoothing for instance) so we grab them first.
|
||||
if has_labels:
|
||||
labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
|
||||
if len(labels) == 1:
|
||||
labels = labels[0]
|
||||
else:
|
||||
labels = None
|
||||
|
||||
with torch.no_grad():
|
||||
if is_sagemaker_mp_enabled():
|
||||
raw_outputs = smp_forward_only(model, inputs)
|
||||
if has_labels:
|
||||
if isinstance(raw_outputs, dict):
|
||||
loss_mb = raw_outputs["loss"]
|
||||
logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys + ["loss"])
|
||||
else:
|
||||
loss_mb = raw_outputs[0]
|
||||
logits_mb = raw_outputs[1:]
|
||||
|
||||
loss = loss_mb.reduce_mean().detach().cpu()
|
||||
logits = smp_nested_concat(logits_mb)
|
||||
else:
|
||||
loss = None
|
||||
if isinstance(raw_outputs, dict):
|
||||
logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys)
|
||||
else:
|
||||
logits_mb = raw_outputs
|
||||
logits = smp_nested_concat(logits_mb)
|
||||
else:
|
||||
if has_labels:
|
||||
loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
|
||||
loss = loss.mean().detach()
|
||||
if isinstance(outputs, dict):
|
||||
logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
|
||||
else:
|
||||
logits = outputs[1:]
|
||||
else:
|
||||
loss = None
|
||||
if self.use_amp:
|
||||
with autocast():
|
||||
outputs = model(**inputs)
|
||||
else:
|
||||
outputs = model(**inputs)
|
||||
if isinstance(outputs, dict):
|
||||
logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
|
||||
else:
|
||||
logits = outputs
|
||||
# TODO: this needs to be fixed and made cleaner later.
|
||||
if self.args.past_index >= 0:
|
||||
self._past = outputs[self.args.past_index - 1]
|
||||
|
||||
if prediction_loss_only:
|
||||
return (loss, None, None)
|
||||
|
||||
logits = nested_detach(logits)
|
||||
if len(logits) == 1:
|
||||
logits = logits[0]
|
||||
|
||||
return (loss, logits, labels)
|
||||
|
||||
def floating_point_ops(self, inputs: Dict[str, Union[torch.Tensor, Any]]):
|
||||
"""
|
||||
For models that inherit from :class:`~transformers.PreTrainedModel`, uses that method to compute the number of
|
||||
floating point operations for every backward + forward pass. If using another model, either implement such a
|
||||
method in the model or subclass and override this method.
|
||||
|
||||
Args:
|
||||
inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
|
||||
The inputs and targets of the model.
|
||||
|
||||
Returns:
|
||||
:obj:`int`: The number of floating-point operations.
|
||||
"""
|
||||
if hasattr(self.model, "floating_point_ops"):
|
||||
return self.model.floating_point_ops(inputs)
|
||||
else:
|
||||
return 0
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
"""
|
||||
Callbacks to use with the Trainer class and customize the training loop.
|
||||
"""
|
||||
|
||||
import collections
|
||||
import dataclasses
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
@@ -469,7 +469,7 @@ class ProgressCallback(TrainerCallback):
|
||||
self.current_step = state.global_step
|
||||
|
||||
def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs):
|
||||
if state.is_local_process_zero:
|
||||
if state.is_local_process_zero and isinstance(eval_dataloader.dataset, collections.abc.Sized):
|
||||
if self.prediction_bar is None:
|
||||
self.prediction_bar = tqdm(total=len(eval_dataloader), leave=self.training_bar is None)
|
||||
self.prediction_bar.update(1)
|
||||
|
||||
@@ -102,6 +102,26 @@ def nested_concat(tensors, new_tensors, padding_index=-100):
|
||||
raise TypeError(f"Unsupported type for concatenation: got {type(tensors)}")
|
||||
|
||||
|
||||
def find_batch_size(tensors):
    """
    Find the first dimension of a tensor in a nested list/tuple/mapping of tensors.

    Containers are searched depth-first in iteration order and the size of the first dimension of the first
    tensor or NumPy array that has at least one dimension is returned. Any mapping is searched through its
    values, not only plain :obj:`dict` instances, so dict-like batch containers are handled too.

    Returns:
        :obj:`int` or :obj:`None`: The batch size, or :obj:`None` when no tensor/array with at least one dimension
        is found (including for unsupported input types).
    """
    from collections.abc import Mapping

    if isinstance(tensors, (torch.Tensor, np.ndarray)):
        # Zero-dimensional tensors/arrays have no batch dimension.
        return tensors.shape[0] if len(tensors.shape) >= 1 else None

    if isinstance(tensors, (list, tuple)):
        children = tensors
    elif isinstance(tensors, Mapping):
        children = tensors.values()
    else:
        return None

    for child in children:
        result = find_batch_size(child)
        if result is not None:
            return result
    return None
|
||||
|
||||
|
||||
def nested_numpify(tensors):
|
||||
"Numpify `tensors` (even if it's a nested list/tuple of tensors)."
|
||||
if isinstance(tensors, (list, tuple)):
|
||||
@@ -222,6 +242,10 @@ class SequentialDistributedSampler(Sampler):
|
||||
"""
|
||||
|
||||
def __init__(self, dataset, num_replicas=None, rank=None, batch_size=None):
|
||||
warnings.warn(
|
||||
"SequentialDistributedSampler is deprecated and will be removed in v5 of Tranformers.",
|
||||
FutureWarning,
|
||||
)
|
||||
if num_replicas is None:
|
||||
if not dist.is_available():
|
||||
raise RuntimeError("Requires distributed package to be available")
|
||||
@@ -338,6 +362,10 @@ class DistributedTensorGatherer:
|
||||
"""
|
||||
|
||||
def __init__(self, world_size, num_samples, make_multiple_of=None, padding_index=-100):
|
||||
warnings.warn(
|
||||
"DistributedTensorGatherer is deprecated and will be removed in v5 of Tranformers.",
|
||||
FutureWarning,
|
||||
)
|
||||
self.world_size = world_size
|
||||
self.num_samples = num_samples
|
||||
total_size = world_size if make_multiple_of is None else world_size * make_multiple_of
|
||||
@@ -576,6 +604,55 @@ class DistributedLengthGroupedSampler(DistributedSampler):
|
||||
return iter(indices)
|
||||
|
||||
|
||||
class ShardSampler(Sampler):
    """
    Sampler that shards batches between several processes. Dispatches indices batch by batch: on 2 processes with batch
    size 4, the first two batches are :obj:`[0, 1, 2, 3, 4, 5, 6, 7]` and :obj:`[8, 9, 10, 11, 12, 13, 14, 15]`, which
    shard into :obj:`[0, 1, 2, 3]` and :obj:`[8, 9, 10, 11]` for GPU-0 and :obj:`[4, 5, 6, 7]` and :obj:`[12, 13, 14,
    15]` for GPU-1.

    The sampler thus yields :obj:`[0, 1, 2, 3, 8, 9, 10, 11]` on GPU-0 and :obj:`[4, 5, 6, 7, 12, 13, 14, 15]` on
    GPU-1.
    """

    def __init__(
        self,
        dataset: Dataset,
        batch_size: int = 1,
        drop_last: bool = False,
        num_processes: int = 1,
        process_index: int = 0,
    ):
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.num_processes = num_processes
        self.process_index = process_index

        # Number of samples consumed by all processes together in one step.
        total_batch_size = batch_size * num_processes
        self.total_batch_size = total_batch_size

        dataset_length = len(dataset)
        if drop_last:
            num_batches = dataset_length // total_batch_size
        else:
            num_batches = math.ceil(dataset_length / total_batch_size)
        self.total_num_samples = num_batches * total_batch_size

    def __iter__(self):
        padded = list(range(len(self.dataset)))

        # Add extra samples to make it evenly divisible. While loop is there in the edge case we have a tiny dataset
        # and it needs to be done several times.
        while len(padded) < self.total_num_samples:
            missing = self.total_num_samples - len(padded)
            padded.extend(padded[:missing])

        # This process takes one `batch_size` slice out of every global batch, at an offset set by its index.
        first_start = self.batch_size * self.process_index
        shard = []
        for start in range(first_start, self.total_num_samples, self.total_batch_size):
            shard.extend(padded[start : start + self.batch_size])

        return iter(shard)

    def __len__(self):
        # Each shard only sees a fraction of total_num_samples.
        return self.total_num_samples // self.num_processes
|
||||
|
||||
|
||||
class IterableDatasetShard(IterableDataset):
|
||||
"""
|
||||
Wraps a PyTorch :obj:`IterableDataset` to generate samples for one of the processes only. Instances of this class
|
||||
@@ -634,6 +711,7 @@ class IterableDatasetShard(IterableDataset):
|
||||
self.process_index = process_index
|
||||
self.seed = seed
|
||||
self.epoch = 0
|
||||
self.num_examples = 0
|
||||
|
||||
def set_epoch(self, epoch):
|
||||
self.epoch = epoch
|
||||
@@ -641,6 +719,7 @@ class IterableDatasetShard(IterableDataset):
|
||||
self.dataset.set_epoch(epoch)
|
||||
|
||||
def __iter__(self):
|
||||
self.num_examples = 0
|
||||
if (
|
||||
not hasattr(self.dataset, "set_epoch")
|
||||
and hasattr(self.dataset, "generator")
|
||||
@@ -653,6 +732,7 @@ class IterableDatasetShard(IterableDataset):
|
||||
first_batch = None
|
||||
current_batch = []
|
||||
for element in self.dataset:
|
||||
self.num_examples += 1
|
||||
current_batch.append(element)
|
||||
# Wait to have a full batch before yielding elements.
|
||||
if len(current_batch) == real_batch_size:
|
||||
|
||||
@@ -77,6 +77,13 @@ class EvalPrediction(NamedTuple):
|
||||
label_ids: np.ndarray
|
||||
|
||||
|
||||
class EvalLoopOutput(NamedTuple):
    """
    Immutable record with four fields: predictions, label ids, metrics and a sample count.

    NOTE(review): presumably the value returned by the Trainer's evaluation loop -- confirm against the caller.
    """

    # A single array, or a tuple of arrays.
    predictions: Union[np.ndarray, Tuple[np.ndarray]]
    # May be None.
    label_ids: Optional[np.ndarray]
    # Mapping from metric name to float value; may be None.
    metrics: Optional[Dict[str, float]]
    # May be None (presumably when the number of samples is not known -- TODO confirm).
    num_samples: Optional[int]
|
||||
|
||||
|
||||
class PredictionOutput(NamedTuple):
|
||||
predictions: Union[np.ndarray, Tuple[np.ndarray]]
|
||||
label_ids: Optional[np.ndarray]
|
||||
|
||||
@@ -524,6 +524,9 @@ class TrainingArguments:
|
||||
skip_memory_metrics: bool = field(
|
||||
default=False, metadata={"help": "Whether or not to skip adding of memory profiler reports to metrics."}
|
||||
)
|
||||
use_legacy_prediction_loop: bool = field(
|
||||
default=False, metadata={"help": "Whether or not to use the legacy prediction_loop in the Trainer."}
|
||||
)
|
||||
_n_gpu: int = field(init=False, repr=False, default=-1)
|
||||
mp_parameters: str = field(
|
||||
default="",
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import collections
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
@@ -286,6 +287,8 @@ class NotebookProgressCallback(TrainerCallback):
|
||||
self._force_next_update = False
|
||||
|
||||
def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs):
|
||||
if not isinstance(eval_dataloader.dataset, collections.abc.Sized):
|
||||
return
|
||||
if self.prediction_bar is None:
|
||||
if self.training_tracker is not None:
|
||||
self.prediction_bar = self.training_tracker.add_child(len(eval_dataloader))
|
||||
|
||||
@@ -819,35 +819,64 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
)
|
||||
self.assertEqual(len(dataset), 31)
|
||||
|
||||
def test_trainer_iterable_dataset(self):
|
||||
def test_training_iterable_dataset(self):
|
||||
config = RegressionModelConfig()
|
||||
model = RegressionPreTrainedModel(config)
|
||||
train_dataset = SampleIterableDataset()
|
||||
|
||||
args = RegressionTrainingArguments(output_dir="./examples", max_steps=2)
|
||||
args = RegressionTrainingArguments(output_dir="./examples", max_steps=4)
|
||||
trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
|
||||
trainer.train()
|
||||
self.assertEqual(trainer.state.global_step, 4)
|
||||
|
||||
loader = trainer.get_train_dataloader()
|
||||
self.assertIsInstance(loader, torch.utils.data.DataLoader)
|
||||
self.assertIsInstance(loader.sampler, torch.utils.data.dataloader._InfiniteConstantSampler)
|
||||
|
||||
# Exception if giving iterable dataset and no max_steps
|
||||
with self.assertRaises(ValueError):
|
||||
args1 = RegressionTrainingArguments(output_dir="./examples")
|
||||
_ = Trainer(model=model, args=args1, train_dataset=train_dataset)
|
||||
def test_evaluation_iterable_dataset(self):
|
||||
config = RegressionModelConfig(a=1.5, b=2.5)
|
||||
model = RegressionPreTrainedModel(config)
|
||||
eval_dataset = SampleIterableDataset()
|
||||
|
||||
# Exception if eval_dataset is iterable in __init__
|
||||
with self.assertRaises(ValueError):
|
||||
_ = Trainer(model=model, args=args, train_dataset=train_dataset, eval_dataset=train_dataset)
|
||||
args = RegressionTrainingArguments(output_dir="./examples")
|
||||
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset, compute_metrics=AlmostAccuracy())
|
||||
results = trainer.evaluate()
|
||||
|
||||
# Exception if predicting with iterable dataset
|
||||
with self.assertRaises(ValueError):
|
||||
trainer.predict(train_dataset)
|
||||
x, y = trainer.eval_dataset.dataset.x, trainer.eval_dataset.dataset.ys[0]
|
||||
pred = 1.5 * x + 2.5
|
||||
expected_loss = ((pred - y) ** 2).mean()
|
||||
self.assertAlmostEqual(results["eval_loss"], expected_loss)
|
||||
expected_acc = AlmostAccuracy()((pred, y))["accuracy"]
|
||||
self.assertAlmostEqual(results["eval_accuracy"], expected_acc)
|
||||
|
||||
# Exception if evaluating with iterable dataset
|
||||
with self.assertRaises(ValueError):
|
||||
trainer.evaluate(train_dataset)
|
||||
# With a number of elements not a round multiple of the batch size
|
||||
eval_dataset = SampleIterableDataset(length=66)
|
||||
results = trainer.evaluate(eval_dataset)
|
||||
|
||||
x, y = eval_dataset.dataset.x, eval_dataset.dataset.ys[0]
|
||||
pred = 1.5 * x + 2.5
|
||||
expected_loss = ((pred - y) ** 2).mean()
|
||||
self.assertAlmostEqual(results["eval_loss"], expected_loss)
|
||||
expected_acc = AlmostAccuracy()((pred, y))["accuracy"]
|
||||
self.assertAlmostEqual(results["eval_accuracy"], expected_acc)
|
||||
|
||||
def test_predict_iterable_dataset(self):
    """Check that ``Trainer.predict`` returns ``a * x + b`` when fed an iterable dataset."""
    # The model is built with fixed coefficients so the expected predictions are known in closed form.
    a, b = 1.5, 2.5
    config = RegressionModelConfig(a=a, b=b)
    model = RegressionPreTrainedModel(config)
    eval_dataset = SampleIterableDataset()

    args = RegressionTrainingArguments(output_dir="./examples")
    trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset, compute_metrics=AlmostAccuracy())

    preds = trainer.predict(trainer.eval_dataset).predictions
    x = eval_dataset.dataset.x
    self.assertTrue(np.allclose(preds, a * x + b))

    # With a number of elements not a round multiple of the batch size
    test_dataset = SampleIterableDataset(length=66)
    preds = trainer.predict(test_dataset).predictions
    x = test_dataset.dataset.x
    self.assertTrue(np.allclose(preds, a * x + b))
|
||||
|
||||
def test_num_train_epochs_in_training(self):
|
||||
# len(train_dl) < gradient_accumulation_steps shouldn't give ``ZeroDivisionError`` when ``max_steps`` is given.
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
@@ -34,6 +35,7 @@ if is_torch_available():
|
||||
LabelSmoother,
|
||||
LengthGroupedSampler,
|
||||
SequentialDistributedSampler,
|
||||
ShardSampler,
|
||||
get_parameter_names,
|
||||
)
|
||||
|
||||
@@ -283,6 +285,10 @@ class TrainerUtilsTest(unittest.TestCase):
|
||||
# All shards have the same number of samples
|
||||
self.assertEqual(len(shard), len(shard_lists[0]))
|
||||
|
||||
for shard in shards:
|
||||
# All shards know the total number of samples
|
||||
self.assertEqual(shard.num_examples, len(reference))
|
||||
|
||||
observed = []
|
||||
for idx in range(0, len(shard_lists[0]), batch_size):
|
||||
for shard in shard_lists:
|
||||
@@ -295,11 +301,62 @@ class TrainerUtilsTest(unittest.TestCase):
|
||||
reference += reference
|
||||
self.assertListEqual(observed, reference[: len(observed)])
|
||||
|
||||
# Check equivalence between IterableDataset and ShardSampler
|
||||
dataset.generator.manual_seed(epoch)
|
||||
reference = list(dataset)
|
||||
|
||||
sampler_shards = [
|
||||
ShardSampler(
|
||||
reference, batch_size=batch_size, drop_last=drop_last, num_processes=num_processes, process_index=i
|
||||
)
|
||||
for i in range(num_processes)
|
||||
]
|
||||
for shard, sampler_shard in zip(shard_lists, sampler_shards):
|
||||
self.assertListEqual(shard, list(sampler_shard))
|
||||
|
||||
def test_iterable_dataset_shard(self):
    """Run ``check_iterable_dataset_shard`` on a random iterable dataset for 2 and 3 processes."""
    dataset = RandomIterableDataset()

    # NOTE(review): each ``drop_last=True`` call appears twice in a row in the extracted text. This may be an
    # artifact of the diff rendering (removed + added line shown without markers); both calls are kept so no
    # check is lost -- confirm against the real file.
    self.check_iterable_dataset_shard(dataset, 4, drop_last=True, num_processes=2, epoch=0)
    self.check_iterable_dataset_shard(dataset, 4, drop_last=True, num_processes=2, epoch=0)
    self.check_iterable_dataset_shard(dataset, 4, drop_last=False, num_processes=2, epoch=0)

    self.check_iterable_dataset_shard(dataset, 4, drop_last=True, num_processes=3, epoch=42)
    self.check_iterable_dataset_shard(dataset, 4, drop_last=True, num_processes=3, epoch=42)
    self.check_iterable_dataset_shard(dataset, 4, drop_last=False, num_processes=3, epoch=42)
|
||||
|
||||
def check_shard_sampler(self, dataset, batch_size, drop_last, num_processes=2):
    """
    Build one ``ShardSampler`` per process over ``dataset`` and check the shards against the dataset.

    Asserts that every shard has a length that is a multiple of ``batch_size``, that all shards have the same
    length, and that interleaving the shards batch by batch gives back ``dataset`` in order (looped from the
    start when ``drop_last`` is ``False``).

    ``dataset`` must support ``len()``, slicing and ``+=`` (the caller in this file passes a ``list``).
    """
    shards = [
        ShardSampler(
            dataset, batch_size=batch_size, drop_last=drop_last, num_processes=num_processes, process_index=i
        )
        for i in range(num_processes)
    ]
    shard_lists = [list(shard) for shard in shards]

    for shard in shard_lists:
        # All shards have a number of samples that is a round multiple of batch size
        self.assertEqual(len(shard) % batch_size, 0)
        # All shards have the same number of samples
        self.assertEqual(len(shard), len(shard_lists[0]))

    # Re-interleave the shards: batch 0 of every process, then batch 1 of every process, and so on.
    observed = []
    for idx in range(0, len(shard_lists[0]), batch_size):
        for shard in shard_lists:
            observed += shard[idx : idx + batch_size]

    # If drop_last is False we loop through samples at the beginning to have a size that is a round multiple of
    # batch_size
    reference = copy.copy(dataset)
    if not drop_last:
        while len(reference) < len(observed):
            reference += reference
    self.assertListEqual(observed, reference[: len(observed)])
|
||||
|
||||
def test_shard_sampler(self):
    """Run ``check_shard_sampler`` for two dataset sizes, 2 and 3 processes, with and without ``drop_last``."""
    # 64 and 123 are the two dataset sizes exercised; the batch size is always 4.
    for n_elements in [64, 123]:
        dataset = list(range(n_elements))

        for num_processes in (2, 3):
            for drop_last in (True, False):
                # subTest makes a failure report which configuration broke.
                with self.subTest(n_elements=n_elements, num_processes=num_processes, drop_last=drop_last):
                    self.check_shard_sampler(dataset, 4, drop_last=drop_last, num_processes=num_processes)
|
||||
|
||||
Reference in New Issue
Block a user