Fix datasets set_format (#10178)

This commit is contained in:
Sylvain Gugger
2021-02-15 05:49:07 -05:00
committed by GitHub
parent 8fae93ca19
commit 587197dcd2

View File

@@ -439,7 +439,8 @@ class Trainer:
f"The following columns {dset_description}don't have a corresponding argument in "
f"`{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}."
)
dataset.set_format(type=dataset.format["type"], columns=columns)
dataset.set_format(type=dataset.format["type"], columns=columns, format_kwargs=dataset.format["format_kwargs"])
def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
if isinstance(self.train_dataset, torch.utils.data.IterableDataset) or not isinstance(