From a0cbbba31f498dac4bf92af350d48d968d048c44 Mon Sep 17 00:00:00 2001 From: Shubhamai Date: Sat, 25 Mar 2023 01:15:57 +0530 Subject: [PATCH] Resnet flax (#21472) * [WIP] flax resnet * added pretrained flax models, results reproducible * Added pretrained flax models, results reproducible * working on tests * no real code change, just some comments * [flax] adding support for batch norm layers * fixing bugs related to pt+flax integration * removing loss from modeling flax output class * fixing classifier tests * fixing comments, model output * cleaning comments * review changes * review changes * Apply suggestions from code review Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * renaming Flax to PyTorch --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- docs/source/de/index.mdx | 2 +- docs/source/en/index.mdx | 2 +- docs/source/en/model_doc/resnet.mdx | 10 + docs/source/es/index.mdx | 2 +- docs/source/it/index.mdx | 2 +- docs/source/ja/index.mdx | 2 +- docs/source/ko/index.mdx | 2 +- docs/source/pt/index.mdx | 2 +- src/transformers/__init__.py | 4 + src/transformers/modeling_flax_outputs.py | 58 ++ .../models/auto/modeling_flax_auto.py | 2 + src/transformers/models/resnet/__init__.py | 27 +- .../models/resnet/modeling_flax_resnet.py | 701 ++++++++++++++++++ src/transformers/utils/dummy_flax_objects.py | 21 + .../resnet/test_modeling_flax_resnet.py | 228 ++++++ 15 files changed, 1057 insertions(+), 8 deletions(-) create mode 100644 src/transformers/models/resnet/modeling_flax_resnet.py create mode 100644 tests/models/resnet/test_modeling_flax_resnet.py diff --git a/docs/source/de/index.mdx b/docs/source/de/index.mdx index c1340820b5..5983e47bfd 100644 --- a/docs/source/de/index.mdx +++ b/docs/source/de/index.mdx @@ -285,7 +285,7 @@ Flax), PyTorch, und/oder TensorFlow haben. | Reformer | ✅ | ✅ | ✅ | ❌ | ❌ | | RegNet | ❌ | ❌ | ✅ | ✅ | ❌ | | RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ | -| ResNet | ❌ | ❌ | ✅ | ✅ | ❌ | +| ResNet | ❌ | ❌ | ✅ | ✅ | ✅ | | RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ | | RoBERTa | ✅ | ✅ | ✅ | ✅ | ✅ | | RoFormer | ✅ | ✅ | ✅ | ✅ | ✅ | diff --git a/docs/source/en/index.mdx b/docs/source/en/index.mdx index f67483f84f..198f3d97de 100644 --- a/docs/source/en/index.mdx +++ b/docs/source/en/index.mdx @@ -377,7 +377,7 @@ Flax), PyTorch, and/or TensorFlow. | Reformer | ✅ | ✅ | ✅ | ❌ | ❌ | | RegNet | ❌ | ❌ | ✅ | ✅ | ❌ | | RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ | -| ResNet | ❌ | ❌ | ✅ | ✅ | ❌ | +| ResNet | ❌ | ❌ | ✅ | ✅ | ✅ | | RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ | | RoBERTa | ✅ | ✅ | ✅ | ✅ | ✅ | | RoBERTa-PreLayerNorm | ❌ | ❌ | ✅ | ✅ | ✅ | diff --git a/docs/source/en/model_doc/resnet.mdx b/docs/source/en/model_doc/resnet.mdx index 476698e9ab..a34596bdd6 100644 --- a/docs/source/en/model_doc/resnet.mdx +++ b/docs/source/en/model_doc/resnet.mdx @@ -71,3 +71,13 @@ If you're interested in submitting a resource to be included here, please feel f [[autodoc]] TFResNetForImageClassification - call + +## FlaxResNetModel + +[[autodoc]] FlaxResNetModel + - __call__ + +## FlaxResNetForImageClassification + +[[autodoc]] FlaxResNetForImageClassification + - __call__ diff --git a/docs/source/es/index.mdx b/docs/source/es/index.mdx index baedf45a0f..11943daa60 100644 --- a/docs/source/es/index.mdx +++ b/docs/source/es/index.mdx @@ -237,7 +237,7 @@ Flax), PyTorch y/o TensorFlow. | Reformer | ✅ | ✅ | ✅ | ❌ | ❌ | | RegNet | ❌ | ❌ | ✅ | ❌ | ❌ | | RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ | -| ResNet | ❌ | ❌ | ✅ | ❌ | ❌ | +| ResNet | ❌ | ❌ | ✅ | ❌ | ✅ | | RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ | | RoBERTa | ✅ | ✅ | ✅ | ✅ | ✅ | | RoFormer | ✅ | ✅ | ✅ | ✅ | ✅ | diff --git a/docs/source/it/index.mdx b/docs/source/it/index.mdx index a6d474bb6c..4ede7f0a9f 100644 --- a/docs/source/it/index.mdx +++ b/docs/source/it/index.mdx @@ -254,7 +254,7 @@ tokenizer (chiamato "slow"). Un tokenizer "fast" supportato dalla libreria 🤗 | Reformer | ✅ | ✅ | ✅ | ❌ | ❌ | | RegNet | ❌ | ❌ | ✅ | ❌ | ❌ | | RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ | -| ResNet | ❌ | ❌ | ✅ | ❌ | ❌ | +| ResNet | ❌ | ❌ | ✅ | ✅ | ✅ | | RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ | | RoBERTa | ✅ | ✅ | ✅ | ✅ | ✅ | | RoFormer | ✅ | ✅ | ✅ | ✅ | ✅ | diff --git a/docs/source/ja/index.mdx b/docs/source/ja/index.mdx index a644a885d3..cca8f52be3 100644 --- a/docs/source/ja/index.mdx +++ b/docs/source/ja/index.mdx @@ -339,7 +339,7 @@ specific language governing permissions and limitations under the License. | Reformer | ✅ | ✅ | ✅ | ❌ | ❌ | | RegNet | ❌ | ❌ | ✅ | ✅ | ❌ | | RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ | -| ResNet | ❌ | ❌ | ✅ | ✅ | ❌ | +| ResNet | ❌ | ❌ | ✅ | ✅ | ✅ | | RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ | | RoBERTa | ✅ | ✅ | ✅ | ✅ | ✅ | | RoBERTa-PreLayerNorm | ❌ | ❌ | ✅ | ✅ | ✅ | diff --git a/docs/source/ko/index.mdx b/docs/source/ko/index.mdx index 789aa41a28..4cf6863df0 100644 --- a/docs/source/ko/index.mdx +++ b/docs/source/ko/index.mdx @@ -308,7 +308,7 @@ specific language governing permissions and limitations under the License. | Reformer | ✅ | ✅ | ✅ | ❌ | ❌ | | RegNet | ❌ | ❌ | ✅ | ✅ | ❌ | | RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ | -| ResNet | ❌ | ❌ | ✅ | ✅ | ❌ | +| ResNet | ❌ | ❌ | ✅ | ✅ | ✅ | | RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ | | RoBERTa | ✅ | ✅ | ✅ | ✅ | ✅ | | RoCBert | ✅ | ❌ | ✅ | ❌ | ❌ | diff --git a/docs/source/pt/index.mdx b/docs/source/pt/index.mdx index 8e74a1feb4..d20d746a2a 100644 --- a/docs/source/pt/index.mdx +++ b/docs/source/pt/index.mdx @@ -252,7 +252,7 @@ disso, são diferenciados pelo suporte em diferentes frameworks: JAX (por meio d | Reformer | ✅ | ✅ | ✅ | ❌ | ❌ | | RegNet | ❌ | ❌ | ✅ | ❌ | ❌ | | RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ | -| ResNet | ❌ | ❌ | ✅ | ❌ | ❌ | +| ResNet | ❌ | ❌ | ✅ | ❌ | ✅ | | RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ | | RoBERTa | ✅ | ✅ | ✅ | ✅ | ✅ | | RoFormer | ✅ | ✅ | ✅ | ✅ | ✅ | diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index d1258bbc6a..584c83783b 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -3637,6 +3637,9 @@ else: "FlaxPegasusPreTrainedModel", ] ) + _import_structure["models.resnet"].extend( + ["FlaxResNetForImageClassification", "FlaxResNetModel", "FlaxResNetPreTrainedModel"] + ) _import_structure["models.roberta"].extend( [ "FlaxRobertaForCausalLM", @@ -6692,6 +6695,7 @@ if TYPE_CHECKING: from .models.mt5 import FlaxMT5EncoderModel, FlaxMT5ForConditionalGeneration, FlaxMT5Model from .models.opt import FlaxOPTForCausalLM, FlaxOPTModel, FlaxOPTPreTrainedModel from .models.pegasus import FlaxPegasusForConditionalGeneration, FlaxPegasusModel, FlaxPegasusPreTrainedModel + from .models.resnet import FlaxResNetForImageClassification, FlaxResNetModel, FlaxResNetPreTrainedModel from .models.roberta import ( FlaxRobertaForCausalLM, FlaxRobertaForMaskedLM, diff --git a/src/transformers/modeling_flax_outputs.py b/src/transformers/modeling_flax_outputs.py index 4f6cc5a901..179a0b7879 100644 --- a/src/transformers/modeling_flax_outputs.py +++ b/src/transformers/modeling_flax_outputs.py @@ -45,6 +45,64 @@ class FlaxBaseModelOutput(ModelOutput): attentions: Optional[Tuple[jnp.ndarray]] = None +@flax.struct.dataclass +class FlaxBaseModelOutputWithNoAttention(ModelOutput): + """ + Base class for model's outputs, with potential hidden states. + + Args: + last_hidden_state (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `jnp.ndarray` (one for the output of the embeddings, if the model has an embedding layer, + one + for the output of each layer) of shape `(batch_size, num_channels, height, width)`. Hidden-states of the + model at the output of each layer plus the optional initial embedding outputs. + """ + + last_hidden_state: jnp.ndarray = None + hidden_states: Optional[Tuple[jnp.ndarray]] = None + + +@flax.struct.dataclass +class FlaxBaseModelOutputWithPoolingAndNoAttention(ModelOutput): + """ + Base class for model's outputs that also contains a pooling of the last hidden states. + + Args: + last_hidden_state (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (`jnp.ndarray` of shape `(batch_size, hidden_size)`): + Last layer hidden-state after a pooling operation on the spatial dimensions. + hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `jnp.ndarray` (one for the output of the embeddings, if the model has an embedding layer, + one + for the output of each layer) of shape `(batch_size, num_channels, height, width)`. Hidden-states of the + model at the output of each layer plus the optional initial embedding outputs. + """ + + last_hidden_state: jnp.ndarray = None + pooler_output: jnp.ndarray = None + hidden_states: Optional[Tuple[jnp.ndarray]] = None + + +@flax.struct.dataclass +class FlaxImageClassifierOutputWithNoAttention(ModelOutput): + """ + Base class for outputs of image classification models. + + Args: + logits (`jnp.ndarray` of shape `(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when + `config.output_hidden_states=True`): + Tuple of `jnp.ndarray` (one for the output of the embeddings, if the model has an embedding layer, + one + for the output of each stage) of shape `(batch_size, num_channels, height, width)`. Hidden-states (also + called feature maps) of the model at the output of each stage. + """ + + logits: jnp.ndarray = None + hidden_states: Optional[Tuple[jnp.ndarray]] = None + + @flax.struct.dataclass class FlaxBaseModelOutputWithPast(ModelOutput): """ diff --git a/src/transformers/models/auto/modeling_flax_auto.py b/src/transformers/models/auto/modeling_flax_auto.py index 77be9b33f0..139533939d 100644 --- a/src/transformers/models/auto/modeling_flax_auto.py +++ b/src/transformers/models/auto/modeling_flax_auto.py @@ -48,6 +48,7 @@ FLAX_MODEL_MAPPING_NAMES = OrderedDict( ("mt5", "FlaxMT5Model"), ("opt", "FlaxOPTModel"), ("pegasus", "FlaxPegasusModel"), + ("resnet", "FlaxResNetModel"), ("roberta", "FlaxRobertaModel"), ("roberta-prelayernorm", "FlaxRobertaPreLayerNormModel"), ("roformer", "FlaxRoFormerModel"), @@ -119,6 +120,7 @@ FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( [ # Model for Image-classsification ("beit", "FlaxBeitForImageClassification"), + ("resnet", "FlaxResNetForImageClassification"), ("vit", "FlaxViTForImageClassification"), ] ) diff --git a/src/transformers/models/resnet/__init__.py b/src/transformers/models/resnet/__init__.py index 2baa71a277..62e6b1c2ca 100644 --- a/src/transformers/models/resnet/__init__.py +++ b/src/transformers/models/resnet/__init__.py @@ -13,7 +13,13 @@ # limitations under the License. from typing import TYPE_CHECKING -from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_available, is_torch_available +from ...utils import ( + OptionalDependencyNotAvailable, + _LazyModule, + is_flax_available, + is_tf_available, + is_torch_available, +) _import_structure = { @@ -47,6 +53,17 @@ else: "TFResNetPreTrainedModel", ] +try: + if not is_flax_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_flax_resnet"] = [ + "FlaxResNetForImageClassification", + "FlaxResNetModel", + "FlaxResNetPreTrainedModel", + ] if TYPE_CHECKING: from .configuration_resnet import RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP, ResNetConfig, ResNetOnnxConfig @@ -78,6 +95,14 @@ if TYPE_CHECKING: TFResNetPreTrainedModel, ) + try: + if not is_flax_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_flax_resnet import FlaxResNetForImageClassification, FlaxResNetModel, FlaxResNetPreTrainedModel + else: import sys diff --git a/src/transformers/models/resnet/modeling_flax_resnet.py b/src/transformers/models/resnet/modeling_flax_resnet.py new file mode 100644 index 0000000000..36b2869607 --- /dev/null +++ b/src/transformers/models/resnet/modeling_flax_resnet.py @@ -0,0 +1,701 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial +from typing import Optional, Tuple + +import flax.linen as nn +import jax +import jax.numpy as jnp +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze +from flax.traverse_util import flatten_dict, unflatten_dict + +from ...modeling_flax_outputs import ( + FlaxBaseModelOutputWithNoAttention, + FlaxBaseModelOutputWithPoolingAndNoAttention, + FlaxImageClassifierOutputWithNoAttention, +) +from ...modeling_flax_utils import ( + ACT2FN, + FlaxPreTrainedModel, + append_replace_return_docstrings, + overwrite_call_docstring, +) +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward +from .configuration_resnet import ResNetConfig + + +RESNET_START_DOCSTRING = r""" + + This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading, saving and converting weights from PyTorch models) + + This model is also a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module) + subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to + general usage and behavior. + + Finally, this model supports inherent JAX features such as: + + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) + - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) + - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) + - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) + + Parameters: + config ([`ResNetConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. + dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): + The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and + `jax.numpy.bfloat16` (on TPUs). + + This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If + specified all the computation will be performed with the given `dtype`. + + **Note that this only specifies the dtype of the computation and does not influence the dtype of model + parameters.** + + If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and + [`~FlaxPreTrainedModel.to_bf16`]. +""" + + +RESNET_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`jax.numpy.float32` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See + [`AutoImageProcessor.__call__`] for details. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +class Identity(nn.Module): + """Identity function.""" + + @nn.compact + def __call__(self, x): + return x + + +class FlaxResNetConvLayer(nn.Module): + out_channels: int + kernel_size: int = 3 + stride: int = 1 + activation: Optional[str] = "relu" + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.convolution = nn.Conv( + self.out_channels, + kernel_size=(self.kernel_size, self.kernel_size), + strides=self.stride, + padding=self.kernel_size // 2, + dtype=self.dtype, + use_bias=False, + kernel_init=nn.initializers.variance_scaling(2.0, mode="fan_out", distribution="normal", dtype=self.dtype), + ) + self.normalization = nn.BatchNorm(momentum=0.9, epsilon=1e-05, dtype=self.dtype) + self.activation_func = ACT2FN[self.activation] if self.activation is not None else Identity() + + def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: + hidden_state = self.convolution(x) + hidden_state = self.normalization(hidden_state, use_running_average=deterministic) + hidden_state = self.activation_func(hidden_state) + return hidden_state + + +class FlaxResNetEmbeddings(nn.Module): + """ + ResNet Embeddings (stem) composed of a single aggressive convolution. + """ + + config: ResNetConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.embedder = FlaxResNetConvLayer( + self.config.embedding_size, + kernel_size=7, + stride=2, + activation=self.config.hidden_act, + dtype=self.dtype, + ) + + self.max_pool = partial(nn.max_pool, window_shape=(3, 3), strides=(2, 2), padding=((1, 1), (1, 1))) + + def __call__(self, pixel_values: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: + num_channels = pixel_values.shape[-1] + if num_channels != self.config.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) + embedding = self.embedder(pixel_values, deterministic=deterministic) + embedding = self.max_pool(embedding) + return embedding + + +class FlaxResNetShortCut(nn.Module): + """ + ResNet shortcut, used to project the residual features to the correct size. If needed, it is also used to + downsample the input using `stride=2`. + """ + + out_channels: int + stride: int = 2 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.convolution = nn.Conv( + self.out_channels, + kernel_size=(1, 1), + strides=self.stride, + use_bias=False, + kernel_init=nn.initializers.variance_scaling(2.0, mode="fan_out", distribution="truncated_normal"), + dtype=self.dtype, + ) + self.normalization = nn.BatchNorm(momentum=0.9, epsilon=1e-05, dtype=self.dtype) + + def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: + hidden_state = self.convolution(x) + hidden_state = self.normalization(hidden_state, use_running_average=deterministic) + return hidden_state + + +class FlaxResNetBasicLayerCollection(nn.Module): + out_channels: int + stride: int = 1 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.layer = [ + FlaxResNetConvLayer(self.out_channels, stride=self.stride, dtype=self.dtype), + FlaxResNetConvLayer(self.out_channels, activation=None, dtype=self.dtype), + ] + + def __call__(self, hidden_state: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: + for layer in self.layer: + hidden_state = layer(hidden_state, deterministic=deterministic) + return hidden_state + + +class FlaxResNetBasicLayer(nn.Module): + """ + A classic ResNet's residual layer composed by two `3x3` convolutions. + """ + + in_channels: int + out_channels: int + stride: int = 1 + activation: Optional[str] = "relu" + dtype: jnp.dtype = jnp.float32 + + def setup(self): + should_apply_shortcut = self.in_channels != self.out_channels or self.stride != 1 + self.shortcut = ( + FlaxResNetShortCut(self.out_channels, stride=self.stride, dtype=self.dtype) + if should_apply_shortcut + else None + ) + self.layer = FlaxResNetBasicLayerCollection( + out_channels=self.out_channels, + stride=self.stride, + activation=self.activation, + dtype=self.dtype, + ) + self.activation_func = ACT2FN[self.activation] + + def __call__(self, hidden_state, deterministic: bool = True): + residual = hidden_state + hidden_state = self.layer(hidden_state, deterministic=deterministic) + + if self.shortcut is not None: + residual = self.shortcut(residual, deterministic=deterministic) + hidden_state += residual + + hidden_state = self.activation_func(hidden_state) + return hidden_state + + +class FlaxResNetBottleNeckLayerCollection(nn.Module): + out_channels: int + stride: int = 1 + activation: Optional[str] = "relu" + reduction: int = 4 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + reduces_channels = self.out_channels // self.reduction + + self.layer = [ + FlaxResNetConvLayer(reduces_channels, kernel_size=1, dtype=self.dtype, name="0"), + FlaxResNetConvLayer(reduces_channels, stride=self.stride, dtype=self.dtype, name="1"), + FlaxResNetConvLayer(self.out_channels, kernel_size=1, activation=None, dtype=self.dtype, name="2"), + ] + + def __call__(self, hidden_state: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: + for layer in self.layer: + hidden_state = layer(hidden_state, deterministic=deterministic) + return hidden_state + + +class FlaxResNetBottleNeckLayer(nn.Module): + """ + A classic ResNet's bottleneck layer composed by three `3x3` convolutions. The first `1x1` convolution reduces the + input by a factor of `reduction` in order to make the second `3x3` convolution faster. The last `1x1` convolution + remaps the reduced features to `out_channels`. + """ + + in_channels: int + out_channels: int + stride: int = 1 + activation: Optional[str] = "relu" + reduction: int = 4 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + should_apply_shortcut = self.in_channels != self.out_channels or self.stride != 1 + self.shortcut = ( + FlaxResNetShortCut(self.out_channels, stride=self.stride, dtype=self.dtype) + if should_apply_shortcut + else None + ) + + self.layer = FlaxResNetBottleNeckLayerCollection( + self.out_channels, + stride=self.stride, + activation=self.activation, + reduction=self.reduction, + dtype=self.dtype, + ) + + self.activation_func = ACT2FN[self.activation] + + def __call__(self, hidden_state: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: + residual = hidden_state + + if self.shortcut is not None: + residual = self.shortcut(residual, deterministic=deterministic) + hidden_state = self.layer(hidden_state, deterministic) + hidden_state += residual + hidden_state = self.activation_func(hidden_state) + return hidden_state + + +class FlaxResNetStageLayersCollection(nn.Module): + """ + A ResNet stage composed by stacked layers. + """ + + config: ResNetConfig + in_channels: int + out_channels: int + stride: int = 2 + depth: int = 2 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + layer = FlaxResNetBottleNeckLayer if self.config.layer_type == "bottleneck" else FlaxResNetBasicLayer + + layers = [ + # downsampling is done in the first layer with stride of 2 + layer( + self.in_channels, + self.out_channels, + stride=self.stride, + activation=self.config.hidden_act, + dtype=self.dtype, + name="0", + ), + ] + + for i in range(self.depth - 1): + layers.append( + layer( + self.out_channels, + self.out_channels, + activation=self.config.hidden_act, + dtype=self.dtype, + name=str(i + 1), + ) + ) + + self.layers = layers + + def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: + hidden_state = x + for layer in self.layers: + hidden_state = layer(hidden_state, deterministic=deterministic) + return hidden_state + + +class FlaxResNetStage(nn.Module): + """ + A ResNet stage composed by stacked layers. + """ + + config: ResNetConfig + in_channels: int + out_channels: int + stride: int = 2 + depth: int = 2 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.layers = FlaxResNetStageLayersCollection( + self.config, + in_channels=self.in_channels, + out_channels=self.out_channels, + stride=self.stride, + depth=self.depth, + dtype=self.dtype, + ) + + def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray: + return self.layers(x, deterministic=deterministic) + + +class FlaxResNetStageCollection(nn.Module): + config: ResNetConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + in_out_channels = zip(self.config.hidden_sizes, self.config.hidden_sizes[1:]) + stages = [ + FlaxResNetStage( + self.config, + self.config.embedding_size, + self.config.hidden_sizes[0], + stride=2 if self.config.downsample_in_first_stage else 1, + depth=self.config.depths[0], + dtype=self.dtype, + name="0", + ) + ] + + for i, ((in_channels, out_channels), depth) in enumerate(zip(in_out_channels, self.config.depths[1:])): + stages.append( + FlaxResNetStage(self.config, in_channels, out_channels, depth=depth, dtype=self.dtype, name=str(i + 1)) + ) + + self.stages = stages + + def __call__( + self, + hidden_state: jnp.ndarray, + output_hidden_states: bool = False, + deterministic: bool = True, + ) -> FlaxBaseModelOutputWithNoAttention: + hidden_states = () if output_hidden_states else None + + for stage_module in self.stages: + if output_hidden_states: + hidden_states = hidden_states + (hidden_state.transpose(0, 3, 1, 2),) + + hidden_state = stage_module(hidden_state, deterministic=deterministic) + + return hidden_state, hidden_states + + +class FlaxResNetEncoder(nn.Module): + config: ResNetConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.stages = FlaxResNetStageCollection(self.config, dtype=self.dtype) + + def __call__( + self, + hidden_state: jnp.ndarray, + output_hidden_states: bool = False, + return_dict: bool = True, + deterministic: bool = True, + ) -> FlaxBaseModelOutputWithNoAttention: + hidden_state, hidden_states = self.stages( + hidden_state, output_hidden_states=output_hidden_states, deterministic=deterministic + ) + + if output_hidden_states: + hidden_states = hidden_states + (hidden_state.transpose(0, 3, 1, 2),) + + if not return_dict: + return tuple(v for v in [hidden_state, hidden_states] if v is not None) + + return FlaxBaseModelOutputWithNoAttention( + last_hidden_state=hidden_state, + hidden_states=hidden_states, + ) + + +class FlaxResNetPreTrainedModel(FlaxPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = ResNetConfig + base_model_prefix = "resnet" + main_input_name = "pixel_values" + module_class: nn.Module = None + + def __init__( + self, + config: ResNetConfig, + input_shape=(1, 224, 224, 3), + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, + **kwargs, + ): + module = self.module_class(config=config, dtype=dtype, **kwargs) + if input_shape is None: + input_shape = (1, config.image_size, config.image_size, config.num_channels) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: + # init input tensors + pixel_values = jnp.zeros(input_shape, dtype=self.dtype) + + rngs = {"params": rng} + + random_params = self.module.init(rngs, pixel_values, return_dict=False) + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + + @add_start_docstrings_to_model_forward(RESNET_INPUTS_DOCSTRING) + def __call__( + self, + pixel_values, + params: dict = None, + train: bool = False, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1)) + + # Handle any PRNG if needed + rngs = {} + + return self.module.apply( + { + "params": params["params"] if params is not None else self.params["params"], + "batch_stats": params["batch_stats"] if params is not None else self.params["batch_stats"], + }, + jnp.array(pixel_values, dtype=jnp.float32), + not train, + output_hidden_states, + return_dict, + rngs=rngs, + mutable=["batch_stats"] if train else False, # Returing tuple with batch_stats only when train is True + ) + + +class FlaxResNetModule(nn.Module): + config: ResNetConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.embedder = FlaxResNetEmbeddings(self.config, dtype=self.dtype) + self.encoder = FlaxResNetEncoder(self.config, dtype=self.dtype) + + # Adaptive average pooling used in resnet + self.pooler = partial( + nn.avg_pool, + padding=((0, 0), (0, 0)), + ) + + def __call__( + self, + pixel_values, + deterministic: bool = True, + output_hidden_states: bool = False, + return_dict: bool = True, + ) -> FlaxBaseModelOutputWithPoolingAndNoAttention: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + embedding_output = self.embedder(pixel_values, deterministic=deterministic) + + encoder_outputs = self.encoder( + embedding_output, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + last_hidden_state = encoder_outputs[0] + + pooled_output = self.pooler( + last_hidden_state, + window_shape=(last_hidden_state.shape[1], last_hidden_state.shape[2]), + strides=(last_hidden_state.shape[1], last_hidden_state.shape[2]), + ).transpose(0, 3, 1, 2) + + last_hidden_state = last_hidden_state.transpose(0, 3, 1, 2) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return FlaxBaseModelOutputWithPoolingAndNoAttention( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + ) + + +@add_start_docstrings( + "The bare ResNet model outputting raw features without any specific head on top.", + RESNET_START_DOCSTRING, +) +class FlaxResNetModel(FlaxResNetPreTrainedModel): + module_class = FlaxResNetModule + + +FLAX_VISION_MODEL_DOCSTRING = """ + Returns: + + Examples: + + ```python + >>> from transformers import AutoImageProcessor, FlaxResNetModel + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50") + >>> model = FlaxResNetModel.from_pretrained("microsoft/resnet-50") + >>> inputs = image_processor(images=image, return_tensors="np") + >>> outputs = model(**inputs) + >>> last_hidden_states = outputs.last_hidden_state + ``` +""" + +overwrite_call_docstring(FlaxResNetModel, FLAX_VISION_MODEL_DOCSTRING) +append_replace_return_docstrings( + FlaxResNetModel, output_type=FlaxBaseModelOutputWithPoolingAndNoAttention, config_class=ResNetConfig +) + + +class FlaxResNetClassifierCollection(nn.Module): + config: ResNetConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype, name="1") + + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + return self.classifier(x) + + +class FlaxResNetForImageClassificationModule(nn.Module): + config: ResNetConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.resnet = FlaxResNetModule(config=self.config, dtype=self.dtype) + + if self.config.num_labels > 0: + self.classifier = FlaxResNetClassifierCollection(self.config, dtype=self.dtype) + else: + self.classifier = Identity() + + def __call__( + self, + pixel_values=None, + deterministic: bool = True, + output_hidden_states=None, + return_dict=None, + ): + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.resnet( + pixel_values, + deterministic=deterministic, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs.pooler_output if return_dict else outputs[1] + + logits = self.classifier(pooled_output[:, :, 0, 0]) + + if not return_dict: + output = (logits,) + outputs[2:] + return output + + return FlaxImageClassifierOutputWithNoAttention(logits=logits, hidden_states=outputs.hidden_states) + + +@add_start_docstrings( + """ + ResNet Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for + ImageNet. + """, + RESNET_START_DOCSTRING, +) +class FlaxResNetForImageClassification(FlaxResNetPreTrainedModel): + module_class = FlaxResNetForImageClassificationModule + + +FLAX_VISION_CLASSIF_DOCSTRING = """ + Returns: + + Example: + + ```python + >>> from transformers import AutoImageProcessor, FlaxResNetForImageClassification + >>> from PIL import Image + >>> import jax + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50") + >>> model = FlaxResNetForImageClassification.from_pretrained("microsoft/resnet-50") + + >>> inputs = image_processor(images=image, return_tensors="np") + >>> outputs = model(**inputs) + >>> logits = outputs.logits + + >>> # model predicts one of the 1000 ImageNet classes + >>> predicted_class_idx = jax.numpy.argmax(logits, axis=-1) + >>> print("Predicted class:", model.config.id2label[predicted_class_idx.item()]) + ``` +""" + +overwrite_call_docstring(FlaxResNetForImageClassification, FLAX_VISION_CLASSIF_DOCSTRING) +append_replace_return_docstrings( + FlaxResNetForImageClassification, output_type=FlaxImageClassifierOutputWithNoAttention, config_class=ResNetConfig +) diff --git a/src/transformers/utils/dummy_flax_objects.py b/src/transformers/utils/dummy_flax_objects.py index 60004790ec..2bba612508 100644 --- a/src/transformers/utils/dummy_flax_objects.py +++ b/src/transformers/utils/dummy_flax_objects.py @@ -881,6 +881,27 @@ class FlaxPegasusPreTrainedModel(metaclass=DummyObject): requires_backends(self, ["flax"]) +class FlaxResNetForImageClassification(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxResNetModel(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxResNetPreTrainedModel(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + class FlaxRobertaForCausalLM(metaclass=DummyObject): _backends = ["flax"] diff --git a/tests/models/resnet/test_modeling_flax_resnet.py b/tests/models/resnet/test_modeling_flax_resnet.py new file mode 100644 index 0000000000..ee56cfe113 --- /dev/null +++ b/tests/models/resnet/test_modeling_flax_resnet.py @@ -0,0 +1,228 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import unittest + +from transformers import ResNetConfig, is_flax_available +from transformers.testing_utils import require_flax, slow +from transformers.utils import cached_property, is_vision_available + +from ...test_configuration_common import ConfigTester +from ...test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor + + +if is_flax_available(): + import jax + import jax.numpy as jnp + + from transformers.models.resnet.modeling_flax_resnet import FlaxResNetForImageClassification, FlaxResNetModel + +if is_vision_available(): + from PIL import Image + + from transformers import AutoFeatureExtractor + + +class FlaxResNetModelTester(unittest.TestCase): + def __init__( + self, + parent, + batch_size=3, + image_size=32, + num_channels=3, + embeddings_size=10, + hidden_sizes=[10, 20, 30, 40], + depths=[1, 1, 2, 1], + is_training=True, + use_labels=True, + hidden_act="relu", + num_labels=3, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.image_size = image_size + self.num_channels = num_channels + self.embeddings_size = embeddings_size + self.hidden_sizes = hidden_sizes + self.depths = depths + self.is_training = is_training + self.use_labels = use_labels + self.hidden_act = hidden_act + self.num_labels = num_labels + self.scope = scope + self.num_stages = len(hidden_sizes) + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) + + config = self.get_config() + + return config, pixel_values + + def get_config(self): + return ResNetConfig( + num_channels=self.num_channels, + embeddings_size=self.embeddings_size, + hidden_sizes=self.hidden_sizes, + depths=self.depths, + hidden_act=self.hidden_act, + num_labels=self.num_labels, + image_size=self.image_size, + ) + + def create_and_check_model(self, config, pixel_values): + model = FlaxResNetModel(config=config) + result = model(pixel_values) + + # Output shape (b, c, h, w) + self.parent.assertEqual( + result.last_hidden_state.shape, + (self.batch_size, self.hidden_sizes[-1], self.image_size // 32, self.image_size // 32), + ) + + def create_and_check_for_image_classification(self, config, pixel_values): + config.num_labels = self.num_labels + model = FlaxResNetForImageClassification(config=config) + result = model(pixel_values) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels)) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values = config_and_inputs + inputs_dict = {"pixel_values": pixel_values} + return config, inputs_dict + + +@require_flax +class FlaxResNetModelTest(FlaxModelTesterMixin, unittest.TestCase): + all_model_classes = (FlaxResNetModel, FlaxResNetForImageClassification) if is_flax_available() else () + + is_encoder_decoder = False + test_head_masking = False + has_attentions = False + + def setUp(self) -> None: + self.model_tester = FlaxResNetModelTester(self) + self.config_tester = ConfigTester(self, config_class=ResNetConfig, has_text_modality=False) + + def test_config(self): + self.create_and_test_config_common_properties() + self.config_tester.create_and_test_config_to_json_string() + self.config_tester.create_and_test_config_to_json_file() + self.config_tester.create_and_test_config_from_and_save_pretrained() + self.config_tester.create_and_test_config_with_num_labels() + self.config_tester.check_config_can_be_init_without_params() + self.config_tester.check_config_arguments_init() + + def create_and_test_config_common_properties(self): + return + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_for_image_classification(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_image_classification(*config_and_inputs) + + @unittest.skip(reason="ResNet does not use inputs_embeds") + def test_inputs_embeds(self): + pass + + @unittest.skip(reason="ResNet does not support input and output embeddings") + def test_model_common_attributes(self): + pass + + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.__call__) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + expected_arg_names = ["pixel_values"] + self.assertListEqual(arg_names[:1], expected_arg_names) + + def test_hidden_states_output(self): + def check_hidden_states_output(inputs_dict, config, model_class): + model = model_class(config) + + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states + + expected_num_stages = self.model_tester.num_stages + self.assertEqual(len(hidden_states), expected_num_stages + 1) + + @unittest.skip(reason="ResNet does not use feedforward chunking") + def test_feed_forward_chunking(self): + pass + + def test_jit_compilation(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + with self.subTest(model_class.__name__): + prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class) + model = model_class(config) + + @jax.jit + def model_jitted(pixel_values, **kwargs): + return model(pixel_values=pixel_values, **kwargs) + + with self.subTest("JIT Enabled"): + jitted_outputs = model_jitted(**prepared_inputs_dict).to_tuple() + + with self.subTest("JIT Disabled"): + with jax.disable_jit(): + outputs = model_jitted(**prepared_inputs_dict).to_tuple() + + self.assertEqual(len(outputs), len(jitted_outputs)) + for jitted_output, output in zip(jitted_outputs, outputs): + self.assertEqual(jitted_output.shape, output.shape) + + +# We will verify our results on an image of cute cats +def prepare_img(): + image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") + return image + + +@require_flax +class FlaxResNetModelIntegrationTest(unittest.TestCase): + @cached_property + def default_feature_extractor(self): + return AutoFeatureExtractor.from_pretrained("microsoft/resnet-50") if is_vision_available() else None + + @slow + def test_inference_image_classification_head(self): + model = FlaxResNetForImageClassification.from_pretrained("microsoft/resnet-50") + + feature_extractor = self.default_feature_extractor + image = prepare_img() + inputs = feature_extractor(images=image, return_tensors="np") + + outputs = model(**inputs) + + # verify the logits + expected_shape = (1, 1000) + self.assertEqual(outputs.logits.shape, expected_shape) + + expected_slice = jnp.array([-11.1069, -9.7877, -8.3777]) + + self.assertTrue(jnp.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))