Fixing FLOPS merge by checking if torch is available (#7013)

* Should check if `torch` is available

* fixed samples_count error, distributed_concat arguments

* style

* Import torch at beginning of file

Co-authored-by: TevenLeScao <teven.lescao@gmail.com>
This commit is contained in:
Lysandre Debut
2020-09-08 16:51:58 +02:00
committed by GitHub
parent 01d340adfa
commit 5c4eb4b1ac
2 changed files with 34 additions and 25 deletions

View File

@@ -1315,8 +1315,6 @@ class Trainer:
label_ids = xm.mesh_reduce("eval_label_ids", label_ids, torch.cat)
if eval_losses is not None:
eval_losses = xm.mesh_reduce("eval_losses", torch.tensor(eval_losses), torch.cat).tolist()
if samples_count is not None:
samples_count = sum(xm.mesh_reduce("samples_count", torch.tensor([samples_count]), torch.cat).tolist())
# Finally, turn the aggregated tensors into numpy arrays.
if preds is not None: