diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 678052b5cc..067d4ff305 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -2261,17 +2261,7 @@ class Trainer: metrics = None if self.control.should_evaluate: - if isinstance(self.eval_dataset, dict): - metrics = {} - for eval_dataset_name, eval_dataset in self.eval_dataset.items(): - dataset_metrics = self.evaluate( - eval_dataset=eval_dataset, - ignore_keys=ignore_keys_for_eval, - metric_key_prefix=f"eval_{eval_dataset_name}", - ) - metrics.update(dataset_metrics) - else: - metrics = self.evaluate(ignore_keys=ignore_keys_for_eval) + metrics = self.evaluate(ignore_keys=ignore_keys_for_eval) self._report_to_hp_search(trial, self.state.global_step, metrics) # Run delayed LR scheduler now that metrics are populated @@ -2997,7 +2987,7 @@ class Trainer: def evaluate( self, - eval_dataset: Optional[Dataset] = None, + eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, ignore_keys: Optional[List[str]] = None, metric_key_prefix: str = "eval", ) -> Dict[str, float]: @@ -3010,10 +3000,24 @@ class Trainer: You can also subclass and override this method to inject custom behavior. Args: - eval_dataset (`Dataset`, *optional*): + eval_dataset (Union[`Dataset`, Dict[str, `Dataset`]], *optional*): Pass a dataset if you wish to override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns - not accepted by the `model.forward()` method are automatically removed. It must implement the `__len__` - method. + not accepted by the `model.forward()` method are automatically removed. If it is a dictionary, it will + evaluate on each dataset, prepending the dictionary key to the metric name. Datasets must implement the + `__len__` method. + + + + If you pass a dictionary with names of datasets as keys and datasets as values, evaluate will run + separate evaluations on each dataset. 
This can be useful to monitor how training affects other + datasets or simply to get a more fine-grained evaluation. + When used with `load_best_model_at_end`, make sure `metric_for_best_model` references exactly one + of the datasets. If you, for example, pass in `{"data1": data1, "data2": data2}` for two datasets + `data1` and `data2`, you could specify `metric_for_best_model="eval_data1_loss"` for using the + loss on `data1` and `metric_for_best_model="eval_data2_loss"` for the loss on `data2`. + + + ignore_keys (`List[str]`, *optional*): A list of keys in the output of your model (if it is a dictionary) that should be ignored when gathering predictions. @@ -3025,6 +3029,19 @@ class Trainer: A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The dictionary also contains the epoch number which comes from the training state. """ + # handle multiple eval datasets + eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset + if isinstance(eval_dataset, dict): + metrics = {} + for eval_dataset_name, _eval_dataset in eval_dataset.items(): + dataset_metrics = self.evaluate( + eval_dataset=_eval_dataset, + ignore_keys=ignore_keys, + metric_key_prefix=f"{metric_key_prefix}_{eval_dataset_name}", + ) + metrics.update(dataset_metrics) + return metrics + # memory metrics - must set up as early as possible self._memory_tracker.start() diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 05f84bc00f..0ab0b78112 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -103,6 +103,7 @@ if is_torch_available(): import transformers.optimization from transformers import ( + AutoModelForCausalLM, AutoModelForSequenceClassification, EarlyStoppingCallback, GlueDataset, @@ -1845,6 +1846,35 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): result = trainer.evaluate() self.assertLess(result["eval_loss"], 0.2) + @slow + def 
test_trainer_eval_multiple(self): + MODEL_ID = "gpt2" + tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) + model = AutoModelForCausalLM.from_pretrained(MODEL_ID) + dataset = LineByLineTextDataset( + tokenizer=tokenizer, + file_path=PATH_SAMPLE_TEXT, + block_size=tokenizer.max_len_single_sentence, + ) + for example in dataset.examples: + example["labels"] = example["input_ids"] + training_args = TrainingArguments( + output_dir="./examples", + use_cpu=True, + per_device_eval_batch_size=1, + ) + trainer = Trainer( + model=model, + args=training_args, + eval_dataset={ + "data1": dataset, + "data2": dataset, + }, + ) + result = trainer.evaluate() + self.assertIn("eval_data1_loss", result) + self.assertIn("eval_data2_loss", result) + @slow def test_trainer_eval_lm(self): MODEL_ID = "distilroberta-base"