Added Beit model output class (#14133)
* add Beit model output class * inheriting from BaseModelOutputWithPooling * updated docs if use_mean_pooling is False * added beit specific outputs in model docs * changed the import path * Fix docs Co-authored-by: Niels Rogge <niels.rogge1@gmail.com>
This commit is contained in:
@@ -63,6 +63,17 @@ This model was contributed by `nielsr <https://huggingface.co/nielsr>`__. The JA
|
|||||||
contributed by `kamalkraj <https://huggingface.co/kamalkraj>`__. The original code can be found `here
|
contributed by `kamalkraj <https://huggingface.co/kamalkraj>`__. The original code can be found `here
|
||||||
<https://github.com/microsoft/unilm/tree/master/beit>`__.
|
<https://github.com/microsoft/unilm/tree/master/beit>`__.
|
||||||
|
|
||||||
|
|
||||||
|
BEiT specific outputs
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.models.beit.modeling_beit.BeitModelOutputWithPooling
|
||||||
|
:members:
|
||||||
|
|
||||||
|
.. autoclass:: transformers.models.beit.modeling_flax_beit.FlaxBeitModelOutputWithPooling
|
||||||
|
:members:
|
||||||
|
|
||||||
|
|
||||||
BeitConfig
|
BeitConfig
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
|||||||
@@ -17,6 +17,7 @@
|
|||||||
|
|
||||||
import collections.abc
|
import collections.abc
|
||||||
import math
|
import math
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
@@ -42,6 +43,32 @@ BEIT_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BeitModelOutputWithPooling(BaseModelOutputWithPooling):
|
||||||
|
"""
|
||||||
|
Class for outputs of :class:`~transformers.BeitModel`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
|
||||||
|
Sequence of hidden-states at the output of the last layer of the model.
|
||||||
|
pooler_output (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, hidden_size)`):
|
||||||
|
Average of the last layer hidden states of the patch tokens (excluding the `[CLS]` token) if
|
||||||
|
`config.use_mean_pooling` is set to True. If set to False, then the final hidden state of the `[CLS]` token
|
||||||
|
will be returned.
|
||||||
|
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
|
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||||
|
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||||
|
|
||||||
|
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||||
|
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||||
|
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
||||||
|
sequence_length, sequence_length)`.
|
||||||
|
|
||||||
|
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||||
|
heads.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
# Inspired by
|
# Inspired by
|
||||||
# https://github.com/rwightman/pytorch-image-models/blob/b9bd960a032c75ca6b808ddeed76bee5f3ed4972/timm/models/layers/helpers.py
|
# https://github.com/rwightman/pytorch-image-models/blob/b9bd960a032c75ca6b808ddeed76bee5f3ed4972/timm/models/layers/helpers.py
|
||||||
# From PyTorch internals
|
# From PyTorch internals
|
||||||
@@ -585,7 +612,7 @@ class BeitModel(BeitPreTrainedModel):
|
|||||||
self.encoder.layer[layer].attention.prune_heads(heads)
|
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING)
|
||||||
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
|
@replace_return_docstrings(output_type=BeitModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
pixel_values=None,
|
pixel_values=None,
|
||||||
@@ -646,7 +673,7 @@ class BeitModel(BeitPreTrainedModel):
|
|||||||
if not return_dict:
|
if not return_dict:
|
||||||
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
||||||
|
|
||||||
return BaseModelOutputWithPooling(
|
return BeitModelOutputWithPooling(
|
||||||
last_hidden_state=sequence_output,
|
last_hidden_state=sequence_output,
|
||||||
pooler_output=pooled_output,
|
pooler_output=pooled_output,
|
||||||
hidden_states=encoder_outputs.hidden_states,
|
hidden_states=encoder_outputs.hidden_states,
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from typing import Callable, List, Optional, Tuple
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
import flax
|
||||||
import flax.linen as nn
|
import flax.linen as nn
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
@@ -40,6 +41,29 @@ from ...modeling_flax_utils import (
|
|||||||
from .configuration_beit import BeitConfig
|
from .configuration_beit import BeitConfig
|
||||||
|
|
||||||
|
|
||||||
|
@flax.struct.dataclass
|
||||||
|
class FlaxBeitModelOutputWithPooling(FlaxBaseModelOutputWithPooling):
|
||||||
|
"""
|
||||||
|
Class for outputs of :class:`~transformers.FlaxBeitModel`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
last_hidden_state (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
|
||||||
|
Sequence of hidden-states at the output of the last layer of the model.
|
||||||
|
pooler_output (:obj:`jnp.ndarray` of shape :obj:`(batch_size, hidden_size)`):
|
||||||
|
Average of the last layer hidden states of the patch tokens (excluding the `[CLS]` token) if
|
||||||
|
`config.use_mean_pooling` is set to True. If set to False, then the final hidden state of the `[CLS]` token
|
||||||
|
will be returned.
|
||||||
|
hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
|
Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
|
||||||
|
shape :obj:`(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each
|
||||||
|
layer plus the initial embedding outputs.
|
||||||
|
attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||||
|
Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
|
||||||
|
sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
|
||||||
|
the self-attention heads.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
BEIT_START_DOCSTRING = r"""
|
BEIT_START_DOCSTRING = r"""
|
||||||
|
|
||||||
This model inherits from :class:`~transformers.FlaxPreTrainedModel`. Check the superclass documentation for the
|
This model inherits from :class:`~transformers.FlaxPreTrainedModel`. Check the superclass documentation for the
|
||||||
@@ -674,7 +698,7 @@ class FlaxBeitModule(nn.Module):
|
|||||||
return (hidden_states,) + outputs[1:]
|
return (hidden_states,) + outputs[1:]
|
||||||
return (hidden_states, pooled) + outputs[1:]
|
return (hidden_states, pooled) + outputs[1:]
|
||||||
|
|
||||||
return FlaxBaseModelOutputWithPooling(
|
return FlaxBeitModelOutputWithPooling(
|
||||||
last_hidden_state=hidden_states,
|
last_hidden_state=hidden_states,
|
||||||
pooler_output=pooled,
|
pooler_output=pooled,
|
||||||
hidden_states=outputs.hidden_states,
|
hidden_states=outputs.hidden_states,
|
||||||
@@ -711,7 +735,7 @@ FLAX_BEIT_MODEL_DOCSTRING = """
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
overwrite_call_docstring(FlaxBeitModel, FLAX_BEIT_MODEL_DOCSTRING)
|
overwrite_call_docstring(FlaxBeitModel, FLAX_BEIT_MODEL_DOCSTRING)
|
||||||
append_replace_return_docstrings(FlaxBeitModel, output_type=FlaxBaseModelOutputWithPooling, config_class=BeitConfig)
|
append_replace_return_docstrings(FlaxBeitModel, output_type=FlaxBeitModelOutputWithPooling, config_class=BeitConfig)
|
||||||
|
|
||||||
|
|
||||||
class FlaxBeitForMaskedImageModelingModule(nn.Module):
|
class FlaxBeitForMaskedImageModelingModule(nn.Module):
|
||||||
|
|||||||
Reference in New Issue
Block a user