🚨🚨🚨 fix(Mask2Former): torch export 🚨🚨🚨 (#34393)
* fix(Mask2Former): torch export

  Signed-off-by: Phillip Kuznetsov <philkuz@gimletlabs.ai>

* revert level_start_index and create a level_start_index_list

  Signed-off-by: Phillip Kuznetsov <philkuz@gimletlabs.ai>

* Add a comment to explain the level_start_index_list

  Signed-off-by: Phillip Kuznetsov <philkuz@gimletlabs.ai>

* Address comment

  Signed-off-by: Phillip Kuznetsov <philkuz@gimletlabs.ai>

* add torch.export.export test

  Signed-off-by: Phillip Kuznetsov <philkuz@gimletlabs.ai>

* rename arg

  Signed-off-by: Phillip Kuznetsov <philkuz@gimletlabs.ai>

* remove spatial_shapes

  Signed-off-by: Phillip Kuznetsov <philkuz@gimletlabs.ai>

* Use the version check from pytorch_utils

  Signed-off-by: Phillip Kuznetsov <philkuz@gimletlabs.ai>

* [run_slow] mask2former

  Signed-off-by: Phillip Kuznetsov <philkuz@gimletlabs.ai>

---------

Signed-off-by: Phillip Kuznetsov <philkuz@gimletlabs.ai>
This commit is contained in:
committed by
GitHub
parent
581524389a
commit
5fa4f64605
@@ -926,7 +926,7 @@ class Mask2FormerPixelDecoderEncoderMultiscaleDeformableAttention(nn.Module):
|
|||||||
encoder_attention_mask=None,
|
encoder_attention_mask=None,
|
||||||
position_embeddings: Optional[torch.Tensor] = None,
|
position_embeddings: Optional[torch.Tensor] = None,
|
||||||
reference_points=None,
|
reference_points=None,
|
||||||
spatial_shapes=None,
|
spatial_shapes_list=None,
|
||||||
level_start_index=None,
|
level_start_index=None,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
):
|
):
|
||||||
@@ -936,7 +936,8 @@ class Mask2FormerPixelDecoderEncoderMultiscaleDeformableAttention(nn.Module):
|
|||||||
|
|
||||||
batch_size, num_queries, _ = hidden_states.shape
|
batch_size, num_queries, _ = hidden_states.shape
|
||||||
batch_size, sequence_length, _ = encoder_hidden_states.shape
|
batch_size, sequence_length, _ = encoder_hidden_states.shape
|
||||||
if (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() != sequence_length:
|
total_elements = sum(height * width for height, width in spatial_shapes_list)
|
||||||
|
if total_elements != sequence_length:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Make sure to align the spatial shapes with the sequence length of the encoder hidden states"
|
"Make sure to align the spatial shapes with the sequence length of the encoder hidden states"
|
||||||
)
|
)
|
||||||
@@ -957,7 +958,11 @@ class Mask2FormerPixelDecoderEncoderMultiscaleDeformableAttention(nn.Module):
|
|||||||
)
|
)
|
||||||
# batch_size, num_queries, n_heads, n_levels, n_points, 2
|
# batch_size, num_queries, n_heads, n_levels, n_points, 2
|
||||||
if reference_points.shape[-1] == 2:
|
if reference_points.shape[-1] == 2:
|
||||||
offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
|
offset_normalizer = torch.tensor(
|
||||||
|
[[shape[1], shape[0]] for shape in spatial_shapes_list],
|
||||||
|
dtype=torch.long,
|
||||||
|
device=reference_points.device,
|
||||||
|
)
|
||||||
sampling_locations = (
|
sampling_locations = (
|
||||||
reference_points[:, :, None, :, None, :]
|
reference_points[:, :, None, :, None, :]
|
||||||
+ sampling_offsets / offset_normalizer[None, None, None, :, None, :]
|
+ sampling_offsets / offset_normalizer[None, None, None, :, None, :]
|
||||||
@@ -970,7 +975,7 @@ class Mask2FormerPixelDecoderEncoderMultiscaleDeformableAttention(nn.Module):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")
|
raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")
|
||||||
|
|
||||||
output = multi_scale_deformable_attention(value, spatial_shapes, sampling_locations, attention_weights)
|
output = multi_scale_deformable_attention(value, spatial_shapes_list, sampling_locations, attention_weights)
|
||||||
output = self.output_proj(output)
|
output = self.output_proj(output)
|
||||||
|
|
||||||
return output, attention_weights
|
return output, attention_weights
|
||||||
@@ -1001,7 +1006,7 @@ class Mask2FormerPixelDecoderEncoderLayer(nn.Module):
|
|||||||
attention_mask: torch.Tensor,
|
attention_mask: torch.Tensor,
|
||||||
position_embeddings: torch.Tensor = None,
|
position_embeddings: torch.Tensor = None,
|
||||||
reference_points=None,
|
reference_points=None,
|
||||||
spatial_shapes=None,
|
spatial_shapes_list=None,
|
||||||
level_start_index=None,
|
level_start_index=None,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
):
|
):
|
||||||
@@ -1015,8 +1020,8 @@ class Mask2FormerPixelDecoderEncoderLayer(nn.Module):
|
|||||||
Position embeddings, to be added to `hidden_states`.
|
Position embeddings, to be added to `hidden_states`.
|
||||||
reference_points (`torch.FloatTensor`, *optional*):
|
reference_points (`torch.FloatTensor`, *optional*):
|
||||||
Reference points.
|
Reference points.
|
||||||
spatial_shapes (`torch.LongTensor`, *optional*):
|
spatial_shapes_list (`list` of `tuple`):
|
||||||
Spatial shapes of the backbone feature maps.
|
Spatial shapes of the backbone feature maps as a list of tuples.
|
||||||
level_start_index (`torch.LongTensor`, *optional*):
|
level_start_index (`torch.LongTensor`, *optional*):
|
||||||
Level start index.
|
Level start index.
|
||||||
output_attentions (`bool`, *optional*):
|
output_attentions (`bool`, *optional*):
|
||||||
@@ -1033,7 +1038,7 @@ class Mask2FormerPixelDecoderEncoderLayer(nn.Module):
|
|||||||
encoder_attention_mask=attention_mask,
|
encoder_attention_mask=attention_mask,
|
||||||
position_embeddings=position_embeddings,
|
position_embeddings=position_embeddings,
|
||||||
reference_points=reference_points,
|
reference_points=reference_points,
|
||||||
spatial_shapes=spatial_shapes,
|
spatial_shapes_list=spatial_shapes_list,
|
||||||
level_start_index=level_start_index,
|
level_start_index=level_start_index,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
@@ -1086,13 +1091,13 @@ class Mask2FormerPixelDecoderEncoderOnly(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_reference_points(spatial_shapes, valid_ratios, device):
|
def get_reference_points(spatial_shapes_list, valid_ratios, device):
|
||||||
"""
|
"""
|
||||||
Get reference points for each feature map. Used in decoder.
|
Get reference points for each feature map. Used in decoder.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
spatial_shapes (`torch.LongTensor`):
|
spatial_shapes_list (`list` of `tuple`):
|
||||||
Spatial shapes of each feature map, has shape of `(num_feature_levels, 2)`.
|
Spatial shapes of the backbone feature maps as a list of tuples.
|
||||||
valid_ratios (`torch.FloatTensor`):
|
valid_ratios (`torch.FloatTensor`):
|
||||||
Valid ratios of each feature map, has shape of `(batch_size, num_feature_levels, 2)`.
|
Valid ratios of each feature map, has shape of `(batch_size, num_feature_levels, 2)`.
|
||||||
device (`torch.device`):
|
device (`torch.device`):
|
||||||
@@ -1101,7 +1106,7 @@ class Mask2FormerPixelDecoderEncoderOnly(nn.Module):
|
|||||||
`torch.FloatTensor` of shape `(batch_size, num_queries, num_feature_levels, 2)`
|
`torch.FloatTensor` of shape `(batch_size, num_queries, num_feature_levels, 2)`
|
||||||
"""
|
"""
|
||||||
reference_points_list = []
|
reference_points_list = []
|
||||||
for lvl, (height, width) in enumerate(spatial_shapes):
|
for lvl, (height, width) in enumerate(spatial_shapes_list):
|
||||||
ref_y, ref_x = torch.meshgrid(
|
ref_y, ref_x = torch.meshgrid(
|
||||||
torch.linspace(0.5, height - 0.5, height, dtype=valid_ratios.dtype, device=device),
|
torch.linspace(0.5, height - 0.5, height, dtype=valid_ratios.dtype, device=device),
|
||||||
torch.linspace(0.5, width - 0.5, width, dtype=valid_ratios.dtype, device=device),
|
torch.linspace(0.5, width - 0.5, width, dtype=valid_ratios.dtype, device=device),
|
||||||
@@ -1122,7 +1127,7 @@ class Mask2FormerPixelDecoderEncoderOnly(nn.Module):
|
|||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
position_embeddings=None,
|
position_embeddings=None,
|
||||||
spatial_shapes=None,
|
spatial_shapes_list=None,
|
||||||
level_start_index=None,
|
level_start_index=None,
|
||||||
valid_ratios=None,
|
valid_ratios=None,
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
@@ -1140,8 +1145,8 @@ class Mask2FormerPixelDecoderEncoderOnly(nn.Module):
|
|||||||
[What are attention masks?](../glossary#attention-mask)
|
[What are attention masks?](../glossary#attention-mask)
|
||||||
position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
||||||
Position embeddings that are added to the queries and keys in each self-attention layer.
|
Position embeddings that are added to the queries and keys in each self-attention layer.
|
||||||
spatial_shapes (`torch.LongTensor` of shape `(num_feature_levels, 2)`):
|
spatial_shapes_list (`list` of `tuple`):
|
||||||
Spatial shapes of each feature map.
|
Spatial shapes of each feature map as a list of tuples.
|
||||||
level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`):
|
level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`):
|
||||||
Starting index of each feature map.
|
Starting index of each feature map.
|
||||||
valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`):
|
valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`):
|
||||||
@@ -1162,7 +1167,7 @@ class Mask2FormerPixelDecoderEncoderOnly(nn.Module):
|
|||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
hidden_states = inputs_embeds
|
hidden_states = inputs_embeds
|
||||||
reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=inputs_embeds.device)
|
reference_points = self.get_reference_points(spatial_shapes_list, valid_ratios, device=inputs_embeds.device)
|
||||||
|
|
||||||
all_hidden_states = () if output_hidden_states else None
|
all_hidden_states = () if output_hidden_states else None
|
||||||
all_attentions = () if output_attentions else None
|
all_attentions = () if output_attentions else None
|
||||||
@@ -1176,7 +1181,7 @@ class Mask2FormerPixelDecoderEncoderOnly(nn.Module):
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
position_embeddings=position_embeddings,
|
position_embeddings=position_embeddings,
|
||||||
reference_points=reference_points,
|
reference_points=reference_points,
|
||||||
spatial_shapes=spatial_shapes,
|
spatial_shapes_list=spatial_shapes_list,
|
||||||
level_start_index=level_start_index,
|
level_start_index=level_start_index,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
@@ -1302,9 +1307,9 @@ class Mask2FormerPixelDecoder(nn.Module):
|
|||||||
]
|
]
|
||||||
|
|
||||||
# Prepare encoder inputs (by flattening)
|
# Prepare encoder inputs (by flattening)
|
||||||
spatial_shapes = [(embed.shape[2], embed.shape[3]) for embed in input_embeds]
|
spatial_shapes_list = [(embed.shape[2], embed.shape[3]) for embed in input_embeds]
|
||||||
input_embeds_flat = torch.cat([embed.flatten(2).transpose(1, 2) for embed in input_embeds], 1)
|
input_embeds_flat = torch.cat([embed.flatten(2).transpose(1, 2) for embed in input_embeds], 1)
|
||||||
spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=input_embeds_flat.device)
|
spatial_shapes = torch.as_tensor(spatial_shapes_list, dtype=torch.long, device=input_embeds_flat.device)
|
||||||
masks_flat = torch.cat([mask.flatten(1) for mask in masks], 1)
|
masks_flat = torch.cat([mask.flatten(1) for mask in masks], 1)
|
||||||
|
|
||||||
position_embeddings = [embed.flatten(2).transpose(1, 2) for embed in position_embeddings]
|
position_embeddings = [embed.flatten(2).transpose(1, 2) for embed in position_embeddings]
|
||||||
@@ -1320,7 +1325,7 @@ class Mask2FormerPixelDecoder(nn.Module):
|
|||||||
inputs_embeds=input_embeds_flat,
|
inputs_embeds=input_embeds_flat,
|
||||||
attention_mask=masks_flat,
|
attention_mask=masks_flat,
|
||||||
position_embeddings=level_pos_embed_flat,
|
position_embeddings=level_pos_embed_flat,
|
||||||
spatial_shapes=spatial_shapes,
|
spatial_shapes_list=spatial_shapes_list,
|
||||||
level_start_index=level_start_index,
|
level_start_index=level_start_index,
|
||||||
valid_ratios=valid_ratios,
|
valid_ratios=valid_ratios,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
@@ -1331,18 +1336,23 @@ class Mask2FormerPixelDecoder(nn.Module):
|
|||||||
last_hidden_state = encoder_outputs.last_hidden_state
|
last_hidden_state = encoder_outputs.last_hidden_state
|
||||||
batch_size = last_hidden_state.shape[0]
|
batch_size = last_hidden_state.shape[0]
|
||||||
|
|
||||||
|
# We compute level_start_index_list separately from the tensor version level_start_index
|
||||||
|
# to avoid iterating over a tensor which breaks torch.compile/export.
|
||||||
|
level_start_index_list = [0]
|
||||||
|
for height, width in spatial_shapes_list[:-1]:
|
||||||
|
level_start_index_list.append(level_start_index_list[-1] + height * width)
|
||||||
split_sizes = [None] * self.num_feature_levels
|
split_sizes = [None] * self.num_feature_levels
|
||||||
for i in range(self.num_feature_levels):
|
for i in range(self.num_feature_levels):
|
||||||
if i < self.num_feature_levels - 1:
|
if i < self.num_feature_levels - 1:
|
||||||
split_sizes[i] = level_start_index[i + 1] - level_start_index[i]
|
split_sizes[i] = level_start_index_list[i + 1] - level_start_index_list[i]
|
||||||
else:
|
else:
|
||||||
split_sizes[i] = last_hidden_state.shape[1] - level_start_index[i]
|
split_sizes[i] = last_hidden_state.shape[1] - level_start_index_list[i]
|
||||||
|
|
||||||
encoder_output = torch.split(last_hidden_state, [size.item() for size in split_sizes], dim=1)
|
encoder_output = torch.split(last_hidden_state, split_sizes, dim=1)
|
||||||
|
|
||||||
# Compute final features
|
# Compute final features
|
||||||
outputs = [
|
outputs = [
|
||||||
x.transpose(1, 2).view(batch_size, -1, spatial_shapes[i][0], spatial_shapes[i][1])
|
x.transpose(1, 2).view(batch_size, -1, spatial_shapes_list[i][0], spatial_shapes_list[i][1])
|
||||||
for i, x in enumerate(encoder_output)
|
for i, x in enumerate(encoder_output)
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -1876,7 +1886,9 @@ class Mask2FormerMaskedAttentionDecoder(nn.Module):
|
|||||||
else:
|
else:
|
||||||
level_index = idx % self.num_feature_levels
|
level_index = idx % self.num_feature_levels
|
||||||
|
|
||||||
attention_mask[torch.where(attention_mask.sum(-1) == attention_mask.shape[-1])] = False
|
where = (attention_mask.sum(-1) != attention_mask.shape[-1]).to(attention_mask.dtype)
|
||||||
|
# Multiply the attention mask instead of indexing to avoid issue in torch.export.
|
||||||
|
attention_mask = attention_mask * where.unsqueeze(-1)
|
||||||
|
|
||||||
layer_outputs = decoder_layer(
|
layer_outputs = decoder_layer(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ import numpy as np
|
|||||||
|
|
||||||
from tests.test_modeling_common import floats_tensor
|
from tests.test_modeling_common import floats_tensor
|
||||||
from transformers import Mask2FormerConfig, is_torch_available, is_vision_available
|
from transformers import Mask2FormerConfig, is_torch_available, is_vision_available
|
||||||
|
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
require_timm,
|
require_timm,
|
||||||
require_torch,
|
require_torch,
|
||||||
@@ -481,3 +482,28 @@ class Mask2FormerModelIntegrationTest(unittest.TestCase):
|
|||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
|
|
||||||
self.assertTrue(outputs.loss is not None)
|
self.assertTrue(outputs.loss is not None)
|
||||||
|
|
||||||
|
def test_export(self):
|
||||||
|
if not is_torch_greater_or_equal_than_2_4:
|
||||||
|
self.skipTest(reason="This test requires torch >= 2.4 to run.")
|
||||||
|
model = Mask2FormerForUniversalSegmentation.from_pretrained(self.model_checkpoints).to(torch_device).eval()
|
||||||
|
image_processor = self.default_image_processor
|
||||||
|
image = prepare_img()
|
||||||
|
inputs = image_processor(image, return_tensors="pt").to(torch_device)
|
||||||
|
|
||||||
|
exported_program = torch.export.export(
|
||||||
|
model,
|
||||||
|
args=(inputs["pixel_values"], inputs["pixel_mask"]),
|
||||||
|
strict=True,
|
||||||
|
)
|
||||||
|
with torch.no_grad():
|
||||||
|
eager_outputs = model(**inputs)
|
||||||
|
exported_outputs = exported_program.module().forward(inputs["pixel_values"], inputs["pixel_mask"])
|
||||||
|
self.assertEqual(eager_outputs.masks_queries_logits.shape, exported_outputs.masks_queries_logits.shape)
|
||||||
|
self.assertTrue(
|
||||||
|
torch.allclose(eager_outputs.masks_queries_logits, exported_outputs.masks_queries_logits, atol=TOLERANCE)
|
||||||
|
)
|
||||||
|
self.assertEqual(eager_outputs.class_queries_logits.shape, exported_outputs.class_queries_logits.shape)
|
||||||
|
self.assertTrue(
|
||||||
|
torch.allclose(eager_outputs.class_queries_logits, exported_outputs.class_queries_logits, atol=TOLERANCE)
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user