Avoid accessing .dataset of a DataLoader in Trainer (#16451)
* Avoid accessing .dataset of a dataloader * style * fix * cleaning up, reverting some misunderstandings * black * add train_dataset argument to get_train_dataloader, and fix other instances of length checks * flake8 * address comments * fix bug * cleanup * add test * Update tests/trainer/test_trainer.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * under torch * merge * stylistic suggestion Co-authored-by: Sander Land <sander@chatdesk.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -189,6 +189,26 @@ if is_torch_available():
|
||||
yield self.dataset[self.current_sample]
|
||||
self.current_sample += 1
|
||||
|
||||
class MultiLoader:
|
||||
def __init__(self, loaders):
|
||||
self.loaders = loaders
|
||||
|
||||
def __len__(self):
|
||||
return sum(len(loader) for loader in self.loaders)
|
||||
|
||||
def __iter__(self):
|
||||
for loader in self.loaders:
|
||||
yield from loader
|
||||
|
||||
class CustomDataloaderTrainer(Trainer):
|
||||
def get_train_dataloader(self):
|
||||
dataloaders = [super().get_train_dataloader(), super().get_train_dataloader()]
|
||||
return MultiLoader(dataloaders)
|
||||
|
||||
def get_eval_dataloader(self, eval_dataset):
|
||||
dataloaders = [super().get_eval_dataloader(eval_dataset), super().get_eval_dataloader(eval_dataset)]
|
||||
return MultiLoader(dataloaders)
|
||||
|
||||
class RegressionModel(nn.Module):
|
||||
def __init__(self, a=0, b=0, double_output=False):
|
||||
super().__init__()
|
||||
@@ -647,6 +667,15 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
new_eval_dataset = RegressionDataset(length=128)
|
||||
self.assertEqual(len(trainer.get_eval_dataloader(new_eval_dataset)), 128 // (32 * n_gpu))
|
||||
|
||||
# tests that we do not require dataloader to have a .dataset attribute
|
||||
def test_dataloader_without_dataset(self):
|
||||
train_dataset = RegressionDataset(length=128)
|
||||
trainer = CustomDataloaderTrainer(
|
||||
model=RegressionModel(), train_dataset=train_dataset, eval_dataset=train_dataset
|
||||
)
|
||||
trainer.train()
|
||||
trainer.evaluate()
|
||||
|
||||
def test_sampler_seed(self):
|
||||
# nb: we don't want to inherit from IterableDataset to hit the right code path
|
||||
class DummyDataset(torch.utils.data.Dataset):
|
||||
|
||||
Reference in New Issue
Block a user