From 0558914dff91963f0488bd28747cdd45e933e7a4 Mon Sep 17 00:00:00 2001 From: Alara Dirik <8944735+alaradirik@users.noreply.github.com> Date: Wed, 22 Mar 2023 07:35:47 +0300 Subject: [PATCH] Add MaskedImageModelingOutput (#22212) * Add MaskedImageModelingOutput --- src/transformers/modeling_outputs.py | 38 +++++++++++++++++ src/transformers/modeling_tf_outputs.py | 42 ++++++++++++++++++- src/transformers/models/deit/modeling_deit.py | 17 +++++--- .../models/deit/modeling_tf_deit.py | 18 ++++---- src/transformers/models/vit/modeling_vit.py | 17 +++++--- tests/models/deit/test_modeling_deit.py | 4 +- tests/models/deit/test_modeling_tf_deit.py | 4 +- tests/models/vit/test_modeling_vit.py | 4 +- 8 files changed, 116 insertions(+), 28 deletions(-) diff --git a/src/transformers/modeling_outputs.py b/src/transformers/modeling_outputs.py index 4f7540d0ff..c69e426ab5 100755 --- a/src/transformers/modeling_outputs.py +++ b/src/transformers/modeling_outputs.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings from dataclasses import dataclass from typing import Optional, Tuple @@ -1622,3 +1623,40 @@ class SampleTSPredictionOutput(ModelOutput): """ sequences: torch.FloatTensor = None + + +@dataclass +class MaskedImageModelingOutput(ModelOutput): + """ + Base class for outputs of masked image completion / in-painting models. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided): + Reconstruction loss. + reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Reconstructed / completed images. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or + when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states + (also called feature maps) of the model at the output of each stage. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when + `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, + sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in + the self-attention heads. + """ + + loss: Optional[torch.FloatTensor] = None + reconstruction: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + @property + def logits(self): + warnings.warn( + "logits attribute is deprecated and will be removed in version 5 of Transformers." + " Please use the reconstruction attribute to retrieve the final output instead.", + FutureWarning, + ) + return self.reconstruction diff --git a/src/transformers/modeling_tf_outputs.py b/src/transformers/modeling_tf_outputs.py index 0fed3e7851..f8148b1695 100644 --- a/src/transformers/modeling_tf_outputs.py +++ b/src/transformers/modeling_tf_outputs.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings from dataclasses import dataclass from typing import List, Optional, Tuple @@ -55,8 +56,8 @@ class TFBaseModelOutputWithNoAttention(ModelOutput): last_hidden_state (`tf.Tensor` shape `(batch_size, num_channels, height, width)`): Sequence of hidden-states at the output of the last layer of the model. hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, num_channels, height, width)`. + Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for + the output of each layer) of shape `(batch_size, num_channels, height, width)`. Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. """ @@ -949,3 +950,40 @@ class TFImageClassifierOutputWithNoAttention(ModelOutput): loss: Optional[tf.Tensor] = None logits: tf.Tensor = None hidden_states: Optional[Tuple[tf.Tensor, ...]] = None + + +@dataclass +class TFMaskedImageModelingOutput(ModelOutput): + """ + Base class for outputs of masked image completion / in-painting models. + + Args: + loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided): + Reconstruction loss. + reconstruction (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`): + Reconstructed / completed images. + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when + `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for + the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states (also called + feature maps) of the model at the output of each stage. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when + `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, sequence_length)`. + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[tf.Tensor] = None + reconstruction: tf.Tensor = None + hidden_states: Optional[Tuple[tf.Tensor]] = None + attentions: Optional[Tuple[tf.Tensor]] = None + + @property + def logits(self): + warnings.warn( + "logits attribute is deprecated and will be removed in version 5 of Transformers." + " Please use the reconstruction attribute to retrieve the final output instead.", + FutureWarning, + ) + return self.reconstruction diff --git a/src/transformers/models/deit/modeling_deit.py b/src/transformers/models/deit/modeling_deit.py index 2a37e3ebc1..b5d1aa28c9 100644 --- a/src/transformers/models/deit/modeling_deit.py +++ b/src/transformers/models/deit/modeling_deit.py @@ -26,7 +26,12 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput, MaskedLMOutput +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPooling, + ImageClassifierOutput, + MaskedImageModelingOutput, +) from ...modeling_utils import PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( @@ -592,7 +597,7 @@ class DeiTForMaskedImageModeling(DeiTPreTrainedModel): self.post_init() @add_start_docstrings_to_model_forward(DEIT_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC) + @replace_return_docstrings(output_type=MaskedImageModelingOutput, config_class=_CONFIG_FOR_DOC) def forward( self, pixel_values: Optional[torch.Tensor] = None, @@ -601,7 +606,7 @@ class DeiTForMaskedImageModeling(DeiTPreTrainedModel): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - ) -> Union[tuple, MaskedLMOutput]: + ) -> Union[tuple, MaskedImageModelingOutput]: r""" bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`): Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). @@ -627,7 +632,7 @@ class DeiTForMaskedImageModeling(DeiTPreTrainedModel): >>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool() >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos) - >>> loss, reconstructed_pixel_values = outputs.loss, outputs.logits + >>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction >>> list(reconstructed_pixel_values.shape) [1, 3, 224, 224] ```""" @@ -670,9 +675,9 @@ class DeiTForMaskedImageModeling(DeiTPreTrainedModel): output = (reconstructed_pixel_values,) + outputs[1:] return ((masked_im_loss,) + output) if masked_im_loss is not None else output - return MaskedLMOutput( + return MaskedImageModelingOutput( loss=masked_im_loss, - logits=reconstructed_pixel_values, + reconstruction=reconstructed_pixel_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) diff --git a/src/transformers/models/deit/modeling_tf_deit.py b/src/transformers/models/deit/modeling_tf_deit.py index 1bcfd07b38..558de60614 100644 --- a/src/transformers/models/deit/modeling_tf_deit.py +++ b/src/transformers/models/deit/modeling_tf_deit.py @@ -27,7 +27,7 @@ from ...modeling_tf_outputs import ( TFBaseModelOutput, TFBaseModelOutputWithPooling, TFImageClassifierOutput, - TFMaskedLMOutput, + TFMaskedImageModelingOutput, ) from ...modeling_tf_utils import ( TFPreTrainedModel, @@ -769,7 +769,7 @@ class TFDeiTForMaskedImageModeling(TFDeiTPreTrainedModel): @unpack_inputs @add_start_docstrings_to_model_forward(DEIT_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=TFMaskedLMOutput, config_class=_CONFIG_FOR_DOC) + @replace_return_docstrings(output_type=TFMaskedImageModelingOutput, config_class=_CONFIG_FOR_DOC) def call( self, pixel_values: Optional[tf.Tensor] = None, @@ -779,7 +779,7 @@ class TFDeiTForMaskedImageModeling(TFDeiTPreTrainedModel): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, training: bool = False, - ) -> Union[tuple, TFMaskedLMOutput]: + ) -> Union[tuple, TFMaskedImageModelingOutput]: r""" bool_masked_pos (`tf.Tensor` of type bool and shape `(batch_size, num_patches)`): Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). @@ -805,7 +805,7 @@ class TFDeiTForMaskedImageModeling(TFDeiTPreTrainedModel): >>> bool_masked_pos = tf.cast(tf.random.uniform((1, num_patches), minval=0, maxval=2, dtype=tf.int32), tf.bool) >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos) - >>> loss, reconstructed_pixel_values = outputs.loss, outputs.logits + >>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction >>> list(reconstructed_pixel_values.shape) [1, 3, 224, 224] ```""" @@ -860,18 +860,20 @@ class TFDeiTForMaskedImageModeling(TFDeiTPreTrainedModel): output = (reconstructed_pixel_values,) + outputs[1:] return ((masked_im_loss,) + output) if masked_im_loss is not None else output - return TFMaskedLMOutput( + return TFMaskedImageModelingOutput( loss=masked_im_loss, - logits=reconstructed_pixel_values, + reconstruction=reconstructed_pixel_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) - def serving_output(self, output: TFMaskedLMOutput) -> TFMaskedLMOutput: + def serving_output(self, output: TFMaskedImageModelingOutput) -> TFMaskedImageModelingOutput: hidden_states = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None attentions = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None - return TFMaskedLMOutput(logits=output.logits, hidden_states=hidden_states, attentions=attentions) + return TFMaskedImageModelingOutput( + reconstruction=output.reconstruction, hidden_states=hidden_states, attentions=attentions + ) @add_start_docstrings( diff --git a/src/transformers/models/vit/modeling_vit.py b/src/transformers/models/vit/modeling_vit.py index 03a9212ff0..eba8278eaa 100644 --- a/src/transformers/models/vit/modeling_vit.py +++ b/src/transformers/models/vit/modeling_vit.py @@ -25,7 +25,12 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput, MaskedLMOutput +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPooling, + ImageClassifierOutput, + MaskedImageModelingOutput, +) from ...modeling_utils import PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( @@ -647,7 +652,7 @@ class ViTForMaskedImageModeling(ViTPreTrainedModel): self.post_init() @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC) + @replace_return_docstrings(output_type=MaskedImageModelingOutput, config_class=_CONFIG_FOR_DOC) def forward( self, pixel_values: Optional[torch.Tensor] = None, @@ -657,7 +662,7 @@ class ViTForMaskedImageModeling(ViTPreTrainedModel): output_hidden_states: Optional[bool] = None, interpolate_pos_encoding: Optional[bool] = None, return_dict: Optional[bool] = None, - ) -> Union[tuple, MaskedLMOutput]: + ) -> Union[tuple, MaskedImageModelingOutput]: r""" bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`): Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). @@ -683,7 +688,7 @@ class ViTForMaskedImageModeling(ViTPreTrainedModel): >>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool() >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos) - >>> loss, reconstructed_pixel_values = outputs.loss, outputs.logits + >>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction >>> list(reconstructed_pixel_values.shape) [1, 3, 224, 224] ```""" @@ -727,9 +732,9 @@ class ViTForMaskedImageModeling(ViTPreTrainedModel): output = (reconstructed_pixel_values,) + outputs[1:] return ((masked_im_loss,) + output) if masked_im_loss is not None else output - return MaskedLMOutput( + return MaskedImageModelingOutput( loss=masked_im_loss, - logits=reconstructed_pixel_values, + reconstruction=reconstructed_pixel_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) diff --git a/tests/models/deit/test_modeling_deit.py b/tests/models/deit/test_modeling_deit.py index fbf6d7353b..1564b23aa6 100644 --- a/tests/models/deit/test_modeling_deit.py +++ b/tests/models/deit/test_modeling_deit.py @@ -145,7 +145,7 @@ class DeiTModelTester: model.eval() result = model(pixel_values) self.parent.assertEqual( - result.logits.shape, (self.batch_size, self.num_channels, self.image_size, self.image_size) + result.reconstruction.shape, (self.batch_size, self.num_channels, self.image_size, self.image_size) ) # test greyscale images @@ -156,7 +156,7 @@ class DeiTModelTester: pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size]) result = model(pixel_values) - self.parent.assertEqual(result.logits.shape, (self.batch_size, 1, self.image_size, self.image_size)) + self.parent.assertEqual(result.reconstruction.shape, (self.batch_size, 1, self.image_size, self.image_size)) def create_and_check_for_image_classification(self, config, pixel_values, labels): config.num_labels = self.type_sequence_label_size diff --git a/tests/models/deit/test_modeling_tf_deit.py b/tests/models/deit/test_modeling_tf_deit.py index c7c1fc8456..223d164d4a 100644 --- a/tests/models/deit/test_modeling_tf_deit.py +++ b/tests/models/deit/test_modeling_tf_deit.py @@ -130,7 +130,7 @@ class TFDeiTModelTester: model = TFDeiTForMaskedImageModeling(config=config) result = model(pixel_values) self.parent.assertEqual( - result.logits.shape, (self.batch_size, self.num_channels, self.image_size, self.image_size) + result.reconstruction.shape, (self.batch_size, self.num_channels, self.image_size, self.image_size) ) # test greyscale images @@ -139,7 +139,7 @@ class TFDeiTModelTester: pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size]) result = model(pixel_values) - self.parent.assertEqual(result.logits.shape, (self.batch_size, 1, self.image_size, self.image_size)) + self.parent.assertEqual(result.reconstruction.shape, (self.batch_size, 1, self.image_size, self.image_size)) def create_and_check_for_image_classification(self, config, pixel_values, labels): config.num_labels = self.type_sequence_label_size diff --git a/tests/models/vit/test_modeling_vit.py b/tests/models/vit/test_modeling_vit.py index be509d460e..d7ae3c162f 100644 --- a/tests/models/vit/test_modeling_vit.py +++ b/tests/models/vit/test_modeling_vit.py @@ -134,7 +134,7 @@ class ViTModelTester: model.eval() result = model(pixel_values) self.parent.assertEqual( - result.logits.shape, (self.batch_size, self.num_channels, self.image_size, self.image_size) + result.reconstruction.shape, (self.batch_size, self.num_channels, self.image_size, self.image_size) ) # test greyscale images @@ -145,7 +145,7 @@ class ViTModelTester: pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size]) result = model(pixel_values) - self.parent.assertEqual(result.logits.shape, (self.batch_size, 1, self.image_size, self.image_size)) + self.parent.assertEqual(result.reconstruction.shape, (self.batch_size, 1, self.image_size, self.image_size)) def create_and_check_for_image_classification(self, config, pixel_values, labels): config.num_labels = self.type_sequence_label_size