MaskFormer - enable return_dict in order to compile (#25052)

* Enable return_dict in order to compile

* Update tests
This commit is contained in:
amyeroberts
2023-07-26 16:23:30 +01:00
committed by GitHub
parent b914ec9847
commit 659829b6ae
2 changed files with 123 additions and 38 deletions

View File

@@ -1254,11 +1254,16 @@ class MaskFormerPixelDecoder(nn.Module):
self.fpn = MaskFormerFPNModel(*args, feature_size=feature_size, **kwargs) self.fpn = MaskFormerFPNModel(*args, feature_size=feature_size, **kwargs)
self.mask_projection = nn.Conv2d(feature_size, mask_feature_size, kernel_size=3, padding=1) self.mask_projection = nn.Conv2d(feature_size, mask_feature_size, kernel_size=3, padding=1)
def forward(self, features: List[Tensor], output_hidden_states: bool = False) -> MaskFormerPixelDecoderOutput: def forward(
self, features: List[Tensor], output_hidden_states: bool = False, return_dict: bool = True
) -> MaskFormerPixelDecoderOutput:
fpn_features = self.fpn(features) fpn_features = self.fpn(features)
# we use the last feature map # we use the last feature map
last_feature_projected = self.mask_projection(fpn_features[-1]) last_feature_projected = self.mask_projection(fpn_features[-1])
if not return_dict:
return (last_feature_projected, tuple(fpn_features)) if output_hidden_states else (last_feature_projected,)
return MaskFormerPixelDecoderOutput( return MaskFormerPixelDecoderOutput(
last_hidden_state=last_feature_projected, hidden_states=tuple(fpn_features) if output_hidden_states else () last_hidden_state=last_feature_projected, hidden_states=tuple(fpn_features) if output_hidden_states else ()
) )
@@ -1387,9 +1392,20 @@ class MaskFormerPixelLevelModule(nn.Module):
lateral_widths=feature_channels[:-1], lateral_widths=feature_channels[:-1],
) )
def forward(self, pixel_values: Tensor, output_hidden_states: bool = False) -> MaskFormerPixelLevelModuleOutput: def forward(
self, pixel_values: Tensor, output_hidden_states: bool = False, return_dict: bool = True
) -> MaskFormerPixelLevelModuleOutput:
features = self.encoder(pixel_values).feature_maps features = self.encoder(pixel_values).feature_maps
decoder_output = self.decoder(features, output_hidden_states) decoder_output = self.decoder(features, output_hidden_states, return_dict=return_dict)
if not return_dict:
last_hidden_state = decoder_output[0]
outputs = (features[-1], last_hidden_state)
if output_hidden_states:
hidden_states = decoder_output[1]
outputs = outputs + (tuple(features),) + (hidden_states,)
return outputs
return MaskFormerPixelLevelModuleOutput( return MaskFormerPixelLevelModuleOutput(
# the last feature is actually the output from the last layer # the last feature is actually the output from the last layer
encoder_last_hidden_state=features[-1], encoder_last_hidden_state=features[-1],
@@ -1414,7 +1430,11 @@ class MaskFormerTransformerModule(nn.Module):
self.decoder = DetrDecoder(config=config.decoder_config) self.decoder = DetrDecoder(config=config.decoder_config)
def forward( def forward(
self, image_features: Tensor, output_hidden_states: bool = False, output_attentions: bool = False self,
image_features: Tensor,
output_hidden_states: bool = False,
output_attentions: bool = False,
return_dict: Optional[bool] = None,
) -> DetrDecoderOutput: ) -> DetrDecoderOutput:
if self.input_projection is not None: if self.input_projection is not None:
image_features = self.input_projection(image_features) image_features = self.input_projection(image_features)
@@ -1438,7 +1458,7 @@ class MaskFormerTransformerModule(nn.Module):
query_position_embeddings=queries_embeddings, query_position_embeddings=queries_embeddings,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=None, return_dict=return_dict,
) )
return decoder_output return decoder_output
@@ -1593,9 +1613,11 @@ class MaskFormerModel(MaskFormerPreTrainedModel):
if pixel_mask is None: if pixel_mask is None:
pixel_mask = torch.ones((batch_size, height, width), device=pixel_values.device) pixel_mask = torch.ones((batch_size, height, width), device=pixel_values.device)
pixel_level_module_output = self.pixel_level_module(pixel_values, output_hidden_states) pixel_level_module_output = self.pixel_level_module(
image_features = pixel_level_module_output.encoder_last_hidden_state pixel_values, output_hidden_states, return_dict=return_dict
pixel_embeddings = pixel_level_module_output.decoder_last_hidden_state )
image_features = pixel_level_module_output[0]
pixel_embeddings = pixel_level_module_output[1]
transformer_module_output = self.transformer_module(image_features, output_hidden_states, output_attentions) transformer_module_output = self.transformer_module(image_features, output_hidden_states, output_attentions)
queries = transformer_module_output.last_hidden_state queries = transformer_module_output.last_hidden_state
@@ -1606,9 +1628,9 @@ class MaskFormerModel(MaskFormerPreTrainedModel):
hidden_states = None hidden_states = None
if output_hidden_states: if output_hidden_states:
encoder_hidden_states = pixel_level_module_output.encoder_hidden_states encoder_hidden_states = pixel_level_module_output[2]
pixel_decoder_hidden_states = pixel_level_module_output.decoder_hidden_states pixel_decoder_hidden_states = pixel_level_module_output[3]
transformer_decoder_hidden_states = transformer_module_output.hidden_states transformer_decoder_hidden_states = transformer_module_output[1]
hidden_states = encoder_hidden_states + pixel_decoder_hidden_states + transformer_decoder_hidden_states hidden_states = encoder_hidden_states + pixel_decoder_hidden_states + transformer_decoder_hidden_states
output = MaskFormerModelOutput( output = MaskFormerModelOutput(
@@ -1803,13 +1825,25 @@ class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
) )
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs: MaskFormerModelOutput = self.model( raw_outputs = self.model(
pixel_values, pixel_values,
pixel_mask, pixel_mask,
output_hidden_states=output_hidden_states or self.config.use_auxiliary_loss, output_hidden_states=output_hidden_states or self.config.use_auxiliary_loss,
return_dict=True, return_dict=return_dict,
output_attentions=output_attentions, output_attentions=output_attentions,
) )
# We need to have raw_outputs optionally be returned as a dict to use torch.compile. For backwards
# compatibility we convert to a dataclass for the rest of the model logic
outputs = MaskFormerModelOutput(
encoder_last_hidden_state=raw_outputs[0],
pixel_decoder_last_hidden_state=raw_outputs[1],
transformer_decoder_last_hidden_state=raw_outputs[2],
encoder_hidden_states=raw_outputs[3] if output_hidden_states else None,
pixel_decoder_hidden_states=raw_outputs[4] if output_hidden_states else None,
transformer_decoder_hidden_states=raw_outputs[5] if output_hidden_states else None,
hidden_states=raw_outputs[6] if output_hidden_states else None,
attentions=raw_outputs[-1] if output_attentions else None,
)
loss, loss_dict, auxiliary_logits = None, None, None loss, loss_dict, auxiliary_logits = None, None, None
@@ -1827,16 +1861,18 @@ class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
if not output_auxiliary_logits: if not output_auxiliary_logits:
auxiliary_logits = None auxiliary_logits = None
output = MaskFormerForInstanceSegmentationOutput( if not return_dict:
output = tuple(
v
for v in (loss, class_queries_logits, masks_queries_logits, auxiliary_logits, *outputs.values())
if v is not None
)
return output
return MaskFormerForInstanceSegmentationOutput(
loss=loss, loss=loss,
**outputs, **outputs,
class_queries_logits=class_queries_logits, class_queries_logits=class_queries_logits,
masks_queries_logits=masks_queries_logits, masks_queries_logits=masks_queries_logits,
auxiliary_logits=auxiliary_logits, auxiliary_logits=auxiliary_logits,
) )
if not return_dict:
output = tuple(v for v in output.values())
if loss is not None:
output = ((loss)) + output
return output

