From 519a677e87d2f72045567aed9853c020524935e0 Mon Sep 17 00:00:00 2001 From: lumliolum Date: Tue, 2 Nov 2021 22:59:14 +0530 Subject: [PATCH] Added Beit model output class (#14133) * add Beit model output class * inheriting from BaseModelOutputWithPooling * updated docs if use_mean_pooling is False * added beit specific outputs in model docs * changed the import path * Fix docs Co-authored-by: Niels Rogge --- docs/source/model_doc/beit.rst | 11 +++++++ src/transformers/models/beit/modeling_beit.py | 31 +++++++++++++++++-- .../models/beit/modeling_flax_beit.py | 28 +++++++++++++++-- 3 files changed, 66 insertions(+), 4 deletions(-) diff --git a/docs/source/model_doc/beit.rst b/docs/source/model_doc/beit.rst index af658cf60d..aac5d9f127 100644 --- a/docs/source/model_doc/beit.rst +++ b/docs/source/model_doc/beit.rst @@ -63,6 +63,17 @@ This model was contributed by `nielsr `__. The JA contributed by `kamalkraj `__. The original code can be found `here `__. + +BEiT specific outputs +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.models.beit.modeling_beit.BeitModelOutputWithPooling + :members: + +.. 
autoclass:: transformers.models.beit.modeling_flax_beit.FlaxBeitModelOutputWithPooling + :members: + + BeitConfig ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/src/transformers/models/beit/modeling_beit.py b/src/transformers/models/beit/modeling_beit.py index 2b8d6884ba..a5cca41b0c 100755 --- a/src/transformers/models/beit/modeling_beit.py +++ b/src/transformers/models/beit/modeling_beit.py @@ -17,6 +17,7 @@ import collections.abc import math +from dataclasses import dataclass import torch import torch.utils.checkpoint @@ -42,6 +43,32 @@ BEIT_PRETRAINED_MODEL_ARCHIVE_LIST = [ ] +@dataclass +class BeitModelOutputWithPooling(BaseModelOutputWithPooling): + """ + Class for outputs of :class:`~transformers.BeitModel`. + + Args: + last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, hidden_size)`): + Average of the last layer hidden states of the patch tokens (excluding the `[CLS]` token) if + `config.use_mean_pooling` is set to True. If set to False, then the final hidden state of the `[CLS]` token + will be returned. + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. 
+ attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. + + Attention weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + # Inspired by # https://github.com/rwightman/pytorch-image-models/blob/b9bd960a032c75ca6b808ddeed76bee5f3ed4972/timm/models/layers/helpers.py # From PyTorch internals @@ -585,7 +612,7 @@ class BeitModel(BeitPreTrainedModel): self.encoder.layer[layer].attention.prune_heads(heads) @add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC) + @replace_return_docstrings(output_type=BeitModelOutputWithPooling, config_class=_CONFIG_FOR_DOC) def forward( self, pixel_values=None, @@ -646,7 +673,7 @@ class BeitModel(BeitPreTrainedModel): if not return_dict: return (sequence_output, pooled_output) + encoder_outputs[1:] - return BaseModelOutputWithPooling( + return BeitModelOutputWithPooling( last_hidden_state=sequence_output, pooler_output=pooled_output, hidden_states=encoder_outputs.hidden_states, diff --git a/src/transformers/models/beit/modeling_flax_beit.py b/src/transformers/models/beit/modeling_flax_beit.py index 1b1ddb7ea1..0f4b8b5abe 100644 --- a/src/transformers/models/beit/modeling_flax_beit.py +++ b/src/transformers/models/beit/modeling_flax_beit.py @@ -18,6 +18,7 @@ from typing import Callable, List, Optional, Tuple import numpy as np +import flax import flax.linen as nn import jax import jax.numpy as jnp @@ -40,6 +41,29 @@ from ...modeling_flax_utils import ( from .configuration_beit import BeitConfig +@flax.struct.dataclass +class FlaxBeitModelOutputWithPooling(FlaxBaseModelOutputWithPooling): + """ + Class for outputs of 
:class:`~transformers.FlaxBeitModel`. + + Args: + last_hidden_state (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (:obj:`jnp.ndarray` of shape :obj:`(batch_size, hidden_size)`): + Average of the last layer hidden states of the patch tokens (excluding the `[CLS]` token) if + `config.use_mean_pooling` is set to True. If set to False, then the final hidden state of the `[CLS]` token + will be returned. + hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of + shape :obj:`(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each + layer plus the initial embedding outputs. + attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, + sequence_length)`. Attention weights after the attention softmax, used to compute the weighted average in + the self-attention heads. + """ + + BEIT_START_DOCSTRING = r""" This model inherits from :class:`~transformers.FlaxPreTrainedModel`. 
Check the superclass documentation for the @@ -674,7 +698,7 @@ class FlaxBeitModule(nn.Module): return (hidden_states,) + outputs[1:] return (hidden_states, pooled) + outputs[1:] - return FlaxBaseModelOutputWithPooling( + return FlaxBeitModelOutputWithPooling( last_hidden_state=hidden_states, pooler_output=pooled, hidden_states=outputs.hidden_states, @@ -711,7 +735,7 @@ FLAX_BEIT_MODEL_DOCSTRING = """ """ overwrite_call_docstring(FlaxBeitModel, FLAX_BEIT_MODEL_DOCSTRING) -append_replace_return_docstrings(FlaxBeitModel, output_type=FlaxBaseModelOutputWithPooling, config_class=BeitConfig) +append_replace_return_docstrings(FlaxBeitModel, output_type=FlaxBeitModelOutputWithPooling, config_class=BeitConfig) class FlaxBeitForMaskedImageModelingModule(nn.Module):