Fix device issue in a ConvBertModelTest test (#21438)

fix

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2023-02-03 15:12:28 +01:00
committed by GitHub
parent 0df802822c
commit 197e7ce911

View File

@@ -440,7 +440,7 @@ class ConvBertModelTest(ModelTesterMixin, unittest.TestCase):
def test_model_for_input_embeds(self): def test_model_for_input_embeds(self):
batch_size = 2 batch_size = 2
seq_length = 10 seq_length = 10
inputs_embeds = torch.rand([batch_size, seq_length, 768]) inputs_embeds = torch.rand([batch_size, seq_length, 768], device=torch_device)
config = self.model_tester.get_config() config = self.model_tester.get_config()
model = ConvBertModel(config=config) model = ConvBertModel(config=config)
model.to(torch_device) model.to(torch_device)