Fix device issue in a ConvBertModelTest test (#21438)
fix Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -440,7 +440,7 @@ class ConvBertModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
def test_model_for_input_embeds(self):
|
def test_model_for_input_embeds(self):
|
||||||
batch_size = 2
|
batch_size = 2
|
||||||
seq_length = 10
|
seq_length = 10
|
||||||
inputs_embeds = torch.rand([batch_size, seq_length, 768])
|
inputs_embeds = torch.rand([batch_size, seq_length, 768], device=torch_device)
|
||||||
config = self.model_tester.get_config()
|
config = self.model_tester.get_config()
|
||||||
model = ConvBertModel(config=config)
|
model = ConvBertModel(config=config)
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
|
|||||||
Reference in New Issue
Block a user