Add cache_dir to save features in GLUE + Differentiate match/mismatch for MNLI metrics (#4621)
* Glue task cleaup * Enable writing cache to cache_dir in case dataset lives in readOnly filesystem. * Differentiate match vs mismatch for MNLI metrics. * Style * Fix pytype * Fix type * Use cache_dir in mnli mismatch eval dataset * Small Tweaks Co-authored-by: Julien Chaumond <chaumond@gmail.com>
This commit is contained in:
@@ -21,7 +21,7 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Dict, Optional
|
from typing import Callable, Dict, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -134,16 +134,29 @@ def main():
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Get datasets
|
# Get datasets
|
||||||
train_dataset = GlueDataset(data_args, tokenizer=tokenizer) if training_args.do_train else None
|
train_dataset = (
|
||||||
eval_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev") if training_args.do_eval else None
|
GlueDataset(data_args, tokenizer=tokenizer, cache_dir=model_args.cache_dir) if training_args.do_train else None
|
||||||
test_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="test") if training_args.do_predict else None
|
)
|
||||||
|
eval_dataset = (
|
||||||
|
GlueDataset(data_args, tokenizer=tokenizer, mode="dev", cache_dir=model_args.cache_dir)
|
||||||
|
if training_args.do_eval
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
test_dataset = (
|
||||||
|
GlueDataset(data_args, tokenizer=tokenizer, mode="test", cache_dir=model_args.cache_dir)
|
||||||
|
if training_args.do_predict
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
def compute_metrics(p: EvalPrediction) -> Dict:
|
def build_compute_metrics_fn(task_name: str) -> Callable[[EvalPrediction], Dict]:
|
||||||
|
def compute_metrics_fn(p: EvalPrediction):
|
||||||
if output_mode == "classification":
|
if output_mode == "classification":
|
||||||
preds = np.argmax(p.predictions, axis=1)
|
preds = np.argmax(p.predictions, axis=1)
|
||||||
elif output_mode == "regression":
|
elif output_mode == "regression":
|
||||||
preds = np.squeeze(p.predictions)
|
preds = np.squeeze(p.predictions)
|
||||||
return glue_compute_metrics(data_args.task_name, preds, p.label_ids)
|
return glue_compute_metrics(task_name, preds, p.label_ids)
|
||||||
|
|
||||||
|
return compute_metrics_fn
|
||||||
|
|
||||||
# Initialize our Trainer
|
# Initialize our Trainer
|
||||||
trainer = Trainer(
|
trainer = Trainer(
|
||||||
@@ -151,7 +164,7 @@ def main():
|
|||||||
args=training_args,
|
args=training_args,
|
||||||
train_dataset=train_dataset,
|
train_dataset=train_dataset,
|
||||||
eval_dataset=eval_dataset,
|
eval_dataset=eval_dataset,
|
||||||
compute_metrics=compute_metrics,
|
compute_metrics=build_compute_metrics_fn(data_args.task_name),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Training
|
# Training
|
||||||
@@ -174,9 +187,12 @@ def main():
|
|||||||
eval_datasets = [eval_dataset]
|
eval_datasets = [eval_dataset]
|
||||||
if data_args.task_name == "mnli":
|
if data_args.task_name == "mnli":
|
||||||
mnli_mm_data_args = dataclasses.replace(data_args, task_name="mnli-mm")
|
mnli_mm_data_args = dataclasses.replace(data_args, task_name="mnli-mm")
|
||||||
eval_datasets.append(GlueDataset(mnli_mm_data_args, tokenizer=tokenizer, mode="dev"))
|
eval_datasets.append(
|
||||||
|
GlueDataset(mnli_mm_data_args, tokenizer=tokenizer, mode="dev", cache_dir=model_args.cache_dir)
|
||||||
|
)
|
||||||
|
|
||||||
for eval_dataset in eval_datasets:
|
for eval_dataset in eval_datasets:
|
||||||
|
trainer.compute_metrics = build_compute_metrics_fn(eval_dataset.args.task_name)
|
||||||
eval_result = trainer.evaluate(eval_dataset=eval_dataset)
|
eval_result = trainer.evaluate(eval_dataset=eval_dataset)
|
||||||
|
|
||||||
output_eval_file = os.path.join(
|
output_eval_file = os.path.join(
|
||||||
@@ -196,7 +212,9 @@ def main():
|
|||||||
test_datasets = [test_dataset]
|
test_datasets = [test_dataset]
|
||||||
if data_args.task_name == "mnli":
|
if data_args.task_name == "mnli":
|
||||||
mnli_mm_data_args = dataclasses.replace(data_args, task_name="mnli-mm")
|
mnli_mm_data_args = dataclasses.replace(data_args, task_name="mnli-mm")
|
||||||
test_datasets.append(GlueDataset(mnli_mm_data_args, tokenizer=tokenizer, mode="test"))
|
test_datasets.append(
|
||||||
|
GlueDataset(mnli_mm_data_args, tokenizer=tokenizer, mode="test", cache_dir=model_args.cache_dir)
|
||||||
|
)
|
||||||
|
|
||||||
for test_dataset in test_datasets:
|
for test_dataset in test_datasets:
|
||||||
predictions = trainer.predict(test_dataset=test_dataset).predictions
|
predictions = trainer.predict(test_dataset=test_dataset).predictions
|
||||||
|
|||||||
@@ -70,6 +70,7 @@ class GlueDataset(Dataset):
|
|||||||
tokenizer: PreTrainedTokenizer,
|
tokenizer: PreTrainedTokenizer,
|
||||||
limit_length: Optional[int] = None,
|
limit_length: Optional[int] = None,
|
||||||
mode: Union[str, Split] = Split.train,
|
mode: Union[str, Split] = Split.train,
|
||||||
|
cache_dir: Optional[str] = None,
|
||||||
):
|
):
|
||||||
self.args = args
|
self.args = args
|
||||||
self.processor = glue_processors[args.task_name]()
|
self.processor = glue_processors[args.task_name]()
|
||||||
@@ -81,7 +82,7 @@ class GlueDataset(Dataset):
|
|||||||
raise KeyError("mode is not a valid split name")
|
raise KeyError("mode is not a valid split name")
|
||||||
# Load data features from cache or dataset file
|
# Load data features from cache or dataset file
|
||||||
cached_features_file = os.path.join(
|
cached_features_file = os.path.join(
|
||||||
args.data_dir,
|
cache_dir if cache_dir is not None else args.data_dir,
|
||||||
"cached_{}_{}_{}_{}".format(
|
"cached_{}_{}_{}_{}".format(
|
||||||
mode.value, tokenizer.__class__.__name__, str(args.max_seq_length), args.task_name,
|
mode.value, tokenizer.__class__.__name__, str(args.max_seq_length), args.task_name,
|
||||||
),
|
),
|
||||||
|
|||||||
@@ -63,9 +63,9 @@ if _has_sklearn:
|
|||||||
elif task_name == "qqp":
|
elif task_name == "qqp":
|
||||||
return acc_and_f1(preds, labels)
|
return acc_and_f1(preds, labels)
|
||||||
elif task_name == "mnli":
|
elif task_name == "mnli":
|
||||||
return {"acc": simple_accuracy(preds, labels)}
|
return {"mnli/acc": simple_accuracy(preds, labels)}
|
||||||
elif task_name == "mnli-mm":
|
elif task_name == "mnli-mm":
|
||||||
return {"acc": simple_accuracy(preds, labels)}
|
return {"mnli-mm/acc": simple_accuracy(preds, labels)}
|
||||||
elif task_name == "qnli":
|
elif task_name == "qnli":
|
||||||
return {"acc": simple_accuracy(preds, labels)}
|
return {"acc": simple_accuracy(preds, labels)}
|
||||||
elif task_name == "rte":
|
elif task_name == "rte":
|
||||||
|
|||||||
@@ -553,6 +553,7 @@ class Trainer:
|
|||||||
if self.tb_writer:
|
if self.tb_writer:
|
||||||
for k, v in logs.items():
|
for k, v in logs.items():
|
||||||
self.tb_writer.add_scalar(k, v, self.global_step)
|
self.tb_writer.add_scalar(k, v, self.global_step)
|
||||||
|
self.tb_writer.flush()
|
||||||
if is_wandb_available():
|
if is_wandb_available():
|
||||||
wandb.log(logs, step=self.global_step)
|
wandb.log(logs, step=self.global_step)
|
||||||
output = json.dumps({**logs, **{"step": self.global_step}})
|
output = json.dumps({**logs, **{"step": self.global_step}})
|
||||||
|
|||||||
Reference in New Issue
Block a user