@@ -1334,9 +1334,9 @@ class Trainer:
|
|||||||
elif is_torch_tpu_available():
|
elif is_torch_tpu_available():
|
||||||
# tpu-comment: Get all predictions and labels from all worker shards of eval dataset
|
# tpu-comment: Get all predictions and labels from all worker shards of eval dataset
|
||||||
if preds is not None:
|
if preds is not None:
|
||||||
preds = nested_xla_mesh_reduce("eval_preds", preds)
|
preds = nested_xla_mesh_reduce(preds, "eval_preds")
|
||||||
if label_ids is not None:
|
if label_ids is not None:
|
||||||
label_ids = nested_xla_mesh_reduce("eval_label_ids", label_ids, torch.cat)
|
label_ids = nested_xla_mesh_reduce(label_ids, "eval_label_ids")
|
||||||
if eval_losses is not None:
|
if eval_losses is not None:
|
||||||
eval_losses = xm.mesh_reduce("eval_losses", torch.tensor(eval_losses), torch.cat).tolist()
|
eval_losses = xm.mesh_reduce("eval_losses", torch.tensor(eval_losses), torch.cat).tolist()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user