Fixing FLOPS merge by checking if torch is available (#7013)

* Check whether `torch` is available before using it

* Fixed the `samples_count` error and the `distributed_concat` arguments

* style

* Import torch at beginning of file

Co-authored-by: TevenLeScao <teven.lescao@gmail.com>
This commit is contained in:
Lysandre Debut
2020-09-08 16:51:58 +02:00
committed by GitHub
parent 01d340adfa
commit 5c4eb4b1ac
2 changed files with 34 additions and 25 deletions

View File

@@ -1315,8 +1315,6 @@ class Trainer:
label_ids = xm.mesh_reduce("eval_label_ids", label_ids, torch.cat) label_ids = xm.mesh_reduce("eval_label_ids", label_ids, torch.cat)
if eval_losses is not None: if eval_losses is not None:
eval_losses = xm.mesh_reduce("eval_losses", torch.tensor(eval_losses), torch.cat).tolist() eval_losses = xm.mesh_reduce("eval_losses", torch.tensor(eval_losses), torch.cat).tolist()
if samples_count is not None:
samples_count = sum(xm.mesh_reduce("samples_count", torch.tensor([samples_count]), torch.cat).tolist())
# Finally, turn the aggregated tensors into numpy arrays. # Finally, turn the aggregated tensors into numpy arrays.
if preds is not None: if preds is not None:

View File

@@ -2,12 +2,15 @@ import random
from typing import Any, Dict, List, NamedTuple, Optional, Union from typing import Any, Dict, List, NamedTuple, Optional, Union
import numpy as np import numpy as np
import torch
from .file_utils import is_tf_available, is_torch_available from .file_utils import is_tf_available, is_torch_available
from .tokenization_utils_base import ExplicitEnum from .tokenization_utils_base import ExplicitEnum
if is_torch_available():
import torch
def set_seed(seed: int): def set_seed(seed: int):
""" """
Helper function for reproducible behavior to set the seed in ``random``, ``numpy``, ``torch`` and/or ``tf`` Helper function for reproducible behavior to set the seed in ``random``, ``numpy``, ``torch`` and/or ``tf``
@@ -129,30 +132,38 @@ default_hp_space = {
} }
def distributed_concat(self, tensor: torch.Tensor, num_total_examples: Optional[int] = None) -> torch.Tensor: def distributed_concat(tensor: "torch.Tensor", num_total_examples: Optional[int] = None) -> "torch.Tensor":
assert self.args.local_rank != -1 if is_torch_available():
try:
output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())]
torch.distributed.all_gather(output_tensors, tensor)
concat = torch.cat(output_tensors, dim=0)
output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())] # truncate the dummy elements added by SequentialDistributedSampler
torch.distributed.all_gather(output_tensors, tensor) if num_total_examples is not None:
concat = torch.cat(output_tensors, dim=0) concat = concat[:num_total_examples]
return concat
# truncate the dummy elements added by SequentialDistributedSampler except AssertionError:
if num_total_examples is not None: raise AssertionError("Not currently using distributed training")
concat = concat[:num_total_examples] else:
return concat raise ImportError("Torch must be installed to use `distributed_concat`")
def distributed_broadcast_scalars( def distributed_broadcast_scalars(
self, scalars: List[Union[int, float]], num_total_examples: Optional[int] = None scalars: List[Union[int, float]], num_total_examples: Optional[int] = None
) -> torch.Tensor: ) -> "torch.Tensor":
assert self.args.local_rank != -1 if is_torch_available():
try:
tensorized_scalar = torch.Tensor(scalars).cuda()
output_tensors = [tensorized_scalar.clone() for _ in range(torch.distributed.get_world_size())]
torch.distributed.all_gather(output_tensors, tensorized_scalar)
concat = torch.cat(output_tensors, dim=0)
tensorized_scalar = torch.Tensor(scalars).cuda() # truncate the dummy elements added by SequentialDistributedSampler
output_tensors = [tensorized_scalar.clone() for _ in range(torch.distributed.get_world_size())] if num_total_examples is not None:
torch.distributed.all_gather(output_tensors, tensorized_scalar) concat = concat[:num_total_examples]
concat = torch.cat(output_tensors, dim=0) return concat
except AssertionError:
# truncate the dummy elements added by SequentialDistributedSampler raise AssertionError("Not currently using distributed training")
if num_total_examples is not None: else:
concat = concat[:num_total_examples] raise ImportError("Torch must be installed to use `distributed_broadcast_scalars`")
return concat