update test (#16219)
This commit is contained in:
committed by
GitHub
parent
7e0d04bed1
commit
d9b8d1a9f5
@@ -66,7 +66,9 @@ class MaskFormerModelTester:
|
|||||||
self.mask_feature_size = mask_feature_size
|
self.mask_feature_size = mask_feature_size
|
||||||
|
|
||||||
def prepare_config_and_inputs(self):
|
def prepare_config_and_inputs(self):
|
||||||
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.min_size, self.max_size])
|
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.min_size, self.max_size]).to(
|
||||||
|
torch_device
|
||||||
|
)
|
||||||
|
|
||||||
pixel_mask = torch.ones([self.batch_size, self.min_size, self.max_size], device=torch_device)
|
pixel_mask = torch.ones([self.batch_size, self.min_size, self.max_size], device=torch_device)
|
||||||
|
|
||||||
@@ -232,12 +234,12 @@ class MaskFormerModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
def test_model_with_labels(self):
|
def test_model_with_labels(self):
|
||||||
size = (self.model_tester.min_size,) * 2
|
size = (self.model_tester.min_size,) * 2
|
||||||
inputs = {
|
inputs = {
|
||||||
"pixel_values": torch.randn((2, 3, *size)),
|
"pixel_values": torch.randn((2, 3, *size), device=torch_device),
|
||||||
"mask_labels": torch.randn((2, 10, *size)),
|
"mask_labels": torch.randn((2, 10, *size), device=torch_device),
|
||||||
"class_labels": torch.zeros(2, 10).long(),
|
"class_labels": torch.zeros(2, 10, device=torch_device).long(),
|
||||||
}
|
}
|
||||||
|
|
||||||
model = MaskFormerForInstanceSegmentation(MaskFormerConfig())
|
model = MaskFormerForInstanceSegmentation(MaskFormerConfig()).to(torch_device)
|
||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
self.assertTrue(outputs.loss is not None)
|
self.assertTrue(outputs.loss is not None)
|
||||||
|
|
||||||
@@ -249,7 +251,7 @@ class MaskFormerModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
model = model_class(config)
|
model = model_class(config).to(torch_device)
|
||||||
outputs = model(**inputs, output_attentions=True)
|
outputs = model(**inputs, output_attentions=True)
|
||||||
self.assertTrue(outputs.attentions is not None)
|
self.assertTrue(outputs.attentions is not None)
|
||||||
|
|
||||||
@@ -381,7 +383,7 @@ class MaskFormerModelIntegrationTest(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
expected_slice = torch.tensor(
|
expected_slice = torch.tensor(
|
||||||
[[-1.3738, -1.7725, -1.9365], [-1.5978, -1.9869, -2.1524], [-1.5796, -1.9271, -2.0940]]
|
[[-1.3738, -1.7725, -1.9365], [-1.5978, -1.9869, -2.1524], [-1.5796, -1.9271, -2.0940]]
|
||||||
)
|
).to(torch_device)
|
||||||
self.assertTrue(torch.allclose(masks_queries_logits[0, 0, :3, :3], expected_slice, atol=TOLERANCE))
|
self.assertTrue(torch.allclose(masks_queries_logits[0, 0, :3, :3], expected_slice, atol=TOLERANCE))
|
||||||
# class_queries_logits
|
# class_queries_logits
|
||||||
class_queries_logits = outputs.class_queries_logits
|
class_queries_logits = outputs.class_queries_logits
|
||||||
@@ -392,7 +394,7 @@ class MaskFormerModelIntegrationTest(unittest.TestCase):
|
|||||||
[3.6169e-02, -5.9025e00, -2.9313e00],
|
[3.6169e-02, -5.9025e00, -2.9313e00],
|
||||||
[1.0766e-04, -7.7630e00, -5.1263e00],
|
[1.0766e-04, -7.7630e00, -5.1263e00],
|
||||||
]
|
]
|
||||||
)
|
).to(torch_device)
|
||||||
self.assertTrue(torch.allclose(outputs.class_queries_logits[0, :3, :3], expected_slice, atol=TOLERANCE))
|
self.assertTrue(torch.allclose(outputs.class_queries_logits[0, :3, :3], expected_slice, atol=TOLERANCE))
|
||||||
|
|
||||||
def test_with_annotations_and_loss(self):
|
def test_with_annotations_and_loss(self):
|
||||||
@@ -406,7 +408,7 @@ class MaskFormerModelIntegrationTest(unittest.TestCase):
|
|||||||
{"masks": np.random.rand(10, 384, 384).astype(np.float32), "labels": np.zeros(10).astype(np.int64)},
|
{"masks": np.random.rand(10, 384, 384).astype(np.float32), "labels": np.zeros(10).astype(np.int64)},
|
||||||
],
|
],
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
)
|
).to(torch_device)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
|
|||||||
Reference in New Issue
Block a user