Fixing FLOPS merge by checking if torch is available (#7013)

* Should check if `torch` is available

* fixed samples_count error, distributed_concat arguments

* style

* Import torch at beginning of file

Co-authored-by: TevenLeScao <teven.lescao@gmail.com>
This commit is contained in:
Lysandre Debut
2020-09-08 16:51:58 +02:00
committed by GitHub
parent 01d340adfa
commit 5c4eb4b1ac
2 changed files with 34 additions and 25 deletions

View File

@@ -1315,8 +1315,6 @@ class Trainer:
label_ids = xm.mesh_reduce("eval_label_ids", label_ids, torch.cat)
if eval_losses is not None:
eval_losses = xm.mesh_reduce("eval_losses", torch.tensor(eval_losses), torch.cat).tolist()
if samples_count is not None:
samples_count = sum(xm.mesh_reduce("samples_count", torch.tensor([samples_count]), torch.cat).tolist())
# Finally, turn the aggregated tensors into numpy arrays.
if preds is not None:

View File

@@ -2,12 +2,15 @@ import random
from typing import Any, Dict, List, NamedTuple, Optional, Union
import numpy as np
import torch
from .file_utils import is_tf_available, is_torch_available
from .tokenization_utils_base import ExplicitEnum
if is_torch_available():
import torch
def set_seed(seed: int):
"""
Helper function for reproducible behavior to set the seed in ``random``, ``numpy``, ``torch`` and/or ``tf``
@@ -129,9 +132,9 @@ default_hp_space = {
}
def distributed_concat(self, tensor: torch.Tensor, num_total_examples: Optional[int] = None) -> torch.Tensor:
assert self.args.local_rank != -1
def distributed_concat(tensor: "torch.Tensor", num_total_examples: Optional[int] = None) -> "torch.Tensor":
if is_torch_available():
try:
output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())]
torch.distributed.all_gather(output_tensors, tensor)
concat = torch.cat(output_tensors, dim=0)
@@ -140,13 +143,17 @@ def distributed_concat(self, tensor: torch.Tensor, num_total_examples: Optional[
if num_total_examples is not None:
concat = concat[:num_total_examples]
return concat
except AssertionError:
raise AssertionError("Not currently using distributed training")
else:
raise ImportError("Torch must be installed to use `distributed_concat`")
def distributed_broadcast_scalars(
self, scalars: List[Union[int, float]], num_total_examples: Optional[int] = None
) -> torch.Tensor:
assert self.args.local_rank != -1
scalars: List[Union[int, float]], num_total_examples: Optional[int] = None
) -> "torch.Tensor":
if is_torch_available():
try:
tensorized_scalar = torch.Tensor(scalars).cuda()
output_tensors = [tensorized_scalar.clone() for _ in range(torch.distributed.get_world_size())]
torch.distributed.all_gather(output_tensors, tensorized_scalar)
@@ -156,3 +163,7 @@ def distributed_broadcast_scalars(
if num_total_examples is not None:
concat = concat[:num_total_examples]
return concat
except AssertionError:
raise AssertionError("Not currently using distributed training")
else:
raise ImportError("Torch must be installed to use `distributed_broadcast_scalars`")