Fix IterableDataset with __len__ in Trainer (#8095)
This commit is contained in:
@@ -384,7 +384,9 @@ class Trainer:
|
|||||||
dataset.set_format(type=dataset.format["type"], columns=columns)
|
dataset.set_format(type=dataset.format["type"], columns=columns)
|
||||||
|
|
||||||
def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
|
def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
|
||||||
if not isinstance(self.train_dataset, collections.abc.Sized):
|
if isinstance(self.train_dataset, torch.utils.data.IterableDataset) or not isinstance(
|
||||||
|
self.train_dataset, collections.abc.Sized
|
||||||
|
):
|
||||||
return None
|
return None
|
||||||
elif is_torch_tpu_available():
|
elif is_torch_tpu_available():
|
||||||
return get_tpu_sampler(self.train_dataset)
|
return get_tpu_sampler(self.train_dataset)
|
||||||
|
|||||||
Reference in New Issue
Block a user