adding positional encoder changes and tests (#32600)
* adding positional encoder changes and tests * adding ruff suggestions * changes added by python utils/check_copies.py --fix_and_overwrite * removing pos_encoding added by script * adding interpolation to clipseg * formatting * adding further testing to altclip and better documentation to kosmos2 * skipping test_inputs_embeds_matches_input_ids_with_generate in git model * fixing clipseg comment suggestions * [run_slow] altclip, bridgetower, chinese_clip, clip, clipseg, git, kosmos2, x_clip * fixing bridgetower test * fixing altclip tensor output POS test * adding ruff formatting * fixing several tests * formatting with ruff * adding positional encoder changes and tests * adding ruff suggestions * changes added by python utils/check_copies.py --fix_and_overwrite * removing pos_encoding added by script * adding interpolation to clipseg * formatting * adding further testing to altclip and better documentation to kosmos2 * skipping test_inputs_embeds_matches_input_ids_with_generate in git model * fixing clipseg comment suggestions * fixing bridgetower test * fixing altclip tensor output POS test * adding ruff formatting * fixing several tests * formatting with ruff * adding right pretrained model * [run_slow] altclip, bridgetower, chinese_clip, clip, clipseg, git, kosmos2, x_clip * fixing test_inference_image_segmentation * [run_slow] altclip, bridgetower, chinese_clip, clip, clipseg, git, kosmos2, x_clip * fixing test_inference_interpolate_pos_encoding for the git model as there is no vision_model_output * [run_slow] altclip, bridgetower, chinese_clip, clip, clipseg, git, kosmos2, x_clip * adding ruff formatting * [run_slow] altclip, bridgetower, chinese_clip, clip, clipseg, git, kosmos2, x_clip * adding new interpolate_pos_encoding function * [run_slow] altclip, bridgetower, chinese_clip, clip, clipseg, git, kosmos2, x_clip * fixing interpolate_POS funciton * adapting output tensor in teests * [run_slow] altclip, bridgetower, chinese_clip, clip, clipseg, git, kosmos2, x_clip * modifying output tensor * [run_slow] altclip, bridgetower, chinese_clip, clip, clipseg, git, kosmos2, x_clip * adding the correct tensor * [run_slow] clipseg * fixing spaces * [run_slow] clipseg * [run_slow] clipseg --------- Co-authored-by: Manuel Sanchez Hernandez <manuel.sanchez.hernandez@schibsted.com>
This commit is contained in:
@@ -614,3 +614,38 @@ class GitModelIntegrationTest(unittest.TestCase):
|
||||
generated_captions = processor.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
|
||||
self.assertEqual(generated_captions, ["two cats sleeping on a pink blanket next to remotes."] * 2)
|
||||
|
||||
@slow
|
||||
def test_inference_interpolate_pos_encoding(self):
|
||||
# CLIP family models have an `interpolate_pos_encoding` argument in their forward method,
|
||||
# allowing to interpolate the pre-trained position embeddings in order to use
|
||||
# the model on higher resolutions. The DINO model by Facebook AI leverages this
|
||||
# to visualize self-attention on higher resolution images.
|
||||
model = GitModel.from_pretrained("microsoft/git-base").to(torch_device)
|
||||
|
||||
processor = GitProcessor.from_pretrained(
|
||||
"microsoft/git-base", size={"height": 180, "width": 180}, crop_size={"height": 180, "width": 180}
|
||||
)
|
||||
|
||||
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
|
||||
inputs = processor(text="what's in the image", images=image, return_tensors="pt").to(torch_device)
|
||||
|
||||
# interpolate_pos_encodiung false should return value error
|
||||
with self.assertRaises(ValueError, msg="doesn't match model"):
|
||||
with torch.no_grad():
|
||||
model(**inputs, interpolate_pos_encoding=False)
|
||||
|
||||
# forward pass
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs, interpolate_pos_encoding=True)
|
||||
|
||||
# verify the logits
|
||||
expected_shape = torch.Size((1, 130, 768))
|
||||
|
||||
self.assertEqual(outputs.last_hidden_state.shape, expected_shape)
|
||||
|
||||
expected_slice = torch.tensor(
|
||||
[[-1.0296, 2.5960, 0.8703], [1.7027, 1.3302, -0.4543], [-1.4932, -0.1084, 0.0502]]
|
||||
).to(torch_device)
|
||||
|
||||
self.assertTrue(torch.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4))
|
||||
|
||||
Reference in New Issue
Block a user