Add TFData2VecVision for semantic segmentation (#17271)
* feat: initial implementation of data2vec segmentation model in TF. * chore: minor corrections to make the segmenter work. * chore: removed unnecessary files. * chore: add tests and other modifications. * fix: loss computation for segmentation. * chore: remove unused variable. * chore: formatting. * added a dummy adaptive pooling layer. * removed unnecessary file. * potentially add identifiers to layer names. * fix: layer naming. * chore: removed unnecessary print. * Skipping unneeded test * chore: add logging to debug tolerance. * fix: segmentation tests for tfdata2vecvision * chore: make style. * fix: layer names, assertion to be resolved. * Bumping test tolerance a bit * chore: bump the tol in PT test. Co-authored-by: matt <rocketknight1@gmail.com>
This commit is contained in:
@@ -42,7 +42,7 @@ Tips:
|
|||||||
|
|
||||||
|
|
||||||
This model was contributed by [edugp](https://huggingface.co/edugp) and [patrickvonplaten](https://huggingface.co/patrickvonplaten).
|
This model was contributed by [edugp](https://huggingface.co/edugp) and [patrickvonplaten](https://huggingface.co/patrickvonplaten).
|
||||||
[sayakpaul](https://github.com/sayakpaul) contributed Data2Vec for vision in TensorFlow.
|
[sayakpaul](https://github.com/sayakpaul) and [Rocketknight1](https://github.com/Rocketknight1) contributed Data2Vec for vision in TensorFlow.
|
||||||
|
|
||||||
The original code (for NLP and Speech) can be found [here](https://github.com/pytorch/fairseq/tree/main/examples/data2vec).
|
The original code (for NLP and Speech) can be found [here](https://github.com/pytorch/fairseq/tree/main/examples/data2vec).
|
||||||
The original code for vision can be found [here](https://github.com/facebookresearch/data2vec_vision/tree/main/beit).
|
The original code for vision can be found [here](https://github.com/facebookresearch/data2vec_vision/tree/main/beit).
|
||||||
@@ -145,3 +145,8 @@ The original code for vision can be found [here](https://github.com/facebookrese
|
|||||||
|
|
||||||
[[autodoc]] TFData2VecVisionForImageClassification
|
[[autodoc]] TFData2VecVisionForImageClassification
|
||||||
- call
|
- call
|
||||||
|
|
||||||
|
## TFData2VecVisionForSemanticSegmentation
|
||||||
|
|
||||||
|
[[autodoc]] TFData2VecVisionForSemanticSegmentation
|
||||||
|
- call
|
||||||
|
|||||||
@@ -2036,6 +2036,7 @@ else:
|
|||||||
_import_structure["models.data2vec"].extend(
|
_import_structure["models.data2vec"].extend(
|
||||||
[
|
[
|
||||||
"TFData2VecVisionForImageClassification",
|
"TFData2VecVisionForImageClassification",
|
||||||
|
"TFData2VecVisionForSemanticSegmentation",
|
||||||
"TFData2VecVisionModel",
|
"TFData2VecVisionModel",
|
||||||
"TFData2VecVisionPreTrainedModel",
|
"TFData2VecVisionPreTrainedModel",
|
||||||
]
|
]
|
||||||
@@ -4342,6 +4343,7 @@ if TYPE_CHECKING:
|
|||||||
)
|
)
|
||||||
from .models.data2vec import (
|
from .models.data2vec import (
|
||||||
TFData2VecVisionForImageClassification,
|
TFData2VecVisionForImageClassification,
|
||||||
|
TFData2VecVisionForSemanticSegmentation,
|
||||||
TFData2VecVisionModel,
|
TFData2VecVisionModel,
|
||||||
TFData2VecVisionPreTrainedModel,
|
TFData2VecVisionPreTrainedModel,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -607,6 +607,43 @@ class TFSeq2SeqSequenceClassifierOutput(ModelOutput):
|
|||||||
encoder_attentions: Optional[Tuple[tf.Tensor]] = None
|
encoder_attentions: Optional[Tuple[tf.Tensor]] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class TFSemanticSegmenterOutput(ModelOutput):
    """
    Output container returned by semantic segmentation models.

    Args:
        loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Classification (or regression if config.num_labels==1) loss.
        logits (`tf.Tensor` of shape `(batch_size, config.num_labels, logits_height, logits_width)`):
            Per-pixel classification scores.

            <Tip warning={true}>

            The spatial size of the logits is not guaranteed to match the size of the `pixel_values` given as
            inputs. Resizing is left to the caller so that a user who needs logits at the original image size
            interpolates only once instead of twice, which would cost some quality. Always check the shape of
            the logits and resize as needed.

            </Tip>

        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
            the output of each layer) of shape `(batch_size, patch_size, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[tf.Tensor] = None
    logits: tf.Tensor = None
    hidden_states: Optional[Tuple[tf.Tensor]] = None
    attentions: Optional[Tuple[tf.Tensor]] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TFMultipleChoiceModelOutput(ModelOutput):
|
class TFMultipleChoiceModelOutput(ModelOutput):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -163,6 +163,10 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
|
|||||||
new_key = key.replace("gamma", "weight")
|
new_key = key.replace("gamma", "weight")
|
||||||
if "beta" in key:
|
if "beta" in key:
|
||||||
new_key = key.replace("beta", "bias")
|
new_key = key.replace("beta", "bias")
|
||||||
|
if "running_var" in key:
|
||||||
|
new_key = key.replace("running_var", "moving_variance")
|
||||||
|
if "running_mean" in key:
|
||||||
|
new_key = key.replace("running_mean", "moving_mean")
|
||||||
if new_key:
|
if new_key:
|
||||||
old_keys.append(key)
|
old_keys.append(key)
|
||||||
new_keys.append(new_key)
|
new_keys.append(new_key)
|
||||||
|
|||||||
@@ -178,6 +178,13 @@ TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Model type -> name of the TF class carrying a semantic segmentation head.
TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Semantic Segmentation mapping
        ("data2vec-vision", "TFData2VecVisionForSemanticSegmentation"),
    ]
)
|
||||||
|
|
||||||
TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
|
TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
|
||||||
[
|
[
|
||||||
("vision-encoder-decoder", "TFVisionEncoderDecoderModel"),
|
("vision-encoder-decoder", "TFVisionEncoderDecoderModel"),
|
||||||
@@ -365,6 +372,9 @@ TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING = _LazyAutoMapping(
|
|||||||
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
|
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
|
||||||
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
|
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
|
||||||
)
|
)
|
||||||
|
# Lazy mapping built from the config names and the semantic segmentation class names.
TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES,
)
|
||||||
TF_MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES)
|
TF_MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES)
|
||||||
TF_MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES)
|
TF_MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES)
|
||||||
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
|
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
|
||||||
@@ -440,6 +450,15 @@ TFAutoModelForImageClassification = auto_class_update(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TFAutoModelForSemanticSegmentation(_BaseAutoModelClass):
    # Mapping consulted to resolve the concrete semantic segmentation model class.
    _model_mapping = TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING


# Rebind the result of `auto_class_update` to the class name itself, as is done for
# `TFAutoModelForImageClassification`, instead of to a misspelled module-level name.
TFAutoModelForSemanticSegmentation = auto_class_update(
    TFAutoModelForSemanticSegmentation, head_doc="semantic segmentation"
)

# Backward-compatible alias for the previously (mis)spelled module-level name.
TF_AutoModelForSemanticSegmentation = TFAutoModelForSemanticSegmentation
|
||||||
|
|
||||||
|
|
||||||
class TFAutoModelForVision2Seq(_BaseAutoModelClass):
|
class TFAutoModelForVision2Seq(_BaseAutoModelClass):
|
||||||
_model_mapping = TF_MODEL_FOR_VISION_2_SEQ_MAPPING
|
_model_mapping = TF_MODEL_FOR_VISION_2_SEQ_MAPPING
|
||||||
|
|
||||||
|
|||||||
@@ -73,6 +73,7 @@ else:
|
|||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
_import_structure["modeling_tf_data2vec_vision"] = [
|
_import_structure["modeling_tf_data2vec_vision"] = [
|
||||||
"TFData2VecVisionForImageClassification",
|
"TFData2VecVisionForImageClassification",
|
||||||
|
"TFData2VecVisionForSemanticSegmentation",
|
||||||
"TFData2VecVisionModel",
|
"TFData2VecVisionModel",
|
||||||
"TFData2VecVisionPreTrainedModel",
|
"TFData2VecVisionPreTrainedModel",
|
||||||
]
|
]
|
||||||
@@ -127,6 +128,7 @@ if TYPE_CHECKING:
|
|||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
from .modeling_tf_data2vec_vision import (
|
from .modeling_tf_data2vec_vision import (
|
||||||
TFData2VecVisionForImageClassification,
|
TFData2VecVisionForImageClassification,
|
||||||
|
TFData2VecVisionForSemanticSegmentation,
|
||||||
TFData2VecVisionModel,
|
TFData2VecVisionModel,
|
||||||
TFData2VecVisionPreTrainedModel,
|
TFData2VecVisionPreTrainedModel,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -17,7 +17,7 @@
|
|||||||
import collections.abc
|
import collections.abc
|
||||||
import math
|
import math
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, Optional, Tuple, Union
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
@@ -25,7 +25,12 @@ import tensorflow as tf
|
|||||||
from transformers.tf_utils import shape_list, stable_softmax
|
from transformers.tf_utils import shape_list, stable_softmax
|
||||||
|
|
||||||
from ...activations_tf import get_tf_activation
|
from ...activations_tf import get_tf_activation
|
||||||
from ...modeling_tf_outputs import TFBaseModelOutput, TFBaseModelOutputWithPooling, TFSequenceClassifierOutput
|
from ...modeling_tf_outputs import (
|
||||||
|
TFBaseModelOutput,
|
||||||
|
TFBaseModelOutputWithPooling,
|
||||||
|
TFSemanticSegmenterOutput,
|
||||||
|
TFSequenceClassifierOutput,
|
||||||
|
)
|
||||||
from ...modeling_tf_utils import (
|
from ...modeling_tf_utils import (
|
||||||
TFModelInputType,
|
TFModelInputType,
|
||||||
TFPreTrainedModel,
|
TFPreTrainedModel,
|
||||||
@@ -34,7 +39,13 @@ from ...modeling_tf_utils import (
|
|||||||
keras_serializable,
|
keras_serializable,
|
||||||
unpack_inputs,
|
unpack_inputs,
|
||||||
)
|
)
|
||||||
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
|
from ...utils import (
|
||||||
|
add_code_sample_docstrings,
|
||||||
|
add_start_docstrings,
|
||||||
|
add_start_docstrings_to_model_forward,
|
||||||
|
logging,
|
||||||
|
replace_return_docstrings,
|
||||||
|
)
|
||||||
from .configuration_data2vec_vision import Data2VecVisionConfig
|
from .configuration_data2vec_vision import Data2VecVisionConfig
|
||||||
|
|
||||||
|
|
||||||
@@ -978,3 +989,464 @@ class TFData2VecVisionForImageClassification(TFData2VecVisionPreTrainedModel, TF
|
|||||||
hidden_states=outputs.hidden_states,
|
hidden_states=outputs.hidden_states,
|
||||||
attentions=outputs.attentions,
|
attentions=outputs.attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TFData2VecVisionConvModule(tf.keras.layers.Layer):
    """
    Convolution block made of a `Conv2D` layer followed by `BatchNormalization` and a ReLU activation.

    Bundling the three keeps call sites short, since a convolution is commonly paired with a norm layer
    (e.g., BatchNorm) and an activation layer (e.g., ReLU).

    Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.

    Args:
        out_channels (int): Number of filters of the convolution.
        kernel_size (int or tuple of two ints): Size of the convolution kernel.
        padding (str): Padding mode forwarded to `Conv2D`. Default: "valid".
        bias (bool): Whether the convolution uses a bias term. Default: False.
        dilation (int or tuple of two ints): Dilation rate of the convolution. Default: 1.
    """

    def __init__(
        self,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        padding: str = "valid",
        bias: bool = False,
        dilation: Union[int, Tuple[int, int]] = 1,
        **kwargs
    ) -> None:
        super().__init__(**kwargs)
        conv_options = dict(
            filters=out_channels,
            kernel_size=kernel_size,
            padding=padding,
            use_bias=bias,
            dilation_rate=dilation,
        )
        self.conv = tf.keras.layers.Conv2D(name="conv", **conv_options)
        self.bn = tf.keras.layers.BatchNormalization(name="bn")
        self.activation = tf.nn.relu

    def call(self, input: tf.Tensor) -> tf.Tensor:
        # conv -> batch norm -> ReLU, applied in that order
        normalized = self.bn(self.conv(input))
        return self.activation(normalized)
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from:
# https://gist.github.com/Rocketknight1/43abbe6e73f1008e6e459486e01e0ceb
class TFAdaptiveAvgPool1D(tf.keras.layers.Layer):
    """
    Adaptive average pooling over the last axis of the input, implemented as a matrix product.

    `build()` pre-computes an `(input_dim, output_dim)` averaging matrix; `call()` multiplies the inputs by it.

    Args:
        output_dim (int): Size of the last axis after pooling.
        mode (str): `"dense"` stores the averaging matrix as a dense constant and uses `@`; any other value stores
            it as a sparse tensor and uses `tf.sparse.sparse_dense_matmul`. Default: "dense".
    """

    def __init__(self, output_dim, mode="dense", **kwargs):
        super().__init__(**kwargs)
        self.output_dim = output_dim
        self.mode = mode
        self.map = None

    def build(self, input_shape):
        super().build(input_shape)
        # We pre-compute the sparse matrix for the build() step once. The below code comes
        # from https://stackoverflow.com/questions/53841509/how-does-adaptive-pooling-in-pytorch-work/63603993#63603993.

        def get_kernels(ind, outd) -> List:
            """Returns a List [(kernel_offset_start,kernel_length)] defining all the pooling kernels for a 1-D adaptive
            pooling layer that takes an input of dimension `ind` and yields an output of dimension `outd`"""
            kernels = []
            for ow in range(outd):
                start = math.floor((float(ow) * float(ind)) / outd)
                end = math.ceil((float(ow + 1) * float(ind)) / outd)
                kernels.append((start, end - start))
            return kernels

        in_dim = int(input_shape[-1])
        sparse_map = np.zeros((in_dim, self.output_dim), dtype=np.float32)
        # Column `i` averages the input positions covered by kernel `i`.
        for column, (start, size) in enumerate(get_kernels(in_dim, self.output_dim)):
            sparse_map[start : start + size, column] = 1 / size

        if self.mode == "dense":
            self.map = tf.constant(sparse_map)
        else:
            self.map = tf.sparse.from_dense(sparse_map)

    def call(self, inputs):
        if self.mode == "dense":
            return inputs @ self.map

        # `sparse_dense_matmul` needs a 2-D operand: flatten the leading axes, multiply, then restore them.
        # `shape_list` is used rather than `inputs.shape.as_list()` so that dimensions unknown at trace time
        # (e.g. the batch size) do not put `None` into the reshape target.
        input_dims = shape_list(inputs)
        input_matrix = tf.reshape(inputs, (-1, input_dims[-1]))
        out = tf.sparse.sparse_dense_matmul(input_matrix, self.map)
        return tf.reshape(out, input_dims[:-1] + [self.output_dim])

    def get_config(self):
        config = super().get_config()
        config.update({"output_dim": self.output_dim, "mode": self.mode})
        return config
|
||||||
|
|
||||||
|
|
||||||
|
class TFAdaptiveAvgPool2D(tf.keras.layers.Layer):
    """
    Adaptive average pooling over the two spatial axes of an NHWC tensor.

    The pooling is separable: one `TFAdaptiveAvgPool1D` pools the width axis and another pools the height axis.

    Args:
        output_shape (sequence of two ints): Target `(height, width)` after pooling.
        mode (str): Forwarded to both `TFAdaptiveAvgPool1D` layers. Default: "dense".
    """

    def __init__(self, output_shape, mode="dense", **kwargs):
        super().__init__(**kwargs)
        self.mode = mode
        # Kept under a private name so it can be serialized by `get_config()`; the attribute is deliberately not
        # called `output_shape` to stay clear of the Keras layer attribute of that name.
        self._pool_output_shape = (output_shape[0], output_shape[1])
        self.h_pool = TFAdaptiveAvgPool1D(output_shape[0], mode=mode, name="h_pool")
        self.w_pool = TFAdaptiveAvgPool1D(output_shape[1], mode=mode, name="w_pool")

    def call(self, inputs):
        # Rearrange from NHWC -> NCHW
        inputs = tf.transpose(inputs, perm=[0, 3, 1, 2])
        # Perform W-pooling
        inputs = self.w_pool(inputs)
        # Rearrange NCHW -> NCWH
        inputs = tf.transpose(inputs, perm=[0, 1, 3, 2])
        # Perform H-pooling
        inputs = self.h_pool(inputs)
        # Rearrange from NCWH -> NHWC
        inputs = tf.transpose(inputs, perm=[0, 3, 2, 1])
        return inputs

    def get_config(self):
        config = super().get_config()
        # `output_shape` is a required constructor argument, so it has to be part of the config for the layer
        # to be re-creatable from it.
        config.update({"output_shape": list(self._pool_output_shape), "mode": self.mode})
        return config
|
||||||
|
|
||||||
|
|
||||||
|
class TFData2VecVisionPyramidPoolingModule(tf.keras.layers.Layer):
    """
    Pyramid Pooling Module (PPM) used in PSPNet.

    Args:
        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module.
        channels (int): Channels after modules, before conv_seg.

    Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
    """

    def __init__(self, pool_scales: Tuple[int, ...], channels: int, **kwargs) -> None:
        super().__init__(**kwargs)
        self.pool_scales = pool_scales
        self.channels = channels

        # One [adaptive pooling, 1x1 conv module] pair per pooling scale.
        self.layer_list = []
        for idx, pool_scale in enumerate(pool_scales):
            pool_scale = pool_scale if isinstance(pool_scale, collections.abc.Iterable) else (pool_scale, pool_scale)
            self.layer_list.append(
                [
                    TFAdaptiveAvgPool2D(output_shape=pool_scale),
                    TFData2VecVisionConvModule(out_channels=self.channels, kernel_size=1, name=f"{idx}.1"),
                ]
            )

    def call(self, x: tf.Tensor) -> List[tf.Tensor]:
        """
        Run every pooling branch on `x` and resize each result back to the spatial size of `x`.

        Returns:
            One tensor per pooling scale, in the order of `pool_scales`.
        """
        # Spatial (height, width) size of the NHWC input that every branch is resized back to.
        target_size = shape_list(x)[1:-1]

        ppm_outs = []
        for ppm in self.layer_list:
            # Every branch starts from the original feature map. Feeding a branch with the output of the
            # previous one would chain the pooling scales instead of applying them in parallel.
            ppm_out = x
            for layer_module in ppm:
                ppm_out = layer_module(ppm_out)

            upsampled_ppm_out = tf.image.resize(ppm_out, size=target_size, method="bilinear")
            ppm_outs.append(upsampled_ppm_out)
        return ppm_outs
|
||||||
|
|
||||||
|
|
||||||
|
class TFData2VecVisionUperHead(tf.keras.layers.Layer):
    """
    Decode head implementing [UPerNet](https://arxiv.org/abs/1807.10221) (Unified Perceptual Parsing for Scene
    Understanding).

    The last feature map goes through a pyramid pooling branch, the other ones through 1x1 lateral convolutions.
    The results are merged top-down, fused by a bottleneck and classified per pixel by a 1x1 convolution.

    Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
    """

    def __init__(self, config: Data2VecVisionConfig, **kwargs) -> None:
        super().__init__(**kwargs)

        hidden_size = config.hidden_size
        self.pool_scales = config.pool_scales  # e.g. (1, 2, 3, 6)
        self.in_channels = [hidden_size] * 4  # e.g. [768, 768, 768, 768]
        self.channels = hidden_size
        self.classifier = tf.keras.layers.Conv2D(config.num_labels, kernel_size=1, name="classifier")

        # PSP Module
        self.psp_modules = TFData2VecVisionPyramidPoolingModule(self.pool_scales, self.channels, name="psp_modules")
        self.bottleneck = TFData2VecVisionConvModule(self.channels, kernel_size=3, padding="same", name="bottleneck")

        # FPN Module: one lateral conv and one output conv per level, the top level excluded
        self.lateral_convs = []
        self.fpn_convs = []
        for idx in range(len(self.in_channels) - 1):
            self.lateral_convs.append(
                TFData2VecVisionConvModule(out_channels=self.channels, kernel_size=1, name=f"lateral_convs.{idx}")
            )
            self.fpn_convs.append(
                TFData2VecVisionConvModule(
                    out_channels=self.channels, kernel_size=3, padding="same", name=f"fpn_convs.{idx}"
                )
            )

        self.fpn_bottleneck = TFData2VecVisionConvModule(
            out_channels=self.channels, kernel_size=3, padding="same", name="fpn_bottleneck"
        )

    def psp_forward(self, inputs):
        """Concatenate the last feature map with its pyramid-pooled versions and fuse them with the bottleneck."""
        top_feature = inputs[-1]
        pooled = self.psp_modules(top_feature)
        stacked = tf.concat([top_feature, *pooled], axis=-1)
        return self.bottleneck(stacked)

    def call(self, encoder_hidden_states: tf.Tensor) -> tf.Tensor:
        # lateral branches for every level but the top one, which goes through the PSP branch instead
        laterals = []
        for level, lateral_conv in enumerate(self.lateral_convs):
            laterals.append(lateral_conv(encoder_hidden_states[level]))
        laterals.append(self.psp_forward(encoder_hidden_states))

        # top-down path: add each level, resized, to the level right below it
        num_levels = len(laterals)
        for upper in reversed(range(1, num_levels)):
            lower = upper - 1
            lower_size = shape_list(laterals[lower])[1:-1]
            laterals[lower] = laterals[lower] + tf.image.resize(laterals[upper], size=lower_size, method="bilinear")

        # per-level outputs; the PSP feature is appended as is
        fpn_outs = [self.fpn_convs[level](laterals[level]) for level in range(num_levels - 1)]
        fpn_outs.append(laterals[-1])

        # bring every level to the spatial size of the first one before fusing
        fused_size = shape_list(fpn_outs[0])[1:-1]
        for level in range(1, num_levels):
            fpn_outs[level] = tf.image.resize(fpn_outs[level], size=fused_size, method="bilinear")

        fused = self.fpn_bottleneck(tf.concat(fpn_outs, axis=-1))
        return self.classifier(fused)
|
||||||
|
|
||||||
|
|
||||||
|
class TFData2VecVisionFCNHead(tf.keras.layers.Layer):
    """
    Fully Convolution Networks for Semantic Segmentation. This head is implemented from
    [FCNNet](https://arxiv.org/abs/1411.4038).

    Args:
        config (Data2VecVisionConfig): Configuration.
        in_index (int): Index of the feature map (among those passed to `call`) the head operates on. Default: 2.
        kernel_size (int): The kernel size for convs in the head. Default: 3.
        dilation (int): The dilation rate for convs in the head. Default: 1.


    Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
    """

    def __init__(
        self,
        config: Data2VecVisionConfig,
        in_index: int = 2,
        kernel_size: int = 3,
        dilation: Union[int, Tuple[int, int]] = 1,
        **kwargs
    ) -> None:
        super().__init__(**kwargs)
        self.in_channels = config.hidden_size
        self.channels = config.auxiliary_channels
        self.num_convs = config.auxiliary_num_convs
        self.concat_input = config.auxiliary_concat_input
        self.in_index = in_index

        if self.num_convs == 0:
            # No convolution requested: the selected feature map is passed through unchanged.
            self.convs = [tf.identity]
        else:
            # Every conv module follows the same `convs.{i}` naming scheme as the first one, so that all of them
            # can be addressed uniformly (presumably matching the PyTorch sub-module names - TODO confirm).
            self.convs = [
                TFData2VecVisionConvModule(
                    out_channels=self.channels,
                    kernel_size=kernel_size,
                    padding="same",
                    dilation=dilation,
                    name=f"convs.{i}",
                )
                for i in range(self.num_convs)
            ]

        if self.concat_input:
            self.conv_cat = TFData2VecVisionConvModule(
                out_channels=self.channels, kernel_size=kernel_size, padding="same", name="conv_cat"
            )

        self.classifier = tf.keras.layers.Conv2D(config.num_labels, kernel_size=1, name="classifier")

    def call(self, encoder_hidden_states: tf.Tensor) -> tf.Tensor:
        # just take the relevant feature maps
        hidden_states = encoder_hidden_states[self.in_index]
        output = hidden_states
        for layer_module in self.convs:
            output = layer_module(output)
        if self.concat_input:
            # fuse the head input with the conv output along the channel axis
            output = self.conv_cat(tf.concat([hidden_states, output], axis=-1))
        output = self.classifier(output)
        return output
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings(
    """
    Data2VecVision Model transformer with a semantic segmentation head on top e.g. for ADE20k, CityScapes.
    """,
    DATA2VEC_VISION_START_DOCSTRING,
)
class TFData2VecVisionForSemanticSegmentation(TFData2VecVisionPreTrainedModel):
    def __init__(self, config: Data2VecVisionConfig, *inputs, **kwargs) -> None:
        super().__init__(config, *inputs, **kwargs)
        self.num_labels = config.num_labels
        self.data2vec_vision = TFData2VecVisionMainLayer(config, add_pooling_layer=False, name="data2vec_vision")

        # FPNs: one op (or list of ops) per selected feature map
        self.fpn1 = [
            tf.keras.layers.Conv2DTranspose(config.hidden_size, kernel_size=2, strides=2, name="fpn1.0"),
            tf.keras.layers.BatchNormalization(name="fpn1.1"),
            tf.keras.layers.Activation("gelu"),
            tf.keras.layers.Conv2DTranspose(config.hidden_size, kernel_size=2, strides=2, name="fpn1.3"),
        ]
        self.fpn2 = [tf.keras.layers.Conv2DTranspose(config.hidden_size, kernel_size=2, strides=2, name="fpn2.0")]

        self.fpn3 = tf.identity
        self.fpn4 = tf.keras.layers.MaxPool2D(pool_size=2, strides=2)

        # Semantic segmentation head(s)
        self.decode_head = TFData2VecVisionUperHead(config, name="decode_head")
        self.auxiliary_head = (
            TFData2VecVisionFCNHead(config, name="auxiliary_head") if config.use_auxiliary_head else None
        )

    def compute_loss(self, logits, auxiliary_logits, labels):
        """
        Compute the masked cross-entropy loss of the decode head, plus the weighted auxiliary loss if available.

        Args:
            logits: Channels-last logits of the decode head.
            auxiliary_logits: Channels-last logits of the auxiliary head, or `None` when there is no auxiliary head.
            labels: Segmentation maps; positions equal to `config.semantic_loss_ignore_index` are ignored.

        Returns:
            The main loss, to which `config.auxiliary_loss_weight` times the auxiliary loss is added when
            `auxiliary_logits` is not `None`.
        """
        # upsample logits to the images' original size
        if len(shape_list(labels)) > 3:
            label_interp_shape = shape_list(labels)[1:-1]
        else:
            label_interp_shape = shape_list(labels)[-2:]

        # compute weighted loss
        loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction="none")

        # Copied from https://www.tensorflow.org/text/tutorials/transformer#loss_and_metrics.
        # Utility to mask the index to ignore during computing the loss.
        def masked_loss(real, pred):
            mask = tf.math.logical_not(tf.math.equal(real, self.config.semantic_loss_ignore_index))
            loss_ = loss_fct(real, pred)
            mask = tf.cast(mask, dtype=loss_.dtype)
            loss_ *= mask

            return tf.reduce_sum(loss_) / tf.reduce_sum(mask)

        upsampled_logits = tf.image.resize(logits, size=label_interp_shape, method="bilinear")
        loss = masked_loss(labels, upsampled_logits)

        # The auxiliary head is optional: only add its contribution when it produced logits, instead of
        # referencing upsampled auxiliary logits that were never computed.
        if auxiliary_logits is not None:
            upsampled_auxiliary_logits = tf.image.resize(auxiliary_logits, size=label_interp_shape, method="bilinear")
            auxiliary_loss = masked_loss(labels, upsampled_auxiliary_logits)
            loss = loss + self.config.auxiliary_loss_weight * auxiliary_loss

        return loss

    @unpack_inputs
    @add_start_docstrings_to_model_forward(DATA2VEC_VISION_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=TFSemanticSegmenterOutput, config_class=_CONFIG_FOR_DOC)
    def call(
        self,
        pixel_values: Optional[tf.Tensor] = None,
        head_mask: Optional[tf.Tensor] = None,
        labels: Optional[tf.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, TFSemanticSegmenterOutput]:
        r"""
        labels (`tf.Tensor` of shape `(batch_size, height, width)`, *optional*):
            Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).

        Returns:

        Examples:

        ```python
        >>> from transformers import AutoFeatureExtractor, TFData2VecVisionForSemanticSegmentation
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/data2vec-vision-base")
        >>> model = TFData2VecVisionForSemanticSegmentation.from_pretrained("facebook/data2vec-vision-base")

        >>> inputs = feature_extractor(images=image, return_tensors="tf")
        >>> outputs = model(**inputs)
        >>> # logits are of shape (batch_size, num_labels, height, width)
        >>> logits = outputs.logits
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        outputs = self.data2vec_vision(
            pixel_values,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=True,  # we need the intermediate hidden states
            return_dict=return_dict,
        )
        encoder_hidden_states = outputs.hidden_states if return_dict else outputs[1]

        # only keep certain features, and reshape
        # note that we do +1 as the encoder_hidden_states also includes the initial embeddings
        features = [feature for idx, feature in enumerate(encoder_hidden_states) if idx + 1 in self.config.out_indices]
        batch_size = shape_list(pixel_values)[0]
        patch_resolution = self.config.image_size // self.config.patch_size

        def reshape_features(x):
            x = tf.reshape(x, (batch_size, patch_resolution, patch_resolution, -1))
            return x

        # drop the first token of every sequence, then fold the remaining tokens back into a 2D grid
        features = [reshape_features(x[:, 1:, :]) for x in features]

        # apply FPNs
        ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]
        for module in ops[0]:
            features[0] = module(features[0])
        features[1] = ops[1][0](features[1])
        for i in range(len(features[2:])):
            features[i + 2] = ops[i + 2](features[i + 2])

        logits = self.decode_head(features)
        # Transpose the logits to maintain consistency in the output formats.
        transposed_logits = tf.transpose(logits, perm=[0, 3, 1, 2])

        auxiliary_logits = None
        if self.auxiliary_head is not None:
            auxiliary_logits = self.auxiliary_head(features)

        loss = None
        if labels is not None:
            if self.config.num_labels == 1:
                raise ValueError("The number of labels should be greater than one")
            else:
                # the loss works on the channels-last logits, as `tf.image.resize` resizes the two middle axes
                loss = self.compute_loss(logits, auxiliary_logits, labels)

        if not return_dict:
            # Return the same (transposed) logits as the dict output, so that both return formats agree.
            if output_hidden_states:
                output = (transposed_logits,) + outputs[1:]
            else:
                output = (transposed_logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TFSemanticSegmenterOutput(
            loss=loss,
            logits=transposed_logits,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=outputs.attentions,
        )
|
||||||
|
|||||||
@@ -759,6 +759,13 @@ class TFData2VecVisionForImageClassification(metaclass=DummyObject):
|
|||||||
requires_backends(self, ["tf"])
|
requires_backends(self, ["tf"])
|
||||||
|
|
||||||
|
|
||||||
|
class TFData2VecVisionForSemanticSegmentation(metaclass=DummyObject):
|
||||||
|
_backends = ["tf"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["tf"])
|
||||||
|
|
||||||
|
|
||||||
class TFData2VecVisionModel(metaclass=DummyObject):
|
class TFData2VecVisionModel(metaclass=DummyObject):
|
||||||
_backends = ["tf"]
|
_backends = ["tf"]
|
||||||
|
|
||||||
|
|||||||
@@ -389,6 +389,10 @@ class Data2VecVisionModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
check_hidden_states_output(inputs_dict, config, model_class)
|
check_hidden_states_output(inputs_dict, config, model_class)
|
||||||
|
|
||||||
|
def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=2e-4, name="outputs", attributes=None):
|
||||||
|
# We override with a slightly higher tol value, as semseg models tend to diverge a bit more
|
||||||
|
super().check_pt_tf_outputs(tf_outputs, pt_outputs, model_class, tol, name, attributes)
|
||||||
|
|
||||||
def test_for_image_classification(self):
|
def test_for_image_classification(self):
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
|
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
|
||||||
|
|||||||
@@ -31,7 +31,11 @@ from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_te
|
|||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from transformers import TFData2VecVisionForImageClassification, TFData2VecVisionModel
|
from transformers import (
|
||||||
|
TFData2VecVisionForImageClassification,
|
||||||
|
TFData2VecVisionForSemanticSegmentation,
|
||||||
|
TFData2VecVisionModel,
|
||||||
|
)
|
||||||
from transformers.models.data2vec.modeling_tf_data2vec_vision import (
|
from transformers.models.data2vec.modeling_tf_data2vec_vision import (
|
||||||
TF_DATA2VEC_VISION_PRETRAINED_MODEL_ARCHIVE_LIST,
|
TF_DATA2VEC_VISION_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
)
|
)
|
||||||
@@ -142,6 +146,18 @@ class TFData2VecVisionModelTester:
|
|||||||
result = model(pixel_values, labels=labels, training=False)
|
result = model(pixel_values, labels=labels, training=False)
|
||||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
|
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
|
||||||
|
|
||||||
|
def create_and_check_for_image_segmentation(self, config, pixel_values, labels, pixel_labels):
|
||||||
|
config.num_labels = self.num_labels
|
||||||
|
model = TFData2VecVisionForSemanticSegmentation(config)
|
||||||
|
result = model(pixel_values, training=False)
|
||||||
|
self.parent.assertEqual(
|
||||||
|
result.logits.shape, (self.batch_size, self.num_labels, self.image_size * 2, self.image_size * 2)
|
||||||
|
)
|
||||||
|
result = model(pixel_values, labels=pixel_labels)
|
||||||
|
self.parent.assertEqual(
|
||||||
|
result.logits.shape, (self.batch_size, self.num_labels, self.image_size * 2, self.image_size * 2)
|
||||||
|
)
|
||||||
|
|
||||||
def prepare_config_and_inputs_for_common(self):
|
def prepare_config_and_inputs_for_common(self):
|
||||||
config_and_inputs = self.prepare_config_and_inputs()
|
config_and_inputs = self.prepare_config_and_inputs()
|
||||||
config, pixel_values, labels, pixel_labels = config_and_inputs
|
config, pixel_values, labels, pixel_labels = config_and_inputs
|
||||||
@@ -162,7 +178,11 @@ class TFData2VecVisionModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
attention_mask and seq_length.
|
attention_mask and seq_length.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
all_model_classes = (TFData2VecVisionModel, TFData2VecVisionForImageClassification) if is_tf_available() else ()
|
all_model_classes = (
|
||||||
|
(TFData2VecVisionModel, TFData2VecVisionForImageClassification, TFData2VecVisionForSemanticSegmentation)
|
||||||
|
if is_tf_available()
|
||||||
|
else ()
|
||||||
|
)
|
||||||
|
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_onnx = False
|
test_onnx = False
|
||||||
@@ -208,6 +228,14 @@ class TFData2VecVisionModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_for_image_segmentation(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_for_image_segmentation(*config_and_inputs)
|
||||||
|
|
||||||
|
@unittest.skip("Test was written for TF 1.x and isn't really relevant here")
|
||||||
|
def test_compile_tf_model(self):
|
||||||
|
pass
|
||||||
|
|
||||||
def test_attention_outputs(self):
|
def test_attention_outputs(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
config.return_dict = True
|
config.return_dict = True
|
||||||
@@ -354,6 +382,10 @@ class TFData2VecVisionModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
val_loss2 = history2.history["val_loss"][0]
|
val_loss2 = history2.history["val_loss"][0]
|
||||||
self.assertTrue(np.allclose(val_loss1, val_loss2, atol=1e-2, rtol=1e-3))
|
self.assertTrue(np.allclose(val_loss1, val_loss2, atol=1e-2, rtol=1e-3))
|
||||||
|
|
||||||
|
def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=2e-4, name="outputs", attributes=None):
|
||||||
|
# We override with a slightly higher tol value, as semseg models tend to diverge a bit more
|
||||||
|
super().check_pt_tf_outputs(tf_outputs, pt_outputs, model_class, tol, name, attributes)
|
||||||
|
|
||||||
# Overriding this method since the base method won't be compatible with Data2VecVision.
|
# Overriding this method since the base method won't be compatible with Data2VecVision.
|
||||||
def test_loss_computation(self):
|
def test_loss_computation(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|||||||
Reference in New Issue
Block a user