Specify dataset dtype (#10195)
Co-authored-by: Quentin Lhoest <lhoest.q@gmail.com> Co-authored-by: Quentin Lhoest <lhoest.q@gmail.com>
This commit is contained in:
@@ -500,7 +500,7 @@ class TrainerIntegrationTest(unittest.TestCase):
|
||||
self.check_trained_model(trainer.model)
|
||||
|
||||
# Can return tensors.
|
||||
train_dataset.set_format(type="torch")
|
||||
train_dataset.set_format(type="torch", dtype=torch.float32)
|
||||
model = RegressionModel()
|
||||
trainer = Trainer(model, args, train_dataset=train_dataset)
|
||||
trainer.train()
|
||||
|
||||
Reference in New Issue
Block a user