From f71c9ccf592730fd9c733da56915569e2e8753aa Mon Sep 17 00:00:00 2001 From: YQ Date: Mon, 23 Oct 2023 18:33:05 +0800 Subject: [PATCH] fix logit-to-multi-hot conversion in example (#26936) * fix logit to multi-hot converstion * add comments * typo --- examples/pytorch/text-classification/run_classification.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/examples/pytorch/text-classification/run_classification.py b/examples/pytorch/text-classification/run_classification.py index 3033a61404..1bc4dbe5fa 100755 --- a/examples/pytorch/text-classification/run_classification.py +++ b/examples/pytorch/text-classification/run_classification.py @@ -655,7 +655,7 @@ def main(): preds = np.squeeze(preds) result = metric.compute(predictions=preds, references=p.label_ids) elif is_multi_label: - preds = np.array([np.where(p > 0.5, 1, 0) for p in preds]) + preds = np.array([np.where(p > 0, 1, 0) for p in preds]) # convert logits to multi-hot encoding # Micro F1 is commonly used in multi-label classification result = metric.compute(predictions=preds, references=p.label_ids, average="micro") else: @@ -721,7 +721,10 @@ def main(): if is_regression: predictions = np.squeeze(predictions) elif is_multi_label: - predictions = np.array([np.where(p > 0.5, 1, 0) for p in predictions]) + # Convert logits to multi-hot encoding. We compare the logits to 0 instead of 0.5, because the sigmoid is not applied. + # You can also pass `preprocess_logits_for_metrics=lambda logits, labels: nn.functional.sigmoid(logits)` to the Trainer + # and set p > 0.5 below (less efficient in this case) + predictions = np.array([np.where(p > 0, 1, 0) for p in predictions]) else: predictions = np.argmax(predictions, axis=1) output_predict_file = os.path.join(training_args.output_dir, "predict_results.txt")