[CLIPSeg] Make interpolate_pos_encoding default to True (#34419)
* Remove interpolate_pos_encoding
* Make fixup
* Make interpolate_pos_encoding default to True
* Reuse existing interpolation
* Add integration test
This commit is contained in:
@@ -205,7 +205,7 @@ class CLIPSegVisionEmbeddings(nn.Module):
|
|||||||
|
|
||||||
return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
|
return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
|
||||||
|
|
||||||
def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
|
def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=True) -> torch.Tensor:
|
||||||
batch_size, _, height, width = pixel_values.shape
|
batch_size, _, height, width = pixel_values.shape
|
||||||
if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size):
|
if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -535,7 +535,7 @@ CLIPSEG_VISION_INPUTS_DOCSTRING = r"""
|
|||||||
output_hidden_states (`bool`, *optional*):
|
output_hidden_states (`bool`, *optional*):
|
||||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||||
more detail.
|
more detail.
|
||||||
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
|
interpolate_pos_encoding (`bool`, *optional*, defaults to `True`):
|
||||||
Whether to interpolate the pre-trained position encodings.
|
Whether to interpolate the pre-trained position encodings.
|
||||||
return_dict (`bool`, *optional*):
|
return_dict (`bool`, *optional*):
|
||||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||||
@@ -574,7 +574,7 @@ CLIPSEG_INPUTS_DOCSTRING = r"""
|
|||||||
output_hidden_states (`bool`, *optional*):
|
output_hidden_states (`bool`, *optional*):
|
||||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||||
more detail.
|
more detail.
|
||||||
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
|
interpolate_pos_encoding (`bool`, *optional*, defaults to `True`):
|
||||||
Whether to interpolate the pre-trained position encodings.
|
Whether to interpolate the pre-trained position encodings.
|
||||||
return_dict (`bool`, *optional*):
|
return_dict (`bool`, *optional*):
|
||||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||||
@@ -845,14 +845,13 @@ class CLIPSegVisionTransformer(nn.Module):
|
|||||||
|
|
||||||
@add_start_docstrings_to_model_forward(CLIPSEG_VISION_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(CLIPSEG_VISION_INPUTS_DOCSTRING)
|
||||||
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPSegVisionConfig)
|
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPSegVisionConfig)
|
||||||
# Copied from transformers.models.clip.modeling_clip.CLIPVisionTransformer.forward
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
pixel_values: Optional[torch.FloatTensor] = None,
|
pixel_values: Optional[torch.FloatTensor],
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
interpolate_pos_encoding: Optional[bool] = False,
|
interpolate_pos_encoding: Optional[bool] = True,
|
||||||
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
||||||
r"""
|
r"""
|
||||||
Returns:
|
Returns:
|
||||||
@@ -864,9 +863,6 @@ class CLIPSegVisionTransformer(nn.Module):
|
|||||||
)
|
)
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
if pixel_values is None:
|
|
||||||
raise ValueError("You have to specify pixel_values")
|
|
||||||
|
|
||||||
hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
|
hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
|
||||||
hidden_states = self.pre_layrnorm(hidden_states)
|
hidden_states = self.pre_layrnorm(hidden_states)
|
||||||
|
|
||||||
@@ -912,7 +908,7 @@ class CLIPSegVisionModel(CLIPSegPreTrainedModel):
|
|||||||
pixel_values: Optional[torch.FloatTensor] = None,
|
pixel_values: Optional[torch.FloatTensor] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
interpolate_pos_encoding: Optional[bool] = False,
|
interpolate_pos_encoding: Optional[bool] = True,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
||||||
r"""
|
r"""
|
||||||
@@ -1035,7 +1031,7 @@ class CLIPSegModel(CLIPSegPreTrainedModel):
|
|||||||
pixel_values: Optional[torch.FloatTensor] = None,
|
pixel_values: Optional[torch.FloatTensor] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
interpolate_pos_encoding: bool = False,
|
interpolate_pos_encoding: bool = True,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
) -> torch.FloatTensor:
|
) -> torch.FloatTensor:
|
||||||
r"""
|
r"""
|
||||||
@@ -1091,7 +1087,7 @@ class CLIPSegModel(CLIPSegPreTrainedModel):
|
|||||||
return_loss: Optional[bool] = None,
|
return_loss: Optional[bool] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
interpolate_pos_encoding: bool = False,
|
interpolate_pos_encoding: bool = True,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
) -> Union[Tuple, CLIPSegOutput]:
|
) -> Union[Tuple, CLIPSegOutput]:
|
||||||
r"""
|
r"""
|
||||||
@@ -1397,7 +1393,7 @@ class CLIPSegForImageSegmentation(CLIPSegPreTrainedModel):
|
|||||||
labels: Optional[torch.LongTensor] = None,
|
labels: Optional[torch.LongTensor] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
interpolate_pos_encoding: bool = False,
|
interpolate_pos_encoding: bool = True,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
) -> Union[Tuple, CLIPSegOutput]:
|
) -> Union[Tuple, CLIPSegOutput]:
|
||||||
r"""
|
r"""
|
||||||
|
|||||||
@@ -796,7 +796,7 @@ class CLIPSegModelIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
# forward pass
|
# forward pass
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
outputs = model(**inputs, interpolate_pos_encoding=True)
|
outputs = model(**inputs)
|
||||||
|
|
||||||
# verify the predicted masks
|
# verify the predicted masks
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
@@ -804,7 +804,7 @@ class CLIPSegModelIntegrationTest(unittest.TestCase):
|
|||||||
torch.Size((3, 352, 352)),
|
torch.Size((3, 352, 352)),
|
||||||
)
|
)
|
||||||
expected_masks_slice = torch.tensor(
|
expected_masks_slice = torch.tensor(
|
||||||
[[-7.4613, -7.4785, -7.3627], [-7.3268, -7.0898, -7.1333], [-6.9838, -6.7900, -6.8913]]
|
[[-7.4613, -7.4785, -7.3628], [-7.3268, -7.0899, -7.1333], [-6.9838, -6.7900, -6.8913]]
|
||||||
).to(torch_device)
|
).to(torch_device)
|
||||||
|
|
||||||
self.assertTrue(torch.allclose(outputs.logits[0, :3, :3], expected_masks_slice, atol=1e-3))
|
self.assertTrue(torch.allclose(outputs.logits[0, :3, :3], expected_masks_slice, atol=1e-3))
|
||||||
|
|||||||
Reference in New Issue
Block a user