From 8cbd0bd137d9f35c6e909cd602e6dde5866b8574 Mon Sep 17 00:00:00 2001 From: Lysandre Debut Date: Mon, 15 Feb 2021 18:57:17 +0100 Subject: [PATCH] Specify dataset dtype (#10195) Co-authored-by: Quentin Lhoest Co-authored-by: Quentin Lhoest --- tests/test_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 48720ccc9f..d382dbc40b 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -500,7 +500,7 @@ class TrainerIntegrationTest(unittest.TestCase): self.check_trained_model(trainer.model) # Can return tensors. - train_dataset.set_format(type="torch") + train_dataset.set_format(type="torch", dtype=torch.float32) model = RegressionModel() trainer = Trainer(model, args, train_dataset=train_dataset) trainer.train()