update test (#16219)
This commit is contained in:
committed by
GitHub
parent
7e0d04bed1
commit
d9b8d1a9f5
@@ -66,7 +66,9 @@ class MaskFormerModelTester:
|
||||
self.mask_feature_size = mask_feature_size
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.min_size, self.max_size])
|
||||
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.min_size, self.max_size]).to(
|
||||
torch_device
|
||||
)
|
||||
|
||||
pixel_mask = torch.ones([self.batch_size, self.min_size, self.max_size], device=torch_device)
|
||||
|
||||
@@ -232,12 +234,12 @@ class MaskFormerModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
def test_model_with_labels(self):
|
||||
size = (self.model_tester.min_size,) * 2
|
||||
inputs = {
|
||||
"pixel_values": torch.randn((2, 3, *size)),
|
||||
"mask_labels": torch.randn((2, 10, *size)),
|
||||
"class_labels": torch.zeros(2, 10).long(),
|
||||
"pixel_values": torch.randn((2, 3, *size), device=torch_device),
|
||||
"mask_labels": torch.randn((2, 10, *size), device=torch_device),
|
||||
"class_labels": torch.zeros(2, 10, device=torch_device).long(),
|
||||
}
|
||||
|
||||
model = MaskFormerForInstanceSegmentation(MaskFormerConfig())
|
||||
model = MaskFormerForInstanceSegmentation(MaskFormerConfig()).to(torch_device)
|
||||
outputs = model(**inputs)
|
||||
self.assertTrue(outputs.loss is not None)
|
||||
|
||||
@@ -249,7 +251,7 @@ class MaskFormerModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model = model_class(config).to(torch_device)
|
||||
outputs = model(**inputs, output_attentions=True)
|
||||
self.assertTrue(outputs.attentions is not None)
|
||||
|
||||
@@ -381,7 +383,7 @@ class MaskFormerModelIntegrationTest(unittest.TestCase):
|
||||
)
|
||||
expected_slice = torch.tensor(
|
||||
[[-1.3738, -1.7725, -1.9365], [-1.5978, -1.9869, -2.1524], [-1.5796, -1.9271, -2.0940]]
|
||||
)
|
||||
).to(torch_device)
|
||||
self.assertTrue(torch.allclose(masks_queries_logits[0, 0, :3, :3], expected_slice, atol=TOLERANCE))
|
||||
# class_queries_logits
|
||||
class_queries_logits = outputs.class_queries_logits
|
||||
@@ -392,7 +394,7 @@ class MaskFormerModelIntegrationTest(unittest.TestCase):
|
||||
[3.6169e-02, -5.9025e00, -2.9313e00],
|
||||
[1.0766e-04, -7.7630e00, -5.1263e00],
|
||||
]
|
||||
)
|
||||
).to(torch_device)
|
||||
self.assertTrue(torch.allclose(outputs.class_queries_logits[0, :3, :3], expected_slice, atol=TOLERANCE))
|
||||
|
||||
def test_with_annotations_and_loss(self):
|
||||
@@ -406,7 +408,7 @@ class MaskFormerModelIntegrationTest(unittest.TestCase):
|
||||
{"masks": np.random.rand(10, 384, 384).astype(np.float32), "labels": np.zeros(10).astype(np.int64)},
|
||||
],
|
||||
return_tensors="pt",
|
||||
)
|
||||
).to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
|
||||
Reference in New Issue
Block a user