no numpy
This commit is contained in:
@@ -456,7 +456,7 @@ def main():
|
|||||||
preds = np.argmax(preds, axis=1)
|
preds = np.argmax(preds, axis=1)
|
||||||
elif output_mode == "regression":
|
elif output_mode == "regression":
|
||||||
preds = np.squeeze(preds)
|
preds = np.squeeze(preds)
|
||||||
result = compute_metrics(task_name, preds, out_label_ids.numpy())
|
result = compute_metrics(task_name, preds, out_label_ids)
|
||||||
|
|
||||||
if args.local_rank != -1:
|
if args.local_rank != -1:
|
||||||
# Average over distributed nodes if needed
|
# Average over distributed nodes if needed
|
||||||
@@ -533,7 +533,7 @@ def main():
|
|||||||
eval_loss = eval_loss / nb_eval_steps
|
eval_loss = eval_loss / nb_eval_steps
|
||||||
preds = preds[0]
|
preds = preds[0]
|
||||||
preds = np.argmax(preds, axis=1)
|
preds = np.argmax(preds, axis=1)
|
||||||
result = compute_metrics(task_name, preds, out_label_ids.numpy())
|
result = compute_metrics(task_name, preds, out_label_ids)
|
||||||
|
|
||||||
if args.local_rank != -1:
|
if args.local_rank != -1:
|
||||||
# Average over distributed nodes if needed
|
# Average over distributed nodes if needed
|
||||||
|
|||||||
Reference in New Issue
Block a user