Added Beit model output class (#14133)
* add Beit model ouput class * inherting from BaseModelOuputWithPooling * updated docs if use_mean_pooling is False * added beit specific outputs in model docs * changed the import path * Fix docs Co-authored-by: Niels Rogge <niels.rogge1@gmail.com>
This commit is contained in:
@@ -63,6 +63,17 @@ This model was contributed by `nielsr <https://huggingface.co/nielsr>`__. The JA
|
||||
contributed by `kamalkraj <https://huggingface.co/kamalkraj>`__. The original code can be found `here
|
||||
<https://github.com/microsoft/unilm/tree/master/beit>`__.
|
||||
|
||||
|
||||
BEiT specific outputs
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.models.beit.modeling_beit.BeitModelOutputWithPooling
|
||||
:members:
|
||||
|
||||
.. autoclass:: transformers.models.beit.modeling_flax_beit.FlaxBeitModelOutputWithPooling
|
||||
:members:
|
||||
|
||||
|
||||
BeitConfig
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
|
||||
import collections.abc
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
@@ -42,6 +43,32 @@ BEIT_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class BeitModelOutputWithPooling(BaseModelOutputWithPooling):
|
||||
"""
|
||||
Class for outputs of :class:`~transformers.BeitModel`.
|
||||
|
||||
Args:
|
||||
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
|
||||
Sequence of hidden-states at the output of the last layer of the model.
|
||||
pooler_output (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, hidden_size)`):
|
||||
Average of the last layer hidden states of the patch tokens (excluding the `[CLS]` token) if
|
||||
`config.use_mean_pooling` is set to True. If set to False, then the final hidden state of the `[CLS]` token
|
||||
will be returned.
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
||||
sequence_length, sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
"""
|
||||
|
||||
|
||||
# Inspired by
|
||||
# https://github.com/rwightman/pytorch-image-models/blob/b9bd960a032c75ca6b808ddeed76bee5f3ed4972/timm/models/layers/helpers.py
|
||||
# From PyTorch internals
|
||||
@@ -585,7 +612,7 @@ class BeitModel(BeitPreTrainedModel):
|
||||
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||
|
||||
@add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
|
||||
@replace_return_docstrings(output_type=BeitModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
self,
|
||||
pixel_values=None,
|
||||
@@ -646,7 +673,7 @@ class BeitModel(BeitPreTrainedModel):
|
||||
if not return_dict:
|
||||
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
||||
|
||||
return BaseModelOutputWithPooling(
|
||||
return BeitModelOutputWithPooling(
|
||||
last_hidden_state=sequence_output,
|
||||
pooler_output=pooled_output,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
|
||||
@@ -18,6 +18,7 @@ from typing import Callable, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
import flax
|
||||
import flax.linen as nn
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
@@ -40,6 +41,29 @@ from ...modeling_flax_utils import (
|
||||
from .configuration_beit import BeitConfig
|
||||
|
||||
|
||||
@flax.struct.dataclass
|
||||
class FlaxBeitModelOutputWithPooling(FlaxBaseModelOutputWithPooling):
|
||||
"""
|
||||
Class for outputs of :class:`~transformers.FlaxBeitModel`.
|
||||
|
||||
Args:
|
||||
last_hidden_state (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
|
||||
Sequence of hidden-states at the output of the last layer of the model.
|
||||
pooler_output (:obj:`jnp.ndarray` of shape :obj:`(batch_size, hidden_size)`):
|
||||
Average of the last layer hidden states of the patch tokens (excluding the `[CLS]` token) if
|
||||
`config.use_mean_pooling` is set to True. If set to False, then the final hidden state of the `[CLS]` token
|
||||
will be returned.
|
||||
hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
|
||||
shape :obj:`(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each
|
||||
layer plus the initial embedding outputs.
|
||||
attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
|
||||
sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
|
||||
the self-attention heads.
|
||||
"""
|
||||
|
||||
|
||||
BEIT_START_DOCSTRING = r"""
|
||||
|
||||
This model inherits from :class:`~transformers.FlaxPreTrainedModel`. Check the superclass documentation for the
|
||||
@@ -674,7 +698,7 @@ class FlaxBeitModule(nn.Module):
|
||||
return (hidden_states,) + outputs[1:]
|
||||
return (hidden_states, pooled) + outputs[1:]
|
||||
|
||||
return FlaxBaseModelOutputWithPooling(
|
||||
return FlaxBeitModelOutputWithPooling(
|
||||
last_hidden_state=hidden_states,
|
||||
pooler_output=pooled,
|
||||
hidden_states=outputs.hidden_states,
|
||||
@@ -711,7 +735,7 @@ FLAX_BEIT_MODEL_DOCSTRING = """
|
||||
"""
|
||||
|
||||
overwrite_call_docstring(FlaxBeitModel, FLAX_BEIT_MODEL_DOCSTRING)
|
||||
append_replace_return_docstrings(FlaxBeitModel, output_type=FlaxBaseModelOutputWithPooling, config_class=BeitConfig)
|
||||
append_replace_return_docstrings(FlaxBeitModel, output_type=FlaxBeitModelOutputWithPooling, config_class=BeitConfig)
|
||||
|
||||
|
||||
class FlaxBeitForMaskedImageModelingModule(nn.Module):
|
||||
|
||||
Reference in New Issue
Block a user