This commit is contained in:
Francesco Saverio Zuppichini
2022-03-09 15:51:56 +01:00
committed by GitHub
parent 38bce1d4cf
commit 1e8f37992f
2 changed files with 7 additions and 8 deletions

View File

@@ -2313,16 +2313,16 @@ class MaskFormerModel(MaskFormerPreTrainedModel):
) )
queries = transformer_module_output.last_hidden_state queries = transformer_module_output.last_hidden_state
encoder_hidden_states = None
pixel_decoder_hidden_states = None
transformer_decoder_hidden_states = None
hidden_states = None
if output_hidden_states: if output_hidden_states:
encoder_hidden_states = pixel_level_module_output.encoder_hidden_states encoder_hidden_states = pixel_level_module_output.encoder_hidden_states
pixel_decoder_hidden_states = pixel_level_module_output.decoder_hidden_states pixel_decoder_hidden_states = pixel_level_module_output.decoder_hidden_states
transformer_decoder_hidden_states = transformer_module_output.hidden_states transformer_decoder_hidden_states = transformer_module_output.hidden_states
hidden_states = encoder_hidden_states + pixel_decoder_hidden_states + transformer_decoder_hidden_states hidden_states = encoder_hidden_states + pixel_decoder_hidden_states + transformer_decoder_hidden_states
else:
encoder_hidden_states = None
pixel_decoder_hidden_states = None
transformer_decoder_hidden_states = None
hidden_states = None
output = MaskFormerModelOutput( output = MaskFormerModelOutput(
encoder_last_hidden_state=image_features, encoder_last_hidden_state=image_features,
@@ -2463,7 +2463,6 @@ class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
>>> # you can pass them to feature_extractor for postprocessing >>> # you can pass them to feature_extractor for postprocessing
>>> output = feature_extractor.post_process_segmentation(outputs) >>> output = feature_extractor.post_process_segmentation(outputs)
>>> output = feature_extractor.post_process_semantic_segmentation(outputs) >>> output = feature_extractor.post_process_semantic_segmentation(outputs)
>>> output = feature_extractor.post_process_panoptic_segmentation(outputs) >>> output = feature_extractor.post_process_panoptic_segmentation(outputs)
``` ```
""" """
@@ -2477,7 +2476,7 @@ class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
outputs: MaskFormerModelOutput = self.model( outputs: MaskFormerModelOutput = self.model(
pixel_values, pixel_values,
pixel_mask, pixel_mask,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states or self.config.use_auxiliary_loss,
return_dict=True, return_dict=True,
output_attentions=output_attentions, output_attentions=output_attentions,
) )

View File

@@ -139,7 +139,7 @@ class MaskFormerModelTester:
def comm_check_on_output(result): def comm_check_on_output(result):
# let's still check that all the required stuff is there # let's still check that all the required stuff is there
self.parent.assertTrue(result.transformer_decoder_hidden_states is not None) self.parent.assertTrue(result.transformer_decoder_last_hidden_state is not None)
self.parent.assertTrue(result.pixel_decoder_last_hidden_state is not None) self.parent.assertTrue(result.pixel_decoder_last_hidden_state is not None)
self.parent.assertTrue(result.encoder_last_hidden_state is not None) self.parent.assertTrue(result.encoder_last_hidden_state is not None)
# okay, now we need to check the logits shape # okay, now we need to check the logits shape