Allow soft dependencies in the namespace with ImportErrors at use (#7537)
* PoC on RAG * Format class name/obj name * Better name in message * PoC on one TF model * Add PyTorch and TF dummy objects + script * Treat scikit-learn * Bad copy pastes * Typo
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
# There's no way to ignore "F401 '...' imported but unused" warnings in this
|
||||
# module, but to preserve other warnings. So, don't check this module at all.
|
||||
|
||||
from .metrics import is_sklearn_available
|
||||
from .metrics import glue_compute_metrics, xnli_compute_metrics
|
||||
from .processors import (
|
||||
DataProcessor,
|
||||
InputExample,
|
||||
@@ -21,7 +21,3 @@ from .processors import (
|
||||
xnli_processors,
|
||||
xnli_tasks_num_labels,
|
||||
)
|
||||
|
||||
|
||||
if is_sklearn_available():
|
||||
from .metrics import glue_compute_metrics, xnli_compute_metrics
|
||||
|
||||
@@ -14,77 +14,75 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
try:
|
||||
from ...file_utils import is_sklearn_available, requires_sklearn
|
||||
|
||||
|
||||
if is_sklearn_available():
|
||||
from sklearn.metrics import f1_score, matthews_corrcoef
|
||||
|
||||
from scipy.stats import pearsonr, spearmanr
|
||||
|
||||
_has_sklearn = True
|
||||
except (AttributeError, ImportError):
|
||||
_has_sklearn = False
|
||||
|
||||
def simple_accuracy(preds, labels):
|
||||
requires_sklearn(simple_accuracy)
|
||||
return (preds == labels).mean()
|
||||
|
||||
|
||||
def is_sklearn_available():
|
||||
return _has_sklearn
|
||||
def acc_and_f1(preds, labels):
|
||||
requires_sklearn(acc_and_f1)
|
||||
acc = simple_accuracy(preds, labels)
|
||||
f1 = f1_score(y_true=labels, y_pred=preds)
|
||||
return {
|
||||
"acc": acc,
|
||||
"f1": f1,
|
||||
"acc_and_f1": (acc + f1) / 2,
|
||||
}
|
||||
|
||||
|
||||
if _has_sklearn:
|
||||
def pearson_and_spearman(preds, labels):
|
||||
requires_sklearn(pearson_and_spearman)
|
||||
pearson_corr = pearsonr(preds, labels)[0]
|
||||
spearman_corr = spearmanr(preds, labels)[0]
|
||||
return {
|
||||
"pearson": pearson_corr,
|
||||
"spearmanr": spearman_corr,
|
||||
"corr": (pearson_corr + spearman_corr) / 2,
|
||||
}
|
||||
|
||||
def simple_accuracy(preds, labels):
|
||||
return (preds == labels).mean()
|
||||
|
||||
def acc_and_f1(preds, labels):
|
||||
acc = simple_accuracy(preds, labels)
|
||||
f1 = f1_score(y_true=labels, y_pred=preds)
|
||||
return {
|
||||
"acc": acc,
|
||||
"f1": f1,
|
||||
"acc_and_f1": (acc + f1) / 2,
|
||||
}
|
||||
def glue_compute_metrics(task_name, preds, labels):
|
||||
requires_sklearn(glue_compute_metrics)
|
||||
assert len(preds) == len(labels), f"Predictions and labels have mismatched lengths {len(preds)} and {len(labels)}"
|
||||
if task_name == "cola":
|
||||
return {"mcc": matthews_corrcoef(labels, preds)}
|
||||
elif task_name == "sst-2":
|
||||
return {"acc": simple_accuracy(preds, labels)}
|
||||
elif task_name == "mrpc":
|
||||
return acc_and_f1(preds, labels)
|
||||
elif task_name == "sts-b":
|
||||
return pearson_and_spearman(preds, labels)
|
||||
elif task_name == "qqp":
|
||||
return acc_and_f1(preds, labels)
|
||||
elif task_name == "mnli":
|
||||
return {"mnli/acc": simple_accuracy(preds, labels)}
|
||||
elif task_name == "mnli-mm":
|
||||
return {"mnli-mm/acc": simple_accuracy(preds, labels)}
|
||||
elif task_name == "qnli":
|
||||
return {"acc": simple_accuracy(preds, labels)}
|
||||
elif task_name == "rte":
|
||||
return {"acc": simple_accuracy(preds, labels)}
|
||||
elif task_name == "wnli":
|
||||
return {"acc": simple_accuracy(preds, labels)}
|
||||
elif task_name == "hans":
|
||||
return {"acc": simple_accuracy(preds, labels)}
|
||||
else:
|
||||
raise KeyError(task_name)
|
||||
|
||||
def pearson_and_spearman(preds, labels):
|
||||
pearson_corr = pearsonr(preds, labels)[0]
|
||||
spearman_corr = spearmanr(preds, labels)[0]
|
||||
return {
|
||||
"pearson": pearson_corr,
|
||||
"spearmanr": spearman_corr,
|
||||
"corr": (pearson_corr + spearman_corr) / 2,
|
||||
}
|
||||
|
||||
def glue_compute_metrics(task_name, preds, labels):
|
||||
assert len(preds) == len(
|
||||
labels
|
||||
), f"Predictions and labels have mismatched lengths {len(preds)} and {len(labels)}"
|
||||
if task_name == "cola":
|
||||
return {"mcc": matthews_corrcoef(labels, preds)}
|
||||
elif task_name == "sst-2":
|
||||
return {"acc": simple_accuracy(preds, labels)}
|
||||
elif task_name == "mrpc":
|
||||
return acc_and_f1(preds, labels)
|
||||
elif task_name == "sts-b":
|
||||
return pearson_and_spearman(preds, labels)
|
||||
elif task_name == "qqp":
|
||||
return acc_and_f1(preds, labels)
|
||||
elif task_name == "mnli":
|
||||
return {"mnli/acc": simple_accuracy(preds, labels)}
|
||||
elif task_name == "mnli-mm":
|
||||
return {"mnli-mm/acc": simple_accuracy(preds, labels)}
|
||||
elif task_name == "qnli":
|
||||
return {"acc": simple_accuracy(preds, labels)}
|
||||
elif task_name == "rte":
|
||||
return {"acc": simple_accuracy(preds, labels)}
|
||||
elif task_name == "wnli":
|
||||
return {"acc": simple_accuracy(preds, labels)}
|
||||
elif task_name == "hans":
|
||||
return {"acc": simple_accuracy(preds, labels)}
|
||||
else:
|
||||
raise KeyError(task_name)
|
||||
|
||||
def xnli_compute_metrics(task_name, preds, labels):
|
||||
assert len(preds) == len(
|
||||
labels
|
||||
), f"Predictions and labels have mismatched lengths {len(preds)} and {len(labels)}"
|
||||
if task_name == "xnli":
|
||||
return {"acc": simple_accuracy(preds, labels)}
|
||||
else:
|
||||
raise KeyError(task_name)
|
||||
def xnli_compute_metrics(task_name, preds, labels):
|
||||
requires_sklearn(xnli_compute_metrics)
|
||||
assert len(preds) == len(labels), f"Predictions and labels have mismatched lengths {len(preds)} and {len(labels)}"
|
||||
if task_name == "xnli":
|
||||
return {"acc": simple_accuracy(preds, labels)}
|
||||
else:
|
||||
raise KeyError(task_name)
|
||||
|
||||
Reference in New Issue
Block a user