fix logit-to-multi-hot conversion in example (#26936)
* fix logit to multi-hot converstion * add comments * typo
This commit is contained in:
@@ -655,7 +655,7 @@ def main():
|
|||||||
preds = np.squeeze(preds)
|
preds = np.squeeze(preds)
|
||||||
result = metric.compute(predictions=preds, references=p.label_ids)
|
result = metric.compute(predictions=preds, references=p.label_ids)
|
||||||
elif is_multi_label:
|
elif is_multi_label:
|
||||||
preds = np.array([np.where(p > 0.5, 1, 0) for p in preds])
|
preds = np.array([np.where(p > 0, 1, 0) for p in preds]) # convert logits to multi-hot encoding
|
||||||
# Micro F1 is commonly used in multi-label classification
|
# Micro F1 is commonly used in multi-label classification
|
||||||
result = metric.compute(predictions=preds, references=p.label_ids, average="micro")
|
result = metric.compute(predictions=preds, references=p.label_ids, average="micro")
|
||||||
else:
|
else:
|
||||||
@@ -721,7 +721,10 @@ def main():
|
|||||||
if is_regression:
|
if is_regression:
|
||||||
predictions = np.squeeze(predictions)
|
predictions = np.squeeze(predictions)
|
||||||
elif is_multi_label:
|
elif is_multi_label:
|
||||||
predictions = np.array([np.where(p > 0.5, 1, 0) for p in predictions])
|
# Convert logits to multi-hot encoding. We compare the logits to 0 instead of 0.5, because the sigmoid is not applied.
|
||||||
|
# You can also pass `preprocess_logits_for_metrics=lambda logits, labels: nn.functional.sigmoid(logits)` to the Trainer
|
||||||
|
# and set p > 0.5 below (less efficient in this case)
|
||||||
|
predictions = np.array([np.where(p > 0, 1, 0) for p in predictions])
|
||||||
else:
|
else:
|
||||||
predictions = np.argmax(predictions, axis=1)
|
predictions = np.argmax(predictions, axis=1)
|
||||||
output_predict_file = os.path.join(training_args.output_dir, "predict_results.txt")
|
output_predict_file = os.path.join(training_args.output_dir, "predict_results.txt")
|
||||||
|
|||||||
Reference in New Issue
Block a user