@@ -1334,9 +1334,9 @@ class Trainer:
|
||||
elif is_torch_tpu_available():
|
||||
# tpu-comment: Get all predictions and labels from all worker shards of eval dataset
|
||||
if preds is not None:
|
||||
preds = nested_xla_mesh_reduce("eval_preds", preds)
|
||||
preds = nested_xla_mesh_reduce(preds, "eval_preds")
|
||||
if label_ids is not None:
|
||||
label_ids = nested_xla_mesh_reduce("eval_label_ids", label_ids, torch.cat)
|
||||
label_ids = nested_xla_mesh_reduce(label_ids, "eval_label_ids")
|
||||
if eval_losses is not None:
|
||||
eval_losses = xm.mesh_reduce("eval_losses", torch.tensor(eval_losses), torch.cat).tolist()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user