From 197e7ce911d91d85eb2f91858720957c2d979cd2 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Fri, 3 Feb 2023 15:12:28 +0100 Subject: [PATCH] Fix device issue in a `ConvBertModelTest` test (#21438) fix Co-authored-by: ydshieh --- tests/models/convbert/test_modeling_convbert.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/convbert/test_modeling_convbert.py b/tests/models/convbert/test_modeling_convbert.py index 49f363ff7b..b8ab2c6472 100644 --- a/tests/models/convbert/test_modeling_convbert.py +++ b/tests/models/convbert/test_modeling_convbert.py @@ -440,7 +440,7 @@ class ConvBertModelTest(ModelTesterMixin, unittest.TestCase): def test_model_for_input_embeds(self): batch_size = 2 seq_length = 10 - inputs_embeds = torch.rand([batch_size, seq_length, 768]) + inputs_embeds = torch.rand([batch_size, seq_length, 768], device=torch_device) config = self.model_tester.get_config() model = ConvBertModel(config=config) model.to(torch_device)