minor fixes (#14026)

This commit is contained in:
Suraj Patil
2021-10-16 05:38:57 +05:30
committed by GitHub
parent f5af873617
commit 84ad6af49a
2 changed files with 11 additions and 8 deletions

View File

@@ -102,7 +102,8 @@ class CLIPVisionModelTester:
model = CLIPVisionModel(config=config)
model.to(torch_device)
model.eval()
result = model(pixel_values)
with torch.no_grad():
result = model(pixel_values)
# expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
image_size = (self.image_size, self.image_size)
patch_size = (self.patch_size, self.patch_size)
@@ -350,8 +351,9 @@ class CLIPTextModelTester:
model = CLIPTextModel(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask)
result = model(input_ids)
with torch.no_grad():
result = model(input_ids, attention_mask=input_mask)
result = model(input_ids)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
@@ -429,7 +431,8 @@ class CLIPModelTester:
def create_and_check_model(self, config, input_ids, attention_mask, pixel_values):
model = CLIPModel(config).to(torch_device).eval()
result = model(input_ids, pixel_values, attention_mask)
with torch.no_grad():
result = model(input_ids, pixel_values, attention_mask)
self.parent.assertEqual(
result.logits_per_image.shape, (self.vision_model_tester.batch_size, self.text_model_tester.batch_size)
)