🚨🚨🚨 fix(Mask2Former): torch export 🚨🚨🚨 (#34393)

* fix(Mask2Former): torch export

Signed-off-by: Phillip Kuznetsov <philkuz@gimletlabs.ai>

* revert level_start_index and create a level_start_index_list

Signed-off-by: Phillip Kuznetsov <philkuz@gimletlabs.ai>

* Add a comment to explain the level_start_index_list

Signed-off-by: Phillip Kuznetsov <philkuz@gimletlabs.ai>

* Address comment

Signed-off-by: Phillip Kuznetsov <philkuz@gimletlabs.ai>

* add torch.export.export test

Signed-off-by: Phillip Kuznetsov <philkuz@gimletlabs.ai>

* rename arg

Signed-off-by: Phillip Kuznetsov <philkuz@gimletlabs.ai>

* remove spatial_shapes

Signed-off-by: Phillip Kuznetsov <philkuz@gimletlabs.ai>

* Use the version check from pytorch_utils

Signed-off-by: Phillip Kuznetsov <philkuz@gimletlabs.ai>

* [run_slow] mask2former

Signed-off-by: Phillip Kuznetsov <philkuz@gimletlabs.ai>

---------

Signed-off-by: Phillip Kuznetsov <philkuz@gimletlabs.ai>
This commit is contained in:
Phillip Kuznetsov
2024-11-19 07:44:53 -08:00
committed by GitHub
parent 581524389a
commit 5fa4f64605
2 changed files with 63 additions and 25 deletions

View File

@@ -926,7 +926,7 @@ class Mask2FormerPixelDecoderEncoderMultiscaleDeformableAttention(nn.Module):
encoder_attention_mask=None, encoder_attention_mask=None,
position_embeddings: Optional[torch.Tensor] = None, position_embeddings: Optional[torch.Tensor] = None,
reference_points=None, reference_points=None,
spatial_shapes=None, spatial_shapes_list=None,
level_start_index=None, level_start_index=None,
output_attentions: bool = False, output_attentions: bool = False,
): ):
@@ -936,7 +936,8 @@ class Mask2FormerPixelDecoderEncoderMultiscaleDeformableAttention(nn.Module):
batch_size, num_queries, _ = hidden_states.shape batch_size, num_queries, _ = hidden_states.shape
batch_size, sequence_length, _ = encoder_hidden_states.shape batch_size, sequence_length, _ = encoder_hidden_states.shape
if (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() != sequence_length: total_elements = sum(height * width for height, width in spatial_shapes_list)
if total_elements != sequence_length:
raise ValueError( raise ValueError(
"Make sure to align the spatial shapes with the sequence length of the encoder hidden states" "Make sure to align the spatial shapes with the sequence length of the encoder hidden states"
) )
@@ -957,7 +958,11 @@ class Mask2FormerPixelDecoderEncoderMultiscaleDeformableAttention(nn.Module):
) )
# batch_size, num_queries, n_heads, n_levels, n_points, 2 # batch_size, num_queries, n_heads, n_levels, n_points, 2
if reference_points.shape[-1] == 2: if reference_points.shape[-1] == 2:
offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1) offset_normalizer = torch.tensor(
[[shape[1], shape[0]] for shape in spatial_shapes_list],
dtype=torch.long,
device=reference_points.device,
)
sampling_locations = ( sampling_locations = (
reference_points[:, :, None, :, None, :] reference_points[:, :, None, :, None, :]
+ sampling_offsets / offset_normalizer[None, None, None, :, None, :] + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
@@ -970,7 +975,7 @@ class Mask2FormerPixelDecoderEncoderMultiscaleDeformableAttention(nn.Module):
else: else:
raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}") raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")
output = multi_scale_deformable_attention(value, spatial_shapes, sampling_locations, attention_weights) output = multi_scale_deformable_attention(value, spatial_shapes_list, sampling_locations, attention_weights)
output = self.output_proj(output) output = self.output_proj(output)
return output, attention_weights return output, attention_weights
@@ -1001,7 +1006,7 @@ class Mask2FormerPixelDecoderEncoderLayer(nn.Module):
attention_mask: torch.Tensor, attention_mask: torch.Tensor,
position_embeddings: torch.Tensor = None, position_embeddings: torch.Tensor = None,
reference_points=None, reference_points=None,
spatial_shapes=None, spatial_shapes_list=None,
level_start_index=None, level_start_index=None,
output_attentions: bool = False, output_attentions: bool = False,
): ):
@@ -1015,8 +1020,8 @@ class Mask2FormerPixelDecoderEncoderLayer(nn.Module):
Position embeddings, to be added to `hidden_states`. Position embeddings, to be added to `hidden_states`.
reference_points (`torch.FloatTensor`, *optional*): reference_points (`torch.FloatTensor`, *optional*):
Reference points. Reference points.
spatial_shapes (`torch.LongTensor`, *optional*): spatial_shapes_list (`list` of `tuple`):
Spatial shapes of the backbone feature maps. Spatial shapes of the backbone feature maps as a list of tuples.
level_start_index (`torch.LongTensor`, *optional*): level_start_index (`torch.LongTensor`, *optional*):
Level start index. Level start index.
output_attentions (`bool`, *optional*): output_attentions (`bool`, *optional*):
@@ -1033,7 +1038,7 @@ class Mask2FormerPixelDecoderEncoderLayer(nn.Module):
encoder_attention_mask=attention_mask, encoder_attention_mask=attention_mask,
position_embeddings=position_embeddings, position_embeddings=position_embeddings,
reference_points=reference_points, reference_points=reference_points,
spatial_shapes=spatial_shapes, spatial_shapes_list=spatial_shapes_list,
level_start_index=level_start_index, level_start_index=level_start_index,
output_attentions=output_attentions, output_attentions=output_attentions,
) )
@@ -1086,13 +1091,13 @@ class Mask2FormerPixelDecoderEncoderOnly(nn.Module):
) )
@staticmethod @staticmethod
def get_reference_points(spatial_shapes, valid_ratios, device): def get_reference_points(spatial_shapes_list, valid_ratios, device):
""" """
Get reference points for each feature map. Used in decoder. Get reference points for each feature map. Used in decoder.
Args: Args:
spatial_shapes (`torch.LongTensor`): spatial_shapes_list (`list` of `tuple`):
Spatial shapes of each feature map, has shape of `(num_feature_levels, 2)`. Spatial shapes of the backbone feature maps as a list of tuples.
valid_ratios (`torch.FloatTensor`): valid_ratios (`torch.FloatTensor`):
Valid ratios of each feature map, has shape of `(batch_size, num_feature_levels, 2)`. Valid ratios of each feature map, has shape of `(batch_size, num_feature_levels, 2)`.
device (`torch.device`): device (`torch.device`):
@@ -1101,7 +1106,7 @@ class Mask2FormerPixelDecoderEncoderOnly(nn.Module):
`torch.FloatTensor` of shape `(batch_size, num_queries, num_feature_levels, 2)` `torch.FloatTensor` of shape `(batch_size, num_queries, num_feature_levels, 2)`
""" """
reference_points_list = [] reference_points_list = []
for lvl, (height, width) in enumerate(spatial_shapes): for lvl, (height, width) in enumerate(spatial_shapes_list):
ref_y, ref_x = torch.meshgrid( ref_y, ref_x = torch.meshgrid(
torch.linspace(0.5, height - 0.5, height, dtype=valid_ratios.dtype, device=device), torch.linspace(0.5, height - 0.5, height, dtype=valid_ratios.dtype, device=device),
torch.linspace(0.5, width - 0.5, width, dtype=valid_ratios.dtype, device=device), torch.linspace(0.5, width - 0.5, width, dtype=valid_ratios.dtype, device=device),
@@ -1122,7 +1127,7 @@ class Mask2FormerPixelDecoderEncoderOnly(nn.Module):
inputs_embeds=None, inputs_embeds=None,
attention_mask=None, attention_mask=None,
position_embeddings=None, position_embeddings=None,
spatial_shapes=None, spatial_shapes_list=None,
level_start_index=None, level_start_index=None,
valid_ratios=None, valid_ratios=None,
output_attentions=None, output_attentions=None,
@@ -1140,8 +1145,8 @@ class Mask2FormerPixelDecoderEncoderOnly(nn.Module):
[What are attention masks?](../glossary#attention-mask) [What are attention masks?](../glossary#attention-mask)
position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Position embeddings that are added to the queries and keys in each self-attention layer. Position embeddings that are added to the queries and keys in each self-attention layer.
spatial_shapes (`torch.LongTensor` of shape `(num_feature_levels, 2)`): spatial_shapes_list (`list` of `tuple`):
Spatial shapes of each feature map. Spatial shapes of each feature map as a list of tuples.
level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`): level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`):
Starting index of each feature map. Starting index of each feature map.
valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`): valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`):
@@ -1162,7 +1167,7 @@ class Mask2FormerPixelDecoderEncoderOnly(nn.Module):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
hidden_states = inputs_embeds hidden_states = inputs_embeds
reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=inputs_embeds.device) reference_points = self.get_reference_points(spatial_shapes_list, valid_ratios, device=inputs_embeds.device)
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None all_attentions = () if output_attentions else None
@@ -1176,7 +1181,7 @@ class Mask2FormerPixelDecoderEncoderOnly(nn.Module):
attention_mask, attention_mask,
position_embeddings=position_embeddings, position_embeddings=position_embeddings,
reference_points=reference_points, reference_points=reference_points,
spatial_shapes=spatial_shapes, spatial_shapes_list=spatial_shapes_list,
level_start_index=level_start_index, level_start_index=level_start_index,
output_attentions=output_attentions, output_attentions=output_attentions,
) )
@@ -1302,9 +1307,9 @@ class Mask2FormerPixelDecoder(nn.Module):
] ]
# Prepare encoder inputs (by flattening) # Prepare encoder inputs (by flattening)
spatial_shapes = [(embed.shape[2], embed.shape[3]) for embed in input_embeds] spatial_shapes_list = [(embed.shape[2], embed.shape[3]) for embed in input_embeds]
input_embeds_flat = torch.cat([embed.flatten(2).transpose(1, 2) for embed in input_embeds], 1) input_embeds_flat = torch.cat([embed.flatten(2).transpose(1, 2) for embed in input_embeds], 1)
spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=input_embeds_flat.device) spatial_shapes = torch.as_tensor(spatial_shapes_list, dtype=torch.long, device=input_embeds_flat.device)
masks_flat = torch.cat([mask.flatten(1) for mask in masks], 1) masks_flat = torch.cat([mask.flatten(1) for mask in masks], 1)
position_embeddings = [embed.flatten(2).transpose(1, 2) for embed in position_embeddings] position_embeddings = [embed.flatten(2).transpose(1, 2) for embed in position_embeddings]
@@ -1320,7 +1325,7 @@ class Mask2FormerPixelDecoder(nn.Module):
inputs_embeds=input_embeds_flat, inputs_embeds=input_embeds_flat,
attention_mask=masks_flat, attention_mask=masks_flat,
position_embeddings=level_pos_embed_flat, position_embeddings=level_pos_embed_flat,
spatial_shapes=spatial_shapes, spatial_shapes_list=spatial_shapes_list,
level_start_index=level_start_index, level_start_index=level_start_index,
valid_ratios=valid_ratios, valid_ratios=valid_ratios,
output_attentions=output_attentions, output_attentions=output_attentions,
@@ -1331,18 +1336,23 @@ class Mask2FormerPixelDecoder(nn.Module):
last_hidden_state = encoder_outputs.last_hidden_state last_hidden_state = encoder_outputs.last_hidden_state
batch_size = last_hidden_state.shape[0] batch_size = last_hidden_state.shape[0]
# We compute level_start_index_list separately from the tensor version level_start_index
# to avoid iterating over a tensor which breaks torch.compile/export.
level_start_index_list = [0]
for height, width in spatial_shapes_list[:-1]:
level_start_index_list.append(level_start_index_list[-1] + height * width)
split_sizes = [None] * self.num_feature_levels split_sizes = [None] * self.num_feature_levels
for i in range(self.num_feature_levels): for i in range(self.num_feature_levels):
if i < self.num_feature_levels - 1: if i < self.num_feature_levels - 1:
split_sizes[i] = level_start_index[i + 1] - level_start_index[i] split_sizes[i] = level_start_index_list[i + 1] - level_start_index_list[i]
else: else:
split_sizes[i] = last_hidden_state.shape[1] - level_start_index[i] split_sizes[i] = last_hidden_state.shape[1] - level_start_index_list[i]
encoder_output = torch.split(last_hidden_state, [size.item() for size in split_sizes], dim=1) encoder_output = torch.split(last_hidden_state, split_sizes, dim=1)
# Compute final features # Compute final features
outputs = [ outputs = [
x.transpose(1, 2).view(batch_size, -1, spatial_shapes[i][0], spatial_shapes[i][1]) x.transpose(1, 2).view(batch_size, -1, spatial_shapes_list[i][0], spatial_shapes_list[i][1])
for i, x in enumerate(encoder_output) for i, x in enumerate(encoder_output)
] ]
@@ -1876,7 +1886,9 @@ class Mask2FormerMaskedAttentionDecoder(nn.Module):
else: else:
level_index = idx % self.num_feature_levels level_index = idx % self.num_feature_levels
attention_mask[torch.where(attention_mask.sum(-1) == attention_mask.shape[-1])] = False where = (attention_mask.sum(-1) != attention_mask.shape[-1]).to(attention_mask.dtype)
# Multiply the attention mask instead of index-assigning into it, to avoid an issue in torch.export.
attention_mask = attention_mask * where.unsqueeze(-1)
layer_outputs = decoder_layer( layer_outputs = decoder_layer(
hidden_states, hidden_states,

View File

@@ -20,6 +20,7 @@ import numpy as np
from tests.test_modeling_common import floats_tensor from tests.test_modeling_common import floats_tensor
from transformers import Mask2FormerConfig, is_torch_available, is_vision_available from transformers import Mask2FormerConfig, is_torch_available, is_vision_available
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4
from transformers.testing_utils import ( from transformers.testing_utils import (
require_timm, require_timm,
require_torch, require_torch,
@@ -481,3 +482,28 @@ class Mask2FormerModelIntegrationTest(unittest.TestCase):
outputs = model(**inputs) outputs = model(**inputs)
self.assertTrue(outputs.loss is not None) self.assertTrue(outputs.loss is not None)
def test_export(self):
# Checks that the model can be exported with torch.export.export and that the
# exported module produces the same logits as the eager model on one image.
#
# The test is skipped when the torch version flag imported from
# transformers.pytorch_utils reports a torch older than 2.4.
if not is_torch_greater_or_equal_than_2_4:
self.skipTest(reason="This test requires torch >= 2.4 to run.")
# Load the checkpoint named by self.model_checkpoints, move it to the test
# device and switch it to eval mode before exporting.
model = Mask2FormerForUniversalSegmentation.from_pretrained(self.model_checkpoints).to(torch_device).eval()
image_processor = self.default_image_processor
image = prepare_img()
# Preprocess the image into PyTorch tensors on the same device as the model.
inputs = image_processor(image, return_tensors="pt").to(torch_device)
# Export with pixel_values and pixel_mask passed positionally, in that order;
# strict=True is requested explicitly.
exported_program = torch.export.export(
model,
args=(inputs["pixel_values"], inputs["pixel_mask"]),
strict=True,
)
# Run the eager model and the exported module on the same inputs.
# NOTE(review): indentation is not visible in this view; presumably both
# forward calls below sit inside the no_grad block - confirm against the file.
with torch.no_grad():
eager_outputs = model(**inputs)
exported_outputs = exported_program.module().forward(inputs["pixel_values"], inputs["pixel_mask"])
# Mask logits: shapes must be equal and values must agree within atol=TOLERANCE.
self.assertEqual(eager_outputs.masks_queries_logits.shape, exported_outputs.masks_queries_logits.shape)
self.assertTrue(
torch.allclose(eager_outputs.masks_queries_logits, exported_outputs.masks_queries_logits, atol=TOLERANCE)
)
# Class logits: same two checks as for the mask logits.
self.assertEqual(eager_outputs.class_queries_logits.shape, exported_outputs.class_queries_logits.shape)
self.assertTrue(
torch.allclose(eager_outputs.class_queries_logits, exported_outputs.class_queries_logits, atol=TOLERANCE)
)