Support multiple validation datasets when dataloader_persistent_workers=True (#30627)
* Support multiple validation datasets when dataloader_persistent_workers=True
* Test support of multiple validation datasets
This commit is contained in:
committed by
GitHub
parent
147c404fb1
commit
485fd81471
@@ -919,25 +919,36 @@ class Trainer:
|
|||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
|
def get_eval_dataloader(self, eval_dataset: Optional[Union[str, Dataset]] = None) -> DataLoader:
|
||||||
"""
|
"""
|
||||||
Returns the evaluation [`~torch.utils.data.DataLoader`].
|
Returns the evaluation [`~torch.utils.data.DataLoader`].
|
||||||
|
|
||||||
Subclass and override this method if you want to inject some custom behavior.
|
Subclass and override this method if you want to inject some custom behavior.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
eval_dataset (`torch.utils.data.Dataset`, *optional*):
|
eval_dataset (`str` or `torch.utils.data.Dataset`, *optional*):
|
||||||
If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted
|
If a `str`, will use `self.eval_dataset[eval_dataset]` as the evaluation dataset. If a `Dataset`, will override `self.eval_dataset` and must implement `__len__`. If it is a [`~datasets.Dataset`], columns not accepted by the `model.forward()` method are automatically removed.
|
||||||
by the `model.forward()` method are automatically removed. It must implement `__len__`.
|
|
||||||
"""
|
"""
|
||||||
if eval_dataset is None and self.eval_dataset is None:
|
if eval_dataset is None and self.eval_dataset is None:
|
||||||
raise ValueError("Trainer: evaluation requires an eval_dataset.")
|
raise ValueError("Trainer: evaluation requires an eval_dataset.")
|
||||||
|
|
||||||
# If we have persistent workers, don't do a fork bomb especially as eval datasets
|
# If we have persistent workers, don't do a fork bomb especially as eval datasets
|
||||||
# don't change during training
|
# don't change during training
|
||||||
if hasattr(self, "_eval_dataloader") and self.args.dataloader_persistent_workers:
|
dataloader_key = eval_dataset if isinstance(eval_dataset, str) else "eval"
|
||||||
return self.accelerator.prepare(self._eval_dataloader)
|
if (
|
||||||
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
|
hasattr(self, "_eval_dataloaders")
|
||||||
|
and dataloader_key in self._eval_dataloaders
|
||||||
|
and self.args.dataloader_persistent_workers
|
||||||
|
):
|
||||||
|
return self.accelerator.prepare(self._eval_dataloaders[dataloader_key])
|
||||||
|
|
||||||
|
eval_dataset = (
|
||||||
|
self.eval_dataset[eval_dataset]
|
||||||
|
if isinstance(eval_dataset, str)
|
||||||
|
else eval_dataset
|
||||||
|
if eval_dataset is not None
|
||||||
|
else self.eval_dataset
|
||||||
|
)
|
||||||
data_collator = self.data_collator
|
data_collator = self.data_collator
|
||||||
|
|
||||||
if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
|
if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
|
||||||
@@ -962,7 +973,10 @@ class Trainer:
|
|||||||
# we need to store the non-prepared version
|
# we need to store the non-prepared version
|
||||||
eval_dataloader = DataLoader(eval_dataset, **dataloader_params)
|
eval_dataloader = DataLoader(eval_dataset, **dataloader_params)
|
||||||
if self.args.dataloader_persistent_workers:
|
if self.args.dataloader_persistent_workers:
|
||||||
self._eval_dataloader = eval_dataloader
|
if hasattr(self, "_eval_dataloaders"):
|
||||||
|
self._eval_dataloaders[dataloader_key] = eval_dataloader
|
||||||
|
else:
|
||||||
|
self._eval_dataloaders = {dataloader_key: eval_dataloader}
|
||||||
|
|
||||||
return self.accelerator.prepare(eval_dataloader)
|
return self.accelerator.prepare(eval_dataloader)
|
||||||
|
|
||||||
@@ -3584,12 +3598,13 @@ class Trainer:
|
|||||||
dictionary also contains the epoch number which comes from the training state.
|
dictionary also contains the epoch number which comes from the training state.
|
||||||
"""
|
"""
|
||||||
# handle multipe eval datasets
|
# handle multipe eval datasets
|
||||||
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
|
override = eval_dataset is not None
|
||||||
|
eval_dataset = eval_dataset if override else self.eval_dataset
|
||||||
if isinstance(eval_dataset, dict):
|
if isinstance(eval_dataset, dict):
|
||||||
metrics = {}
|
metrics = {}
|
||||||
for eval_dataset_name, _eval_dataset in eval_dataset.items():
|
for eval_dataset_name, _eval_dataset in eval_dataset.items():
|
||||||
dataset_metrics = self.evaluate(
|
dataset_metrics = self.evaluate(
|
||||||
eval_dataset=_eval_dataset,
|
eval_dataset=_eval_dataset if override else eval_dataset_name,
|
||||||
ignore_keys=ignore_keys,
|
ignore_keys=ignore_keys,
|
||||||
metric_key_prefix=f"{metric_key_prefix}_{eval_dataset_name}",
|
metric_key_prefix=f"{metric_key_prefix}_{eval_dataset_name}",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1231,6 +1231,97 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
trainer.train()
|
trainer.train()
|
||||||
trainer.evaluate()
|
trainer.evaluate()
|
||||||
|
|
||||||
|
def test_get_eval_dataloader_without_persistent_workers(self):
|
||||||
|
train_dataset = RegressionDataset()
|
||||||
|
config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4)
|
||||||
|
tiny_gpt2 = GPT2LMHeadModel(config)
|
||||||
|
args = TrainingArguments("./test", report_to="none", dataloader_persistent_workers=False)
|
||||||
|
|
||||||
|
# Single evaluation dataset
|
||||||
|
eval_dataset = RegressionDataset()
|
||||||
|
trainer = Trainer(tiny_gpt2, args, train_dataset=train_dataset, eval_dataset=eval_dataset)
|
||||||
|
# Mocking the prepare method to avoid the dataloader changing with each call to get_eval_dataloader
|
||||||
|
trainer.accelerator.prepare = lambda x: x
|
||||||
|
|
||||||
|
default_dataloader = trainer.get_eval_dataloader()
|
||||||
|
dataloader_with_dataset = trainer.get_eval_dataloader(eval_dataset)
|
||||||
|
|
||||||
|
self.assertEqual(default_dataloader.dataset, eval_dataset)
|
||||||
|
self.assertEqual(dataloader_with_dataset.dataset, eval_dataset)
|
||||||
|
self.assertNotEqual(default_dataloader, dataloader_with_dataset)
|
||||||
|
|
||||||
|
# Multiple evaluation datasets
|
||||||
|
first_dataset = RegressionDataset()
|
||||||
|
second_dataset = RegressionDataset()
|
||||||
|
trainer = Trainer(
|
||||||
|
tiny_gpt2,
|
||||||
|
args,
|
||||||
|
train_dataset=train_dataset,
|
||||||
|
eval_dataset={"first": first_dataset, "second": second_dataset},
|
||||||
|
)
|
||||||
|
# Mocking the prepare method to avoid the dataloader changing with each call to get_eval_dataloader
|
||||||
|
trainer.accelerator.prepare = lambda x: x
|
||||||
|
|
||||||
|
first_dataloader = trainer.get_eval_dataloader("first")
|
||||||
|
first_dataloader_repeated = trainer.get_eval_dataloader("first")
|
||||||
|
second_dataloader = trainer.get_eval_dataloader("second")
|
||||||
|
second_dataloader_repeated = trainer.get_eval_dataloader("second")
|
||||||
|
|
||||||
|
self.assertEqual(first_dataset, first_dataloader.dataset)
|
||||||
|
self.assertEqual(first_dataloader.dataset, first_dataloader_repeated.dataset)
|
||||||
|
self.assertEqual(second_dataset, second_dataloader.dataset)
|
||||||
|
self.assertEqual(second_dataloader.dataset, second_dataloader_repeated.dataset)
|
||||||
|
self.assertNotEqual(first_dataloader, first_dataloader_repeated)
|
||||||
|
self.assertNotEqual(second_dataloader, second_dataloader_repeated)
|
||||||
|
|
||||||
|
def test_get_eval_dataloader_with_persistent_workers(self):
|
||||||
|
train_dataset = RegressionDataset()
|
||||||
|
config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4)
|
||||||
|
tiny_gpt2 = GPT2LMHeadModel(config)
|
||||||
|
args = TrainingArguments(
|
||||||
|
"./test",
|
||||||
|
report_to="none",
|
||||||
|
dataloader_persistent_workers=True,
|
||||||
|
dataloader_num_workers=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Single evaluation dataset
|
||||||
|
eval_dataset = RegressionDataset()
|
||||||
|
trainer = Trainer(tiny_gpt2, args, train_dataset=train_dataset, eval_dataset=eval_dataset)
|
||||||
|
# Mocking the prepare method to avoid the dataloader changing with each call to get_eval_dataloader
|
||||||
|
trainer.accelerator.prepare = lambda x: x
|
||||||
|
|
||||||
|
default_dataloader = trainer.get_eval_dataloader()
|
||||||
|
dataloader_with_dataset = trainer.get_eval_dataloader(eval_dataset)
|
||||||
|
|
||||||
|
self.assertEqual(default_dataloader.dataset, eval_dataset)
|
||||||
|
self.assertEqual(dataloader_with_dataset.dataset, eval_dataset)
|
||||||
|
self.assertEqual(default_dataloader, dataloader_with_dataset)
|
||||||
|
|
||||||
|
# Multiple evaluation datasets
|
||||||
|
first_dataset = RegressionDataset()
|
||||||
|
second_dataset = RegressionDataset()
|
||||||
|
trainer = Trainer(
|
||||||
|
tiny_gpt2,
|
||||||
|
args,
|
||||||
|
train_dataset=train_dataset,
|
||||||
|
eval_dataset={"first": first_dataset, "second": second_dataset},
|
||||||
|
)
|
||||||
|
# Mocking the prepare method to avoid the dataloader changing with each call to get_eval_dataloader
|
||||||
|
trainer.accelerator.prepare = lambda x: x
|
||||||
|
|
||||||
|
first_dataloader = trainer.get_eval_dataloader("first")
|
||||||
|
first_dataloader_repeated = trainer.get_eval_dataloader("first")
|
||||||
|
second_dataloader = trainer.get_eval_dataloader("second")
|
||||||
|
second_dataloader_repeated = trainer.get_eval_dataloader("second")
|
||||||
|
|
||||||
|
self.assertEqual(first_dataset, first_dataloader.dataset)
|
||||||
|
self.assertEqual(first_dataloader.dataset, first_dataloader_repeated.dataset)
|
||||||
|
self.assertEqual(second_dataset, second_dataloader.dataset)
|
||||||
|
self.assertEqual(second_dataloader.dataset, second_dataloader_repeated.dataset)
|
||||||
|
self.assertEqual(first_dataloader, first_dataloader_repeated)
|
||||||
|
self.assertEqual(second_dataloader, second_dataloader_repeated)
|
||||||
|
|
||||||
@require_lomo
|
@require_lomo
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
def test_lomo(self):
|
def test_lomo(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user