Add TFBartForConditionalGeneration (#5411)
* half done * doc improvement * Cp test file * brokedn * broken test * undo some mess * ckpt * borked * Halfway * 6 passing * boom boom * Much progress but still 6 * boom boom * merged master * 10 passing * boom boom * Style * no t5 changes * 13 passing * Integration test failing, but not gibberish * Frustrated * Merged master * 4 fail * 4 fail * fix return_dict * boom boom * Still only 4 * prepare method * prepare method * before delete classif * Skip tests to avoid adding boilerplate * boom boom * fast tests passing * style * boom boom * Switch to supporting many input types * remove FIXMENORM * working * Fixed past_key_values/decoder_cached_states confusion * new broken test * Fix attention mask kwarg name * undo accidental * Style and reviewers * style * Docs and common tests * Cleaner assert messages * copy docs * style issues * Sphinx fix * Simplify caching logic * test does not require torch * copy _NoLayerEmbedTokens * Update src/transformers/modeling_tf_bart.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Update tests/test_modeling_tf_bart.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Update src/transformers/modeling_tf_bart.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Update src/transformers/modeling_tf_bart.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Update src/transformers/modeling_tf_bart.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Line length and dont document None * Add pipeline test coverage * assert msg * At parity * Assert messages * mark slow * Update compile test * back in init * Merge master * Fix tests Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
@@ -86,3 +86,18 @@ BartForQuestionAnswering
|
||||
|
||||
.. autoclass:: transformers.BartForQuestionAnswering
|
||||
:members: forward
|
||||
|
||||
|
||||
|
||||
TFBartModel
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TFBartModel
|
||||
:members: call
|
||||
|
||||
|
||||
TFBartForConditionalGeneration
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TFBartForConditionalGeneration
|
||||
:members: call
|
||||
|
||||
@@ -652,6 +652,7 @@ if is_tf_available():
|
||||
TFAutoModelForTokenClassification,
|
||||
TFAutoModelWithLMHead,
|
||||
)
|
||||
from .modeling_tf_bart import TFBartForConditionalGeneration, TFBartModel
|
||||
from .modeling_tf_bert import (
|
||||
TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
TFBertEmbeddings,
|
||||
|
||||
@@ -20,6 +20,7 @@ import os
|
||||
|
||||
from transformers import (
|
||||
ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
BART_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
@@ -37,6 +38,7 @@ from transformers import (
|
||||
XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
AlbertConfig,
|
||||
BartConfig,
|
||||
BertConfig,
|
||||
CamembertConfig,
|
||||
CTRLConfig,
|
||||
@@ -49,6 +51,7 @@ from transformers import (
|
||||
RobertaConfig,
|
||||
T5Config,
|
||||
TFAlbertForPreTraining,
|
||||
TFBartForConditionalGeneration,
|
||||
TFBertForPreTraining,
|
||||
TFBertForQuestionAnswering,
|
||||
TFBertForSequenceClassification,
|
||||
@@ -87,6 +90,7 @@ if is_torch_available():
|
||||
|
||||
from transformers import (
|
||||
AlbertForPreTraining,
|
||||
BartForConditionalGeneration,
|
||||
BertForPreTraining,
|
||||
BertForQuestionAnswering,
|
||||
BertForSequenceClassification,
|
||||
@@ -113,6 +117,12 @@ if is_torch_available():
|
||||
logging.set_verbosity_info()
|
||||
|
||||
MODEL_CLASSES = {
|
||||
"bart": (
|
||||
BartConfig,
|
||||
TFBartForConditionalGeneration,
|
||||
BartForConditionalGeneration,
|
||||
BART_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
),
|
||||
"bert": (
|
||||
BertConfig,
|
||||
TFBertForPreTraining,
|
||||
|
||||
@@ -640,6 +640,10 @@ class TFGenerationMixin:
|
||||
if temperature != 1.0:
|
||||
next_token_logits = next_token_logits / temperature
|
||||
|
||||
if self.config.is_encoder_decoder and do_sample is False:
|
||||
next_token_logits = self.adjust_logits_during_generation(
|
||||
next_token_logits, cur_len=cur_len, max_length=max_length
|
||||
)
|
||||
# calculate log softmax score
|
||||
scores = tf.nn.log_softmax(next_token_logits, axis=-1) # (batch_size * num_beams, vocab_size)
|
||||
|
||||
@@ -890,6 +894,13 @@ class TFGenerationMixin:
|
||||
def _reorder_cache(past, beam_idx):
|
||||
return tuple(tf.gather(layer_past, beam_idx, axis=1) for layer_past in past)
|
||||
|
||||
def adjust_logits_during_generation(self, logits, **kwargs):
|
||||
"""
|
||||
Implement in subclasses of :class:`~transformers.TFPreTrainedModel` for custom behavior to adjust the logits in
|
||||
the generate method.
|
||||
"""
|
||||
return logits
|
||||
|
||||
|
||||
def _create_next_token_logits_penalties(input_ids, logits, repetition_penalty):
|
||||
# create logit penalties for already seen input_ids
|
||||
|
||||
@@ -131,7 +131,7 @@ BART_INPUTS_DOCSTRING = r"""
|
||||
:obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`) is a
|
||||
sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of
|
||||
the decoder.
|
||||
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||
past_key_values (:obj:`Tuple[Dict[str, tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding.
|
||||
|
||||
If :obj:`past_key_values` are used, the user can optionally input only the last
|
||||
@@ -217,12 +217,6 @@ def _make_linear_from_emb(emb):
|
||||
return lin_layer
|
||||
|
||||
|
||||
# Helper Functions, mostly for making masks
|
||||
def _check_shapes(shape_1, shape2):
|
||||
if shape_1 != shape2:
|
||||
raise AssertionError("shape mismatch: {} != {}".format(shape_1, shape2))
|
||||
|
||||
|
||||
def shift_tokens_right(input_ids, pad_token_id):
|
||||
"""Shift input ids one token to the right, and wrap the last non pad token (usually <eos>)."""
|
||||
prev_output_tokens = input_ids.clone()
|
||||
@@ -595,7 +589,7 @@ class BartDecoder(nn.Module):
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = []
|
||||
next_decoder_cache: List[Dict] = []
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||
if output_hidden_states:
|
||||
@@ -640,7 +634,7 @@ class BartDecoder(nn.Module):
|
||||
)
|
||||
|
||||
|
||||
def _reorder_buffer(attn_cache, new_order):
|
||||
def _reorder_buffer(attn_cache: Dict, new_order) -> Dict:
|
||||
for k, input_buffer_k in attn_cache.items():
|
||||
if input_buffer_k is not None:
|
||||
attn_cache[k] = input_buffer_k.index_select(0, new_order)
|
||||
@@ -679,17 +673,15 @@ class Attention(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
query,
|
||||
key: Optional[Tensor],
|
||||
key: Tensor,
|
||||
key_padding_mask: Optional[Tensor] = None,
|
||||
layer_state: Optional[Dict[str, Optional[Tensor]]] = None,
|
||||
layer_state: Optional[Dict[str, Tensor]] = None,
|
||||
attn_mask: Optional[Tensor] = None,
|
||||
output_attentions=False,
|
||||
) -> Tuple[Tensor, Optional[Tensor]]:
|
||||
"""Input shape: Time(SeqLen) x Batch x Channel"""
|
||||
static_kv: bool = self.encoder_decoder_attention
|
||||
tgt_len, bsz, embed_dim = query.size()
|
||||
assert embed_dim == self.embed_dim
|
||||
assert list(query.size()) == [tgt_len, bsz, embed_dim]
|
||||
# get here for encoder decoder cause of static_kv
|
||||
if layer_state is not None: # reuse k,v and encoder_padding_mask
|
||||
saved_state = layer_state.get(self.cache_key, {})
|
||||
@@ -697,17 +689,16 @@ class Attention(nn.Module):
|
||||
# previous time steps are cached - no need to recompute key and value if they are static
|
||||
key = None
|
||||
else:
|
||||
# this branch is hit by encoder
|
||||
saved_state = None
|
||||
layer_state = {}
|
||||
|
||||
q = self.q_proj(query) * self.scaling
|
||||
if static_kv:
|
||||
if key is None:
|
||||
if static_kv and key is None: # cross-attention with cache
|
||||
k = v = None
|
||||
else:
|
||||
elif static_kv and key is not None: # cross-attention no prev_key found in cache
|
||||
k = self.k_proj(key)
|
||||
v = self.v_proj(key)
|
||||
else:
|
||||
else: # self-attention
|
||||
k = self.k_proj(query)
|
||||
v = self.v_proj(query)
|
||||
|
||||
@@ -717,18 +708,16 @@ class Attention(nn.Module):
|
||||
if v is not None:
|
||||
v = self._shape(v, -1, bsz)
|
||||
|
||||
if saved_state is not None:
|
||||
k, v, key_padding_mask = self._use_saved_state(k, v, saved_state, key_padding_mask, static_kv, bsz)
|
||||
if saved_state:
|
||||
k, v = self._concat_saved_state(k, v, saved_state, static_kv, bsz)
|
||||
|
||||
# Update cache
|
||||
layer_state[self.cache_key] = {
|
||||
"prev_key": k.view(bsz, self.num_heads, -1, self.head_dim),
|
||||
"prev_value": v.view(bsz, self.num_heads, -1, self.head_dim),
|
||||
"prev_key_padding_mask": key_padding_mask if not static_kv else None,
|
||||
}
|
||||
if isinstance(layer_state, dict):
|
||||
cached_shape = (bsz, self.num_heads, -1, self.head_dim) # bsz must be first for reorder_cache
|
||||
layer_state[self.cache_key] = dict(prev_key=k.view(*cached_shape), prev_value=v.view(*cached_shape))
|
||||
|
||||
assert k is not None
|
||||
src_len = k.size(1)
|
||||
assert key_padding_mask is None or key_padding_mask.shape == (bsz, src_len)
|
||||
attn_weights = torch.bmm(q, k.transpose(1, 2))
|
||||
assert attn_weights.size() == (bsz * self.num_heads, tgt_len, src_len)
|
||||
|
||||
@@ -736,13 +725,7 @@ class Attention(nn.Module):
|
||||
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_mask
|
||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
|
||||
# This is part of a workaround to get around fork/join parallelism not supporting Optional types.
|
||||
if key_padding_mask is not None and key_padding_mask.dim() == 0:
|
||||
key_padding_mask = None
|
||||
assert key_padding_mask is None or key_padding_mask.size()[:2] == (
|
||||
bsz,
|
||||
src_len,
|
||||
)
|
||||
# Note: the workaround for fork/join parallelism not supporting Optional types was deleted on 2020/10/15.
|
||||
|
||||
if key_padding_mask is not None: # don't attend to padding symbols
|
||||
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||
@@ -750,11 +733,7 @@ class Attention(nn.Module):
|
||||
attn_weights = attn_weights.masked_fill(reshaped, float("-inf"))
|
||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
attn_weights = F.softmax(attn_weights, dim=-1)
|
||||
attn_probs = F.dropout(
|
||||
attn_weights,
|
||||
p=self.dropout,
|
||||
training=self.training,
|
||||
)
|
||||
attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training)
|
||||
|
||||
assert v is not None
|
||||
attn_output = torch.bmm(attn_probs, v)
|
||||
@@ -767,36 +746,13 @@ class Attention(nn.Module):
|
||||
attn_weights = None
|
||||
return attn_output, attn_weights
|
||||
|
||||
def _use_saved_state(self, k, v, saved_state, key_padding_mask, static_kv, bsz):
|
||||
def _concat_saved_state(self, k, v, saved_state, static_kv, bsz) -> Tuple[Tensor]:
|
||||
# saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
|
||||
if "prev_key" in saved_state:
|
||||
_prev_key = saved_state["prev_key"]
|
||||
assert _prev_key is not None
|
||||
prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
|
||||
if static_kv:
|
||||
k = prev_key
|
||||
else:
|
||||
assert k is not None
|
||||
k = torch.cat([prev_key, k], dim=1)
|
||||
if "prev_value" in saved_state:
|
||||
_prev_value = saved_state["prev_value"]
|
||||
assert _prev_value is not None
|
||||
prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
|
||||
if static_kv:
|
||||
v = prev_value
|
||||
else:
|
||||
assert v is not None
|
||||
v = torch.cat([prev_value, v], dim=1)
|
||||
assert k is not None and v is not None
|
||||
prev_key_padding_mask: Optional[Tensor] = saved_state.get("prev_key_padding_mask", None)
|
||||
if prev_key_padding_mask is not None:
|
||||
if static_kv:
|
||||
new_key_padding_mask = prev_key_padding_mask
|
||||
else:
|
||||
new_key_padding_mask = torch.cat([prev_key_padding_mask, key_padding_mask], dim=1)
|
||||
else:
|
||||
new_key_padding_mask = key_padding_mask
|
||||
return k, v, new_key_padding_mask
|
||||
prev_K = saved_state["prev_key"].view(bsz * self.num_heads, -1, self.head_dim)
|
||||
prev_V = saved_state["prev_value"].view(bsz * self.num_heads, -1, self.head_dim)
|
||||
new_K = prev_K if static_kv else torch.cat([prev_K, k], dim=1)
|
||||
new_V = prev_V if static_kv else torch.cat([prev_V, v], dim=1)
|
||||
return new_K, new_V
|
||||
|
||||
|
||||
class BartClassificationHead(nn.Module):
|
||||
@@ -1143,14 +1099,15 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
||||
|
||||
def adjust_logits_during_generation(self, logits, cur_len, max_length):
|
||||
if cur_len == 1 and self.config.force_bos_token_to_be_generated:
|
||||
self._force_token_ids_generation(logits, self.config.bos_token_id)
|
||||
self._force_token_id_to_be_generated(logits, self.config.bos_token_id)
|
||||
elif cur_len == max_length - 1 and self.config.eos_token_id is not None:
|
||||
self._force_token_ids_generation(logits, self.config.eos_token_id)
|
||||
self._force_token_id_to_be_generated(logits, self.config.eos_token_id)
|
||||
return logits
|
||||
|
||||
def _force_token_ids_generation(self, scores, token_id) -> None:
|
||||
@staticmethod
|
||||
def _force_token_id_to_be_generated(scores, token_id) -> None:
|
||||
"""force one of token_ids to be generated by setting prob of all other tokens to 0 (logprob=-float("inf"))"""
|
||||
scores[:, [x for x in range(self.config.vocab_size) if x != token_id]] = -float("inf")
|
||||
scores[:, [x for x in range(scores.shape[1]) if x != token_id]] = -float("inf")
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past, beam_idx):
|
||||
|
||||
@@ -52,5 +52,5 @@ class BlenderbotForConditionalGeneration(BartForConditionalGeneration):
|
||||
def adjust_logits_during_generation(self, logits, cur_len, max_length):
|
||||
logits[:, self.config.bos_token_id] = -torch.finfo(torch.float16).max # near infinity fp16
|
||||
if cur_len == max_length - 1 and self.config.eos_token_id is not None:
|
||||
self._force_token_ids_generation(logits, self.config.eos_token_id)
|
||||
self._force_token_id_to_be_generated(logits, self.config.eos_token_id)
|
||||
return logits
|
||||
|
||||
@@ -51,5 +51,5 @@ class MarianMTModel(BartForConditionalGeneration):
|
||||
def adjust_logits_during_generation(self, logits, cur_len, max_length):
|
||||
logits[:, self.config.pad_token_id] = float("-inf") # never predict pad token.
|
||||
if cur_len == max_length - 1 and self.config.eos_token_id is not None:
|
||||
self._force_token_ids_generation(logits, self.config.eos_token_id)
|
||||
self._force_token_id_to_be_generated(logits, self.config.eos_token_id)
|
||||
return logits
|
||||
|
||||
@@ -21,6 +21,7 @@ from collections import OrderedDict
|
||||
from .configuration_auto import (
|
||||
AlbertConfig,
|
||||
AutoConfig,
|
||||
BartConfig,
|
||||
BertConfig,
|
||||
CamembertConfig,
|
||||
CTRLConfig,
|
||||
@@ -51,6 +52,7 @@ from .modeling_tf_albert import (
|
||||
TFAlbertForTokenClassification,
|
||||
TFAlbertModel,
|
||||
)
|
||||
from .modeling_tf_bart import TFBartForConditionalGeneration
|
||||
from .modeling_tf_bert import (
|
||||
TFBertForMaskedLM,
|
||||
TFBertForMultipleChoice,
|
||||
@@ -206,6 +208,7 @@ TF_MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
|
||||
(T5Config, TFT5ForConditionalGeneration),
|
||||
(DistilBertConfig, TFDistilBertForMaskedLM),
|
||||
(AlbertConfig, TFAlbertForMaskedLM),
|
||||
(BartConfig, TFBartForConditionalGeneration),
|
||||
(CamembertConfig, TFCamembertForMaskedLM),
|
||||
(XLMRobertaConfig, TFXLMRobertaForMaskedLM),
|
||||
(LongformerConfig, TFLongformerForMaskedLM),
|
||||
@@ -256,7 +259,9 @@ TF_MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
|
||||
]
|
||||
)
|
||||
|
||||
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict([(T5Config, TFT5ForConditionalGeneration)])
|
||||
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict(
|
||||
[(T5Config, TFT5ForConditionalGeneration), (BartConfig, TFBartForConditionalGeneration)]
|
||||
)
|
||||
|
||||
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
|
||||
[
|
||||
|
||||
1190
src/transformers/modeling_tf_bart.py
Normal file
1190
src/transformers/modeling_tf_bart.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -717,7 +717,7 @@ BERT_START_DOCSTRING = r"""
|
||||
Args:
|
||||
config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model.
|
||||
Initializing with a config file does not load the weights associated with the model, only the configuration.
|
||||
Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
|
||||
Check out the :meth:`~transformers.TFPreTrainedModel.from_pretrained` method to load the model weights.
|
||||
"""
|
||||
|
||||
BERT_INPUTS_DOCSTRING = r"""
|
||||
|
||||
@@ -229,7 +229,7 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
|
||||
else:
|
||||
logger.warning(
|
||||
f"All the weights of {tf_model.__class__.__name__} were initialized from the PyTorch model.\n"
|
||||
f"If your task is similar to the task the model of the ckeckpoint was trained on, "
|
||||
f"If your task is similar to the task the model of the checkpoint was trained on, "
|
||||
f"you can already use {tf_model.__class__.__name__} for predictions without further training."
|
||||
)
|
||||
|
||||
@@ -383,7 +383,7 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F
|
||||
else:
|
||||
logger.warning(
|
||||
f"All the weights of {pt_model.__class__.__name__} were initialized from the TF 2.0 model.\n"
|
||||
f"If your task is similar to the task the model of the ckeckpoint was trained on, "
|
||||
f"If your task is similar to the task the model of the checkpoint was trained on, "
|
||||
f"you can already use {pt_model.__class__.__name__} for predictions without further training."
|
||||
)
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ import copy
|
||||
import itertools
|
||||
import math
|
||||
import warnings
|
||||
from typing import Tuple
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
@@ -594,7 +595,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
||||
output_hidden_states=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
) -> Tuple:
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
input_ids = inputs[0]
|
||||
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
||||
@@ -699,7 +700,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
||||
|
||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||
# masked positions, this operation will create a tensor which is 0.0 for
|
||||
# positions we want to attend and -10000.0 for masked positions.
|
||||
# positions we want to attend and -1e9 for masked positions.
|
||||
# Since we are adding it to the raw scores before the softmax, this is
|
||||
# effectively the same as removing these entirely.
|
||||
|
||||
@@ -721,7 +722,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
||||
if num_dims_encoder_attention_mask == 2:
|
||||
encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
|
||||
|
||||
# T5 has a mask that can compare sequence ids, we can simulate this here with this transposistion
|
||||
# T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
|
||||
# Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
|
||||
# encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask,
|
||||
# tf.transpose(encoder_extended_attention_mask, perm=(-1, -2)))
|
||||
@@ -1417,7 +1418,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
|
||||
"use_cache": use_cache,
|
||||
}
|
||||
|
||||
def _reorder_cache(self, past, beam_idx):
|
||||
def _reorder_cache(self, past, beam_idx) -> Tuple:
|
||||
# if decoder past is not included in output
|
||||
# speedy decoding is disabled and no need to reorder
|
||||
|
||||
|
||||
@@ -136,8 +136,7 @@ class TFCausalLanguageModelingLoss:
|
||||
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
|
||||
from_logits=True, reduction=tf.keras.losses.Reduction.NONE
|
||||
)
|
||||
# make sure only labels that are not equal to -100
|
||||
# are taken into account as loss
|
||||
# make sure only labels that are not equal to -100 affect the loss
|
||||
active_loss = tf.not_equal(tf.reshape(labels, (-1,)), -100)
|
||||
reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
|
||||
labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
|
||||
|
||||
@@ -1945,11 +1945,6 @@ class SummarizationPipeline(Pipeline):
|
||||
assert return_tensors or return_text, "You must specify return_tensors=True or return_text=True"
|
||||
assert len(documents) > 0, "Please provide a document to summarize"
|
||||
|
||||
if self.framework == "tf" and "BartForConditionalGeneration" in self.model.__class__.__name__:
|
||||
raise NotImplementedError(
|
||||
"Tensorflow is not yet supported for Bart. Please consider using T5, e.g. `t5-base`"
|
||||
)
|
||||
|
||||
prefix = self.model.config.prefix if self.model.config.prefix is not None else ""
|
||||
|
||||
if isinstance(documents[0], list):
|
||||
|
||||
@@ -212,6 +212,24 @@ class TFAutoModelWithLMHead:
|
||||
requires_tf(self)
|
||||
|
||||
|
||||
class TFBartForConditionalGeneration:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_tf(self)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_tf(self)
|
||||
|
||||
|
||||
class TFBartModel:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_tf(self)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_tf(self)
|
||||
|
||||
|
||||
TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||
|
||||
|
||||
|
||||
File diff suppressed because one or more lines are too long
357
tests/test_modeling_tf_bart.py
Normal file
357
tests/test_modeling_tf_bart.py
Normal file
File diff suppressed because one or more lines are too long
@@ -302,7 +302,7 @@ class TFModelTesterMixin:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
pt_model_class_name = model_class.__name__[2:] # Skip the "TF" at the beggining
|
||||
pt_model_class_name = model_class.__name__[2:] # Skip the "TF" at the beginning
|
||||
pt_model_class = getattr(transformers, pt_model_class_name)
|
||||
|
||||
config.output_hidden_states = True
|
||||
@@ -472,10 +472,9 @@ class TFModelTesterMixin:
|
||||
|
||||
# Prepare our model
|
||||
model = model_class(config)
|
||||
|
||||
model(self._prepare_for_class(inputs_dict, model_class)) # Model must be called before saving.
|
||||
# Let's load it from the disk to be sure we can use pretrained weights
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
outputs = model(self._prepare_for_class(inputs_dict, model_class)) # build the model
|
||||
model.save_pretrained(tmpdirname)
|
||||
model = model_class.from_pretrained(tmpdirname)
|
||||
|
||||
@@ -494,7 +493,9 @@ class TFModelTesterMixin:
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
outputs_dict = model(self._prepare_for_class(inputs_dict, model_class))
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
outputs_dict = model(inputs)
|
||||
|
||||
inputs_keywords = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
|
||||
input_ids = inputs_keywords.pop("input_ids", None)
|
||||
@@ -507,28 +508,18 @@ class TFModelTesterMixin:
|
||||
def test_attention_outputs(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
decoder_seq_length = (
|
||||
self.model_tester.decoder_seq_length
|
||||
if hasattr(self.model_tester, "decoder_seq_length")
|
||||
else self.model_tester.seq_length
|
||||
)
|
||||
encoder_seq_length = (
|
||||
self.model_tester.encoder_seq_length
|
||||
if hasattr(self.model_tester, "encoder_seq_length")
|
||||
else self.model_tester.seq_length
|
||||
)
|
||||
decoder_key_length = (
|
||||
self.model_tester.key_length if hasattr(self.model_tester, "key_length") else decoder_seq_length
|
||||
)
|
||||
encoder_key_length = (
|
||||
self.model_tester.key_length if hasattr(self.model_tester, "key_length") else encoder_seq_length
|
||||
)
|
||||
decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", self.model_tester.seq_length)
|
||||
encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", self.model_tester.seq_length)
|
||||
decoder_key_length = getattr(self.model_tester, "key_length", decoder_seq_length)
|
||||
encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
inputs_dict["output_attentions"] = True
|
||||
inputs_dict["use_cache"] = False
|
||||
config.output_hidden_states = False
|
||||
model = model_class(config)
|
||||
outputs = model(self._prepare_for_class(inputs_dict, model_class))
|
||||
model_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
outputs = model(model_inputs)
|
||||
attentions = [t.numpy() for t in outputs[-1]]
|
||||
self.assertEqual(model.config.output_hidden_states, False)
|
||||
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
|
||||
|
||||
@@ -279,8 +279,7 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_name in ["t5-small"]:
|
||||
model = TFT5Model.from_pretrained(model_name)
|
||||
model = TFT5Model.from_pretrained("t5-small")
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ FILL_MASK_FINETUNED_MODELS = ["sshleifer/tiny-distilroberta-base"]
|
||||
LARGE_FILL_MASK_FINETUNED_MODELS = ["distilroberta-base"] # @slow
|
||||
|
||||
SUMMARIZATION_FINETUNED_MODELS = ["sshleifer/bart-tiny-random", "patrickvonplaten/t5-tiny-random"]
|
||||
TF_SUMMARIZATION_FINETUNED_MODELS = ["patrickvonplaten/t5-tiny-random"]
|
||||
TF_SUMMARIZATION_FINETUNED_MODELS = ["sshleifer/bart-tiny-random", "patrickvonplaten/t5-tiny-random"]
|
||||
|
||||
TRANSLATION_FINETUNED_MODELS = [
|
||||
("patrickvonplaten/t5-tiny-random", "translation_en_to_de"),
|
||||
|
||||
Reference in New Issue
Block a user