Trainer callbacks (#7596)
* Initial callback proposal * Finish various callbacks * Post-rebase conflicts * Fix tests * Don't use something that's not set * Documentation * Remove unwanted print. * Document all models can work * Add tests + small fixes * Update docs/source/internal/trainer_utils.rst Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Address review comments * Fix TF tests * Real fix this time * This one should work * Fix typo * Really fix typo Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
179
src/transformers/trainer_pt_utils.py
Normal file
179
src/transformers/trainer_pt_utils.py
Normal file
@@ -0,0 +1,179 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020-present the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Torch utilities for the Trainer class.
|
||||
"""
|
||||
|
||||
import math
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from torch.utils.data.sampler import RandomSampler, Sampler
|
||||
|
||||
from .file_utils import is_torch_tpu_available
|
||||
|
||||
|
||||
if is_torch_tpu_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
PT_LR_SCHEDULER_WARNING = "Please also save or load the state of the optimzer when saving or loading the scheduler."
|
||||
|
||||
|
||||
def nested_concat(tensors, new_tensors, dim=0):
|
||||
"Concat the `new_tensors` to `tensors` on `dim`. Works for tensors or nested list/tuples of tensors."
|
||||
assert type(tensors) == type(
|
||||
new_tensors
|
||||
), f"Expected `tensors` and `new_tensors` to have the same type but found {type(tensors)} and {type(new_tensors)}."
|
||||
if isinstance(tensors, (list, tuple)):
|
||||
return type(tensors)(nested_concat(t, n, dim) for t, n in zip(tensors, new_tensors))
|
||||
return torch.cat((tensors, new_tensors), dim=dim)
|
||||
|
||||
|
||||
def nested_numpify(tensors):
|
||||
"Numpify `tensors` (even if it's a nested list/tuple of tensors)."
|
||||
if isinstance(tensors, (list, tuple)):
|
||||
return type(tensors)(nested_numpify(t) for t in tensors)
|
||||
return tensors.cpu().numpy()
|
||||
|
||||
|
||||
def nested_detach(tensors):
|
||||
"Detach `tensors` (even if it's a nested list/tuple of tensors)."
|
||||
if isinstance(tensors, (list, tuple)):
|
||||
return type(tensors)(nested_detach(t) for t in tensors)
|
||||
return tensors.detach()
|
||||
|
||||
|
||||
def nested_xla_mesh_reduce(tensors, name):
|
||||
if is_torch_tpu_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
if isinstance(tensors, (list, tuple)):
|
||||
return type(tensors)(nested_xla_mesh_reduce(t, f"{name}_{i}") for i, t in enumerate(tensors))
|
||||
return xm.mesh_reduce(name, tensors, torch.cat)
|
||||
else:
|
||||
raise ImportError("Torch xla must be installed to use `nested_xla_mesh_reduce`")
|
||||
|
||||
|
||||
def distributed_concat(tensor: "torch.Tensor", num_total_examples: Optional[int] = None) -> torch.Tensor:
|
||||
try:
|
||||
if isinstance(tensor, (tuple, list)):
|
||||
return type(tensor)(distributed_concat(t, num_total_examples) for t in tensor)
|
||||
output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())]
|
||||
torch.distributed.all_gather(output_tensors, tensor)
|
||||
concat = torch.cat(output_tensors, dim=0)
|
||||
|
||||
# truncate the dummy elements added by SequentialDistributedSampler
|
||||
if num_total_examples is not None:
|
||||
concat = concat[:num_total_examples]
|
||||
return concat
|
||||
except AssertionError:
|
||||
raise AssertionError("Not currently using distributed training")
|
||||
|
||||
|
||||
def distributed_broadcast_scalars(
|
||||
scalars: List[Union[int, float]], num_total_examples: Optional[int] = None
|
||||
) -> torch.Tensor:
|
||||
try:
|
||||
tensorized_scalar = torch.tensor(scalars).cuda()
|
||||
output_tensors = [tensorized_scalar.clone() for _ in range(torch.distributed.get_world_size())]
|
||||
torch.distributed.all_gather(output_tensors, tensorized_scalar)
|
||||
concat = torch.cat(output_tensors, dim=0)
|
||||
|
||||
# truncate the dummy elements added by SequentialDistributedSampler
|
||||
if num_total_examples is not None:
|
||||
concat = concat[:num_total_examples]
|
||||
return concat
|
||||
except AssertionError:
|
||||
raise AssertionError("Not currently using distributed training")
|
||||
|
||||
|
||||
def reissue_pt_warnings(caught_warnings):
|
||||
# Reissue warnings that are not the PT_LR_SCHEDULER_WARNING
|
||||
if len(caught_warnings) > 1:
|
||||
for w in caught_warnings:
|
||||
if w.category != UserWarning or w.message != PT_LR_SCHEDULER_WARNING:
|
||||
warnings.warn(w.message, w.category)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def torch_distributed_zero_first(local_rank: int):
|
||||
"""
|
||||
Decorator to make all processes in distributed training wait for each local_master to do something.
|
||||
|
||||
Args:
|
||||
local_rank (:obj:`int`): The rank of the local process.
|
||||
"""
|
||||
if local_rank not in [-1, 0]:
|
||||
torch.distributed.barrier()
|
||||
yield
|
||||
if local_rank == 0:
|
||||
torch.distributed.barrier()
|
||||
|
||||
|
||||
class SequentialDistributedSampler(Sampler):
|
||||
"""
|
||||
Distributed Sampler that subsamples indicies sequentially,
|
||||
making it easier to collate all results at the end.
|
||||
|
||||
Even though we only use this sampler for eval and predict (no training),
|
||||
which means that the model params won't have to be synced (i.e. will not hang
|
||||
for synchronization even if varied number of forward passes), we still add extra
|
||||
samples to the sampler to make it evenly divisible (like in `DistributedSampler`)
|
||||
to make it easy to `gather` or `reduce` resulting tensors at the end of the loop.
|
||||
"""
|
||||
|
||||
def __init__(self, dataset, num_replicas=None, rank=None):
|
||||
if num_replicas is None:
|
||||
if not torch.distributed.is_available():
|
||||
raise RuntimeError("Requires distributed package to be available")
|
||||
num_replicas = torch.distributed.get_world_size()
|
||||
if rank is None:
|
||||
if not torch.distributed.is_available():
|
||||
raise RuntimeError("Requires distributed package to be available")
|
||||
rank = torch.distributed.get_rank()
|
||||
self.dataset = dataset
|
||||
self.num_replicas = num_replicas
|
||||
self.rank = rank
|
||||
self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
|
||||
self.total_size = self.num_samples * self.num_replicas
|
||||
|
||||
def __iter__(self):
|
||||
indices = list(range(len(self.dataset)))
|
||||
|
||||
# add extra samples to make it evenly divisible
|
||||
indices += indices[: (self.total_size - len(indices))]
|
||||
assert (
|
||||
len(indices) == self.total_size
|
||||
), f"Indices length {len(indices)} and total size {self.total_size} mismatched"
|
||||
|
||||
# subsample
|
||||
indices = indices[self.rank * self.num_samples : (self.rank + 1) * self.num_samples]
|
||||
assert (
|
||||
len(indices) == self.num_samples
|
||||
), f"Indices length {len(indices)} and sample number {self.num_samples} mismatched"
|
||||
|
||||
return iter(indices)
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
|
||||
def get_tpu_sampler(dataset: torch.utils.data.dataset.Dataset):
|
||||
if xm.xrt_world_size() <= 1:
|
||||
return RandomSampler(dataset)
|
||||
return DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
|
||||
Reference in New Issue
Block a user