From 9d99489f2f79b81fa9131c9299c236006dff94fb Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 8 Jun 2022 18:33:18 +0530 Subject: [PATCH] Add TFData2VecVision for semantic segmentation (#17271) * feat: initial implementation of data2vec segmentation model in TF. * chore: minor corrections to make the segmenter work. * chore: removed unnecessary files. * chore: add tests and other modifications. * fix: loss computation for segmentation. * chore: remove unused variable. * chore: formatting. * added a dummy adaptive pooling layer. * removed unnecessary file. * potentially add identifiers to layer names. * fix: layer naming. * chore: removed unnecessary print. * Skipping unneeded test * chore: add logging to debug tolerance. * fix: segmentation tests for tfdata2vecvision * chore: make style. * fix: layer names, assertion to be resolved. * Bumping test tolerance a bit * chore: bump the tol in PT test. Co-authored-by: matt --- docs/source/en/model_doc/data2vec.mdx | 7 +- src/transformers/__init__.py | 2 + src/transformers/modeling_tf_outputs.py | 37 ++ src/transformers/modeling_tf_pytorch_utils.py | 4 + .../models/auto/modeling_tf_auto.py | 19 + src/transformers/models/data2vec/__init__.py | 2 + .../data2vec/modeling_tf_data2vec_vision.py | 478 +++++++++++++++++- src/transformers/utils/dummy_tf_objects.py | 7 + .../data2vec/test_modeling_data2vec_vision.py | 4 + .../test_modeling_tf_data2vec_vision.py | 36 +- 10 files changed, 590 insertions(+), 6 deletions(-) diff --git a/docs/source/en/model_doc/data2vec.mdx b/docs/source/en/model_doc/data2vec.mdx index 3dbff94bdf..8623d64afe 100644 --- a/docs/source/en/model_doc/data2vec.mdx +++ b/docs/source/en/model_doc/data2vec.mdx @@ -42,7 +42,7 @@ Tips: This model was contributed by [edugp](https://huggingface.co/edugp) and [patrickvonplaten](https://huggingface.co/patrickvonplaten). -[sayakpaul](https://github.com/sayakpaul) contributed Data2Vec for vision in TensorFlow. 
+[sayakpaul](https://github.com/sayakpaul) and [Rocketknight1](https://github.com/Rocketknight1) contributed Data2Vec for vision in TensorFlow. The original code (for NLP and Speech) can be found [here](https://github.com/pytorch/fairseq/tree/main/examples/data2vec). The original code for vision can be found [here](https://github.com/facebookresearch/data2vec_vision/tree/main/beit). @@ -145,3 +145,8 @@ The original code for vision can be found [here](https://github.com/facebookrese [[autodoc]] TFData2VecVisionForImageClassification - call + +## TFData2VecVisionForSemanticSegmentation + +[[autodoc]] TFData2VecVisionForSemanticSegmentation + - call diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index e484600814..2b73d62aa8 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -2036,6 +2036,7 @@ else: _import_structure["models.data2vec"].extend( [ "TFData2VecVisionForImageClassification", + "TFData2VecVisionForSemanticSegmentation", "TFData2VecVisionModel", "TFData2VecVisionPreTrainedModel", ] @@ -4342,6 +4343,7 @@ if TYPE_CHECKING: ) from .models.data2vec import ( TFData2VecVisionForImageClassification, + TFData2VecVisionForSemanticSegmentation, TFData2VecVisionModel, TFData2VecVisionPreTrainedModel, ) diff --git a/src/transformers/modeling_tf_outputs.py b/src/transformers/modeling_tf_outputs.py index 5c74236607..1e71556ec1 100644 --- a/src/transformers/modeling_tf_outputs.py +++ b/src/transformers/modeling_tf_outputs.py @@ -607,6 +607,43 @@ class TFSeq2SeqSequenceClassifierOutput(ModelOutput): encoder_attentions: Optional[Tuple[tf.Tensor]] = None +@dataclass +class TFSemanticSegmenterOutput(ModelOutput): + """ + Base class for outputs of semantic segmentation models. + + Args: + loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification (or regression if config.num_labels==1) loss. 
+ logits (`tf.Tensor` of shape `(batch_size, config.num_labels, logits_height, logits_width)`): + Classification scores for each pixel. + + <Tip warning={true}> + + The logits returned do not necessarily have the same size as the `pixel_values` passed as inputs. This is + to avoid doing two interpolations and losing some quality when a user needs to resize the logits to the + original image size as post-processing. You should always check your logits shape and resize as needed. + + </Tip> + + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for + the output of each layer) of shape `(batch_size, patch_size, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, sequence_length)`. + + Attention weights after the attention softmax, used to compute the weighted average in the self-attention + heads. 
+ """ + + loss: Optional[tf.Tensor] = None + logits: tf.Tensor = None + hidden_states: Optional[Tuple[tf.Tensor]] = None + attentions: Optional[Tuple[tf.Tensor]] = None + + @dataclass class TFMultipleChoiceModelOutput(ModelOutput): """ diff --git a/src/transformers/modeling_tf_pytorch_utils.py b/src/transformers/modeling_tf_pytorch_utils.py index 59846a8925..8882771014 100644 --- a/src/transformers/modeling_tf_pytorch_utils.py +++ b/src/transformers/modeling_tf_pytorch_utils.py @@ -163,6 +163,10 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a new_key = key.replace("gamma", "weight") if "beta" in key: new_key = key.replace("beta", "bias") + if "running_var" in key: + new_key = key.replace("running_var", "moving_variance") + if "running_mean" in key: + new_key = key.replace("running_mean", "moving_mean") if new_key: old_keys.append(key) new_keys.append(new_key) diff --git a/src/transformers/models/auto/modeling_tf_auto.py b/src/transformers/models/auto/modeling_tf_auto.py index 716fd2575b..9c889597e5 100644 --- a/src/transformers/models/auto/modeling_tf_auto.py +++ b/src/transformers/models/auto/modeling_tf_auto.py @@ -178,6 +178,13 @@ TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( ] ) +TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = OrderedDict( + [ + # Model for Semantic Segmentation mapping + ("data2vec-vision", "TFData2VecVisionForSemanticSegmentation"), + ] +) + TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict( [ ("vision-encoder-decoder", "TFVisionEncoderDecoderModel"), @@ -365,6 +372,9 @@ TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING = _LazyAutoMapping( TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping( CONFIG_MAPPING_NAMES, TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES ) +TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES +) TF_MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, 
TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES) TF_MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES) TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping( @@ -440,6 +450,15 @@ TFAutoModelForImageClassification = auto_class_update( ) +class TFAutoModelForSemanticSegmentation(_BaseAutoModelClass): + _model_mapping = TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING + + +TF_AutoModelForSemanticSegmentation = auto_class_update( + TFAutoModelForSemanticSegmentation, head_doc="semantic segmentation" +) + + class TFAutoModelForVision2Seq(_BaseAutoModelClass): _model_mapping = TF_MODEL_FOR_VISION_2_SEQ_MAPPING diff --git a/src/transformers/models/data2vec/__init__.py b/src/transformers/models/data2vec/__init__.py index 2a92a620d4..794124575e 100644 --- a/src/transformers/models/data2vec/__init__.py +++ b/src/transformers/models/data2vec/__init__.py @@ -73,6 +73,7 @@ else: if is_tf_available(): _import_structure["modeling_tf_data2vec_vision"] = [ "TFData2VecVisionForImageClassification", + "TFData2VecVisionForSemanticSegmentation", "TFData2VecVisionModel", "TFData2VecVisionPreTrainedModel", ] @@ -127,6 +128,7 @@ if TYPE_CHECKING: if is_tf_available(): from .modeling_tf_data2vec_vision import ( TFData2VecVisionForImageClassification, + TFData2VecVisionForSemanticSegmentation, TFData2VecVisionModel, TFData2VecVisionPreTrainedModel, ) diff --git a/src/transformers/models/data2vec/modeling_tf_data2vec_vision.py b/src/transformers/models/data2vec/modeling_tf_data2vec_vision.py index 618e66a10d..e7cc7d2449 100644 --- a/src/transformers/models/data2vec/modeling_tf_data2vec_vision.py +++ b/src/transformers/models/data2vec/modeling_tf_data2vec_vision.py @@ -17,7 +17,7 @@ import collections.abc import math from dataclasses import dataclass -from typing import Dict, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import numpy as np import tensorflow as tf @@ -25,7 +25,12 @@ import tensorflow as tf from 
transformers.tf_utils import shape_list, stable_softmax from ...activations_tf import get_tf_activation -from ...modeling_tf_outputs import TFBaseModelOutput, TFBaseModelOutputWithPooling, TFSequenceClassifierOutput +from ...modeling_tf_outputs import ( + TFBaseModelOutput, + TFBaseModelOutputWithPooling, + TFSemanticSegmenterOutput, + TFSequenceClassifierOutput, +) from ...modeling_tf_utils import ( TFModelInputType, TFPreTrainedModel, @@ -34,7 +39,13 @@ from ...modeling_tf_utils import ( keras_serializable, unpack_inputs, ) -from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging +from ...utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) from .configuration_data2vec_vision import Data2VecVisionConfig @@ -978,3 +989,464 @@ class TFData2VecVisionForImageClassification(TFData2VecVisionPreTrainedModel, TF hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) + + +class TFData2VecVisionConvModule(tf.keras.layers.Layer): + """ + A convolutional block that bundles conv/norm/activation layers. This block simplifies the usage of convolution + layers, which are commonly used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU). + + Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation. 
+ """ + + def __init__( + self, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]], + padding: str = "valid", + bias: bool = False, + dilation: Union[int, Tuple[int, int]] = 1, + **kwargs + ) -> None: + super().__init__(**kwargs) + self.conv = tf.keras.layers.Conv2D( + filters=out_channels, + kernel_size=kernel_size, + padding=padding, + use_bias=bias, + dilation_rate=dilation, + name="conv", + ) + self.bn = tf.keras.layers.BatchNormalization(name="bn") + self.activation = tf.nn.relu + + def call(self, input: tf.Tensor) -> tf.Tensor: + output = self.conv(input) + output = self.bn(output) + output = self.activation(output) + return output + + +# Copied from: +# https://gist.github.com/Rocketknight1/43abbe6e73f1008e6e459486e01e0ceb +class TFAdaptiveAvgPool1D(tf.keras.layers.Layer): + def __init__(self, output_dim, mode="dense", **kwargs): + super().__init__(**kwargs) + self.output_dim = output_dim + self.mode = mode + self.map = None + + def build(self, input_shape): + super().build(input_shape) + """We pre-compute the sparse matrix for the build() step once. 
The below code comes + from https://stackoverflow.com/questions/53841509/how-does-adaptive-pooling-in-pytorch-work/63603993#63603993.""" + + def get_kernels(ind, outd) -> List: + """Returns a List [(kernel_offset_start,kernel_length)] defining all the pooling kernels for a 1-D adaptive + pooling layer that takes an input of dimension `ind` and yields an output of dimension `outd`""" + + def start_index(a, b, c): + return math.floor((float(a) * float(c)) / b) + + def end_index(a, b, c): + return math.ceil((float(a + 1) * float(c)) / b) + + results = [] + for ow in range(outd): + start = start_index(ow, outd, ind) + end = end_index(ow, outd, ind) + sz = end - start + results.append((start, sz)) + return results + + in_dim = int(input_shape[-1]) + kernels = get_kernels(in_dim, self.output_dim) + sparse_map = np.zeros((in_dim, self.output_dim), dtype=np.float32) + for i, kernel in enumerate(kernels): + sparse_map[kernel[0] : kernel[0] + kernel[1], i] = 1 / kernel[1] + if self.mode == "dense": + self.map = tf.constant(sparse_map) + else: + self.map = tf.sparse.from_dense(sparse_map) + + def call(self, inputs): + if self.mode == "dense": + return inputs @ self.map + else: + input_dims = inputs.shape + input_matrix = tf.reshape(inputs, (-1, input_dims[-1])) + out = tf.sparse.sparse_dense_matmul(input_matrix, self.map) + return tf.reshape(out, input_dims[:-1].as_list() + [-1]) + + def get_config(self): + config = super().get_config() + config.update({"output_dim": self.output_dim, "mode": self.mode}) + return config + + +class TFAdaptiveAvgPool2D(tf.keras.layers.Layer): + def __init__(self, output_shape, mode="dense", **kwargs): + super().__init__(**kwargs) + self.mode = mode + self.h_pool = TFAdaptiveAvgPool1D(output_shape[0], mode=mode, name="h_pool") + self.w_pool = TFAdaptiveAvgPool1D(output_shape[1], mode=mode, name="w_pool") + + def call(self, inputs): + # Rearrange from NHWC -> NCHW + inputs = tf.transpose(inputs, perm=[0, 3, 1, 2]) + # Perform W-pooling + inputs = 
self.w_pool(inputs) + # Rearrange NCHW -> NCWH + inputs = tf.transpose(inputs, perm=[0, 1, 3, 2]) + # Perform H-pooling + inputs = self.h_pool(inputs) + # Rearrange from NCWH -> NHWC + inputs = tf.transpose(inputs, perm=[0, 3, 2, 1]) + return inputs + + def get_config(self): + config = super().get_config() + config.update({"mode": self.mode}) + return config + + +class TFData2VecVisionPyramidPoolingModule(tf.keras.layers.Layer): + """ + Pyramid Pooling Module (PPM) used in PSPNet. + + Args: + pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid + Module. + channels (int): Channels after modules, before conv_seg. + + Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation. + """ + + def __init__(self, pool_scales: Tuple[int, ...], channels: int, **kwargs) -> None: + super().__init__(**kwargs) + self.pool_scales = pool_scales + self.channels = channels + + self.layer_list = [] + for idx, pool_scale in enumerate(pool_scales): + pool_scale = pool_scale if isinstance(pool_scale, collections.abc.Iterable) else (pool_scale, pool_scale) + self.layer_list.append( + [ + TFAdaptiveAvgPool2D(output_shape=pool_scale), + TFData2VecVisionConvModule(out_channels=self.channels, kernel_size=1, name=f"{idx}.1"), + ] + ) + + def call(self, x: tf.Tensor) -> List[tf.Tensor]: + ppm_outs = [] + inputs = x + + for ppm in self.layer_list: + for layer_module in ppm: + ppm_out = layer_module(x) + x = ppm_out + + upsampled_ppm_out = tf.image.resize(ppm_out, size=shape_list(inputs)[1:-1], method="bilinear") + ppm_outs.append(upsampled_ppm_out) + return ppm_outs + + +class TFData2VecVisionUperHead(tf.keras.layers.Layer): + """ + Unified Perceptual Parsing for Scene Understanding. This head is the implementation of + [UPerNet](https://arxiv.org/abs/1807.10221). + + Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation. 
+ """ + + def __init__(self, config: Data2VecVisionConfig, **kwargs) -> None: + super().__init__(**kwargs) + + self.pool_scales = config.pool_scales # e.g. (1, 2, 3, 6) + self.in_channels = [config.hidden_size] * 4 # e.g. [768, 768, 768, 768] + self.channels = config.hidden_size + self.classifier = tf.keras.layers.Conv2D(config.num_labels, kernel_size=1, name="classifier") + + # PSP Module + self.psp_modules = TFData2VecVisionPyramidPoolingModule(self.pool_scales, self.channels, name="psp_modules") + self.bottleneck = TFData2VecVisionConvModule(self.channels, kernel_size=3, padding="same", name="bottleneck") + # FPN Module + self.lateral_convs = [] + self.fpn_convs = [] + for idx, _ in enumerate(self.in_channels[:-1]): # skip the top layer + l_conv = TFData2VecVisionConvModule(out_channels=self.channels, kernel_size=1, name=f"lateral_convs.{idx}") + fpn_conv = TFData2VecVisionConvModule( + out_channels=self.channels, kernel_size=3, padding="same", name=f"fpn_convs.{idx}" + ) + self.lateral_convs.append(l_conv) + self.fpn_convs.append(fpn_conv) + + self.fpn_bottleneck = TFData2VecVisionConvModule( + out_channels=self.channels, kernel_size=3, padding="same", name="fpn_bottleneck" + ) + + def psp_forward(self, inputs): + x = inputs[-1] + psp_outs = [x] + psp_outs.extend(self.psp_modules(x)) + psp_outs = tf.concat(psp_outs, axis=-1) + output = self.bottleneck(psp_outs) + + return output + + def call(self, encoder_hidden_states: tf.Tensor) -> tf.Tensor: + # build laterals + laterals = [lateral_conv(encoder_hidden_states[i]) for i, lateral_conv in enumerate(self.lateral_convs)] + + laterals.append(self.psp_forward(encoder_hidden_states)) + + # build top-down path + used_backbone_levels = len(laterals) + for i in range(used_backbone_levels - 1, 0, -1): + prev_shape = shape_list(laterals[i - 1])[1:-1] + laterals[i - 1] = laterals[i - 1] + tf.image.resize(laterals[i], size=prev_shape, method="bilinear") + + # build outputs + fpn_outs = [self.fpn_convs[i](laterals[i]) for i 
in range(used_backbone_levels - 1)] + # append psp feature + fpn_outs.append(laterals[-1]) + + for i in range(used_backbone_levels - 1, 0, -1): + fpn_outs[i] = tf.image.resize(fpn_outs[i], size=shape_list(fpn_outs[0])[1:-1], method="bilinear") + fpn_outs = tf.concat(fpn_outs, axis=-1) + output = self.fpn_bottleneck(fpn_outs) + output = self.classifier(output) + + return output + + +class TFData2VecVisionFCNHead(tf.keras.layers.Layer): + """ + Fully Convolution Networks for Semantic Segmentation. This head is implemented from + [FCNNet](https://arxiv.org/abs/1411.4038). + + Args: + config (Data2VecVisionConfig): Configuration. + kernel_size (int): The kernel size for convs in the head. Default: 3. + dilation (int): The dilation rate for convs in the head. Default: 1. + + + Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation. + """ + + def __init__( + self, + config: Data2VecVisionConfig, + in_index: int = 2, + kernel_size: int = 3, + dilation: Union[int, Tuple[int, int]] = 1, + **kwargs + ) -> None: + super().__init__(**kwargs) + self.in_channels = config.hidden_size + self.channels = config.auxiliary_channels + self.num_convs = config.auxiliary_num_convs + self.concat_input = config.auxiliary_concat_input + self.in_index = in_index + + convs = [] + convs.append( + TFData2VecVisionConvModule( + out_channels=self.channels, + kernel_size=kernel_size, + padding="same", + dilation=dilation, + name="convs.0", + ) + ) + for i in range(self.num_convs - 1): + convs.append( + TFData2VecVisionConvModule( + out_channels=self.channels, + kernel_size=kernel_size, + padding="same", + dilation=dilation, + name=f"conv_module_{i+2}", + ) + ) + if self.num_convs == 0: + self.convs = [tf.identity] + else: + self.convs = convs + if self.concat_input: + self.conv_cat = TFData2VecVisionConvModule( + out_channels=self.channels, kernel_size=kernel_size, padding="same", name="conv_cat" + ) + + self.classifier = tf.keras.layers.Conv2D(config.num_labels, 
kernel_size=1, name="classifier") + + def call(self, encoder_hidden_states: tf.Tensor) -> tf.Tensor: + # just take the relevant feature maps + hidden_states = encoder_hidden_states[self.in_index] + output = hidden_states + for layer_module in self.convs: + output = layer_module(output) + if self.concat_input: + output = self.conv_cat(tf.concat([hidden_states, output], axis=-1)) + output = self.classifier(output) + return output + + +@add_start_docstrings( + """ + Data2VecVision Model transformer with a semantic segmentation head on top e.g. for ADE20k, CityScapes. + """, + DATA2VEC_VISION_START_DOCSTRING, +) +class TFData2VecVisionForSemanticSegmentation(TFData2VecVisionPreTrainedModel): + def __init__(self, config: Data2VecVisionConfig, *inputs, **kwargs) -> None: + super().__init__(config, *inputs, **kwargs) + self.num_labels = config.num_labels + self.data2vec_vision = TFData2VecVisionMainLayer(config, add_pooling_layer=False, name="data2vec_vision") + + # FPNs + self.fpn1 = [ + tf.keras.layers.Conv2DTranspose(config.hidden_size, kernel_size=2, strides=2, name="fpn1.0"), + tf.keras.layers.BatchNormalization(name="fpn1.1"), + tf.keras.layers.Activation("gelu"), + tf.keras.layers.Conv2DTranspose(config.hidden_size, kernel_size=2, strides=2, name="fpn1.3"), + ] + self.fpn2 = [tf.keras.layers.Conv2DTranspose(config.hidden_size, kernel_size=2, strides=2, name="fpn2.0")] + + self.fpn3 = tf.identity + self.fpn4 = tf.keras.layers.MaxPool2D(pool_size=2, strides=2) + + # Semantic segmentation head(s) + self.decode_head = TFData2VecVisionUperHead(config, name="decode_head") + self.auxiliary_head = ( + TFData2VecVisionFCNHead(config, name="auxiliary_head") if config.use_auxiliary_head else None + ) + + def compute_loss(self, logits, auxiliary_logits, labels): + # upsample logits to the images' original size + if len(shape_list(labels)) > 3: + label_interp_shape = shape_list(labels)[1:-1] + else: + label_interp_shape = shape_list(labels)[-2:] + + upsampled_logits = 
tf.image.resize(logits, size=label_interp_shape, method="bilinear") + if auxiliary_logits is not None: + upsampled_auxiliary_logits = tf.image.resize(auxiliary_logits, size=label_interp_shape, method="bilinear") + # compute weighted loss + loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction="none") + + # Copied from https://www.tensorflow.org/text/tutorials/transformer#loss_and_metrics. + # Utility to mask the index to ignore during computing the loss. + def masked_loss(real, pred): + mask = tf.math.logical_not(tf.math.equal(real, self.config.semantic_loss_ignore_index)) + loss_ = loss_fct(real, pred) + mask = tf.cast(mask, dtype=loss_.dtype) + loss_ *= mask + + return tf.reduce_sum(loss_) / tf.reduce_sum(mask) + + main_loss = masked_loss(labels, upsampled_logits) + auxiliary_loss = masked_loss(labels, upsampled_auxiliary_logits) + loss = main_loss + self.config.auxiliary_loss_weight * auxiliary_loss + + return loss + + @unpack_inputs + @add_start_docstrings_to_model_forward(DATA2VEC_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=TFSemanticSegmenterOutput, config_class=_CONFIG_FOR_DOC) + def call( + self, + pixel_values: Optional[tf.Tensor] = None, + head_mask: Optional[tf.Tensor] = None, + labels: Optional[tf.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, TFSemanticSegmenterOutput]: + r""" + labels (`tf.Tensor` of shape `(batch_size, height, width)`, *optional*): + Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy). 
+ + Returns: + + Examples: + + ```python + >>> from transformers import AutoFeatureExtractor, TFData2VecVisionForSemanticSegmentation + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/data2vec-vision-base") + >>> model = TFData2VecVisionForSemanticSegmentation.from_pretrained("facebook/data2vec-vision-base") + + >>> inputs = feature_extractor(images=image, return_tensors="tf") + >>> outputs = model(**inputs) + >>> # logits are of shape (batch_size, num_labels, height, width) + >>> logits = outputs.logits + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + outputs = self.data2vec_vision( + pixel_values, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=True, # we need the intermediate hidden states + return_dict=return_dict, + ) + encoder_hidden_states = outputs.hidden_states if return_dict else outputs[1] + + # only keep certain features, and reshape + # note that we do +1 as the encoder_hidden_states also includes the initial embeddings + features = [feature for idx, feature in enumerate(encoder_hidden_states) if idx + 1 in self.config.out_indices] + batch_size = shape_list(pixel_values)[0] + patch_resolution = self.config.image_size // self.config.patch_size + + def reshape_features(x): + x = tf.reshape(x, (batch_size, patch_resolution, patch_resolution, -1)) + return x + + features = [reshape_features(x[:, 1:, :]) for x in features] + + # apply FPNs + ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4] + for module in ops[0]: + features[0] = module(features[0]) + features[1] = ops[1][0](features[1]) + for i in range(len(features[2:])): + features[i + 
2] = ops[i + 2](features[i + 2]) + + logits = self.decode_head(features) + # Transpose the logits to maintain consistency in the output formats. + transposed_logits = tf.transpose(logits, perm=[0, 3, 1, 2]) + + auxiliary_logits = None + if self.auxiliary_head is not None: + auxiliary_logits = self.auxiliary_head(features) + + loss = None + if labels is not None: + if self.config.num_labels == 1: + raise ValueError("The number of labels should be greater than one") + else: + loss = self.compute_loss(logits, auxiliary_logits, labels) + + if not return_dict: + if output_hidden_states: + output = (logits,) + outputs[1:] + else: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFSemanticSegmenterOutput( + loss=loss, + logits=transposed_logits, + hidden_states=outputs.hidden_states if output_hidden_states else None, + attentions=outputs.attentions, + ) diff --git a/src/transformers/utils/dummy_tf_objects.py b/src/transformers/utils/dummy_tf_objects.py index 00965e2f0a..4eb40113e7 100644 --- a/src/transformers/utils/dummy_tf_objects.py +++ b/src/transformers/utils/dummy_tf_objects.py @@ -759,6 +759,13 @@ class TFData2VecVisionForImageClassification(metaclass=DummyObject): requires_backends(self, ["tf"]) +class TFData2VecVisionForSemanticSegmentation(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + class TFData2VecVisionModel(metaclass=DummyObject): _backends = ["tf"] diff --git a/tests/models/data2vec/test_modeling_data2vec_vision.py b/tests/models/data2vec/test_modeling_data2vec_vision.py index 2dc9f1e45e..8966b90997 100644 --- a/tests/models/data2vec/test_modeling_data2vec_vision.py +++ b/tests/models/data2vec/test_modeling_data2vec_vision.py @@ -389,6 +389,10 @@ class Data2VecVisionModelTest(ModelTesterMixin, unittest.TestCase): check_hidden_states_output(inputs_dict, config, model_class) + def check_pt_tf_outputs(self, tf_outputs, 
pt_outputs, model_class, tol=2e-4, name="outputs", attributes=None): + # We override with a slightly higher tol value, as semseg models tend to diverge a bit more + super().check_pt_tf_outputs(tf_outputs, pt_outputs, model_class, tol, name, attributes) + def test_for_image_classification(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_for_image_classification(*config_and_inputs) diff --git a/tests/models/data2vec/test_modeling_tf_data2vec_vision.py b/tests/models/data2vec/test_modeling_tf_data2vec_vision.py index 17b02d037c..eb085af0d8 100644 --- a/tests/models/data2vec/test_modeling_tf_data2vec_vision.py +++ b/tests/models/data2vec/test_modeling_tf_data2vec_vision.py @@ -31,7 +31,11 @@ from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_te if is_tf_available(): import tensorflow as tf - from transformers import TFData2VecVisionForImageClassification, TFData2VecVisionModel + from transformers import ( + TFData2VecVisionForImageClassification, + TFData2VecVisionForSemanticSegmentation, + TFData2VecVisionModel, + ) from transformers.models.data2vec.modeling_tf_data2vec_vision import ( TF_DATA2VEC_VISION_PRETRAINED_MODEL_ARCHIVE_LIST, ) @@ -142,6 +146,18 @@ class TFData2VecVisionModelTester: result = model(pixel_values, labels=labels, training=False) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size)) + def create_and_check_for_image_segmentation(self, config, pixel_values, labels, pixel_labels): + config.num_labels = self.num_labels + model = TFData2VecVisionForSemanticSegmentation(config) + result = model(pixel_values, training=False) + self.parent.assertEqual( + result.logits.shape, (self.batch_size, self.num_labels, self.image_size * 2, self.image_size * 2) + ) + result = model(pixel_values, labels=pixel_labels) + self.parent.assertEqual( + result.logits.shape, (self.batch_size, self.num_labels, self.image_size * 2, self.image_size * 2) 
+ ) + def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() config, pixel_values, labels, pixel_labels = config_and_inputs @@ -162,7 +178,11 @@ class TFData2VecVisionModelTest(TFModelTesterMixin, unittest.TestCase): attention_mask and seq_length. """ - all_model_classes = (TFData2VecVisionModel, TFData2VecVisionForImageClassification) if is_tf_available() else () + all_model_classes = ( + (TFData2VecVisionModel, TFData2VecVisionForImageClassification, TFData2VecVisionForSemanticSegmentation) + if is_tf_available() + else () + ) test_pruning = False test_onnx = False @@ -208,6 +228,14 @@ class TFData2VecVisionModelTest(TFModelTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_model(*config_and_inputs) + def test_for_image_segmentation(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_image_segmentation(*config_and_inputs) + + @unittest.skip("Test was written for TF 1.x and isn't really relevant here") + def test_compile_tf_model(self): + pass + def test_attention_outputs(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config.return_dict = True @@ -354,6 +382,10 @@ class TFData2VecVisionModelTest(TFModelTesterMixin, unittest.TestCase): val_loss2 = history2.history["val_loss"][0] self.assertTrue(np.allclose(val_loss1, val_loss2, atol=1e-2, rtol=1e-3)) + def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=2e-4, name="outputs", attributes=None): + # We override with a slightly higher tol value, as semseg models tend to diverge a bit more + super().check_pt_tf_outputs(tf_outputs, pt_outputs, model_class, tol, name, attributes) + # Overriding this method since the base method won't be compatible with Data2VecVision. 
def test_loss_computation(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()