@@ -406,7 +406,7 @@ class Pix2StructTextImageModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
def test_model(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model = model_class(config).to(torch_device)
|
||||
|
||||
output = model(**input_dict)
|
||||
self.assertEqual(
|
||||
|
||||
Reference in New Issue
Block a user