FlaxBart (#11537)
* Start working on FlaxBart
* Create modeling_flax_bart.py
* Write FlaxBartAttention
* Add FlaxBartEncoderLayer
* Add FlaxBartDecoderLayer and some typing
* Add helepr function for FlaxBart
* shift_tokens_right
* _make_causal_mask
* _expand_mask
* Add PositionalEmbedding and fix init_std naming
* Add FlaxBartPretrainedModel
* Add FlaxBartEncoder
* Add FlaxBartEncoder
* Add FlaxBartEncoder among modules to be imported
* YET WE CANNOT INITIALIZE THAT!! :(
* Make BartEncoder working
Change BartEncoder to instance of nn.Module so far
* Add FlaxBartDecoder
* Add FlaxBartModel
* TODO to make model run -> Prepapre model inputs
* Resolve padding
* Add FlaxBartModel
* Add FlaxBartModel into importable modules
* Remove FlaxBartEncoder and FlaxBartDecoder from importable modules
* make style; not properly working
* make style; make quality not pass due to some import I left
* Remove TODO for padding_idx in nn.Embed so far
* Add FlaxBartForConditionalGeneration
* Incorporate Flax model output classes, i.e. return_dict
* Add another models and incorporate use_cache arg
* Add FlaxBartForSequenceClassification and FlaxBartForQuestionAnswering
* Incorporate use_cache arg from PyTorch implementation
* Add all necessary Flax output utils
* Add FlaxBartForCausalLM; not working yet'
* Add minor improvements; still lacks some functionality
* Update docs, src and tests
* Add support of FlaxBart to docs/source
* Fix some bugs in FlaxBart souce code
* Add some neccessary tests for FlaxBart models - jit_compilation not passing
* Fix tests and add test_head_masking
* Fix tests for @jax.jit computation
* Add test_head_masking
* Migrate FlaxBart tests from jax.numpy to numpy
* Remove FlaxBartForCausalLM
* Clean repo
* fix bart model weight structure
* Fix FlaxBartForSequenceClassification
Slicing is not possible to use below jit, therefore, selecting sentence
representation from hidden_states must be changed.
* Allow FlaxBartForSequenceClassification for testing pt_flax equivalence
* Allow testing for FlaxBartForQA for pt_flax equivalence
* Add a comment to FlaxBartForSequenceClassification + change noise from 1e-3 to 1e-6
* remove past_key_values
* remove inputs_mebeds and make input_ids required
* add position ids
* re-write attention layer
* fix dataclass
* fix pos embeds and attention output
* fix pos embeds
* expose encode method
* expose decode method
* move docstring to top
* add cache for causal attn layer
* remove head masking for now
* s2s greedy search first pass
* boom boom
* fix typos
* fix greedy generate for bart
* use encoder, decoder layers instead of num_hidden_layers
* handle encoder_outputs
* cleanup
* simplify decoding
* more clean-up
* typos
* Change header + add {decoder_,}position_ids into 2 models
* add BartConfig
* fix existing tests
* add encode, decode methods
* Fix shift_tokens_right for JIT compilation + clarify one condition
* fix decode
* encoder => encode
* simplify generate
* add tests for encode and decode
* style
* add tests for cache
* fix equivalence tests
* sample generate now works with seq2seq
* generation tests
* initialize dense layers
* docstring and cleanup
* quality
* remove get/set input_embeddings
* address Patricks suggestions
* decode for every model, remove encoder_outputs from call
* update tests accordingly
* decode returns only decoder outputs and logits
* fix arguments
* doc encode, decode methods
* correct base_model_prefix
* fix test for seq classif model
* fix docs
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
This commit is contained in:
@@ -299,7 +299,7 @@ Flax), PyTorch, and/or TensorFlow.
|
|||||||
+=============================+================+================+=================+====================+==============+
|
+=============================+================+================+=================+====================+==============+
|
||||||
| ALBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
|
| ALBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||||
| BART | ✅ | ✅ | ✅ | ✅ | ❌ |
|
| BART | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||||
| BERT | ✅ | ✅ | ✅ | ✅ | ✅ |
|
| BERT | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||||
|
|||||||
@@ -131,6 +131,7 @@ BartForQuestionAnswering
|
|||||||
.. autoclass:: transformers.BartForQuestionAnswering
|
.. autoclass:: transformers.BartForQuestionAnswering
|
||||||
:members: forward
|
:members: forward
|
||||||
|
|
||||||
|
|
||||||
BartForCausalLM
|
BartForCausalLM
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
@@ -138,7 +139,6 @@ BartForCausalLM
|
|||||||
:members: forward
|
:members: forward
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
TFBartModel
|
TFBartModel
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
@@ -151,3 +151,32 @@ TFBartForConditionalGeneration
|
|||||||
|
|
||||||
.. autoclass:: transformers.TFBartForConditionalGeneration
|
.. autoclass:: transformers.TFBartForConditionalGeneration
|
||||||
:members: call
|
:members: call
|
||||||
|
|
||||||
|
|
||||||
|
FlaxBartModel
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.FlaxBartModel
|
||||||
|
:members: __call__, encode, decode
|
||||||
|
|
||||||
|
|
||||||
|
FlaxBartForConditionalGeneration
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.FlaxBartForConditionalGeneration
|
||||||
|
:members: __call__, encode, decode
|
||||||
|
|
||||||
|
|
||||||
|
FlaxBartForSequenceClassification
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.FlaxBartForSequenceClassification
|
||||||
|
:members: __call__, encode, decode
|
||||||
|
|
||||||
|
|
||||||
|
FlaxBartForQuestionAnswering
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.FlaxBartForQuestionAnswering
|
||||||
|
:members: __call__, encode, decode
|
||||||
|
|
||||||
|
|||||||
@@ -1508,6 +1508,14 @@ if is_flax_available():
|
|||||||
"FlaxAutoModelForTokenClassification",
|
"FlaxAutoModelForTokenClassification",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
_import_structure["models.bart"].extend(
|
||||||
|
[
|
||||||
|
"FlaxBartForConditionalGeneration",
|
||||||
|
"FlaxBartForQuestionAnswering",
|
||||||
|
"FlaxBartForSequenceClassification",
|
||||||
|
"FlaxBartModel",
|
||||||
|
]
|
||||||
|
)
|
||||||
_import_structure["models.bert"].extend(
|
_import_structure["models.bert"].extend(
|
||||||
[
|
[
|
||||||
"FlaxBertForMaskedLM",
|
"FlaxBertForMaskedLM",
|
||||||
@@ -2808,6 +2816,12 @@ if TYPE_CHECKING:
|
|||||||
FlaxAutoModelForSequenceClassification,
|
FlaxAutoModelForSequenceClassification,
|
||||||
FlaxAutoModelForTokenClassification,
|
FlaxAutoModelForTokenClassification,
|
||||||
)
|
)
|
||||||
|
from .models.bart import (
|
||||||
|
FlaxBartForConditionalGeneration,
|
||||||
|
FlaxBartForQuestionAnswering,
|
||||||
|
FlaxBartForSequenceClassification,
|
||||||
|
FlaxBartModel,
|
||||||
|
)
|
||||||
from .models.bert import (
|
from .models.bert import (
|
||||||
FlaxBertForMaskedLM,
|
FlaxBertForMaskedLM,
|
||||||
FlaxBertForMultipleChoice,
|
FlaxBertForMultipleChoice,
|
||||||
|
|||||||
@@ -101,12 +101,23 @@ class FlaxGenerationMixin:
|
|||||||
state = body_fn(state)
|
state = body_fn(state)
|
||||||
return state
|
return state
|
||||||
|
|
||||||
|
def _prepare_encoder_decoder_kwargs_for_generation(self, input_ids, model_kwargs):
|
||||||
|
encoder_kwargs = {
|
||||||
|
argument: value
|
||||||
|
for argument, value in model_kwargs.items()
|
||||||
|
if not (argument.startswith("decoder_") or argument.startswith("cross_attn"))
|
||||||
|
}
|
||||||
|
model_kwargs["encoder_outputs"] = self.encode(input_ids, return_dict=True, **encoder_kwargs)
|
||||||
|
return model_kwargs
|
||||||
|
|
||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
input_ids: jax_xla.DeviceArray,
|
input_ids: jax_xla.DeviceArray,
|
||||||
max_length: Optional[int] = None,
|
max_length: Optional[int] = None,
|
||||||
pad_token_id: Optional[int] = None,
|
pad_token_id: Optional[int] = None,
|
||||||
|
bos_token_id: Optional[int] = None,
|
||||||
eos_token_id: Optional[int] = None,
|
eos_token_id: Optional[int] = None,
|
||||||
|
decoder_start_token_id: Optional[int] = None,
|
||||||
do_sample: Optional[bool] = None,
|
do_sample: Optional[bool] = None,
|
||||||
prng_key: Optional[jax_xla.DeviceArray] = None,
|
prng_key: Optional[jax_xla.DeviceArray] = None,
|
||||||
top_k: Optional[int] = None,
|
top_k: Optional[int] = None,
|
||||||
@@ -147,6 +158,8 @@ class FlaxGenerationMixin:
|
|||||||
The id of the `beginning-of-sequence` token.
|
The id of the `beginning-of-sequence` token.
|
||||||
eos_token_id (:obj:`int`, `optional`):
|
eos_token_id (:obj:`int`, `optional`):
|
||||||
The id of the `end-of-sequence` token.
|
The id of the `end-of-sequence` token.
|
||||||
|
decoder_start_token_id (:obj:`int`, `optional`):
|
||||||
|
If an encoder-decoder model starts decoding with a different token than `bos`, the id of that token.
|
||||||
trace (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
trace (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||||
Whether to trace generation. Setting ``trace=False`` should only be used for debugging and will lead to
|
Whether to trace generation. Setting ``trace=False`` should only be used for debugging and will lead to
|
||||||
a considerably slower runtime.
|
a considerably slower runtime.
|
||||||
@@ -170,10 +183,23 @@ class FlaxGenerationMixin:
|
|||||||
"""
|
"""
|
||||||
# set init values
|
# set init values
|
||||||
max_length = max_length if max_length is not None else self.config.max_length
|
max_length = max_length if max_length is not None else self.config.max_length
|
||||||
|
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
|
||||||
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
|
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
|
||||||
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
||||||
|
decoder_start_token_id = (
|
||||||
|
decoder_start_token_id if decoder_start_token_id else self.config.decoder_start_token_id
|
||||||
|
)
|
||||||
prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)
|
prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)
|
||||||
|
|
||||||
|
if decoder_start_token_id is None and self.config.is_encoder_decoder:
|
||||||
|
raise ValueError("`decoder_start_token_id` has to be defined for encoder-decoder generation.")
|
||||||
|
|
||||||
|
if self.config.is_encoder_decoder:
|
||||||
|
# add encoder_outputs to model_kwargs
|
||||||
|
model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, model_kwargs)
|
||||||
|
# prepare decoder_input_ids for generation
|
||||||
|
input_ids = jnp.ones((input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
|
||||||
|
|
||||||
do_sample = do_sample if do_sample is not None else self.config.do_sample
|
do_sample = do_sample if do_sample is not None else self.config.do_sample
|
||||||
|
|
||||||
if do_sample:
|
if do_sample:
|
||||||
@@ -246,10 +272,11 @@ class FlaxGenerationMixin:
|
|||||||
# per batch-item state bit indicating if sentence has finished.
|
# per batch-item state bit indicating if sentence has finished.
|
||||||
is_sent_finished = jnp.zeros((batch_size,), dtype=jnp.bool_)
|
is_sent_finished = jnp.zeros((batch_size,), dtype=jnp.bool_)
|
||||||
|
|
||||||
model = self
|
# For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop
|
||||||
|
# and pass it the `encoder_outputs`, which are part of the `model_kwargs`.
|
||||||
|
model = self.decode if self.config.is_encoder_decoder else self
|
||||||
# initialize model specific kwargs
|
# initialize model specific kwargs
|
||||||
model_kwargs = model.prepare_inputs_for_generation(input_ids, max_length, **model_kwargs)
|
model_kwargs = self.prepare_inputs_for_generation(input_ids, max_length, **model_kwargs)
|
||||||
|
|
||||||
# initialize state
|
# initialize state
|
||||||
state = GreedyState(
|
state = GreedyState(
|
||||||
@@ -277,8 +304,7 @@ class FlaxGenerationMixin:
|
|||||||
next_token = next_token[:, None]
|
next_token = next_token[:, None]
|
||||||
|
|
||||||
next_sequences = lax.dynamic_update_slice(state.sequences, next_token, (0, state.cur_len))
|
next_sequences = lax.dynamic_update_slice(state.sequences, next_token, (0, state.cur_len))
|
||||||
next_model_kwargs = model.update_inputs_for_generation(model_outputs, model_kwargs)
|
next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs)
|
||||||
|
|
||||||
return GreedyState(
|
return GreedyState(
|
||||||
cur_len=state.cur_len + 1,
|
cur_len=state.cur_len + 1,
|
||||||
sequences=next_sequences,
|
sequences=next_sequences,
|
||||||
@@ -288,6 +314,7 @@ class FlaxGenerationMixin:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
|
# The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
|
||||||
|
if input_ids.shape[1] > 1:
|
||||||
state = greedy_search_body_fn(state)
|
state = greedy_search_body_fn(state)
|
||||||
|
|
||||||
if not trace:
|
if not trace:
|
||||||
@@ -327,10 +354,12 @@ class FlaxGenerationMixin:
|
|||||||
# per batch-item state bit indicating if sentence has finished.
|
# per batch-item state bit indicating if sentence has finished.
|
||||||
is_sent_finished = jnp.zeros((batch_size,), dtype=jnp.bool_)
|
is_sent_finished = jnp.zeros((batch_size,), dtype=jnp.bool_)
|
||||||
|
|
||||||
model = self
|
# For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop
|
||||||
|
# and pass it the `encoder_outputs`, which are part of the `model_kwargs`.
|
||||||
|
model = self.decode if self.config.is_encoder_decoder else self
|
||||||
|
|
||||||
# initialize model specific kwargs
|
# initialize model specific kwargs
|
||||||
model_kwargs = model.prepare_inputs_for_generation(input_ids, max_length, **model_kwargs)
|
model_kwargs = self.prepare_inputs_for_generation(input_ids, max_length, **model_kwargs)
|
||||||
|
|
||||||
# initialize state
|
# initialize state
|
||||||
state = SampleState(
|
state = SampleState(
|
||||||
@@ -366,7 +395,7 @@ class FlaxGenerationMixin:
|
|||||||
next_token = next_token[:, None]
|
next_token = next_token[:, None]
|
||||||
|
|
||||||
next_sequences = lax.dynamic_update_slice(state.sequences, next_token, (0, state.cur_len))
|
next_sequences = lax.dynamic_update_slice(state.sequences, next_token, (0, state.cur_len))
|
||||||
next_model_kwargs = model.update_inputs_for_generation(model_outputs, model_kwargs)
|
next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs)
|
||||||
|
|
||||||
return SampleState(
|
return SampleState(
|
||||||
cur_len=state.cur_len + 1,
|
cur_len=state.cur_len + 1,
|
||||||
@@ -378,6 +407,7 @@ class FlaxGenerationMixin:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
|
# The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
|
||||||
|
if input_ids.shape[1] > 1:
|
||||||
state = sample_search_body_fn(state)
|
state = sample_search_body_fn(state)
|
||||||
|
|
||||||
if not trace:
|
if not trace:
|
||||||
|
|||||||
@@ -106,6 +106,154 @@ class FlaxBaseModelOutputWithPooling(ModelOutput):
|
|||||||
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||||
|
|
||||||
|
|
||||||
|
@flax.struct.dataclass
|
||||||
|
class FlaxBaseModelOutputWithPastAndCrossAttentions(ModelOutput):
|
||||||
|
"""
|
||||||
|
Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
|
||||||
|
Sequence of hidden-states at the output of the last layer of the model.
|
||||||
|
|
||||||
|
If :obj:`past_key_values` is used only the last hidden-state of the sequences of shape :obj:`(batch_size,
|
||||||
|
1, hidden_size)` is output.
|
||||||
|
past_key_values (:obj:`tuple(tuple(jax_xla.DeviceArray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
||||||
|
Tuple of :obj:`tuple(jax_xla.DeviceArray)` of length :obj:`config.n_layers`, with each tuple having 2
|
||||||
|
tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
|
||||||
|
``config.is_encoder_decoder=True`` 2 additional tensors of shape :obj:`(batch_size, num_heads,
|
||||||
|
encoder_sequence_length, embed_size_per_head)`.
|
||||||
|
|
||||||
|
Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
|
||||||
|
``config.is_encoder_decoder=True`` in the cross-attention blocks) that can be used (see
|
||||||
|
:obj:`past_key_values` input) to speed up sequential decoding.
|
||||||
|
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
|
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
|
||||||
|
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||||
|
|
||||||
|
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||||
|
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||||
|
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
||||||
|
sequence_length, sequence_length)`.
|
||||||
|
|
||||||
|
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||||
|
heads.
|
||||||
|
cross_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` and ``config.add_cross_attention=True`` is passed or when ``config.output_attentions=True``):
|
||||||
|
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
||||||
|
sequence_length, sequence_length)`.
|
||||||
|
|
||||||
|
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
|
||||||
|
weighted average in the cross-attention heads.
|
||||||
|
"""
|
||||||
|
|
||||||
|
last_hidden_state: jax_xla.DeviceArray = None
|
||||||
|
past_key_values: Optional[Tuple[Tuple[jax_xla.DeviceArray]]] = None
|
||||||
|
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||||
|
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||||
|
cross_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||||
|
|
||||||
|
|
||||||
|
@flax.struct.dataclass
|
||||||
|
class FlaxSeq2SeqModelOutput(ModelOutput):
|
||||||
|
"""
|
||||||
|
Base class for model encoder's outputs that also contains : pre-computed hidden states that can speed up sequential
|
||||||
|
decoding.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
|
||||||
|
Sequence of hidden-states at the output of the last layer of the decoder of the model.
|
||||||
|
|
||||||
|
If :obj:`past_key_values` is used only the last hidden-state of the sequences of shape :obj:`(batch_size,
|
||||||
|
1, hidden_size)` is output.
|
||||||
|
past_key_values (:obj:`tuple(tuple(jax_xla.DeviceArray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
||||||
|
Tuple of :obj:`tuple(jax_xla.DeviceArray)` of length :obj:`config.n_layers`, with each tuple having 2
|
||||||
|
tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional
|
||||||
|
tensors of shape :obj:`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
||||||
|
|
||||||
|
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
||||||
|
blocks) that can be used (see :obj:`past_key_values` input) to speed up sequential decoding.
|
||||||
|
decoder_hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
|
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
|
||||||
|
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||||
|
|
||||||
|
Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
|
||||||
|
decoder_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||||
|
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
||||||
|
sequence_length, sequence_length)`.
|
||||||
|
|
||||||
|
Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
|
||||||
|
self-attention heads.
|
||||||
|
cross_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||||
|
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
||||||
|
sequence_length, sequence_length)`.
|
||||||
|
|
||||||
|
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
|
||||||
|
weighted average in the cross-attention heads.
|
||||||
|
encoder_last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||||
|
Sequence of hidden-states at the output of the last layer of the encoder of the model.
|
||||||
|
encoder_hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
|
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
|
||||||
|
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||||
|
|
||||||
|
Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
|
||||||
|
encoder_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||||
|
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
||||||
|
sequence_length, sequence_length)`.
|
||||||
|
|
||||||
|
Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
|
||||||
|
self-attention heads.
|
||||||
|
"""
|
||||||
|
|
||||||
|
last_hidden_state: jax_xla.DeviceArray = None
|
||||||
|
past_key_values: Optional[Tuple[Tuple[jax_xla.DeviceArray]]] = None
|
||||||
|
decoder_hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||||
|
decoder_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||||
|
cross_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||||
|
encoder_last_hidden_state: Optional[jax_xla.DeviceArray] = None
|
||||||
|
encoder_hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||||
|
encoder_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||||
|
|
||||||
|
|
||||||
|
@flax.struct.dataclass
|
||||||
|
class FlaxCausalLMOutputWithCrossAttentions(ModelOutput):
|
||||||
|
"""
|
||||||
|
Base class for causal language model (or autoregressive) outputs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
|
||||||
|
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||||
|
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
|
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
|
||||||
|
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||||
|
|
||||||
|
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||||
|
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||||
|
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
||||||
|
sequence_length, sequence_length)`.
|
||||||
|
|
||||||
|
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||||
|
heads.
|
||||||
|
cross_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||||
|
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
||||||
|
sequence_length, sequence_length)`.
|
||||||
|
|
||||||
|
Cross attentions weights after the attention softmax, used to compute the weighted average in the
|
||||||
|
cross-attention heads.
|
||||||
|
past_key_values (:obj:`tuple(tuple(jax_xla.DeviceArray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
||||||
|
Tuple of :obj:`jax_xla.DeviceArray` tuples of length :obj:`config.n_layers`, with each tuple containing the
|
||||||
|
cached key, value states of the self-attention and the cross-attention layers if model is used in
|
||||||
|
encoder-decoder setting. Only relevant if ``config.is_decoder = True``.
|
||||||
|
|
||||||
|
Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
|
||||||
|
:obj:`past_key_values` input) to speed up sequential decoding.
|
||||||
|
"""
|
||||||
|
|
||||||
|
logits: jax_xla.DeviceArray = None
|
||||||
|
past_key_values: Optional[Tuple[Tuple[jax_xla.DeviceArray]]] = None
|
||||||
|
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||||
|
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||||
|
cross_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||||
|
|
||||||
|
|
||||||
@flax.struct.dataclass
|
@flax.struct.dataclass
|
||||||
class FlaxMaskedLMOutput(ModelOutput):
|
class FlaxMaskedLMOutput(ModelOutput):
|
||||||
"""
|
"""
|
||||||
@@ -135,6 +283,63 @@ class FlaxMaskedLMOutput(ModelOutput):
|
|||||||
FlaxCausalLMOutput = FlaxMaskedLMOutput
|
FlaxCausalLMOutput = FlaxMaskedLMOutput
|
||||||
|
|
||||||
|
|
||||||
|
@flax.struct.dataclass
|
||||||
|
class FlaxSeq2SeqLMOutput(ModelOutput):
|
||||||
|
"""
|
||||||
|
Base class for sequence-to-sequence language models outputs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
|
||||||
|
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||||
|
past_key_values (:obj:`tuple(tuple(jax_xla.DeviceArray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
||||||
|
Tuple of :obj:`tuple(jax_xla.DeviceArray)` of length :obj:`config.n_layers`, with each tuple having 2
|
||||||
|
tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional
|
||||||
|
tensors of shape :obj:`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
||||||
|
|
||||||
|
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
||||||
|
blocks) that can be used (see :obj:`past_key_values` input) to speed up sequential decoding.
|
||||||
|
decoder_hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
|
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
|
||||||
|
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||||
|
|
||||||
|
Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
|
||||||
|
decoder_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||||
|
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
||||||
|
sequence_length, sequence_length)`.
|
||||||
|
|
||||||
|
Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
|
||||||
|
self-attention heads.
|
||||||
|
cross_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||||
|
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
||||||
|
sequence_length, sequence_length)`.
|
||||||
|
|
||||||
|
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
|
||||||
|
weighted average in the cross-attention heads.
|
||||||
|
encoder_last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||||
|
Sequence of hidden-states at the output of the last layer of the encoder of the model.
|
||||||
|
encoder_hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
|
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
|
||||||
|
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||||
|
|
||||||
|
Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
|
||||||
|
encoder_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||||
|
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
||||||
|
sequence_length, sequence_length)`.
|
||||||
|
|
||||||
|
Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
|
||||||
|
self-attention heads.
|
||||||
|
"""
|
||||||
|
|
||||||
|
logits: jax_xla.DeviceArray = None
|
||||||
|
past_key_values: Optional[Tuple[Tuple[jax_xla.DeviceArray]]] = None
|
||||||
|
decoder_hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||||
|
decoder_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||||
|
cross_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||||
|
encoder_last_hidden_state: Optional[jax_xla.DeviceArray] = None
|
||||||
|
encoder_hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||||
|
encoder_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||||
|
|
||||||
|
|
||||||
@flax.struct.dataclass
|
@flax.struct.dataclass
|
||||||
class FlaxNextSentencePredictorOutput(ModelOutput):
|
class FlaxNextSentencePredictorOutput(ModelOutput):
|
||||||
"""
|
"""
|
||||||
@@ -188,6 +393,63 @@ class FlaxSequenceClassifierOutput(ModelOutput):
|
|||||||
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||||
|
|
||||||
|
|
||||||
|
@flax.struct.dataclass
|
||||||
|
class FlaxSeq2SeqSequenceClassifierOutput(ModelOutput):
|
||||||
|
"""
|
||||||
|
Base class for outputs of sequence-to-sequence sentence classification models.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, config.num_labels)`):
|
||||||
|
Classification (or regression if config.num_labels==1) scores (before SoftMax).
|
||||||
|
past_key_values (:obj:`tuple(tuple(jax_xla.DeviceArray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
||||||
|
Tuple of :obj:`tuple(jax_xla.DeviceArray)` of length :obj:`config.n_layers`, with each tuple having 2
|
||||||
|
tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional
|
||||||
|
tensors of shape :obj:`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
||||||
|
|
||||||
|
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
||||||
|
blocks) that can be used (see :obj:`past_key_values` input) to speed up sequential decoding.
|
||||||
|
decoder_hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
|
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
|
||||||
|
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||||
|
|
||||||
|
Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
|
||||||
|
decoder_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||||
|
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
||||||
|
sequence_length, sequence_length)`.
|
||||||
|
|
||||||
|
Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
|
||||||
|
self-attention heads.
|
||||||
|
cross_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||||
|
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
||||||
|
sequence_length, sequence_length)`.
|
||||||
|
|
||||||
|
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
|
||||||
|
weighted average in the cross-attention heads.
|
||||||
|
encoder_last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||||
|
Sequence of hidden-states at the output of the last layer of the encoder of the model.
|
||||||
|
encoder_hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
|
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
|
||||||
|
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||||
|
|
||||||
|
Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
|
||||||
|
encoder_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||||
|
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
||||||
|
sequence_length, sequence_length)`.
|
||||||
|
|
||||||
|
Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
|
||||||
|
self-attention heads.
|
||||||
|
"""
|
||||||
|
|
||||||
|
logits: jax_xla.DeviceArray = None
|
||||||
|
past_key_values: Optional[Tuple[Tuple[jax_xla.DeviceArray]]] = None
|
||||||
|
decoder_hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||||
|
decoder_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||||
|
cross_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||||
|
encoder_last_hidden_state: Optional[jax_xla.DeviceArray] = None
|
||||||
|
encoder_hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||||
|
encoder_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||||
|
|
||||||
|
|
||||||
@flax.struct.dataclass
|
@flax.struct.dataclass
|
||||||
class FlaxMultipleChoiceModelOutput(ModelOutput):
|
class FlaxMultipleChoiceModelOutput(ModelOutput):
|
||||||
"""
|
"""
|
||||||
@@ -269,3 +531,63 @@ class FlaxQuestionAnsweringModelOutput(ModelOutput):
|
|||||||
end_logits: jax_xla.DeviceArray = None
|
end_logits: jax_xla.DeviceArray = None
|
||||||
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
|
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||||
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||||
|
|
||||||
|
|
||||||
|
@flax.struct.dataclass
|
||||||
|
class FlaxSeq2SeqQuestionAnsweringModelOutput(ModelOutput):
|
||||||
|
"""
|
||||||
|
Base class for outputs of sequence-to-sequence question answering models.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
start_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length)`):
|
||||||
|
Span-start scores (before SoftMax).
|
||||||
|
end_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length)`):
|
||||||
|
Span-end scores (before SoftMax).
|
||||||
|
past_key_values (:obj:`tuple(tuple(jax_xla.DeviceArray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
||||||
|
Tuple of :obj:`tuple(jax_xla.DeviceArray)` of length :obj:`config.n_layers`, with each tuple having 2
|
||||||
|
tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional
|
||||||
|
tensors of shape :obj:`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
||||||
|
|
||||||
|
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
||||||
|
blocks) that can be used (see :obj:`past_key_values` input) to speed up sequential decoding.
|
||||||
|
decoder_hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
|
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
|
||||||
|
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||||
|
|
||||||
|
Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
|
||||||
|
decoder_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||||
|
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
||||||
|
sequence_length, sequence_length)`.
|
||||||
|
|
||||||
|
Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
|
||||||
|
self-attention heads.
|
||||||
|
cross_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||||
|
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
||||||
|
sequence_length, sequence_length)`.
|
||||||
|
|
||||||
|
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
|
||||||
|
weighted average in the cross-attention heads.
|
||||||
|
encoder_last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||||
|
Sequence of hidden-states at the output of the last layer of the encoder of the model.
|
||||||
|
encoder_hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
|
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
|
||||||
|
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||||
|
|
||||||
|
Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
|
||||||
|
encoder_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||||
|
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
||||||
|
sequence_length, sequence_length)`.
|
||||||
|
|
||||||
|
Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
|
||||||
|
self-attention heads.
|
||||||
|
"""
|
||||||
|
|
||||||
|
start_logits: jax_xla.DeviceArray = None
|
||||||
|
end_logits: jax_xla.DeviceArray = None
|
||||||
|
past_key_values: Optional[Tuple[Tuple[jax_xla.DeviceArray]]] = None
|
||||||
|
decoder_hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||||
|
decoder_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||||
|
cross_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||||
|
encoder_last_hidden_state: Optional[jax_xla.DeviceArray] = None
|
||||||
|
encoder_hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||||
|
encoder_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||||
|
|||||||
@@ -18,6 +18,12 @@
|
|||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
|
from ..bart.modeling_flax_bart import (
|
||||||
|
FlaxBartForConditionalGeneration,
|
||||||
|
FlaxBartForQuestionAnswering,
|
||||||
|
FlaxBartForSequenceClassification,
|
||||||
|
FlaxBartModel,
|
||||||
|
)
|
||||||
from ..bert.modeling_flax_bert import (
|
from ..bert.modeling_flax_bert import (
|
||||||
FlaxBertForMaskedLM,
|
FlaxBertForMaskedLM,
|
||||||
FlaxBertForMultipleChoice,
|
FlaxBertForMultipleChoice,
|
||||||
@@ -49,7 +55,7 @@ from ..roberta.modeling_flax_roberta import (
|
|||||||
)
|
)
|
||||||
from ..vit.modeling_flax_vit import FlaxViTForImageClassification, FlaxViTModel
|
from ..vit.modeling_flax_vit import FlaxViTForImageClassification, FlaxViTModel
|
||||||
from .auto_factory import auto_class_factory
|
from .auto_factory import auto_class_factory
|
||||||
from .configuration_auto import BertConfig, CLIPConfig, ElectraConfig, GPT2Config, RobertaConfig, ViTConfig
|
from .configuration_auto import BartConfig, BertConfig, CLIPConfig, ElectraConfig, GPT2Config, RobertaConfig, ViTConfig
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
@@ -60,6 +66,7 @@ FLAX_MODEL_MAPPING = OrderedDict(
|
|||||||
# Base model mapping
|
# Base model mapping
|
||||||
(RobertaConfig, FlaxRobertaModel),
|
(RobertaConfig, FlaxRobertaModel),
|
||||||
(BertConfig, FlaxBertModel),
|
(BertConfig, FlaxBertModel),
|
||||||
|
(BartConfig, FlaxBartModel),
|
||||||
(GPT2Config, FlaxGPT2Model),
|
(GPT2Config, FlaxGPT2Model),
|
||||||
(ElectraConfig, FlaxElectraModel),
|
(ElectraConfig, FlaxElectraModel),
|
||||||
(CLIPConfig, FlaxCLIPModel),
|
(CLIPConfig, FlaxCLIPModel),
|
||||||
@@ -72,6 +79,7 @@ FLAX_MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
|
|||||||
# Model for pre-training mapping
|
# Model for pre-training mapping
|
||||||
(RobertaConfig, FlaxRobertaForMaskedLM),
|
(RobertaConfig, FlaxRobertaForMaskedLM),
|
||||||
(BertConfig, FlaxBertForPreTraining),
|
(BertConfig, FlaxBertForPreTraining),
|
||||||
|
(BartConfig, FlaxBartForConditionalGeneration),
|
||||||
(ElectraConfig, FlaxElectraForPreTraining),
|
(ElectraConfig, FlaxElectraForPreTraining),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@@ -81,6 +89,7 @@ FLAX_MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
|
|||||||
# Model for Masked LM mapping
|
# Model for Masked LM mapping
|
||||||
(RobertaConfig, FlaxRobertaForMaskedLM),
|
(RobertaConfig, FlaxRobertaForMaskedLM),
|
||||||
(BertConfig, FlaxBertForMaskedLM),
|
(BertConfig, FlaxBertForMaskedLM),
|
||||||
|
(BartConfig, FlaxBartForConditionalGeneration),
|
||||||
(ElectraConfig, FlaxElectraForMaskedLM),
|
(ElectraConfig, FlaxElectraForMaskedLM),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@@ -104,6 +113,7 @@ FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
|
|||||||
# Model for Sequence Classification mapping
|
# Model for Sequence Classification mapping
|
||||||
(RobertaConfig, FlaxRobertaForSequenceClassification),
|
(RobertaConfig, FlaxRobertaForSequenceClassification),
|
||||||
(BertConfig, FlaxBertForSequenceClassification),
|
(BertConfig, FlaxBertForSequenceClassification),
|
||||||
|
(BartConfig, FlaxBartForSequenceClassification),
|
||||||
(ElectraConfig, FlaxElectraForSequenceClassification),
|
(ElectraConfig, FlaxElectraForSequenceClassification),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@@ -113,6 +123,7 @@ FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
|
|||||||
# Model for Question Answering mapping
|
# Model for Question Answering mapping
|
||||||
(RobertaConfig, FlaxRobertaForQuestionAnswering),
|
(RobertaConfig, FlaxRobertaForQuestionAnswering),
|
||||||
(BertConfig, FlaxBertForQuestionAnswering),
|
(BertConfig, FlaxBertForQuestionAnswering),
|
||||||
|
(BartConfig, FlaxBartForQuestionAnswering),
|
||||||
(ElectraConfig, FlaxElectraForQuestionAnswering),
|
(ElectraConfig, FlaxElectraForQuestionAnswering),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -17,7 +17,13 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from ...file_utils import _BaseLazyModule, is_tf_available, is_tokenizers_available, is_torch_available
|
from ...file_utils import (
|
||||||
|
_BaseLazyModule,
|
||||||
|
is_flax_available,
|
||||||
|
is_tf_available,
|
||||||
|
is_tokenizers_available,
|
||||||
|
is_torch_available,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
_import_structure = {
|
_import_structure = {
|
||||||
@@ -43,6 +49,13 @@ if is_torch_available():
|
|||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
_import_structure["modeling_tf_bart"] = ["TFBartForConditionalGeneration", "TFBartModel", "TFBartPretrainedModel"]
|
_import_structure["modeling_tf_bart"] = ["TFBartForConditionalGeneration", "TFBartModel", "TFBartPretrainedModel"]
|
||||||
|
|
||||||
|
if is_flax_available():
|
||||||
|
_import_structure["modeling_flax_bart"] = [
|
||||||
|
"FlaxBartForConditionalGeneration",
|
||||||
|
"FlaxBartForQuestionAnswering",
|
||||||
|
"FlaxBartForSequenceClassification",
|
||||||
|
"FlaxBartModel",
|
||||||
|
]
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .configuration_bart import BART_PRETRAINED_CONFIG_ARCHIVE_MAP, BartConfig
|
from .configuration_bart import BART_PRETRAINED_CONFIG_ARCHIVE_MAP, BartConfig
|
||||||
@@ -66,6 +79,14 @@ if TYPE_CHECKING:
|
|||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
from .modeling_tf_bart import TFBartForConditionalGeneration, TFBartModel, TFBartPretrainedModel
|
from .modeling_tf_bart import TFBartForConditionalGeneration, TFBartModel, TFBartPretrainedModel
|
||||||
|
|
||||||
|
if is_flax_available():
|
||||||
|
from .modeling_flax_bart import (
|
||||||
|
FlaxBartForConditionalGeneration,
|
||||||
|
FlaxBartForQuestionAnswering,
|
||||||
|
FlaxBartForSequenceClassification,
|
||||||
|
FlaxBartModel,
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
import importlib
|
import importlib
|
||||||
import os
|
import os
|
||||||
|
|||||||
1726
src/transformers/models/bart/modeling_flax_bart.py
Normal file
1726
src/transformers/models/bart/modeling_flax_bart.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -149,6 +149,42 @@ class FlaxAutoModelForTokenClassification:
|
|||||||
requires_backends(cls, ["flax"])
|
requires_backends(cls, ["flax"])
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxBartForConditionalGeneration:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["flax"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["flax"])
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxBartForQuestionAnswering:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["flax"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["flax"])
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxBartForSequenceClassification:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["flax"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["flax"])
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxBartModel:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["flax"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["flax"])
|
||||||
|
|
||||||
|
|
||||||
class FlaxBertForMaskedLM:
|
class FlaxBertForMaskedLM:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
requires_backends(self, ["flax"])
|
requires_backends(self, ["flax"])
|
||||||
|
|||||||
417
tests/test_modeling_flax_bart.py
Normal file
417
tests/test_modeling_flax_bart.py
Normal file
@@ -0,0 +1,417 @@
|
|||||||
|
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import timeout_decorator # noqa
|
||||||
|
|
||||||
|
from transformers import BartConfig, is_flax_available
|
||||||
|
from transformers.testing_utils import require_flax, slow
|
||||||
|
|
||||||
|
from .test_generation_flax_utils import FlaxGenerationTesterMixin
|
||||||
|
from .test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor
|
||||||
|
|
||||||
|
|
||||||
|
if is_flax_available():
|
||||||
|
import os
|
||||||
|
|
||||||
|
# The slow tests are often failing with OOM error on GPU
|
||||||
|
# This makes JAX allocate exactly what is needed on demand, and deallocate memory that is no longer needed
|
||||||
|
# but will be slower as stated here https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
|
||||||
|
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
|
||||||
|
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
from transformers.models.bart.modeling_flax_bart import (
|
||||||
|
FlaxBartForConditionalGeneration,
|
||||||
|
FlaxBartForQuestionAnswering,
|
||||||
|
FlaxBartForSequenceClassification,
|
||||||
|
FlaxBartModel,
|
||||||
|
shift_tokens_right,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_bart_inputs_dict(
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
decoder_input_ids=None,
|
||||||
|
attention_mask=None,
|
||||||
|
decoder_attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
decoder_head_mask=None,
|
||||||
|
cross_attn_head_mask=None,
|
||||||
|
):
|
||||||
|
if attention_mask is None:
|
||||||
|
attention_mask = np.where(input_ids != config.pad_token_id, 1, 0)
|
||||||
|
if decoder_attention_mask is None:
|
||||||
|
decoder_attention_mask = np.where(decoder_input_ids != config.pad_token_id, 1, 0)
|
||||||
|
if head_mask is None:
|
||||||
|
head_mask = np.ones((config.encoder_layers, config.encoder_attention_heads))
|
||||||
|
if decoder_head_mask is None:
|
||||||
|
decoder_head_mask = np.ones((config.decoder_layers, config.decoder_attention_heads))
|
||||||
|
if cross_attn_head_mask is None:
|
||||||
|
cross_attn_head_mask = np.ones((config.decoder_layers, config.decoder_attention_heads))
|
||||||
|
return {
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"decoder_input_ids": decoder_input_ids,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
"decoder_attention_mask": attention_mask,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxBartModelTester(unittest.TestCase):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
parent,
|
||||||
|
batch_size=13,
|
||||||
|
seq_length=7,
|
||||||
|
is_training=True,
|
||||||
|
use_labels=False,
|
||||||
|
vocab_size=99,
|
||||||
|
hidden_size=16,
|
||||||
|
num_hidden_layers=2,
|
||||||
|
num_attention_heads=4,
|
||||||
|
intermediate_size=4,
|
||||||
|
hidden_act="gelu",
|
||||||
|
hidden_dropout_prob=0.1,
|
||||||
|
attention_probs_dropout_prob=0.1,
|
||||||
|
max_position_embeddings=32,
|
||||||
|
eos_token_id=2,
|
||||||
|
pad_token_id=1,
|
||||||
|
bos_token_id=0,
|
||||||
|
initializer_range=0.02,
|
||||||
|
):
|
||||||
|
self.parent = parent
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.seq_length = seq_length
|
||||||
|
self.is_training = is_training
|
||||||
|
self.use_labels = use_labels
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.num_hidden_layers = num_hidden_layers
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
self.intermediate_size = intermediate_size
|
||||||
|
self.hidden_act = hidden_act
|
||||||
|
self.hidden_dropout_prob = hidden_dropout_prob
|
||||||
|
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.eos_token_id = eos_token_id
|
||||||
|
self.pad_token_id = pad_token_id
|
||||||
|
self.bos_token_id = bos_token_id
|
||||||
|
self.initializer_range = initializer_range
|
||||||
|
|
||||||
|
def prepare_config_and_inputs(self):
|
||||||
|
input_ids = np.clip(ids_tensor([self.batch_size, self.seq_length - 1], self.vocab_size), 3, self.vocab_size)
|
||||||
|
input_ids = np.concatenate((input_ids, 2 * np.ones((self.batch_size, 1), dtype=np.int64)), -1)
|
||||||
|
|
||||||
|
decoder_input_ids = shift_tokens_right(input_ids, 1, 2)
|
||||||
|
|
||||||
|
config = BartConfig(
|
||||||
|
vocab_size=self.vocab_size,
|
||||||
|
d_model=self.hidden_size,
|
||||||
|
encoder_layers=self.num_hidden_layers,
|
||||||
|
decoder_layers=self.num_hidden_layers,
|
||||||
|
encoder_attention_heads=self.num_attention_heads,
|
||||||
|
decoder_attention_heads=self.num_attention_heads,
|
||||||
|
encoder_ffn_dim=self.intermediate_size,
|
||||||
|
decoder_ffn_dim=self.intermediate_size,
|
||||||
|
dropout=self.hidden_dropout_prob,
|
||||||
|
attention_dropout=self.attention_probs_dropout_prob,
|
||||||
|
max_position_embeddings=self.max_position_embeddings,
|
||||||
|
eos_token_id=self.eos_token_id,
|
||||||
|
bos_token_id=self.bos_token_id,
|
||||||
|
pad_token_id=self.pad_token_id,
|
||||||
|
initializer_range=self.initializer_range,
|
||||||
|
use_cache=False,
|
||||||
|
)
|
||||||
|
inputs_dict = prepare_bart_inputs_dict(config, input_ids, decoder_input_ids)
|
||||||
|
return config, inputs_dict
|
||||||
|
|
||||||
|
def prepare_config_and_inputs_for_common(self):
|
||||||
|
config, inputs_dict = self.prepare_config_and_inputs()
|
||||||
|
return config, inputs_dict
|
||||||
|
|
||||||
|
def check_use_cache_forward(self, model_class_name, config, inputs_dict):
|
||||||
|
max_decoder_length = 20
|
||||||
|
model = model_class_name(config)
|
||||||
|
|
||||||
|
encoder_outputs = model.encode(inputs_dict["input_ids"])
|
||||||
|
|
||||||
|
decoder_input_ids, decoder_attention_mask = (
|
||||||
|
inputs_dict["decoder_input_ids"],
|
||||||
|
inputs_dict["decoder_attention_mask"],
|
||||||
|
)
|
||||||
|
|
||||||
|
past_key_values = model.init_cache(decoder_input_ids.shape[0], max_decoder_length, encoder_outputs)
|
||||||
|
decoder_attention_mask = jnp.ones((decoder_input_ids.shape[0], max_decoder_length), dtype="i4")
|
||||||
|
|
||||||
|
decoder_position_ids = jnp.broadcast_to(
|
||||||
|
jnp.arange(decoder_input_ids.shape[-1] - 1)[None, :],
|
||||||
|
(decoder_input_ids.shape[0], decoder_input_ids.shape[-1] - 1),
|
||||||
|
)
|
||||||
|
outputs_cache = model.decode(
|
||||||
|
decoder_input_ids[:, :-1],
|
||||||
|
encoder_outputs,
|
||||||
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
decoder_position_ids=decoder_position_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
decoder_position_ids = jnp.array(decoder_input_ids.shape[0] * [[decoder_input_ids.shape[-1] - 1]], dtype="i4")
|
||||||
|
outputs_cache_next = model.decode(
|
||||||
|
decoder_input_ids[:, -1:],
|
||||||
|
encoder_outputs,
|
||||||
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
|
past_key_values=outputs_cache.past_key_values,
|
||||||
|
decoder_position_ids=decoder_position_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs = model.decode(decoder_input_ids, encoder_outputs)
|
||||||
|
|
||||||
|
diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
|
||||||
|
self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
|
||||||
|
|
||||||
|
def check_use_cache_forward_with_attn_mask(self, model_class_name, config, inputs_dict):
|
||||||
|
max_decoder_length = 20
|
||||||
|
model = model_class_name(config)
|
||||||
|
|
||||||
|
encoder_outputs = model.encode(inputs_dict["input_ids"])
|
||||||
|
|
||||||
|
decoder_input_ids, decoder_attention_mask = (
|
||||||
|
inputs_dict["decoder_input_ids"],
|
||||||
|
inputs_dict["decoder_attention_mask"],
|
||||||
|
)
|
||||||
|
|
||||||
|
decoder_attention_mask_cache = jnp.concatenate(
|
||||||
|
[
|
||||||
|
decoder_attention_mask,
|
||||||
|
jnp.zeros((decoder_attention_mask.shape[0], max_decoder_length - decoder_attention_mask.shape[1])),
|
||||||
|
],
|
||||||
|
axis=-1,
|
||||||
|
)
|
||||||
|
|
||||||
|
past_key_values = model.init_cache(decoder_input_ids.shape[0], max_decoder_length, encoder_outputs)
|
||||||
|
decoder_position_ids = jnp.broadcast_to(
|
||||||
|
jnp.arange(decoder_input_ids.shape[-1] - 1)[None, :],
|
||||||
|
(decoder_input_ids.shape[0], decoder_input_ids.shape[-1] - 1),
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs_cache = model.decode(
|
||||||
|
decoder_input_ids[:, :-1],
|
||||||
|
encoder_outputs,
|
||||||
|
decoder_attention_mask=decoder_attention_mask_cache,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
decoder_position_ids=decoder_position_ids,
|
||||||
|
)
|
||||||
|
decoder_position_ids = jnp.array(decoder_input_ids.shape[0] * [[decoder_input_ids.shape[-1] - 1]], dtype="i4")
|
||||||
|
outputs_cache_next = model.decode(
|
||||||
|
decoder_input_ids[:, -1:],
|
||||||
|
encoder_outputs,
|
||||||
|
past_key_values=outputs_cache.past_key_values,
|
||||||
|
decoder_attention_mask=decoder_attention_mask_cache,
|
||||||
|
decoder_position_ids=decoder_position_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs = model.decode(decoder_input_ids, encoder_outputs, decoder_attention_mask=decoder_attention_mask)
|
||||||
|
|
||||||
|
diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
|
||||||
|
self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
|
||||||
|
|
||||||
|
|
||||||
|
@require_flax
|
||||||
|
class BartHeadTests(unittest.TestCase):
|
||||||
|
vocab_size = 99
|
||||||
|
|
||||||
|
def _get_config_and_data(self):
|
||||||
|
input_ids = np.array(
|
||||||
|
[
|
||||||
|
[71, 82, 18, 33, 46, 91, 2],
|
||||||
|
[68, 34, 26, 58, 30, 82, 2],
|
||||||
|
[5, 97, 17, 39, 94, 40, 2],
|
||||||
|
[76, 83, 94, 25, 70, 78, 2],
|
||||||
|
[87, 59, 41, 35, 48, 66, 2],
|
||||||
|
[55, 13, 16, 58, 5, 2, 1], # note padding
|
||||||
|
[64, 27, 31, 51, 12, 75, 2],
|
||||||
|
[52, 64, 86, 17, 83, 39, 2],
|
||||||
|
[48, 61, 9, 24, 71, 82, 2],
|
||||||
|
[26, 1, 60, 48, 22, 13, 2],
|
||||||
|
[21, 5, 62, 28, 14, 76, 2],
|
||||||
|
[45, 98, 37, 86, 59, 48, 2],
|
||||||
|
[70, 70, 50, 9, 28, 0, 2],
|
||||||
|
],
|
||||||
|
dtype=np.int64,
|
||||||
|
)
|
||||||
|
|
||||||
|
batch_size = input_ids.shape[0]
|
||||||
|
config = BartConfig(
|
||||||
|
vocab_size=self.vocab_size,
|
||||||
|
d_model=24,
|
||||||
|
encoder_layers=2,
|
||||||
|
decoder_layers=2,
|
||||||
|
encoder_attention_heads=2,
|
||||||
|
decoder_attention_heads=2,
|
||||||
|
encoder_ffn_dim=32,
|
||||||
|
decoder_ffn_dim=32,
|
||||||
|
max_position_embeddings=48,
|
||||||
|
eos_token_id=2,
|
||||||
|
pad_token_id=1,
|
||||||
|
bos_token_id=0,
|
||||||
|
)
|
||||||
|
return config, input_ids, batch_size
|
||||||
|
|
||||||
|
def test_sequence_classification_forward(self):
|
||||||
|
config, input_ids, batch_size = self._get_config_and_data()
|
||||||
|
model = FlaxBartForSequenceClassification(config)
|
||||||
|
outputs = model(input_ids=input_ids, decoder_input_ids=input_ids)
|
||||||
|
expected_shape = (batch_size, config.num_labels)
|
||||||
|
self.assertEqual(outputs["logits"].shape, expected_shape)
|
||||||
|
|
||||||
|
def test_question_answering_forward(self):
|
||||||
|
config, input_ids, batch_size = self._get_config_and_data()
|
||||||
|
model = FlaxBartForQuestionAnswering(config)
|
||||||
|
outputs = model(input_ids=input_ids)
|
||||||
|
|
||||||
|
self.assertEqual(outputs["start_logits"].shape, input_ids.shape)
|
||||||
|
self.assertEqual(outputs["end_logits"].shape, input_ids.shape)
|
||||||
|
|
||||||
|
# @timeout_decorator.timeout(1) # not working with the decorator so far
|
||||||
|
def test_lm_forward(self):
|
||||||
|
config, input_ids, batch_size = self._get_config_and_data()
|
||||||
|
lm_model = FlaxBartForConditionalGeneration(config)
|
||||||
|
outputs = lm_model(input_ids=input_ids)
|
||||||
|
expected_shape = (batch_size, input_ids.shape[1], config.vocab_size)
|
||||||
|
self.assertEqual(outputs["logits"].shape, expected_shape)
|
||||||
|
|
||||||
|
def test_lm_uneven_forward(self):
|
||||||
|
config = BartConfig(
|
||||||
|
vocab_size=self.vocab_size,
|
||||||
|
d_model=14,
|
||||||
|
encoder_layers=2,
|
||||||
|
decoder_layers=2,
|
||||||
|
encoder_attention_heads=2,
|
||||||
|
decoder_attention_heads=2,
|
||||||
|
encoder_ffn_dim=8,
|
||||||
|
decoder_ffn_dim=8,
|
||||||
|
max_position_embeddings=48,
|
||||||
|
)
|
||||||
|
lm_model = FlaxBartForConditionalGeneration(config)
|
||||||
|
context = np.array([[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]], dtype=np.int64)
|
||||||
|
summary = np.array([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]], dtype=np.int64)
|
||||||
|
outputs = lm_model(input_ids=context, decoder_input_ids=summary)
|
||||||
|
expected_shape = (*summary.shape, config.vocab_size)
|
||||||
|
self.assertEqual(outputs["logits"].shape, expected_shape)
|
||||||
|
|
||||||
|
def test_shift_tokens_right(self):
|
||||||
|
input_ids = np.array([[71, 82, 18, 33, 2, 1, 1], [68, 34, 26, 58, 30, 82, 2]], dtype=np.int64)
|
||||||
|
shifted = shift_tokens_right(input_ids, 1, 2)
|
||||||
|
n_pad_before = np.equal(input_ids, 1).astype(np.float32).sum()
|
||||||
|
n_pad_after = np.equal(shifted, 1).astype(np.float32).sum()
|
||||||
|
self.assertEqual(shifted.shape, input_ids.shape)
|
||||||
|
self.assertEqual(n_pad_after, n_pad_before - 1)
|
||||||
|
self.assertTrue(np.equal(shifted[:, 0], 2).all())
|
||||||
|
|
||||||
|
|
||||||
|
@require_flax
|
||||||
|
class FlaxBartModelTest(FlaxModelTesterMixin, unittest.TestCase, FlaxGenerationTesterMixin):
|
||||||
|
is_encoder_decoder = True
|
||||||
|
all_model_classes = (
|
||||||
|
(
|
||||||
|
FlaxBartModel,
|
||||||
|
FlaxBartForConditionalGeneration,
|
||||||
|
FlaxBartForSequenceClassification,
|
||||||
|
FlaxBartForQuestionAnswering,
|
||||||
|
)
|
||||||
|
if is_flax_available()
|
||||||
|
else ()
|
||||||
|
)
|
||||||
|
all_generative_model_classes = (FlaxBartForConditionalGeneration,) if is_flax_available() else ()
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.model_tester = FlaxBartModelTester(self)
|
||||||
|
|
||||||
|
def test_use_cache_forward(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs()
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
self.model_tester.check_use_cache_forward(model_class, config, inputs_dict)
|
||||||
|
|
||||||
|
def test_use_cache_forward_with_attn_mask(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs()
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
self.model_tester.check_use_cache_forward_with_attn_mask(model_class, config, inputs_dict)
|
||||||
|
|
||||||
|
def test_encode(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
with self.subTest(model_class.__name__):
|
||||||
|
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||||
|
model = model_class(config)
|
||||||
|
|
||||||
|
@jax.jit
|
||||||
|
def encode_jitted(input_ids, attention_mask=None, **kwargs):
|
||||||
|
return model.encode(input_ids=input_ids, attention_mask=attention_mask)
|
||||||
|
|
||||||
|
with self.subTest("JIT Enabled"):
|
||||||
|
jitted_outputs = encode_jitted(**prepared_inputs_dict).to_tuple()
|
||||||
|
|
||||||
|
with self.subTest("JIT Disabled"):
|
||||||
|
with jax.disable_jit():
|
||||||
|
outputs = encode_jitted(**prepared_inputs_dict).to_tuple()
|
||||||
|
|
||||||
|
self.assertEqual(len(outputs), len(jitted_outputs))
|
||||||
|
for jitted_output, output in zip(jitted_outputs, outputs):
|
||||||
|
self.assertEqual(jitted_output.shape, output.shape)
|
||||||
|
|
||||||
|
def test_decode(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
with self.subTest(model_class.__name__):
|
||||||
|
model = model_class(config)
|
||||||
|
encoder_outputs = model.encode(inputs_dict["input_ids"], inputs_dict["attention_mask"])
|
||||||
|
|
||||||
|
prepared_inputs_dict = {
|
||||||
|
"decoder_input_ids": inputs_dict["decoder_input_ids"],
|
||||||
|
"decoder_attention_mask": inputs_dict["decoder_attention_mask"],
|
||||||
|
"encoder_outputs": encoder_outputs,
|
||||||
|
}
|
||||||
|
|
||||||
|
@jax.jit
|
||||||
|
def decode_jitted(decoder_input_ids, decoder_attention_mask, encoder_outputs):
|
||||||
|
return model.decode(
|
||||||
|
decoder_input_ids=decoder_input_ids,
|
||||||
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
|
encoder_outputs=encoder_outputs,
|
||||||
|
)
|
||||||
|
|
||||||
|
with self.subTest("JIT Enabled"):
|
||||||
|
jitted_outputs = decode_jitted(**prepared_inputs_dict).to_tuple()
|
||||||
|
|
||||||
|
with self.subTest("JIT Disabled"):
|
||||||
|
with jax.disable_jit():
|
||||||
|
outputs = decode_jitted(**prepared_inputs_dict).to_tuple()
|
||||||
|
|
||||||
|
self.assertEqual(len(outputs), len(jitted_outputs))
|
||||||
|
for jitted_output, output in zip(jitted_outputs, outputs):
|
||||||
|
self.assertEqual(jitted_output.shape, output.shape)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_model_from_pretrained(self):
|
||||||
|
for model_class_name in self.all_model_classes:
|
||||||
|
model = model_class_name.from_pretrained("facebook/bart-base", from_pt=True)
|
||||||
|
# FlaxBartForSequenceClassification expects eos token in input_ids
|
||||||
|
input_ids = np.ones((1, 1)) * model.config.eos_token_id
|
||||||
|
outputs = model(input_ids)
|
||||||
|
self.assertIsNotNone(outputs)
|
||||||
@@ -22,6 +22,7 @@ import numpy as np
|
|||||||
|
|
||||||
import transformers
|
import transformers
|
||||||
from transformers import is_flax_available, is_torch_available
|
from transformers import is_flax_available, is_torch_available
|
||||||
|
from transformers.models.auto import get_values
|
||||||
from transformers.testing_utils import is_pt_flax_cross_test, require_flax
|
from transformers.testing_utils import is_pt_flax_cross_test, require_flax
|
||||||
|
|
||||||
|
|
||||||
@@ -31,6 +32,7 @@ if is_flax_available():
|
|||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import jaxlib.xla_extension as jax_xla
|
import jaxlib.xla_extension as jax_xla
|
||||||
|
from transformers import FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING
|
||||||
from transformers.modeling_flax_pytorch_utils import (
|
from transformers.modeling_flax_pytorch_utils import (
|
||||||
convert_pytorch_state_dict_to_flax,
|
convert_pytorch_state_dict_to_flax,
|
||||||
load_flax_weights_in_pytorch_model,
|
load_flax_weights_in_pytorch_model,
|
||||||
@@ -42,6 +44,14 @@ if is_torch_available():
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def _config_zero_init(config):
|
||||||
|
configs_no_init = copy.deepcopy(config)
|
||||||
|
for key in configs_no_init.__dict__.keys():
|
||||||
|
if "_range" in key or "_std" in key or "initializer_factor" in key:
|
||||||
|
setattr(configs_no_init, key, 1e-10)
|
||||||
|
return configs_no_init
|
||||||
|
|
||||||
|
|
||||||
def ids_tensor(shape, vocab_size, rng=None):
|
def ids_tensor(shape, vocab_size, rng=None):
|
||||||
"""Creates a random int32 tensor of the shape within the vocab size."""
|
"""Creates a random int32 tensor of the shape within the vocab size."""
|
||||||
if rng is None:
|
if rng is None:
|
||||||
@@ -87,6 +97,7 @@ def random_attention_mask(shape, rng=None):
|
|||||||
class FlaxModelTesterMixin:
|
class FlaxModelTesterMixin:
|
||||||
model_tester = None
|
model_tester = None
|
||||||
all_model_classes = ()
|
all_model_classes = ()
|
||||||
|
is_encoder_decoder = False
|
||||||
|
|
||||||
def _prepare_for_class(self, inputs_dict, model_class):
|
def _prepare_for_class(self, inputs_dict, model_class):
|
||||||
inputs_dict = copy.deepcopy(inputs_dict)
|
inputs_dict = copy.deepcopy(inputs_dict)
|
||||||
@@ -156,6 +167,9 @@ class FlaxModelTesterMixin:
|
|||||||
pt_model_class = getattr(transformers, pt_model_class_name)
|
pt_model_class = getattr(transformers, pt_model_class_name)
|
||||||
|
|
||||||
pt_model = pt_model_class(config).eval()
|
pt_model = pt_model_class(config).eval()
|
||||||
|
# Flax models don't use the `use_cache` option and cache is not returned as a default.
|
||||||
|
# So we disable `use_cache` here for PyTorch model.
|
||||||
|
pt_model.config.use_cache = False
|
||||||
fx_model = model_class(config, dtype=jnp.float32)
|
fx_model = model_class(config, dtype=jnp.float32)
|
||||||
|
|
||||||
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
|
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
|
||||||
@@ -167,7 +181,7 @@ class FlaxModelTesterMixin:
|
|||||||
fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
|
fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
|
||||||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||||
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
|
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
|
||||||
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
|
self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-3)
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
pt_model.save_pretrained(tmpdirname)
|
pt_model.save_pretrained(tmpdirname)
|
||||||
@@ -178,7 +192,10 @@ class FlaxModelTesterMixin:
|
|||||||
len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
|
len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
|
||||||
)
|
)
|
||||||
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
|
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
|
||||||
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 4e-2)
|
if not isinstance(
|
||||||
|
fx_output_loaded, tuple
|
||||||
|
): # TODO(Patrick, Daniel) - let's discard use_cache for now
|
||||||
|
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 1e-3)
|
||||||
|
|
||||||
@is_pt_flax_cross_test
|
@is_pt_flax_cross_test
|
||||||
def test_equivalence_flax_to_pt(self):
|
def test_equivalence_flax_to_pt(self):
|
||||||
@@ -195,6 +212,9 @@ class FlaxModelTesterMixin:
|
|||||||
pt_model_class = getattr(transformers, pt_model_class_name)
|
pt_model_class = getattr(transformers, pt_model_class_name)
|
||||||
|
|
||||||
pt_model = pt_model_class(config).eval()
|
pt_model = pt_model_class(config).eval()
|
||||||
|
# Flax models don't use the `use_cache` option and cache is not returned as a default.
|
||||||
|
# So we disable `use_cache` here for PyTorch model.
|
||||||
|
pt_model.config.use_cache = False
|
||||||
fx_model = model_class(config, dtype=jnp.float32)
|
fx_model = model_class(config, dtype=jnp.float32)
|
||||||
|
|
||||||
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
|
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
|
||||||
@@ -207,8 +227,9 @@ class FlaxModelTesterMixin:
|
|||||||
|
|
||||||
fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
|
fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
|
||||||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||||
|
|
||||||
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
|
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
|
||||||
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
|
self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-3)
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
fx_model.save_pretrained(tmpdirname)
|
fx_model.save_pretrained(tmpdirname)
|
||||||
@@ -221,7 +242,8 @@ class FlaxModelTesterMixin:
|
|||||||
len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch"
|
len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch"
|
||||||
)
|
)
|
||||||
for fx_output, pt_output in zip(fx_outputs, pt_outputs_loaded):
|
for fx_output, pt_output in zip(fx_outputs, pt_outputs_loaded):
|
||||||
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
|
if not isinstance(fx_output, tuple): # TODO(Patrick, Daniel) - let's discard use_cache for now
|
||||||
|
self.assert_almost_equals(fx_output, pt_output.numpy(), 5e-3)
|
||||||
|
|
||||||
def test_from_pretrained_save_pretrained(self):
|
def test_from_pretrained_save_pretrained(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
@@ -276,6 +298,7 @@ class FlaxModelTesterMixin:
|
|||||||
|
|
||||||
self.assertEqual(len(outputs), len(jitted_outputs))
|
self.assertEqual(len(outputs), len(jitted_outputs))
|
||||||
for jitted_output, output in zip(jitted_outputs, outputs):
|
for jitted_output, output in zip(jitted_outputs, outputs):
|
||||||
|
|
||||||
self.assertEqual(jitted_output.shape, output.shape)
|
self.assertEqual(jitted_output.shape, output.shape)
|
||||||
|
|
||||||
def test_forward_signature(self):
|
def test_forward_signature(self):
|
||||||
@@ -287,6 +310,15 @@ class FlaxModelTesterMixin:
|
|||||||
# signature.parameters is an OrderedDict => so arg_names order is deterministic
|
# signature.parameters is an OrderedDict => so arg_names order is deterministic
|
||||||
arg_names = [*signature.parameters.keys()]
|
arg_names = [*signature.parameters.keys()]
|
||||||
|
|
||||||
|
if model.config.is_encoder_decoder:
|
||||||
|
expected_arg_names = [
|
||||||
|
"input_ids",
|
||||||
|
"attention_mask",
|
||||||
|
"decoder_input_ids",
|
||||||
|
"decoder_attention_mask",
|
||||||
|
]
|
||||||
|
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
|
||||||
|
else:
|
||||||
expected_arg_names = ["input_ids", "attention_mask"]
|
expected_arg_names = ["input_ids", "attention_mask"]
|
||||||
self.assertListEqual(arg_names[:2], expected_arg_names)
|
self.assertListEqual(arg_names[:2], expected_arg_names)
|
||||||
|
|
||||||
@@ -306,9 +338,16 @@ class FlaxModelTesterMixin:
|
|||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
|
|
||||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||||
hidden_states = outputs.hidden_states
|
hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
|
||||||
|
|
||||||
self.assertEqual(len(hidden_states), self.model_tester.num_hidden_layers + 1)
|
expected_num_layers = getattr(
|
||||||
|
self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
|
||||||
|
)
|
||||||
|
self.assertEqual(len(hidden_states), expected_num_layers)
|
||||||
|
|
||||||
|
if hasattr(self.model_tester, "encoder_seq_length"):
|
||||||
|
seq_length = self.model_tester.encoder_seq_length
|
||||||
|
else:
|
||||||
seq_length = self.model_tester.seq_length
|
seq_length = self.model_tester.seq_length
|
||||||
|
|
||||||
self.assertListEqual(
|
self.assertListEqual(
|
||||||
@@ -316,6 +355,19 @@ class FlaxModelTesterMixin:
|
|||||||
[seq_length, self.model_tester.hidden_size],
|
[seq_length, self.model_tester.hidden_size],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if config.is_encoder_decoder:
|
||||||
|
hidden_states = outputs.decoder_hidden_states
|
||||||
|
|
||||||
|
self.assertIsInstance(hidden_states, (list, tuple))
|
||||||
|
self.assertEqual(len(hidden_states), expected_num_layers)
|
||||||
|
seq_len = getattr(self.model_tester, "seq_length", None)
|
||||||
|
decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len)
|
||||||
|
|
||||||
|
self.assertListEqual(
|
||||||
|
list(hidden_states[0].shape[-2:]),
|
||||||
|
[decoder_seq_length, self.model_tester.hidden_size],
|
||||||
|
)
|
||||||
|
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
@@ -333,13 +385,17 @@ class FlaxModelTesterMixin:
|
|||||||
config.return_dict = True
|
config.return_dict = True
|
||||||
|
|
||||||
seq_length = getattr(self.model_tester, "seq_length", None)
|
seq_length = getattr(self.model_tester, "seq_length", None)
|
||||||
|
decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_length)
|
||||||
|
encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_length)
|
||||||
|
decoder_key_length = getattr(self.model_tester, "decoder_key_length", decoder_seq_length)
|
||||||
|
encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
inputs_dict["output_attentions"] = True
|
inputs_dict["output_attentions"] = True
|
||||||
inputs_dict["output_hidden_states"] = False
|
inputs_dict["output_hidden_states"] = False
|
||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||||
attentions = outputs.attentions
|
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
|
||||||
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
|
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
|
||||||
|
|
||||||
# check that output_attentions also work using config
|
# check that output_attentions also work using config
|
||||||
@@ -347,21 +403,57 @@ class FlaxModelTesterMixin:
|
|||||||
config.output_attentions = True
|
config.output_attentions = True
|
||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||||
attentions = outputs.attentions
|
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
|
||||||
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
|
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
|
||||||
|
|
||||||
self.assertListEqual(
|
self.assertListEqual(
|
||||||
list(attentions[0].shape[-3:]),
|
list(attentions[0].shape[-3:]),
|
||||||
[self.model_tester.num_attention_heads, seq_length, seq_length],
|
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
|
||||||
)
|
)
|
||||||
out_len = len(outputs)
|
out_len = len(outputs)
|
||||||
|
|
||||||
|
if self.is_encoder_decoder:
|
||||||
|
correct_outlen = 5
|
||||||
|
|
||||||
|
# Question Answering model returns start_logits and end_logits
|
||||||
|
if model_class in get_values(FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING):
|
||||||
|
correct_outlen += 1 # start_logits and end_logits instead of only 1 output
|
||||||
|
|
||||||
|
self.assertEqual(out_len, correct_outlen)
|
||||||
|
|
||||||
|
# decoder attentions
|
||||||
|
decoder_attentions = outputs.decoder_attentions
|
||||||
|
self.assertIsInstance(decoder_attentions, (list, tuple))
|
||||||
|
self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
|
||||||
|
self.assertListEqual(
|
||||||
|
list(decoder_attentions[0].shape[-3:]),
|
||||||
|
[self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
|
||||||
|
)
|
||||||
|
|
||||||
|
# cross attentions
|
||||||
|
cross_attentions = outputs.cross_attentions
|
||||||
|
self.assertIsInstance(cross_attentions, (list, tuple))
|
||||||
|
self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers)
|
||||||
|
self.assertListEqual(
|
||||||
|
list(cross_attentions[0].shape[-3:]),
|
||||||
|
[
|
||||||
|
self.model_tester.num_attention_heads,
|
||||||
|
decoder_seq_length,
|
||||||
|
encoder_key_length,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
# Check attention is always last and order is fine
|
# Check attention is always last and order is fine
|
||||||
inputs_dict["output_attentions"] = True
|
inputs_dict["output_attentions"] = True
|
||||||
inputs_dict["output_hidden_states"] = True
|
inputs_dict["output_hidden_states"] = True
|
||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||||
|
|
||||||
|
if hasattr(self.model_tester, "num_hidden_states_types"):
|
||||||
|
added_hidden_states = self.model_tester.num_hidden_states_types
|
||||||
|
elif self.is_encoder_decoder:
|
||||||
|
added_hidden_states = 2
|
||||||
|
else:
|
||||||
added_hidden_states = 1
|
added_hidden_states = 1
|
||||||
self.assertEqual(out_len + added_hidden_states, len(outputs))
|
self.assertEqual(out_len + added_hidden_states, len(outputs))
|
||||||
|
|
||||||
@@ -370,5 +462,5 @@ class FlaxModelTesterMixin:
|
|||||||
|
|
||||||
self.assertListEqual(
|
self.assertListEqual(
|
||||||
list(self_attentions[0].shape[-3:]),
|
list(self_attentions[0].shape[-3:]),
|
||||||
[self.model_tester.num_attention_heads, seq_length, seq_length],
|
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user