move code to Trainer.evaluate to enable use of that function with multiple datasets (#27844)
* move code to Trainer.evaluate to enable use of that function with multiple datasets * test * update doc string * and a tip * forgot the type --------- Co-authored-by: Prof. Peter Schneider-Kamp <jps@ordbogen.com>
This commit is contained in:
@@ -2261,16 +2261,6 @@ class Trainer:
|
|||||||
|
|
||||||
metrics = None
|
metrics = None
|
||||||
if self.control.should_evaluate:
|
if self.control.should_evaluate:
|
||||||
if isinstance(self.eval_dataset, dict):
|
|
||||||
metrics = {}
|
|
||||||
for eval_dataset_name, eval_dataset in self.eval_dataset.items():
|
|
||||||
dataset_metrics = self.evaluate(
|
|
||||||
eval_dataset=eval_dataset,
|
|
||||||
ignore_keys=ignore_keys_for_eval,
|
|
||||||
metric_key_prefix=f"eval_{eval_dataset_name}",
|
|
||||||
)
|
|
||||||
metrics.update(dataset_metrics)
|
|
||||||
else:
|
|
||||||
metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
|
metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
|
||||||
self._report_to_hp_search(trial, self.state.global_step, metrics)
|
self._report_to_hp_search(trial, self.state.global_step, metrics)
|
||||||
|
|
||||||
@@ -2997,7 +2987,7 @@ class Trainer:
|
|||||||
|
|
||||||
def evaluate(
|
def evaluate(
|
||||||
self,
|
self,
|
||||||
eval_dataset: Optional[Dataset] = None,
|
eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
|
||||||
ignore_keys: Optional[List[str]] = None,
|
ignore_keys: Optional[List[str]] = None,
|
||||||
metric_key_prefix: str = "eval",
|
metric_key_prefix: str = "eval",
|
||||||
) -> Dict[str, float]:
|
) -> Dict[str, float]:
|
||||||
@@ -3010,10 +3000,24 @@ class Trainer:
|
|||||||
You can also subclass and override this method to inject custom behavior.
|
You can also subclass and override this method to inject custom behavior.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
eval_dataset (`Dataset`, *optional*):
|
eval_dataset (Union[`Dataset`, Dict[str, `Dataset`]]), *optional*):
|
||||||
Pass a dataset if you wish to override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns
|
Pass a dataset if you wish to override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns
|
||||||
not accepted by the `model.forward()` method are automatically removed. It must implement the `__len__`
|
not accepted by the `model.forward()` method are automatically removed. If it is a dictionary, it will
|
||||||
method.
|
evaluate on each dataset, prepending the dictionary key to the metric name. Datasets must implement the
|
||||||
|
`__len__` method.
|
||||||
|
|
||||||
|
<Tip>
|
||||||
|
|
||||||
|
If you pass a dictionary with names of datasets as keys and datasets as values, evaluate will run
|
||||||
|
separate evaluations on each dataset. This can be useful to monitor how training affects other
|
||||||
|
datasets or simply to get a more fine-grained evaluation.
|
||||||
|
When used with `load_best_model_at_end`, make sure `metric_for_best_model` references exactly one
|
||||||
|
of the datasets. If you, for example, pass in `{"data1": data1, "data2": data2}` for two datasets
|
||||||
|
`data1` and `data2`, you could specify `metric_for_best_model="eval_data1_loss"` for using the
|
||||||
|
loss on `data1` and `metric_for_best_model="eval_data2_loss"` for the loss on `data2`.
|
||||||
|
|
||||||
|
</Tip>
|
||||||
|
|
||||||
ignore_keys (`List[str]`, *optional*):
|
ignore_keys (`List[str]`, *optional*):
|
||||||
A list of keys in the output of your model (if it is a dictionary) that should be ignored when
|
A list of keys in the output of your model (if it is a dictionary) that should be ignored when
|
||||||
gathering predictions.
|
gathering predictions.
|
||||||
@@ -3025,6 +3029,19 @@ class Trainer:
|
|||||||
A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
|
A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
|
||||||
dictionary also contains the epoch number which comes from the training state.
|
dictionary also contains the epoch number which comes from the training state.
|
||||||
"""
|
"""
|
||||||
|
# handle multiple eval datasets
|
||||||
|
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
|
||||||
|
if isinstance(eval_dataset, dict):
|
||||||
|
metrics = {}
|
||||||
|
for eval_dataset_name, _eval_dataset in eval_dataset.items():
|
||||||
|
dataset_metrics = self.evaluate(
|
||||||
|
eval_dataset=_eval_dataset,
|
||||||
|
ignore_keys=ignore_keys,
|
||||||
|
metric_key_prefix=f"{metric_key_prefix}_{eval_dataset_name}",
|
||||||
|
)
|
||||||
|
metrics.update(dataset_metrics)
|
||||||
|
return metrics
|
||||||
|
|
||||||
# memory metrics - must set up as early as possible
|
# memory metrics - must set up as early as possible
|
||||||
self._memory_tracker.start()
|
self._memory_tracker.start()
|
||||||
|
|
||||||
|
|||||||
@@ -103,6 +103,7 @@ if is_torch_available():
|
|||||||
|
|
||||||
import transformers.optimization
|
import transformers.optimization
|
||||||
from transformers import (
|
from transformers import (
|
||||||
|
AutoModelForCausalLM,
|
||||||
AutoModelForSequenceClassification,
|
AutoModelForSequenceClassification,
|
||||||
EarlyStoppingCallback,
|
EarlyStoppingCallback,
|
||||||
GlueDataset,
|
GlueDataset,
|
||||||
@@ -1845,6 +1846,35 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
result = trainer.evaluate()
|
result = trainer.evaluate()
|
||||||
self.assertLess(result["eval_loss"], 0.2)
|
self.assertLess(result["eval_loss"], 0.2)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_trainer_eval_multiple(self):
|
||||||
|
MODEL_ID = "gpt2"
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
|
||||||
|
dataset = LineByLineTextDataset(
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
file_path=PATH_SAMPLE_TEXT,
|
||||||
|
block_size=tokenizer.max_len_single_sentence,
|
||||||
|
)
|
||||||
|
for example in dataset.examples:
|
||||||
|
example["labels"] = example["input_ids"]
|
||||||
|
training_args = TrainingArguments(
|
||||||
|
output_dir="./examples",
|
||||||
|
use_cpu=True,
|
||||||
|
per_device_eval_batch_size=1,
|
||||||
|
)
|
||||||
|
trainer = Trainer(
|
||||||
|
model=model,
|
||||||
|
args=training_args,
|
||||||
|
eval_dataset={
|
||||||
|
"data1": dataset,
|
||||||
|
"data2": dataset,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
result = trainer.evaluate()
|
||||||
|
self.assertIn("eval_data1_loss", result)
|
||||||
|
self.assertIn("eval_data2_loss", result)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_trainer_eval_lm(self):
|
def test_trainer_eval_lm(self):
|
||||||
MODEL_ID = "distilroberta-base"
|
MODEL_ID = "distilroberta-base"
|
||||||
|
|||||||
Reference in New Issue
Block a user