From 900677487d8953fa1fabdc3e298abe63ff09ece8 Mon Sep 17 00:00:00 2001 From: Shubhamai Date: Tue, 4 Apr 2023 22:11:12 +0530 Subject: [PATCH] Flax Regnet (#21867) * initial commit * review changes * post model PR merge * updating doc --- docs/source/de/index.mdx | 2 +- docs/source/en/index.mdx | 2 +- docs/source/en/model_doc/regnet.mdx | 14 +- docs/source/es/index.mdx | 2 +- docs/source/fr/index.mdx | 2 +- docs/source/it/index.mdx | 2 +- docs/source/ja/index.mdx | 2 +- docs/source/ko/index.mdx | 2 +- docs/source/pt/index.mdx | 2 +- docs/source/zh/index.mdx | 2 +- src/transformers/__init__.py | 4 + .../models/auto/modeling_flax_auto.py | 2 + src/transformers/models/regnet/__init__.py | 32 +- .../models/regnet/modeling_flax_regnet.py | 818 ++++++++++++++++++ .../models/resnet/modeling_flax_resnet.py | 2 +- src/transformers/utils/dummy_flax_objects.py | 21 + .../regnet/test_modeling_flax_regnet.py | 237 +++++ 17 files changed, 1136 insertions(+), 12 deletions(-) create mode 100644 src/transformers/models/regnet/modeling_flax_regnet.py create mode 100644 tests/models/regnet/test_modeling_flax_regnet.py diff --git a/docs/source/de/index.mdx b/docs/source/de/index.mdx index 5983e47bfd..c14e803ed0 100644 --- a/docs/source/de/index.mdx +++ b/docs/source/de/index.mdx @@ -283,7 +283,7 @@ Flax), PyTorch, und/oder TensorFlow haben. | RAG | ✅ | ❌ | ✅ | ✅ | ❌ | | REALM | ✅ | ✅ | ✅ | ❌ | ❌ | | Reformer | ✅ | ✅ | ✅ | ❌ | ❌ | -| RegNet | ❌ | ❌ | ✅ | ✅ | ❌ | +| RegNet | ❌ | ❌ | ✅ | ✅ | ✅ | | RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ | | ResNet | ❌ | ❌ | ✅ | ✅ | ✅ | | RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ | diff --git a/docs/source/en/index.mdx b/docs/source/en/index.mdx index 8c73cc98dc..cb936d7336 100644 --- a/docs/source/en/index.mdx +++ b/docs/source/en/index.mdx @@ -377,7 +377,7 @@ Flax), PyTorch, and/or TensorFlow. | RAG | ✅ | ❌ | ✅ | ✅ | ❌ | | REALM | ✅ | ✅ | ✅ | ❌ | ❌ | | Reformer | ✅ | ✅ | ✅ | ❌ | ❌ | -| RegNet | ❌ | ❌ | ✅ | ✅ | ❌ | +| RegNet | ❌ | ❌ | ✅ | ✅ | ✅ | | RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ | | ResNet | ❌ | ❌ | ✅ | ✅ | ✅ | | RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ | diff --git a/docs/source/en/model_doc/regnet.mdx b/docs/source/en/model_doc/regnet.mdx index 1557ab24df..ac94b5b172 100644 --- a/docs/source/en/model_doc/regnet.mdx +++ b/docs/source/en/model_doc/regnet.mdx @@ -67,4 +67,16 @@ If you're interested in submitting a resource to be included here, please feel f ## TFRegNetForImageClassification [[autodoc]] TFRegNetForImageClassification - - call \ No newline at end of file + - call + + +## FlaxRegNetModel + +[[autodoc]] FlaxRegNetModel + - __call__ + + +## FlaxRegNetForImageClassification + +[[autodoc]] FlaxRegNetForImageClassification + - __call__ \ No newline at end of file diff --git a/docs/source/es/index.mdx b/docs/source/es/index.mdx index 11943daa60..49a4f83053 100644 --- a/docs/source/es/index.mdx +++ b/docs/source/es/index.mdx @@ -235,7 +235,7 @@ Flax), PyTorch y/o TensorFlow. | RAG | ✅ | ❌ | ✅ | ✅ | ❌ | | Realm | ✅ | ✅ | ✅ | ❌ | ❌ | | Reformer | ✅ | ✅ | ✅ | ❌ | ❌ | -| RegNet | ❌ | ❌ | ✅ | ❌ | ❌ | +| RegNet | ❌ | ❌ | ✅ | ✅ | ✅ | | RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ | | ResNet | ❌ | ❌ | ✅ | ❌ | ✅ | | RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ | diff --git a/docs/source/fr/index.mdx b/docs/source/fr/index.mdx index 8184d1ced8..63a3e8391f 100644 --- a/docs/source/fr/index.mdx +++ b/docs/source/fr/index.mdx @@ -347,7 +347,7 @@ Le tableau ci-dessous représente la prise en charge actuelle dans la bibliothè | RAG | ✅ | ❌ | ✅ | ✅ | ❌ | | REALM | ✅ | ✅ | ✅ | ❌ | ❌ | | Reformer | ✅ | ✅ | ✅ | ❌ | ❌ | -| RegNet | ❌ | ❌ | ✅ | ✅ | ❌ | +| RegNet | ❌ | ❌ | ✅ | ✅ | ✅ | | RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ | | ResNet | ❌ | ❌ | ✅ | ✅ | ❌ | | RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ | diff --git a/docs/source/it/index.mdx b/docs/source/it/index.mdx index 4ede7f0a9f..4c050bfe52 100644 --- a/docs/source/it/index.mdx +++ b/docs/source/it/index.mdx @@ -252,7 +252,7 @@ tokenizer (chiamato "slow"). Un tokenizer "fast" supportato dalla libreria 🤗 | RAG | ✅ | ❌ | ✅ | ✅ | ❌ | | Realm | ✅ | ✅ | ✅ | ❌ | ❌ | | Reformer | ✅ | ✅ | ✅ | ❌ | ❌ | -| RegNet | ❌ | ❌ | ✅ | ❌ | ❌ | +| RegNet | ❌ | ❌ | ✅ | ✅ | ✅ | | RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ | | ResNet | ❌ | ❌ | ✅ | ✅ | ✅ | | RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ | diff --git a/docs/source/ja/index.mdx b/docs/source/ja/index.mdx index cca8f52be3..f55a3fd42a 100644 --- a/docs/source/ja/index.mdx +++ b/docs/source/ja/index.mdx @@ -337,7 +337,7 @@ specific language governing permissions and limitations under the License. | RAG | ✅ | ❌ | ✅ | ✅ | ❌ | | REALM | ✅ | ✅ | ✅ | ❌ | ❌ | | Reformer | ✅ | ✅ | ✅ | ❌ | ❌ | -| RegNet | ❌ | ❌ | ✅ | ✅ | ❌ | +| RegNet | ❌ | ❌ | ✅ | ✅ | ✅ | | RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ | | ResNet | ❌ | ❌ | ✅ | ✅ | ✅ | | RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ | diff --git a/docs/source/ko/index.mdx b/docs/source/ko/index.mdx index 4cf6863df0..5a6428d694 100644 --- a/docs/source/ko/index.mdx +++ b/docs/source/ko/index.mdx @@ -306,7 +306,7 @@ specific language governing permissions and limitations under the License. | RAG | ✅ | ❌ | ✅ | ✅ | ❌ | | REALM | ✅ | ✅ | ✅ | ❌ | ❌ | | Reformer | ✅ | ✅ | ✅ | ❌ | ❌ | -| RegNet | ❌ | ❌ | ✅ | ✅ | ❌ | +| RegNet | ❌ | ❌ | ✅ | ✅ | ✅ | | RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ | | ResNet | ❌ | ❌ | ✅ | ✅ | ✅ | | RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ | diff --git a/docs/source/pt/index.mdx b/docs/source/pt/index.mdx index d20d746a2a..9b5cbc12e6 100644 --- a/docs/source/pt/index.mdx +++ b/docs/source/pt/index.mdx @@ -250,7 +250,7 @@ disso, são diferenciados pelo suporte em diferentes frameworks: JAX (por meio d | RAG | ✅ | ❌ | ✅ | ✅ | ❌ | | Realm | ✅ | ✅ | ✅ | ❌ | ❌ | | Reformer | ✅ | ✅ | ✅ | ❌ | ❌ | -| RegNet | ❌ | ❌ | ✅ | ❌ | ❌ | +| RegNet | ❌ | ❌ | ✅ | ✅ | ✅ | | RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ | | ResNet | ❌ | ❌ | ✅ | ❌ | ✅ | | RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ | diff --git a/docs/source/zh/index.mdx b/docs/source/zh/index.mdx index 33b9181b71..71f5d7e3b1 100644 --- a/docs/source/zh/index.mdx +++ b/docs/source/zh/index.mdx @@ -336,7 +336,7 @@ Flax), PyTorch, 和/或者 TensorFlow. | RAG | ✅ | ❌ | ✅ | ✅ | ❌ | | REALM | ✅ | ✅ | ✅ | ❌ | ❌ | | Reformer | ✅ | ✅ | ✅ | ❌ | ❌ | -| RegNet | ❌ | ❌ | ✅ | ✅ | ❌ | +| RegNet | ❌ | ❌ | ✅ | ✅ | ✅ | | RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ | | ResNet | ❌ | ❌ | ✅ | ✅ | ❌ | | RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ | diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 0fa10407c4..d6208e69eb 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -3661,6 +3661,9 @@ else: "FlaxPegasusPreTrainedModel", ] ) + _import_structure["models.regnet"].extend( + ["FlaxRegNetForImageClassification", "FlaxRegNetModel", "FlaxRegNetPreTrainedModel"] + ) _import_structure["models.resnet"].extend( ["FlaxResNetForImageClassification", "FlaxResNetModel", "FlaxResNetPreTrainedModel"] ) @@ -6739,6 +6742,7 @@ if TYPE_CHECKING: from .models.mt5 import FlaxMT5EncoderModel, FlaxMT5ForConditionalGeneration, FlaxMT5Model from .models.opt import FlaxOPTForCausalLM, FlaxOPTModel, FlaxOPTPreTrainedModel from .models.pegasus import FlaxPegasusForConditionalGeneration, FlaxPegasusModel, FlaxPegasusPreTrainedModel + from .models.regnet import FlaxRegNetForImageClassification, FlaxRegNetModel, FlaxRegNetPreTrainedModel from .models.resnet import FlaxResNetForImageClassification, FlaxResNetModel, FlaxResNetPreTrainedModel from .models.roberta import ( FlaxRobertaForCausalLM, diff --git a/src/transformers/models/auto/modeling_flax_auto.py b/src/transformers/models/auto/modeling_flax_auto.py index 139533939d..755d1f07a3 100644 --- a/src/transformers/models/auto/modeling_flax_auto.py +++ b/src/transformers/models/auto/modeling_flax_auto.py @@ -48,6 +48,7 @@ FLAX_MODEL_MAPPING_NAMES = OrderedDict( ("mt5", "FlaxMT5Model"), ("opt", "FlaxOPTModel"), ("pegasus", "FlaxPegasusModel"), + ("regnet", "FlaxRegNetModel"), ("resnet", "FlaxResNetModel"), ("roberta", "FlaxRobertaModel"), ("roberta-prelayernorm", "FlaxRobertaPreLayerNormModel"), @@ -120,6 +121,7 @@ FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( [ # Model for Image-classsification ("beit", "FlaxBeitForImageClassification"), + ("regnet", "FlaxRegNetForImageClassification"), ("resnet", "FlaxResNetForImageClassification"), ("vit", "FlaxViTForImageClassification"), ] diff --git a/src/transformers/models/regnet/__init__.py b/src/transformers/models/regnet/__init__.py index 91221e9012..5084c44860 100644 --- a/src/transformers/models/regnet/__init__.py +++ b/src/transformers/models/regnet/__init__.py @@ -13,7 +13,13 @@ # limitations under the License. from typing import TYPE_CHECKING -from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_available, is_torch_available +from ...utils import ( + OptionalDependencyNotAvailable, + _LazyModule, + is_flax_available, + is_tf_available, + is_torch_available, +) _import_structure = {"configuration_regnet": ["REGNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "RegNetConfig"]} @@ -44,6 +50,18 @@ else: "TFRegNetPreTrainedModel", ] +try: + if not is_flax_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_flax_regnet"] = [ + "FlaxRegNetForImageClassification", + "FlaxRegNetModel", + "FlaxRegNetPreTrainedModel", + ] + if TYPE_CHECKING: from .configuration_regnet import REGNET_PRETRAINED_CONFIG_ARCHIVE_MAP, RegNetConfig @@ -74,6 +92,18 @@ if TYPE_CHECKING: TFRegNetPreTrainedModel, ) + try: + if not is_flax_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_flax_regnet import ( + FlaxRegNetForImageClassification, + FlaxRegNetModel, + FlaxRegNetPreTrainedModel, + ) + else: import sys diff --git a/src/transformers/models/regnet/modeling_flax_regnet.py b/src/transformers/models/regnet/modeling_flax_regnet.py new file mode 100644 index 0000000000..9fef1868d6 --- /dev/null +++ b/src/transformers/models/regnet/modeling_flax_regnet.py @@ -0,0 +1,818 @@ +# coding=utf-8 +# Copyright 2023 The Google Flax Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from functools import partial +from typing import Optional, Tuple + +import flax.linen as nn +import jax +import jax.numpy as jnp +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze +from flax.traverse_util import flatten_dict, unflatten_dict + +from transformers import RegNetConfig +from transformers.modeling_flax_outputs import ( + FlaxBaseModelOutputWithNoAttention, + FlaxBaseModelOutputWithPooling, + FlaxBaseModelOutputWithPoolingAndNoAttention, + FlaxImageClassifierOutputWithNoAttention, +) +from transformers.modeling_flax_utils import ( + ACT2FN, + FlaxPreTrainedModel, + append_replace_return_docstrings, + overwrite_call_docstring, +) +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, +) + + +REGNET_START_DOCSTRING = r""" + + This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading, saving and converting weights from PyTorch models) + + This model is also a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module) + subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to + general usage and behavior. + + Finally, this model supports inherent JAX features such as: + + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) + - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) + - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) + - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) + + Parameters: + config ([`RegNetConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. + dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): + The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and + `jax.numpy.bfloat16` (on TPUs). + + This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If + specified all the computation will be performed with the given `dtype`. + + **Note that this only specifies the dtype of the computation and does not influence the dtype of model + parameters.** + + If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and + [`~FlaxPreTrainedModel.to_bf16`]. +""" + +REGNET_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`numpy.ndarray` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See + [`RegNetImageProcessor.__call__`] for details. + + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +# Copied from transformers.models.resnet.modeling_flax_resnet.Identity +class Identity(nn.Module): + """Identity function.""" + + @nn.compact + def __call__(self, x, **kwargs): + return x + + +class FlaxRegNetConvLayer(nn.Module): + out_channels: int + kernel_size: int = 3 + stride: int = 1 + groups: int = 1 + activation: Optional[str] = "relu" + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.convolution = nn.Conv( + self.out_channels, + kernel_size=(self.kernel_size, self.kernel_size), + strides=self.stride, + padding=self.kernel_size // 2, + feature_group_count=self.groups, + use_bias=False, + kernel_init=nn.initializers.variance_scaling(2.0, mode="fan_out", distribution="truncated_normal"), + dtype=self.dtype, + ) + self.normalization = nn.BatchNorm(momentum=0.9, epsilon=1e-05, dtype=self.dtype) + self.activation_func = ACT2FN[self.activation] if self.activation is not None else Identity() + + def __call__(self, hidden_state: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: + hidden_state = self.convolution(hidden_state) + hidden_state = self.normalization(hidden_state, use_running_average=deterministic) + hidden_state = self.activation_func(hidden_state) + return hidden_state + + +class FlaxRegNetEmbeddings(nn.Module): + config: RegNetConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.embedder = FlaxRegNetConvLayer( + self.config.embedding_size, + kernel_size=3, + stride=2, + activation=self.config.hidden_act, + dtype=self.dtype, + ) + + def __call__(self, pixel_values: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: + num_channels = pixel_values.shape[-1] + if num_channels != self.config.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) + hidden_state = self.embedder(pixel_values, deterministic=deterministic) + return hidden_state + + +# Copied from transformers.models.resnet.modeling_flax_resnet.FlaxResNetShortCut with ResNet->RegNet +class FlaxRegNetShortCut(nn.Module): + """ + RegNet shortcut, used to project the residual features to the correct size. If needed, it is also used to + downsample the input using `stride=2`. + """ + + out_channels: int + stride: int = 2 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.convolution = nn.Conv( + self.out_channels, + kernel_size=(1, 1), + strides=self.stride, + use_bias=False, + kernel_init=nn.initializers.variance_scaling(2.0, mode="fan_out", distribution="truncated_normal"), + dtype=self.dtype, + ) + self.normalization = nn.BatchNorm(momentum=0.9, epsilon=1e-05, dtype=self.dtype) + + def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: + hidden_state = self.convolution(x) + hidden_state = self.normalization(hidden_state, use_running_average=deterministic) + return hidden_state + + +class FlaxRegNetSELayerCollection(nn.Module): + in_channels: int + reduced_channels: int + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.conv_1 = nn.Conv( + self.reduced_channels, + kernel_size=(1, 1), + kernel_init=nn.initializers.variance_scaling(2.0, mode="fan_out", distribution="truncated_normal"), + dtype=self.dtype, + name="0", + ) # 0 is the name used in corresponding pytorch implementation + self.conv_2 = nn.Conv( + self.in_channels, + kernel_size=(1, 1), + kernel_init=nn.initializers.variance_scaling(2.0, mode="fan_out", distribution="truncated_normal"), + dtype=self.dtype, + name="2", + ) # 2 is the name used in corresponding pytorch implementation + + def __call__(self, hidden_state: jnp.ndarray) -> jnp.ndarray: + hidden_state = self.conv_1(hidden_state) + hidden_state = nn.relu(hidden_state) + hidden_state = self.conv_2(hidden_state) + attention = nn.sigmoid(hidden_state) + + return attention + + +class FlaxRegNetSELayer(nn.Module): + """ + Squeeze and Excitation layer (SE) proposed in [Squeeze-and-Excitation Networks](https://arxiv.org/abs/1709.01507). + """ + + in_channels: int + reduced_channels: int + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.pooler = partial(nn.avg_pool, padding=((0, 0), (0, 0))) + self.attention = FlaxRegNetSELayerCollection(self.in_channels, self.reduced_channels, dtype=self.dtype) + + def __call__(self, hidden_state: jnp.ndarray) -> jnp.ndarray: + pooled = self.pooler( + hidden_state, + window_shape=(hidden_state.shape[1], hidden_state.shape[2]), + strides=(hidden_state.shape[1], hidden_state.shape[2]), + ) + attention = self.attention(pooled) + hidden_state = hidden_state * attention + return hidden_state + + +class FlaxRegNetXLayerCollection(nn.Module): + config: RegNetConfig + out_channels: int + stride: int = 1 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + groups = max(1, self.out_channels // self.config.groups_width) + + self.layer = [ + FlaxRegNetConvLayer( + self.out_channels, + kernel_size=1, + activation=self.config.hidden_act, + dtype=self.dtype, + name="0", + ), + FlaxRegNetConvLayer( + self.out_channels, + stride=self.stride, + groups=groups, + activation=self.config.hidden_act, + dtype=self.dtype, + name="1", + ), + FlaxRegNetConvLayer( + self.out_channels, + kernel_size=1, + activation=None, + dtype=self.dtype, + name="2", + ), + ] + + def __call__(self, hidden_state: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: + for layer in self.layer: + hidden_state = layer(hidden_state, deterministic=deterministic) + return hidden_state + + +class FlaxRegNetXLayer(nn.Module): + """ + RegNet's layer composed by three `3x3` convolutions, same as a ResNet bottleneck layer with reduction = 1. + """ + + config: RegNetConfig + in_channels: int + out_channels: int + stride: int = 1 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + should_apply_shortcut = self.in_channels != self.out_channels or self.stride != 1 + self.shortcut = ( + FlaxRegNetShortCut( + self.out_channels, + stride=self.stride, + dtype=self.dtype, + ) + if should_apply_shortcut + else Identity() + ) + self.layer = FlaxRegNetXLayerCollection( + self.config, + in_channels=self.in_channels, + out_channels=self.out_channels, + stride=self.stride, + dtype=self.dtype, + ) + self.activation_func = ACT2FN[self.config.hidden_act] + + def __call__(self, hidden_state: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: + residual = hidden_state + hidden_state = self.layer(hidden_state) + residual = self.shortcut(residual, deterministic=deterministic) + hidden_state += residual + hidden_state = self.activation_func(hidden_state) + return hidden_state + + +class FlaxRegNetYLayerCollection(nn.Module): + config: RegNetConfig + in_channels: int + out_channels: int + stride: int = 1 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + groups = max(1, self.out_channels // self.config.groups_width) + + self.layer = [ + FlaxRegNetConvLayer( + self.out_channels, + kernel_size=1, + activation=self.config.hidden_act, + dtype=self.dtype, + name="0", + ), + FlaxRegNetConvLayer( + self.out_channels, + stride=self.stride, + groups=groups, + activation=self.config.hidden_act, + dtype=self.dtype, + name="1", + ), + FlaxRegNetSELayer( + self.out_channels, + reduced_channels=int(round(self.in_channels / 4)), + dtype=self.dtype, + name="2", + ), + FlaxRegNetConvLayer( + self.out_channels, + kernel_size=1, + activation=None, + dtype=self.dtype, + name="3", + ), + ] + + def __call__(self, hidden_state: jnp.ndarray) -> jnp.ndarray: + for layer in self.layer: + hidden_state = layer(hidden_state) + return hidden_state + + +class FlaxRegNetYLayer(nn.Module): + """ + RegNet's Y layer: an X layer with Squeeze and Excitation. + """ + + config: RegNetConfig + in_channels: int + out_channels: int + stride: int = 1 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + should_apply_shortcut = self.in_channels != self.out_channels or self.stride != 1 + + self.shortcut = ( + FlaxRegNetShortCut( + self.out_channels, + stride=self.stride, + dtype=self.dtype, + ) + if should_apply_shortcut + else Identity() + ) + self.layer = FlaxRegNetYLayerCollection( + self.config, + in_channels=self.in_channels, + out_channels=self.out_channels, + stride=self.stride, + dtype=self.dtype, + ) + self.activation_func = ACT2FN[self.config.hidden_act] + + def __call__(self, hidden_state: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: + residual = hidden_state + hidden_state = self.layer(hidden_state) + residual = self.shortcut(residual, deterministic=deterministic) + hidden_state += residual + hidden_state = self.activation_func(hidden_state) + return hidden_state + + +class FlaxRegNetStageLayersCollection(nn.Module): + """ + A RegNet stage composed by stacked layers. + """ + + config: RegNetConfig + in_channels: int + out_channels: int + stride: int = 2 + depth: int = 2 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + layer = FlaxRegNetXLayer if self.config.layer_type == "x" else FlaxRegNetYLayer + + layers = [ + # downsampling is done in the first layer with stride of 2 + layer( + self.config, + self.in_channels, + self.out_channels, + stride=self.stride, + dtype=self.dtype, + name="0", + ) + ] + + for i in range(self.depth - 1): + layers.append( + layer( + self.config, + self.out_channels, + self.out_channels, + dtype=self.dtype, + name=str(i + 1), + ) + ) + + self.layers = layers + + def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: + hidden_state = x + for layer in self.layers: + hidden_state = layer(hidden_state, deterministic=deterministic) + return hidden_state + + +# Copied from transformers.models.resnet.modeling_flax_resnet.FlaxResNetStage with ResNet->RegNet +class FlaxRegNetStage(nn.Module): + """ + A RegNet stage composed by stacked layers. + """ + + config: RegNetConfig + in_channels: int + out_channels: int + stride: int = 2 + depth: int = 2 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.layers = FlaxRegNetStageLayersCollection( + self.config, + in_channels=self.in_channels, + out_channels=self.out_channels, + stride=self.stride, + depth=self.depth, + dtype=self.dtype, + ) + + def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: + return self.layers(x, deterministic=deterministic) + + +# Copied from transformers.models.resnet.modeling_flax_resnet.FlaxResNetStageCollection with ResNet->RegNet +class FlaxRegNetStageCollection(nn.Module): + config: RegNetConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + in_out_channels = zip(self.config.hidden_sizes, self.config.hidden_sizes[1:]) + stages = [ + FlaxRegNetStage( + self.config, + self.config.embedding_size, + self.config.hidden_sizes[0], + stride=2 if self.config.downsample_in_first_stage else 1, + depth=self.config.depths[0], + dtype=self.dtype, + name="0", + ) + ] + + for i, ((in_channels, out_channels), depth) in enumerate(zip(in_out_channels, self.config.depths[1:])): + stages.append( + FlaxRegNetStage(self.config, in_channels, out_channels, depth=depth, dtype=self.dtype, name=str(i + 1)) + ) + + self.stages = stages + + def __call__( + self, + hidden_state: jnp.ndarray, + output_hidden_states: bool = False, + deterministic: bool = True, + ) -> FlaxBaseModelOutputWithNoAttention: + hidden_states = () if output_hidden_states else None + + for stage_module in self.stages: + if output_hidden_states: + hidden_states = hidden_states + (hidden_state.transpose(0, 3, 1, 2),) + + hidden_state = stage_module(hidden_state, deterministic=deterministic) + + return hidden_state, hidden_states + + +# Copied from transformers.models.resnet.modeling_flax_resnet.FlaxResNetEncoder with ResNet->RegNet +class FlaxRegNetEncoder(nn.Module): + config: RegNetConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.stages = FlaxRegNetStageCollection(self.config, dtype=self.dtype) + + def __call__( + self, + hidden_state: jnp.ndarray, + output_hidden_states: bool = False, + return_dict: bool = True, + deterministic: bool = True, + ) -> FlaxBaseModelOutputWithNoAttention: + hidden_state, hidden_states = self.stages( + hidden_state, output_hidden_states=output_hidden_states, deterministic=deterministic + ) + + if output_hidden_states: + hidden_states = hidden_states + (hidden_state.transpose(0, 3, 1, 2),) + + if not return_dict: + return tuple(v for v in [hidden_state, hidden_states] if v is not None) + + return FlaxBaseModelOutputWithNoAttention( + last_hidden_state=hidden_state, + hidden_states=hidden_states, + ) + + +# Copied from transformers.models.resnet.modeling_flax_resnet.FlaxResNetPreTrainedModel with ResNet->RegNet,resnet->regnet,RESNET->REGNET +class FlaxRegNetPreTrainedModel(FlaxPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = RegNetConfig + base_model_prefix = "regnet" + main_input_name = "pixel_values" + module_class: nn.Module = None + + def __init__( + self, + config: RegNetConfig, + input_shape=(1, 224, 224, 3), + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, + **kwargs, + ): + module = self.module_class(config=config, dtype=dtype, **kwargs) + if input_shape is None: + input_shape = (1, config.image_size, config.image_size, config.num_channels) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: + # init input tensors + pixel_values = jnp.zeros(input_shape, dtype=self.dtype) + + rngs = {"params": rng} + + random_params = self.module.init(rngs, pixel_values, return_dict=False) + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + + @add_start_docstrings_to_model_forward(REGNET_INPUTS_DOCSTRING) + def __call__( + self, + pixel_values, + params: dict = None, + train: bool = False, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1)) + + # Handle any PRNG if needed + rngs = {} + + return self.module.apply( + { + "params": params["params"] if params is not None else self.params["params"], + "batch_stats": params["batch_stats"] if params is not None else self.params["batch_stats"], + }, + jnp.array(pixel_values, dtype=jnp.float32), + not train, + output_hidden_states, + return_dict, + rngs=rngs, + mutable=["batch_stats"] if train else False, # Returing tuple with batch_stats only when train is True + ) + + +# Copied from transformers.models.resnet.modeling_flax_resnet.FlaxResNetModule with ResNet->RegNet +class FlaxRegNetModule(nn.Module): + config: RegNetConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.embedder = FlaxRegNetEmbeddings(self.config, dtype=self.dtype) + self.encoder = FlaxRegNetEncoder(self.config, dtype=self.dtype) + + # Adaptive average pooling used in resnet + self.pooler = partial( + nn.avg_pool, + padding=((0, 0), (0, 0)), + ) + + def __call__( + self, + pixel_values, + deterministic: bool = True, + output_hidden_states: bool = False, + return_dict: bool = True, + ) -> FlaxBaseModelOutputWithPoolingAndNoAttention: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + embedding_output = self.embedder(pixel_values, deterministic=deterministic) + + encoder_outputs = self.encoder( + embedding_output, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + last_hidden_state = encoder_outputs[0] + + pooled_output = self.pooler( + last_hidden_state, + window_shape=(last_hidden_state.shape[1], last_hidden_state.shape[2]), + strides=(last_hidden_state.shape[1], last_hidden_state.shape[2]), + ).transpose(0, 3, 1, 2) + + last_hidden_state = last_hidden_state.transpose(0, 3, 1, 2) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return FlaxBaseModelOutputWithPoolingAndNoAttention( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + ) + + +@add_start_docstrings( + "The bare RegNet model outputting raw features without any specific head on top.", + REGNET_START_DOCSTRING, +) +class FlaxRegNetModel(FlaxRegNetPreTrainedModel): + module_class = FlaxRegNetModule + + +FLAX_VISION_MODEL_DOCSTRING = """ + Returns: + + Examples: + + ```python + >>> from transformers import AutoImageProcessor, FlaxRegNetModel + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("facebook/regnet-y-040") + >>> model = FlaxRegNetModel.from_pretrained("facebook/regnet-y-040") + + >>> inputs = image_processor(images=image, return_tensors="np") + >>> outputs = model(**inputs) + >>> last_hidden_states = outputs.last_hidden_state + ``` +""" + +overwrite_call_docstring(FlaxRegNetModel, FLAX_VISION_MODEL_DOCSTRING) +append_replace_return_docstrings( + FlaxRegNetModel, + output_type=FlaxBaseModelOutputWithPooling, + config_class=RegNetConfig, +) + + +# Copied from transformers.models.resnet.modeling_flax_resnet.FlaxResNetClassifierCollection with ResNet->RegNet +class FlaxRegNetClassifierCollection(nn.Module): + config: RegNetConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype, name="1") + + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + return self.classifier(x) + + +# Copied from transformers.models.resnet.modeling_flax_resnet.FlaxResNetForImageClassificationModule with ResNet->RegNet,resnet->regnet,RESNET->REGNET +class FlaxRegNetForImageClassificationModule(nn.Module): + config: RegNetConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.regnet = FlaxRegNetModule(config=self.config, dtype=self.dtype) + + if self.config.num_labels > 0: + self.classifier = FlaxRegNetClassifierCollection(self.config, dtype=self.dtype) + else: + self.classifier = Identity() + + def __call__( + self, + pixel_values=None, + deterministic: bool = True, + output_hidden_states=None, + return_dict=None, + ): + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.regnet( + pixel_values, + deterministic=deterministic, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs.pooler_output if return_dict else outputs[1] + + logits = self.classifier(pooled_output[:, :, 0, 0]) + + if not return_dict: + output = (logits,) + outputs[2:] + return output + + return FlaxImageClassifierOutputWithNoAttention(logits=logits, hidden_states=outputs.hidden_states) + + +@add_start_docstrings( + """ + RegNet Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for + ImageNet. + """, + REGNET_START_DOCSTRING, +) +class FlaxRegNetForImageClassification(FlaxRegNetPreTrainedModel): + module_class = FlaxRegNetForImageClassificationModule + + +FLAX_VISION_CLASSIF_DOCSTRING = """ + Returns: + + Example: + + ```python + >>> from transformers import AutoImageProcessor, FlaxRegNetForImageClassification + >>> from PIL import Image + >>> import jax + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("facebook/regnet-y-040") + >>> model = FlaxRegNetForImageClassification.from_pretrained("facebook/regnet-y-040") + + >>> inputs = image_processor(images=image, return_tensors="np") + >>> outputs = model(**inputs) + >>> logits = outputs.logits + + >>> # model predicts one of the 1000 ImageNet classes + >>> predicted_class_idx = jax.numpy.argmax(logits, axis=-1) + >>> print("Predicted class:", model.config.id2label[predicted_class_idx.item()]) + ``` +""" + +overwrite_call_docstring(FlaxRegNetForImageClassification, FLAX_VISION_CLASSIF_DOCSTRING) +append_replace_return_docstrings( + FlaxRegNetForImageClassification, + output_type=FlaxImageClassifierOutputWithNoAttention, + config_class=RegNetConfig, +) diff --git a/src/transformers/models/resnet/modeling_flax_resnet.py b/src/transformers/models/resnet/modeling_flax_resnet.py index 36b2869607..875716d3f5 100644 --- a/src/transformers/models/resnet/modeling_flax_resnet.py +++ b/src/transformers/models/resnet/modeling_flax_resnet.py @@ -89,7 +89,7 @@ class Identity(nn.Module): """Identity function.""" @nn.compact - def __call__(self, x): + def __call__(self, x, **kwargs): return x diff --git a/src/transformers/utils/dummy_flax_objects.py b/src/transformers/utils/dummy_flax_objects.py index 2bba612508..eeec327749 100644 --- a/src/transformers/utils/dummy_flax_objects.py +++ b/src/transformers/utils/dummy_flax_objects.py @@ -881,6 +881,27 @@ class FlaxPegasusPreTrainedModel(metaclass=DummyObject): requires_backends(self, ["flax"]) +class FlaxRegNetForImageClassification(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxRegNetModel(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxRegNetPreTrainedModel(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + class FlaxResNetForImageClassification(metaclass=DummyObject): _backends = ["flax"] diff --git a/tests/models/regnet/test_modeling_flax_regnet.py b/tests/models/regnet/test_modeling_flax_regnet.py new file mode 100644 index 0000000000..e9788ab09d --- /dev/null +++ b/tests/models/regnet/test_modeling_flax_regnet.py @@ -0,0 +1,237 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import unittest + +from transformers import RegNetConfig, is_flax_available +from transformers.testing_utils import require_flax, slow +from transformers.utils import cached_property, is_vision_available + +from ...test_configuration_common import ConfigTester +from ...test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor + + +if is_flax_available(): + import jax + import jax.numpy as jnp + + from transformers.models.regnet.modeling_flax_regnet import FlaxRegNetForImageClassification, FlaxRegNetModel + +if is_vision_available(): + from PIL import Image + + from transformers import AutoFeatureExtractor + + +class FlaxRegNetModelTester(unittest.TestCase): + def __init__( + self, + parent, + batch_size=3, + image_size=32, + num_channels=3, + embeddings_size=10, + hidden_sizes=[10, 20, 30, 40], + depths=[1, 1, 2, 1], + is_training=True, + use_labels=True, + hidden_act="relu", + num_labels=3, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.image_size = image_size + self.num_channels = num_channels + self.embeddings_size = embeddings_size + self.hidden_sizes = hidden_sizes + self.depths = depths + self.is_training = is_training + self.use_labels = use_labels + self.hidden_act = hidden_act + self.num_labels = num_labels + self.scope = scope + self.num_stages = len(hidden_sizes) + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) + + config = self.get_config() + + return config, pixel_values + + def get_config(self): + return RegNetConfig( + num_channels=self.num_channels, + embeddings_size=self.embeddings_size, + hidden_sizes=self.hidden_sizes, + depths=self.depths, + hidden_act=self.hidden_act, + num_labels=self.num_labels, + image_size=self.image_size, + ) + + def create_and_check_model(self, config, pixel_values): + model = FlaxRegNetModel(config=config) + result = model(pixel_values) + + # Output shape (b, c, h, w) + self.parent.assertEqual( + result.last_hidden_state.shape, + (self.batch_size, self.hidden_sizes[-1], self.image_size // 32, self.image_size // 32), + ) + + def create_and_check_for_image_classification(self, config, pixel_values): + config.num_labels = self.num_labels + model = FlaxRegNetForImageClassification(config=config) + result = model(pixel_values) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels)) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values = config_and_inputs + inputs_dict = {"pixel_values": pixel_values} + return config, inputs_dict + + +@require_flax +class FlaxResNetModelTest(FlaxModelTesterMixin, unittest.TestCase): + all_model_classes = (FlaxRegNetModel, FlaxRegNetForImageClassification) if is_flax_available() else () + + is_encoder_decoder = False + test_head_masking = False + has_attentions = False + + def setUp(self) -> None: + self.model_tester = FlaxRegNetModelTester(self) + self.config_tester = ConfigTester(self, config_class=RegNetConfig, has_text_modality=False) + + def test_config(self): + self.create_and_test_config_common_properties() + self.config_tester.create_and_test_config_to_json_string() + self.config_tester.create_and_test_config_to_json_file() + self.config_tester.create_and_test_config_from_and_save_pretrained() + self.config_tester.create_and_test_config_with_num_labels() + self.config_tester.check_config_can_be_init_without_params() + self.config_tester.check_config_arguments_init() + + def create_and_test_config_common_properties(self): + return + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_for_image_classification(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_image_classification(*config_and_inputs) + + @unittest.skip(reason="RegNet does not use inputs_embeds") + def test_inputs_embeds(self): + pass + + @unittest.skip(reason="RegNet does not support input and output embeddings") + def test_model_common_attributes(self): + pass + + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.__call__) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + expected_arg_names = ["pixel_values"] + self.assertListEqual(arg_names[:1], expected_arg_names) + + def test_hidden_states_output(self): + def check_hidden_states_output(inputs_dict, config, model_class): + model = model_class(config) + + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states + + expected_num_stages = self.model_tester.num_stages + self.assertEqual(len(hidden_states), expected_num_stages + 1) + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + inputs_dict["output_hidden_states"] = True + check_hidden_states_output(inputs_dict, config, model_class) + + # check that output_hidden_states also work using config + del inputs_dict["output_hidden_states"] + config.output_hidden_states = True + + check_hidden_states_output(inputs_dict, config, model_class) + + def test_jit_compilation(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + with self.subTest(model_class.__name__): + prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class) + model = model_class(config) + + @jax.jit + def model_jitted(pixel_values, **kwargs): + return model(pixel_values=pixel_values, **kwargs) + + with self.subTest("JIT Enabled"): + jitted_outputs = model_jitted(**prepared_inputs_dict).to_tuple() + + with self.subTest("JIT Disabled"): + with jax.disable_jit(): + outputs = model_jitted(**prepared_inputs_dict).to_tuple() + + self.assertEqual(len(outputs), len(jitted_outputs)) + for jitted_output, output in zip(jitted_outputs, outputs): + self.assertEqual(jitted_output.shape, output.shape) + + +# We will verify our results on an image of cute cats +def prepare_img(): + image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") + return image + + +@require_flax +class FlaxRegNetModelIntegrationTest(unittest.TestCase): + @cached_property + def default_feature_extractor(self): + return AutoFeatureExtractor.from_pretrained("facebook/regnet-y-040") if is_vision_available() else None + + @slow + def test_inference_image_classification_head(self): + model = FlaxRegNetForImageClassification.from_pretrained("facebook/regnet-y-040") + + feature_extractor = self.default_feature_extractor + image = prepare_img() + inputs = feature_extractor(images=image, return_tensors="np") + + outputs = model(**inputs) + + # verify the logits + expected_shape = (1, 1000) + self.assertEqual(outputs.logits.shape, expected_shape) + + expected_slice = jnp.array([-0.4180, -1.5051, -3.4836]) + + self.assertTrue(jnp.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))