From 5c4eb4b1ac45291e89c1be0fb1fdacd841b19a47 Mon Sep 17 00:00:00 2001 From: Lysandre Debut Date: Tue, 8 Sep 2020 16:51:58 +0200 Subject: [PATCH] Fixing FLOPS merge by checking if torch is available (#7013) * Should check if `torch` is available * fixed samples_count error, distributed_concat arguments * style * Import torch at beginning of file Co-authored-by: TevenLeScao --- src/transformers/trainer.py | 2 -- src/transformers/trainer_utils.py | 57 ++++++++++++++++++------------- 2 files changed, 34 insertions(+), 25 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 68c049643f..c1d1905e21 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1315,8 +1315,6 @@ class Trainer: label_ids = xm.mesh_reduce("eval_label_ids", label_ids, torch.cat) if eval_losses is not None: eval_losses = xm.mesh_reduce("eval_losses", torch.tensor(eval_losses), torch.cat).tolist() - if samples_count is not None: - samples_count = sum(xm.mesh_reduce("samples_count", torch.tensor([samples_count]), torch.cat).tolist()) # Finally, turn the aggregated tensors into numpy arrays. if preds is not None: diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py index 09149dc963..a1204fe8f1 100644 --- a/src/transformers/trainer_utils.py +++ b/src/transformers/trainer_utils.py @@ -2,12 +2,15 @@ import random from typing import Any, Dict, List, NamedTuple, Optional, Union import numpy as np -import torch from .file_utils import is_tf_available, is_torch_available from .tokenization_utils_base import ExplicitEnum +if is_torch_available(): + import torch + + def set_seed(seed: int): """ Helper function for reproducible behavior to set the seed in ``random``, ``numpy``, ``torch`` and/or ``tf`` @@ -129,30 +132,38 @@ default_hp_space = { } -def distributed_concat(self, tensor: torch.Tensor, num_total_examples: Optional[int] = None) -> torch.Tensor: - assert self.args.local_rank != -1 +def distributed_concat(tensor: "torch.Tensor", num_total_examples: Optional[int] = None) -> "torch.Tensor": + if is_torch_available(): + try: + output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())] + torch.distributed.all_gather(output_tensors, tensor) + concat = torch.cat(output_tensors, dim=0) - output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())] - torch.distributed.all_gather(output_tensors, tensor) - concat = torch.cat(output_tensors, dim=0) - - # truncate the dummy elements added by SequentialDistributedSampler - if num_total_examples is not None: - concat = concat[:num_total_examples] - return concat + # truncate the dummy elements added by SequentialDistributedSampler + if num_total_examples is not None: + concat = concat[:num_total_examples] + return concat + except AssertionError: + raise AssertionError("Not currently using distributed training") + else: + raise ImportError("Torch must be installed to use `distributed_concat`") def distributed_broadcast_scalars( - self, scalars: List[Union[int, float]], num_total_examples: Optional[int] = None -) -> torch.Tensor: - assert self.args.local_rank != -1 + scalars: List[Union[int, float]], num_total_examples: Optional[int] = None +) -> "torch.Tensor": + if is_torch_available(): + try: + tensorized_scalar = torch.Tensor(scalars).cuda() + output_tensors = [tensorized_scalar.clone() for _ in range(torch.distributed.get_world_size())] + torch.distributed.all_gather(output_tensors, tensorized_scalar) + concat = torch.cat(output_tensors, dim=0) - tensorized_scalar = torch.Tensor(scalars).cuda() - output_tensors = [tensorized_scalar.clone() for _ in range(torch.distributed.get_world_size())] - torch.distributed.all_gather(output_tensors, tensorized_scalar) - concat = torch.cat(output_tensors, dim=0) - - # truncate the dummy elements added by SequentialDistributedSampler - if num_total_examples is not None: - concat = concat[:num_total_examples] - return concat + # truncate the dummy elements added by SequentialDistributedSampler + if num_total_examples is not None: + concat = concat[:num_total_examples] + return concat + except AssertionError: + raise AssertionError("Not currently using distributed training") + else: + raise ImportError("Torch must be installed to use `distributed_broadcast_scalars`")