diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 34cf5aa490..978f1a6708 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -919,25 +919,36 @@ class Trainer: else: return None - def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader: + def get_eval_dataloader(self, eval_dataset: Optional[Union[str, Dataset]] = None) -> DataLoader: """ Returns the evaluation [`~torch.utils.data.DataLoader`]. Subclass and override this method if you want to inject some custom behavior. Args: - eval_dataset (`torch.utils.data.Dataset`, *optional*): - If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted - by the `model.forward()` method are automatically removed. It must implement `__len__`. + eval_dataset (`str` or `torch.utils.data.Dataset`, *optional*): + If a `str`, will use `self.eval_dataset[eval_dataset]` as the evaluation dataset. If a `Dataset`, will override `self.eval_dataset` and must implement `__len__`. If it is a [`~datasets.Dataset`], columns not accepted by the `model.forward()` method are automatically removed. """ if eval_dataset is None and self.eval_dataset is None: raise ValueError("Trainer: evaluation requires an eval_dataset.") # If we have persistent workers, don't do a fork bomb especially as eval datasets # don't change during training - if hasattr(self, "_eval_dataloader") and self.args.dataloader_persistent_workers: - return self.accelerator.prepare(self._eval_dataloader) - eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset + dataloader_key = eval_dataset if isinstance(eval_dataset, str) else "eval" + if ( + hasattr(self, "_eval_dataloaders") + and dataloader_key in self._eval_dataloaders + and self.args.dataloader_persistent_workers + ): + return self.accelerator.prepare(self._eval_dataloaders[dataloader_key]) + + eval_dataset = ( + self.eval_dataset[eval_dataset] + if isinstance(eval_dataset, str) + else eval_dataset + if eval_dataset is not None + else self.eval_dataset + ) data_collator = self.data_collator if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset): @@ -962,7 +973,10 @@ class Trainer: # we need to store the non-prepared version eval_dataloader = DataLoader(eval_dataset, **dataloader_params) if self.args.dataloader_persistent_workers: - self._eval_dataloader = eval_dataloader + if hasattr(self, "_eval_dataloaders"): + self._eval_dataloaders[dataloader_key] = eval_dataloader + else: + self._eval_dataloaders = {dataloader_key: eval_dataloader} return self.accelerator.prepare(eval_dataloader) @@ -3584,12 +3598,13 @@ class Trainer: dictionary also contains the epoch number which comes from the training state. """ # handle multipe eval datasets - eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset + override = eval_dataset is not None + eval_dataset = eval_dataset if override else self.eval_dataset if isinstance(eval_dataset, dict): metrics = {} for eval_dataset_name, _eval_dataset in eval_dataset.items(): dataset_metrics = self.evaluate( - eval_dataset=_eval_dataset, + eval_dataset=_eval_dataset if override else eval_dataset_name, ignore_keys=ignore_keys, metric_key_prefix=f"{metric_key_prefix}_{eval_dataset_name}", ) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index af456a9bda..dba4b97515 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1231,6 +1231,97 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): trainer.train() trainer.evaluate() + def test_get_eval_dataloader_without_persistent_workers(self): + train_dataset = RegressionDataset() + config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4) + tiny_gpt2 = GPT2LMHeadModel(config) + args = TrainingArguments("./test", report_to="none", dataloader_persistent_workers=False) + + # Single evaluation dataset + eval_dataset = RegressionDataset() + trainer = Trainer(tiny_gpt2, args, train_dataset=train_dataset, eval_dataset=eval_dataset) + # Mocking the prepare method to avoid the dataloader changing with each call to get_eval_dataloader + trainer.accelerator.prepare = lambda x: x + + default_dataloader = trainer.get_eval_dataloader() + dataloader_with_dataset = trainer.get_eval_dataloader(eval_dataset) + + self.assertEqual(default_dataloader.dataset, eval_dataset) + self.assertEqual(dataloader_with_dataset.dataset, eval_dataset) + self.assertNotEqual(default_dataloader, dataloader_with_dataset) + + # Multiple evaluation datasets + first_dataset = RegressionDataset() + second_dataset = RegressionDataset() + trainer = Trainer( + tiny_gpt2, + args, + train_dataset=train_dataset, + eval_dataset={"first": first_dataset, "second": second_dataset}, + ) + # Mocking the prepare method to avoid the dataloader changing with each call to get_eval_dataloader + trainer.accelerator.prepare = lambda x: x + + first_dataloader = trainer.get_eval_dataloader("first") + first_dataloader_repeated = trainer.get_eval_dataloader("first") + second_dataloader = trainer.get_eval_dataloader("second") + second_dataloader_repeated = trainer.get_eval_dataloader("second") + + self.assertEqual(first_dataset, first_dataloader.dataset) + self.assertEqual(first_dataloader.dataset, first_dataloader_repeated.dataset) + self.assertEqual(second_dataset, second_dataloader.dataset) + self.assertEqual(second_dataloader.dataset, second_dataloader_repeated.dataset) + self.assertNotEqual(first_dataloader, first_dataloader_repeated) + self.assertNotEqual(second_dataloader, second_dataloader_repeated) + + def test_get_eval_dataloader_with_persistent_workers(self): + train_dataset = RegressionDataset() + config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4) + tiny_gpt2 = GPT2LMHeadModel(config) + args = TrainingArguments( + "./test", + report_to="none", + dataloader_persistent_workers=True, + dataloader_num_workers=2, + ) + + # Single evaluation dataset + eval_dataset = RegressionDataset() + trainer = Trainer(tiny_gpt2, args, train_dataset=train_dataset, eval_dataset=eval_dataset) + # Mocking the prepare method to avoid the dataloader changing with each call to get_eval_dataloader + trainer.accelerator.prepare = lambda x: x + + default_dataloader = trainer.get_eval_dataloader() + dataloader_with_dataset = trainer.get_eval_dataloader(eval_dataset) + + self.assertEqual(default_dataloader.dataset, eval_dataset) + self.assertEqual(dataloader_with_dataset.dataset, eval_dataset) + self.assertEqual(default_dataloader, dataloader_with_dataset) + + # Multiple evaluation datasets + first_dataset = RegressionDataset() + second_dataset = RegressionDataset() + trainer = Trainer( + tiny_gpt2, + args, + train_dataset=train_dataset, + eval_dataset={"first": first_dataset, "second": second_dataset}, + ) + # Mocking the prepare method to avoid the dataloader changing with each call to get_eval_dataloader + trainer.accelerator.prepare = lambda x: x + + first_dataloader = trainer.get_eval_dataloader("first") + first_dataloader_repeated = trainer.get_eval_dataloader("first") + second_dataloader = trainer.get_eval_dataloader("second") + second_dataloader_repeated = trainer.get_eval_dataloader("second") + + self.assertEqual(first_dataset, first_dataloader.dataset) + self.assertEqual(first_dataloader.dataset, first_dataloader_repeated.dataset) + self.assertEqual(second_dataset, second_dataloader.dataset) + self.assertEqual(second_dataloader.dataset, second_dataloader_repeated.dataset) + self.assertEqual(first_dataloader, first_dataloader_repeated) + self.assertEqual(second_dataloader, second_dataloader_repeated) + @require_lomo @require_torch_gpu def test_lomo(self):