Fix BeitForMaskedImageModeling (#13275)
* First pass * Fix docs of bool_masked_pos * Add integration script * Fix docstring * Add integration test for BeitForMaskedImageModeling * Remove file * Fix docs
This commit is contained in:
@@ -379,6 +379,31 @@ class BeitModelIntegrationTest(unittest.TestCase):
|
||||
BeitFeatureExtractor.from_pretrained("microsoft/beit-base-patch16-224") if is_vision_available() else None
|
||||
)
|
||||
|
||||
@slow
|
||||
def test_inference_masked_image_modeling_head(self):
|
||||
model = BeitForMaskedImageModeling.from_pretrained("microsoft/beit-base-patch16-224-pt22k").to(torch_device)
|
||||
|
||||
feature_extractor = self.default_feature_extractor
|
||||
image = prepare_img()
|
||||
pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values.to(torch_device)
|
||||
|
||||
# prepare bool_masked_pos
|
||||
bool_masked_pos = torch.ones((1, 196), dtype=torch.bool).to(torch_device)
|
||||
|
||||
# forward pass
|
||||
outputs = model(pixel_values=pixel_values, bool_masked_pos=bool_masked_pos)
|
||||
logits = outputs.logits
|
||||
|
||||
# verify the logits
|
||||
expected_shape = torch.Size((1, 196, 8192))
|
||||
self.assertEqual(logits.shape, expected_shape)
|
||||
|
||||
expected_slice = torch.tensor(
|
||||
[[-3.2437, 0.5072, -13.9174], [-3.2456, 0.4948, -13.9401], [-3.2033, 0.5121, -13.8550]]
|
||||
).to(torch_device)
|
||||
|
||||
self.assertTrue(torch.allclose(logits[bool_masked_pos][:3, :3], expected_slice, atol=1e-2))
|
||||
|
||||
@slow
|
||||
def test_inference_image_classification_head_imagenet_1k(self):
|
||||
model = BeitForImageClassification.from_pretrained("microsoft/beit-base-patch16-224").to(torch_device)
|
||||
|
||||
Reference in New Issue
Block a user