Fix IterableDataset with __len__ in Trainer (#8095)

This commit is contained in:
Jonathan Chang
2020-10-27 21:52:35 +08:00
committed by GitHub
parent d93acd6f13
commit 286dc19a4f

View File

@@ -384,7 +384,9 @@ class Trainer:
             dataset.set_format(type=dataset.format["type"], columns=columns)
 
     def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
-        if not isinstance(self.train_dataset, collections.abc.Sized):
+        if isinstance(self.train_dataset, torch.utils.data.IterableDataset) or not isinstance(
+            self.train_dataset, collections.abc.Sized
+        ):
             return None
         elif is_torch_tpu_available():
             return get_tpu_sampler(self.train_dataset)