Avoid accessing .dataset of a DataLoader in Trainer (#16451)

* Avoid accessing .dataset of a dataloader

* style

* fix

* cleaning up, reverting some misunderstandings

* black

* add train_dataset argument to get_train_dataloader, and fix other instances of length checks

* flake8

* address comments

* fix bug

* cleanup

* add test

* Update tests/trainer/test_trainer.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* under torch

* merge

* stylistic suggestion

Co-authored-by: Sander Land <sander@chatdesk.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Sander Land
2022-03-29 21:00:18 +02:00
committed by GitHub
parent 781af7362b
commit d7c8ce57d4
4 changed files with 69 additions and 32 deletions

View File

@@ -189,6 +189,26 @@ if is_torch_available():
yield self.dataset[self.current_sample]
self.current_sample += 1
class MultiLoader:
def __init__(self, loaders):
self.loaders = loaders
def __len__(self):
return sum(len(loader) for loader in self.loaders)
def __iter__(self):
for loader in self.loaders:
yield from loader
class CustomDataloaderTrainer(Trainer):
def get_train_dataloader(self):
dataloaders = [super().get_train_dataloader(), super().get_train_dataloader()]
return MultiLoader(dataloaders)
def get_eval_dataloader(self, eval_dataset):
dataloaders = [super().get_eval_dataloader(eval_dataset), super().get_eval_dataloader(eval_dataset)]
return MultiLoader(dataloaders)
class RegressionModel(nn.Module):
def __init__(self, a=0, b=0, double_output=False):
super().__init__()
@@ -647,6 +667,15 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
new_eval_dataset = RegressionDataset(length=128)
self.assertEqual(len(trainer.get_eval_dataloader(new_eval_dataset)), 128 // (32 * n_gpu))
# tests that we do not require dataloader to have a .dataset attribute
def test_dataloader_without_dataset(self):
train_dataset = RegressionDataset(length=128)
trainer = CustomDataloaderTrainer(
model=RegressionModel(), train_dataset=train_dataset, eval_dataset=train_dataset
)
trainer.train()
trainer.evaluate()
def test_sampler_seed(self):
# nb: we don't want to inherit from IterableDataset to hit the right code path
class DummyDataset(torch.utils.data.Dataset):