View File

@@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
""" Testing suite for the PyTorch MaskFormer model. """ """ Testing suite for the PyTorch MaskFormer model. """
import copy
import inspect import inspect
import unittest import unittest
@@ -54,6 +55,8 @@ class MaskFormerModelTester:
max_size=32 * 6, max_size=32 * 6,
num_labels=4, num_labels=4,
mask_feature_size=32, mask_feature_size=32,
num_hidden_layers=2,
num_attention_heads=2,
): ):
self.parent = parent self.parent = parent
self.batch_size = batch_size self.batch_size = batch_size
@@ -65,6 +68,9 @@ class MaskFormerModelTester:
self.max_size = max_size self.max_size = max_size
self.num_labels = num_labels self.num_labels = num_labels
self.mask_feature_size = mask_feature_size self.mask_feature_size = mask_feature_size
# This is passed to the decoder config. We add it to the model tester here for testing
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
def prepare_config_and_inputs(self): def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.min_size, self.max_size]).to( pixel_values = floats_tensor([self.batch_size, self.num_channels, self.min_size, self.max_size]).to(
@@ -91,11 +97,12 @@ class MaskFormerModelTester:
), ),
decoder_config=DetrConfig( decoder_config=DetrConfig(
decoder_ffn_dim=64, decoder_ffn_dim=64,
decoder_layers=2, decoder_layers=self.num_hidden_layers,
decoder_attention_heads=self.num_attention_heads,
encoder_ffn_dim=64, encoder_ffn_dim=64,
encoder_layers=2, encoder_layers=self.num_hidden_layers,
encoder_attention_heads=self.num_attention_heads,
num_queries=self.num_queries, num_queries=self.num_queries,
decoder_attention_heads=2,
d_model=self.mask_feature_size, d_model=self.mask_feature_size,
), ),
mask_feature_size=self.mask_feature_size, mask_feature_size=self.mask_feature_size,
@@ -196,6 +203,27 @@ class MaskFormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
self.model_tester = MaskFormerModelTester(self) self.model_tester = MaskFormerModelTester(self)
self.config_tester = ConfigTester(self, config_class=MaskFormerConfig, has_text_modality=False) self.config_tester = ConfigTester(self, config_class=MaskFormerConfig, has_text_modality=False)
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
inputs_dict = copy.deepcopy(inputs_dict)
if return_labels:
if model_class in [MaskFormerForInstanceSegmentation]:
inputs_dict["mask_labels"] = torch.zeros(
(
self.model_tester.batch_size,
self.model_tester.num_labels,
self.model_tester.min_size,
self.model_tester.max_size,
),
dtype=torch.float32,
device=torch_device,
)
inputs_dict["class_labels"] = torch.zeros(
(self.model_tester.batch_size, self.model_tester.num_labels), dtype=torch.long, device=torch_device
)
return inputs_dict
def test_config(self): def test_config(self):
self.config_tester.run_common_tests() self.config_tester.run_common_tests()
@@ -265,26 +293,47 @@ class MaskFormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
self.model_tester.create_and_check_maskformer_model(config, **inputs, output_hidden_states=True) self.model_tester.create_and_check_maskformer_model(config, **inputs, output_hidden_states=True)
def test_attention_outputs(self): def test_attention_outputs(self):
config, inputs = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
for model_class in self.all_model_classes: for model_class in self.all_model_classes:
model = model_class(config).to(torch_device) inputs_dict["output_attentions"] = True
outputs = model(**inputs, output_attentions=True) inputs_dict["output_hidden_states"] = False
self.assertTrue(outputs.attentions is not None) config.return_dict = True
def test_training(self):
if not self.model_tester.is_training:
return
# only MaskFormerForInstanceSegmentation has the loss
model_class = self.all_model_classes[1]
config, pixel_values, pixel_mask, mask_labels, class_labels = self.model_tester.prepare_config_and_inputs()
model = model_class(config) model = model_class(config)
model.to(torch_device) model.to(torch_device)
model.train() model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.attentions
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
loss = model(pixel_values, mask_labels=mask_labels, class_labels=class_labels).loss # Check that output_attentions also work using config
loss.backward() del inputs_dict["output_attentions"]
config.output_attentions = True
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.attentions
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
out_len = len(outputs)
# Check attention is always last and order is fine
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = True
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
# encoder_hidden_states, pixel_decoder_hidden_states, transformer_decoder_hidden_states, hidden_states
added_hidden_states = 4
self.assertEqual(out_len + added_hidden_states, len(outputs))
self_attentions = outputs.attentions
self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
def test_retain_grad_hidden_states_attentions(self): def test_retain_grad_hidden_states_attentions(self):
# only MaskFormerForInstanceSegmentation has the loss # only MaskFormerForInstanceSegmentation has the loss