Add TFBartForConditionalGeneration (#5411)
* half done * doc improvement * Cp test file * brokedn * broken test * undo some mess * ckpt * borked * Halfway * 6 passing * boom boom * Much progress but still 6 * boom boom * merged master * 10 passing * boom boom * Style * no t5 changes * 13 passing * Integration test failing, but not gibberish * Frustrated * Merged master * 4 fail * 4 fail * fix return_dict * boom boom * Still only 4 * prepare method * prepare method * before delete classif * Skip tests to avoid adding boilerplate * boom boom * fast tests passing * style * boom boom * Switch to supporting many input types * remove FIXMENORM * working * Fixed past_key_values/decoder_cached_states confusion * new broken test * Fix attention mask kwarg name * undo accidental * Style and reviewers * style * Docs and common tests * Cleaner assert messages * copy docs * style issues * Sphinx fix * Simplify caching logic * test does not require torch * copy _NoLayerEmbedTokens * Update src/transformers/modeling_tf_bart.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Update tests/test_modeling_tf_bart.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Update src/transformers/modeling_tf_bart.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Update src/transformers/modeling_tf_bart.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Update src/transformers/modeling_tf_bart.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Line length and dont document None * Add pipeline test coverage * assert msg * At parity * Assert messages * mark slow * Update compile test * back in init * Merge master * Fix tests Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
@@ -86,3 +86,18 @@ BartForQuestionAnswering
|
|||||||
|
|
||||||
.. autoclass:: transformers.BartForQuestionAnswering
|
.. autoclass:: transformers.BartForQuestionAnswering
|
||||||
:members: forward
|
:members: forward
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
TFBartModel
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.TFBartModel
|
||||||
|
:members: call
|
||||||
|
|
||||||
|
|
||||||
|
TFBartForConditionalGeneration
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.TFBartForConditionalGeneration
|
||||||
|
:members: call
|
||||||
|
|||||||
@@ -652,6 +652,7 @@ if is_tf_available():
|
|||||||
TFAutoModelForTokenClassification,
|
TFAutoModelForTokenClassification,
|
||||||
TFAutoModelWithLMHead,
|
TFAutoModelWithLMHead,
|
||||||
)
|
)
|
||||||
|
from .modeling_tf_bart import TFBartForConditionalGeneration, TFBartModel
|
||||||
from .modeling_tf_bert import (
|
from .modeling_tf_bert import (
|
||||||
TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
TFBertEmbeddings,
|
TFBertEmbeddings,
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ import os
|
|||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
|
BART_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
@@ -37,6 +38,7 @@ from transformers import (
|
|||||||
XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
AlbertConfig,
|
AlbertConfig,
|
||||||
|
BartConfig,
|
||||||
BertConfig,
|
BertConfig,
|
||||||
CamembertConfig,
|
CamembertConfig,
|
||||||
CTRLConfig,
|
CTRLConfig,
|
||||||
@@ -49,6 +51,7 @@ from transformers import (
|
|||||||
RobertaConfig,
|
RobertaConfig,
|
||||||
T5Config,
|
T5Config,
|
||||||
TFAlbertForPreTraining,
|
TFAlbertForPreTraining,
|
||||||
|
TFBartForConditionalGeneration,
|
||||||
TFBertForPreTraining,
|
TFBertForPreTraining,
|
||||||
TFBertForQuestionAnswering,
|
TFBertForQuestionAnswering,
|
||||||
TFBertForSequenceClassification,
|
TFBertForSequenceClassification,
|
||||||
@@ -87,6 +90,7 @@ if is_torch_available():
|
|||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AlbertForPreTraining,
|
AlbertForPreTraining,
|
||||||
|
BartForConditionalGeneration,
|
||||||
BertForPreTraining,
|
BertForPreTraining,
|
||||||
BertForQuestionAnswering,
|
BertForQuestionAnswering,
|
||||||
BertForSequenceClassification,
|
BertForSequenceClassification,
|
||||||
@@ -113,6 +117,12 @@ if is_torch_available():
|
|||||||
logging.set_verbosity_info()
|
logging.set_verbosity_info()
|
||||||
|
|
||||||
MODEL_CLASSES = {
|
MODEL_CLASSES = {
|
||||||
|
"bart": (
|
||||||
|
BartConfig,
|
||||||
|
TFBartForConditionalGeneration,
|
||||||
|
BartForConditionalGeneration,
|
||||||
|
BART_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
|
),
|
||||||
"bert": (
|
"bert": (
|
||||||
BertConfig,
|
BertConfig,
|
||||||
TFBertForPreTraining,
|
TFBertForPreTraining,
|
||||||
|
|||||||
@@ -640,6 +640,10 @@ class TFGenerationMixin:
|
|||||||
if temperature != 1.0:
|
if temperature != 1.0:
|
||||||
next_token_logits = next_token_logits / temperature
|
next_token_logits = next_token_logits / temperature
|
||||||
|
|
||||||
|
if self.config.is_encoder_decoder and do_sample is False:
|
||||||
|
next_token_logits = self.adjust_logits_during_generation(
|
||||||
|
next_token_logits, cur_len=cur_len, max_length=max_length
|
||||||
|
)
|
||||||
# calculate log softmax score
|
# calculate log softmax score
|
||||||
scores = tf.nn.log_softmax(next_token_logits, axis=-1) # (batch_size * num_beams, vocab_size)
|
scores = tf.nn.log_softmax(next_token_logits, axis=-1) # (batch_size * num_beams, vocab_size)
|
||||||
|
|
||||||
@@ -890,6 +894,13 @@ class TFGenerationMixin:
|
|||||||
def _reorder_cache(past, beam_idx):
|
def _reorder_cache(past, beam_idx):
|
||||||
return tuple(tf.gather(layer_past, beam_idx, axis=1) for layer_past in past)
|
return tuple(tf.gather(layer_past, beam_idx, axis=1) for layer_past in past)
|
||||||
|
|
||||||
|
def adjust_logits_during_generation(self, logits, **kwargs):
|
||||||
|
"""
|
||||||
|
Implement in subclasses of :class:`~transfomers.PreTrainedModel` for custom behavior to adjust the logits in
|
||||||
|
the generate method.
|
||||||
|
"""
|
||||||
|
return logits
|
||||||
|
|
||||||
|
|
||||||
def _create_next_token_logits_penalties(input_ids, logits, repetition_penalty):
|
def _create_next_token_logits_penalties(input_ids, logits, repetition_penalty):
|
||||||
# create logit penalties for already seen input_ids
|
# create logit penalties for already seen input_ids
|
||||||
|
|||||||
@@ -131,7 +131,7 @@ BART_INPUTS_DOCSTRING = r"""
|
|||||||
:obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`) is a
|
:obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`) is a
|
||||||
sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of
|
sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of
|
||||||
the decoder.
|
the decoder.
|
||||||
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
past_key_values (:obj:`Tuple[Dict[str: tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||||
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding.
|
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding.
|
||||||
|
|
||||||
If :obj:`past_key_values` are used, the user can optionally input only the last
|
If :obj:`past_key_values` are used, the user can optionally input only the last
|
||||||
@@ -217,12 +217,6 @@ def _make_linear_from_emb(emb):
|
|||||||
return lin_layer
|
return lin_layer
|
||||||
|
|
||||||
|
|
||||||
# Helper Functions, mostly for making masks
|
|
||||||
def _check_shapes(shape_1, shape2):
|
|
||||||
if shape_1 != shape2:
|
|
||||||
raise AssertionError("shape mismatch: {} != {}".format(shape_1, shape2))
|
|
||||||
|
|
||||||
|
|
||||||
def shift_tokens_right(input_ids, pad_token_id):
|
def shift_tokens_right(input_ids, pad_token_id):
|
||||||
"""Shift input ids one token to the right, and wrap the last non pad token (usually <eos>)."""
|
"""Shift input ids one token to the right, and wrap the last non pad token (usually <eos>)."""
|
||||||
prev_output_tokens = input_ids.clone()
|
prev_output_tokens = input_ids.clone()
|
||||||
@@ -595,7 +589,7 @@ class BartDecoder(nn.Module):
|
|||||||
# decoder layers
|
# decoder layers
|
||||||
all_hidden_states = () if output_hidden_states else None
|
all_hidden_states = () if output_hidden_states else None
|
||||||
all_self_attns = () if output_attentions else None
|
all_self_attns = () if output_attentions else None
|
||||||
next_decoder_cache = []
|
next_decoder_cache: List[Dict] = []
|
||||||
for idx, decoder_layer in enumerate(self.layers):
|
for idx, decoder_layer in enumerate(self.layers):
|
||||||
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
@@ -640,7 +634,7 @@ class BartDecoder(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _reorder_buffer(attn_cache, new_order):
|
def _reorder_buffer(attn_cache: Dict, new_order) -> Dict:
|
||||||
for k, input_buffer_k in attn_cache.items():
|
for k, input_buffer_k in attn_cache.items():
|
||||||
if input_buffer_k is not None:
|
if input_buffer_k is not None:
|
||||||
attn_cache[k] = input_buffer_k.index_select(0, new_order)
|
attn_cache[k] = input_buffer_k.index_select(0, new_order)
|
||||||
@@ -679,17 +673,15 @@ class Attention(nn.Module):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
query,
|
query,
|
||||||
key: Optional[Tensor],
|
key: Tensor,
|
||||||
key_padding_mask: Optional[Tensor] = None,
|
key_padding_mask: Optional[Tensor] = None,
|
||||||
layer_state: Optional[Dict[str, Optional[Tensor]]] = None,
|
layer_state: Optional[Dict[str, Tensor]] = None,
|
||||||
attn_mask: Optional[Tensor] = None,
|
attn_mask: Optional[Tensor] = None,
|
||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
) -> Tuple[Tensor, Optional[Tensor]]:
|
) -> Tuple[Tensor, Optional[Tensor]]:
|
||||||
"""Input shape: Time(SeqLen) x Batch x Channel"""
|
"""Input shape: Time(SeqLen) x Batch x Channel"""
|
||||||
static_kv: bool = self.encoder_decoder_attention
|
static_kv: bool = self.encoder_decoder_attention
|
||||||
tgt_len, bsz, embed_dim = query.size()
|
tgt_len, bsz, embed_dim = query.size()
|
||||||
assert embed_dim == self.embed_dim
|
|
||||||
assert list(query.size()) == [tgt_len, bsz, embed_dim]
|
|
||||||
# get here for encoder decoder cause of static_kv
|
# get here for encoder decoder cause of static_kv
|
||||||
if layer_state is not None: # reuse k,v and encoder_padding_mask
|
if layer_state is not None: # reuse k,v and encoder_padding_mask
|
||||||
saved_state = layer_state.get(self.cache_key, {})
|
saved_state = layer_state.get(self.cache_key, {})
|
||||||
@@ -697,17 +689,16 @@ class Attention(nn.Module):
|
|||||||
# previous time steps are cached - no need to recompute key and value if they are static
|
# previous time steps are cached - no need to recompute key and value if they are static
|
||||||
key = None
|
key = None
|
||||||
else:
|
else:
|
||||||
|
# this branch is hit by encoder
|
||||||
saved_state = None
|
saved_state = None
|
||||||
layer_state = {}
|
|
||||||
|
|
||||||
q = self.q_proj(query) * self.scaling
|
q = self.q_proj(query) * self.scaling
|
||||||
if static_kv:
|
if static_kv and key is None: # cross-attention with cache
|
||||||
if key is None:
|
|
||||||
k = v = None
|
k = v = None
|
||||||
else:
|
elif static_kv and key is not None: # cross-attention no prev_key found in cache
|
||||||
k = self.k_proj(key)
|
k = self.k_proj(key)
|
||||||
v = self.v_proj(key)
|
v = self.v_proj(key)
|
||||||
else:
|
else: # self-attention
|
||||||
k = self.k_proj(query)
|
k = self.k_proj(query)
|
||||||
v = self.v_proj(query)
|
v = self.v_proj(query)
|
||||||
|
|
||||||
@@ -717,18 +708,16 @@ class Attention(nn.Module):
|
|||||||
if v is not None:
|
if v is not None:
|
||||||
v = self._shape(v, -1, bsz)
|
v = self._shape(v, -1, bsz)
|
||||||
|
|
||||||
if saved_state is not None:
|
if saved_state:
|
||||||
k, v, key_padding_mask = self._use_saved_state(k, v, saved_state, key_padding_mask, static_kv, bsz)
|
k, v = self._concat_saved_state(k, v, saved_state, static_kv, bsz)
|
||||||
|
|
||||||
# Update cache
|
# Update cache
|
||||||
layer_state[self.cache_key] = {
|
if isinstance(layer_state, dict):
|
||||||
"prev_key": k.view(bsz, self.num_heads, -1, self.head_dim),
|
cached_shape = (bsz, self.num_heads, -1, self.head_dim) # bsz must be first for reorder_cache
|
||||||
"prev_value": v.view(bsz, self.num_heads, -1, self.head_dim),
|
layer_state[self.cache_key] = dict(prev_key=k.view(*cached_shape), prev_value=v.view(*cached_shape))
|
||||||
"prev_key_padding_mask": key_padding_mask if not static_kv else None,
|
|
||||||
}
|
|
||||||
|
|
||||||
assert k is not None
|
|
||||||
src_len = k.size(1)
|
src_len = k.size(1)
|
||||||
|
assert key_padding_mask is None or key_padding_mask.shape == (bsz, src_len)
|
||||||
attn_weights = torch.bmm(q, k.transpose(1, 2))
|
attn_weights = torch.bmm(q, k.transpose(1, 2))
|
||||||
assert attn_weights.size() == (bsz * self.num_heads, tgt_len, src_len)
|
assert attn_weights.size() == (bsz * self.num_heads, tgt_len, src_len)
|
||||||
|
|
||||||
@@ -736,13 +725,7 @@ class Attention(nn.Module):
|
|||||||
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_mask
|
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_mask
|
||||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||||
|
|
||||||
# This is part of a workaround to get around fork/join parallelism not supporting Optional types.
|
# Note: deleted workaround to get around fork/join parallelism not supporting Optional types. on 2020/10/15
|
||||||
if key_padding_mask is not None and key_padding_mask.dim() == 0:
|
|
||||||
key_padding_mask = None
|
|
||||||
assert key_padding_mask is None or key_padding_mask.size()[:2] == (
|
|
||||||
bsz,
|
|
||||||
src_len,
|
|
||||||
)
|
|
||||||
|
|
||||||
if key_padding_mask is not None: # don't attend to padding symbols
|
if key_padding_mask is not None: # don't attend to padding symbols
|
||||||
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||||
@@ -750,11 +733,7 @@ class Attention(nn.Module):
|
|||||||
attn_weights = attn_weights.masked_fill(reshaped, float("-inf"))
|
attn_weights = attn_weights.masked_fill(reshaped, float("-inf"))
|
||||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||||
attn_weights = F.softmax(attn_weights, dim=-1)
|
attn_weights = F.softmax(attn_weights, dim=-1)
|
||||||
attn_probs = F.dropout(
|
attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training)
|
||||||
attn_weights,
|
|
||||||
p=self.dropout,
|
|
||||||
training=self.training,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert v is not None
|
assert v is not None
|
||||||
attn_output = torch.bmm(attn_probs, v)
|
attn_output = torch.bmm(attn_probs, v)
|
||||||
@@ -767,36 +746,13 @@ class Attention(nn.Module):
|
|||||||
attn_weights = None
|
attn_weights = None
|
||||||
return attn_output, attn_weights
|
return attn_output, attn_weights
|
||||||
|
|
||||||
def _use_saved_state(self, k, v, saved_state, key_padding_mask, static_kv, bsz):
|
def _concat_saved_state(self, k, v, saved_state, static_kv, bsz) -> Tuple[Tensor]:
|
||||||
# saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
|
# saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
|
||||||
if "prev_key" in saved_state:
|
prev_K = saved_state["prev_key"].view(bsz * self.num_heads, -1, self.head_dim)
|
||||||
_prev_key = saved_state["prev_key"]
|
prev_V = saved_state["prev_value"].view(bsz * self.num_heads, -1, self.head_dim)
|
||||||
assert _prev_key is not None
|
new_K = prev_K if static_kv else torch.cat([prev_K, k], dim=1)
|
||||||
prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
|
new_V = prev_V if static_kv else torch.cat([prev_V, v], dim=1)
|
||||||
if static_kv:
|
return new_K, new_V
|
||||||
k = prev_key
|
|
||||||
else:
|
|
||||||
assert k is not None
|
|
||||||
k = torch.cat([prev_key, k], dim=1)
|
|
||||||
if "prev_value" in saved_state:
|
|
||||||
_prev_value = saved_state["prev_value"]
|
|
||||||
assert _prev_value is not None
|
|
||||||
prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
|
|
||||||
if static_kv:
|
|
||||||
v = prev_value
|
|
||||||
else:
|
|
||||||
assert v is not None
|
|
||||||
v = torch.cat([prev_value, v], dim=1)
|
|
||||||
assert k is not None and v is not None
|
|
||||||
prev_key_padding_mask: Optional[Tensor] = saved_state.get("prev_key_padding_mask", None)
|
|
||||||
if prev_key_padding_mask is not None:
|
|
||||||
if static_kv:
|
|
||||||
new_key_padding_mask = prev_key_padding_mask
|
|
||||||
else:
|
|
||||||
new_key_padding_mask = torch.cat([prev_key_padding_mask, key_padding_mask], dim=1)
|
|
||||||
else:
|
|
||||||
new_key_padding_mask = key_padding_mask
|
|
||||||
return k, v, new_key_padding_mask
|
|
||||||
|
|
||||||
|
|
||||||
class BartClassificationHead(nn.Module):
|
class BartClassificationHead(nn.Module):
|
||||||
@@ -1143,14 +1099,15 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
|||||||
|
|
||||||
def adjust_logits_during_generation(self, logits, cur_len, max_length):
|
def adjust_logits_during_generation(self, logits, cur_len, max_length):
|
||||||
if cur_len == 1 and self.config.force_bos_token_to_be_generated:
|
if cur_len == 1 and self.config.force_bos_token_to_be_generated:
|
||||||
self._force_token_ids_generation(logits, self.config.bos_token_id)
|
self._force_token_id_to_be_generated(logits, self.config.bos_token_id)
|
||||||
elif cur_len == max_length - 1 and self.config.eos_token_id is not None:
|
elif cur_len == max_length - 1 and self.config.eos_token_id is not None:
|
||||||
self._force_token_ids_generation(logits, self.config.eos_token_id)
|
self._force_token_id_to_be_generated(logits, self.config.eos_token_id)
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
def _force_token_ids_generation(self, scores, token_id) -> None:
|
@staticmethod
|
||||||
|
def _force_token_id_to_be_generated(scores, token_id) -> None:
|
||||||
"""force one of token_ids to be generated by setting prob of all other tokens to 0 (logprob=-float("inf"))"""
|
"""force one of token_ids to be generated by setting prob of all other tokens to 0 (logprob=-float("inf"))"""
|
||||||
scores[:, [x for x in range(self.config.vocab_size) if x != token_id]] = -float("inf")
|
scores[:, [x for x in range(scores.shape[1]) if x != token_id]] = -float("inf")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _reorder_cache(past, beam_idx):
|
def _reorder_cache(past, beam_idx):
|
||||||
|
|||||||
@@ -52,5 +52,5 @@ class BlenderbotForConditionalGeneration(BartForConditionalGeneration):
|
|||||||
def adjust_logits_during_generation(self, logits, cur_len, max_length):
|
def adjust_logits_during_generation(self, logits, cur_len, max_length):
|
||||||
logits[:, self.config.bos_token_id] = -torch.finfo(torch.float16).max # near infinity fp16
|
logits[:, self.config.bos_token_id] = -torch.finfo(torch.float16).max # near infinity fp16
|
||||||
if cur_len == max_length - 1 and self.config.eos_token_id is not None:
|
if cur_len == max_length - 1 and self.config.eos_token_id is not None:
|
||||||
self._force_token_ids_generation(logits, self.config.eos_token_id)
|
self._force_token_id_to_be_generated(logits, self.config.eos_token_id)
|
||||||
return logits
|
return logits
|
||||||
|
|||||||
@@ -51,5 +51,5 @@ class MarianMTModel(BartForConditionalGeneration):
|
|||||||
def adjust_logits_during_generation(self, logits, cur_len, max_length):
|
def adjust_logits_during_generation(self, logits, cur_len, max_length):
|
||||||
logits[:, self.config.pad_token_id] = float("-inf") # never predict pad token.
|
logits[:, self.config.pad_token_id] = float("-inf") # never predict pad token.
|
||||||
if cur_len == max_length - 1 and self.config.eos_token_id is not None:
|
if cur_len == max_length - 1 and self.config.eos_token_id is not None:
|
||||||
self._force_token_ids_generation(logits, self.config.eos_token_id)
|
self._force_token_id_to_be_generated(logits, self.config.eos_token_id)
|
||||||
return logits
|
return logits
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ from collections import OrderedDict
|
|||||||
from .configuration_auto import (
|
from .configuration_auto import (
|
||||||
AlbertConfig,
|
AlbertConfig,
|
||||||
AutoConfig,
|
AutoConfig,
|
||||||
|
BartConfig,
|
||||||
BertConfig,
|
BertConfig,
|
||||||
CamembertConfig,
|
CamembertConfig,
|
||||||
CTRLConfig,
|
CTRLConfig,
|
||||||
@@ -51,6 +52,7 @@ from .modeling_tf_albert import (
|
|||||||
TFAlbertForTokenClassification,
|
TFAlbertForTokenClassification,
|
||||||
TFAlbertModel,
|
TFAlbertModel,
|
||||||
)
|
)
|
||||||
|
from .modeling_tf_bart import TFBartForConditionalGeneration
|
||||||
from .modeling_tf_bert import (
|
from .modeling_tf_bert import (
|
||||||
TFBertForMaskedLM,
|
TFBertForMaskedLM,
|
||||||
TFBertForMultipleChoice,
|
TFBertForMultipleChoice,
|
||||||
@@ -206,6 +208,7 @@ TF_MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
|
|||||||
(T5Config, TFT5ForConditionalGeneration),
|
(T5Config, TFT5ForConditionalGeneration),
|
||||||
(DistilBertConfig, TFDistilBertForMaskedLM),
|
(DistilBertConfig, TFDistilBertForMaskedLM),
|
||||||
(AlbertConfig, TFAlbertForMaskedLM),
|
(AlbertConfig, TFAlbertForMaskedLM),
|
||||||
|
(BartConfig, TFBartForConditionalGeneration),
|
||||||
(CamembertConfig, TFCamembertForMaskedLM),
|
(CamembertConfig, TFCamembertForMaskedLM),
|
||||||
(XLMRobertaConfig, TFXLMRobertaForMaskedLM),
|
(XLMRobertaConfig, TFXLMRobertaForMaskedLM),
|
||||||
(LongformerConfig, TFLongformerForMaskedLM),
|
(LongformerConfig, TFLongformerForMaskedLM),
|
||||||
@@ -256,7 +259,9 @@ TF_MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict([(T5Config, TFT5ForConditionalGeneration)])
|
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict(
|
||||||
|
[(T5Config, TFT5ForConditionalGeneration), (BartConfig, TFBartForConditionalGeneration)]
|
||||||
|
)
|
||||||
|
|
||||||
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
|
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
|
||||||
[
|
[
|
||||||
|
|||||||
1190
src/transformers/modeling_tf_bart.py
Normal file
1190
src/transformers/modeling_tf_bart.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -717,7 +717,7 @@ BERT_START_DOCSTRING = r"""
|
|||||||
Args:
|
Args:
|
||||||
config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model.
|
config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model.
|
||||||
Initializing with a config file does not load the weights associated with the model, only the configuration.
|
Initializing with a config file does not load the weights associated with the model, only the configuration.
|
||||||
Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
|
Check out the :meth:`~transformers.TFPreTrainedModel.from_pretrained` method to load the model weights.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
BERT_INPUTS_DOCSTRING = r"""
|
BERT_INPUTS_DOCSTRING = r"""
|
||||||
|
|||||||
@@ -229,7 +229,7 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
|
|||||||
else:
|
else:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"All the weights of {tf_model.__class__.__name__} were initialized from the PyTorch model.\n"
|
f"All the weights of {tf_model.__class__.__name__} were initialized from the PyTorch model.\n"
|
||||||
f"If your task is similar to the task the model of the ckeckpoint was trained on, "
|
f"If your task is similar to the task the model of the checkpoint was trained on, "
|
||||||
f"you can already use {tf_model.__class__.__name__} for predictions without further training."
|
f"you can already use {tf_model.__class__.__name__} for predictions without further training."
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -383,7 +383,7 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F
|
|||||||
else:
|
else:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"All the weights of {pt_model.__class__.__name__} were initialized from the TF 2.0 model.\n"
|
f"All the weights of {pt_model.__class__.__name__} were initialized from the TF 2.0 model.\n"
|
||||||
f"If your task is similar to the task the model of the ckeckpoint was trained on, "
|
f"If your task is similar to the task the model of the checkpoint was trained on, "
|
||||||
f"you can already use {pt_model.__class__.__name__} for predictions without further training."
|
f"you can already use {pt_model.__class__.__name__} for predictions without further training."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ import copy
|
|||||||
import itertools
|
import itertools
|
||||||
import math
|
import math
|
||||||
import warnings
|
import warnings
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
@@ -594,7 +595,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
|||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
training=False,
|
training=False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
) -> Tuple:
|
||||||
if isinstance(inputs, (tuple, list)):
|
if isinstance(inputs, (tuple, list)):
|
||||||
input_ids = inputs[0]
|
input_ids = inputs[0]
|
||||||
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
||||||
@@ -699,7 +700,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||||
# masked positions, this operation will create a tensor which is 0.0 for
|
# masked positions, this operation will create a tensor which is 0.0 for
|
||||||
# positions we want to attend and -10000.0 for masked positions.
|
# positions we want to attend and -1e9 for masked positions.
|
||||||
# Since we are adding it to the raw scores before the softmax, this is
|
# Since we are adding it to the raw scores before the softmax, this is
|
||||||
# effectively the same as removing these entirely.
|
# effectively the same as removing these entirely.
|
||||||
|
|
||||||
@@ -721,7 +722,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
|||||||
if num_dims_encoder_attention_mask == 2:
|
if num_dims_encoder_attention_mask == 2:
|
||||||
encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
|
encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
|
||||||
|
|
||||||
# T5 has a mask that can compare sequence ids, we can simulate this here with this transposistion
|
# T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
|
||||||
# Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
|
# Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
|
||||||
# encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask,
|
# encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask,
|
||||||
# tf.transpose(encoder_extended_attention_mask, perm=(-1, -2)))
|
# tf.transpose(encoder_extended_attention_mask, perm=(-1, -2)))
|
||||||
@@ -1417,7 +1418,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
|
|||||||
"use_cache": use_cache,
|
"use_cache": use_cache,
|
||||||
}
|
}
|
||||||
|
|
||||||
def _reorder_cache(self, past, beam_idx):
|
def _reorder_cache(self, past, beam_idx) -> Tuple:
|
||||||
# if decoder past is not included in output
|
# if decoder past is not included in output
|
||||||
# speedy decoding is disabled and no need to reorder
|
# speedy decoding is disabled and no need to reorder
|
||||||
|
|
||||||
|
|||||||
@@ -136,8 +136,7 @@ class TFCausalLanguageModelingLoss:
|
|||||||
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
|
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
|
||||||
from_logits=True, reduction=tf.keras.losses.Reduction.NONE
|
from_logits=True, reduction=tf.keras.losses.Reduction.NONE
|
||||||
)
|
)
|
||||||
# make sure only labels that are not equal to -100
|
# make sure only labels that are not equal to -100 do not affect loss
|
||||||
# are taken into account as loss
|
|
||||||
active_loss = tf.not_equal(tf.reshape(labels, (-1,)), -100)
|
active_loss = tf.not_equal(tf.reshape(labels, (-1,)), -100)
|
||||||
reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
|
reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
|
||||||
labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
|
labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
|
||||||
|
|||||||
@@ -1945,11 +1945,6 @@ class SummarizationPipeline(Pipeline):
|
|||||||
assert return_tensors or return_text, "You must specify return_tensors=True or return_text=True"
|
assert return_tensors or return_text, "You must specify return_tensors=True or return_text=True"
|
||||||
assert len(documents) > 0, "Please provide a document to summarize"
|
assert len(documents) > 0, "Please provide a document to summarize"
|
||||||
|
|
||||||
if self.framework == "tf" and "BartForConditionalGeneration" in self.model.__class__.__name__:
|
|
||||||
raise NotImplementedError(
|
|
||||||
"Tensorflow is not yet supported for Bart. Please consider using T5, e.g. `t5-base`"
|
|
||||||
)
|
|
||||||
|
|
||||||
prefix = self.model.config.prefix if self.model.config.prefix is not None else ""
|
prefix = self.model.config.prefix if self.model.config.prefix is not None else ""
|
||||||
|
|
||||||
if isinstance(documents[0], list):
|
if isinstance(documents[0], list):
|
||||||
|
|||||||
@@ -212,6 +212,24 @@ class TFAutoModelWithLMHead:
|
|||||||
requires_tf(self)
|
requires_tf(self)
|
||||||
|
|
||||||
|
|
||||||
|
class TFBartForConditionalGeneration:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_tf(self)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(self, *args, **kwargs):
|
||||||
|
requires_tf(self)
|
||||||
|
|
||||||
|
|
||||||
|
class TFBartModel:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_tf(self)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(self, *args, **kwargs):
|
||||||
|
requires_tf(self)
|
||||||
|
|
||||||
|
|
||||||
TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
357
tests/test_modeling_tf_bart.py
Normal file
357
tests/test_modeling_tf_bart.py
Normal file
File diff suppressed because one or more lines are too long
@@ -302,7 +302,7 @@ class TFModelTesterMixin:
|
|||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
pt_model_class_name = model_class.__name__[2:] # Skip the "TF" at the beggining
|
pt_model_class_name = model_class.__name__[2:] # Skip the "TF" at the beginning
|
||||||
pt_model_class = getattr(transformers, pt_model_class_name)
|
pt_model_class = getattr(transformers, pt_model_class_name)
|
||||||
|
|
||||||
config.output_hidden_states = True
|
config.output_hidden_states = True
|
||||||
@@ -472,10 +472,9 @@ class TFModelTesterMixin:
|
|||||||
|
|
||||||
# Prepare our model
|
# Prepare our model
|
||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
|
model(self._prepare_for_class(inputs_dict, model_class)) # Model must be called before saving.
|
||||||
# Let's load it from the disk to be sure we can use pretrained weights
|
# Let's load it from the disk to be sure we can use pretrained weights
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
outputs = model(self._prepare_for_class(inputs_dict, model_class)) # build the model
|
|
||||||
model.save_pretrained(tmpdirname)
|
model.save_pretrained(tmpdirname)
|
||||||
model = model_class.from_pretrained(tmpdirname)
|
model = model_class.from_pretrained(tmpdirname)
|
||||||
|
|
||||||
@@ -494,7 +493,9 @@ class TFModelTesterMixin:
|
|||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
outputs_dict = model(self._prepare_for_class(inputs_dict, model_class))
|
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||||
|
|
||||||
|
outputs_dict = model(inputs)
|
||||||
|
|
||||||
inputs_keywords = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
|
inputs_keywords = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
|
||||||
input_ids = inputs_keywords.pop("input_ids", None)
|
input_ids = inputs_keywords.pop("input_ids", None)
|
||||||
@@ -507,28 +508,18 @@ class TFModelTesterMixin:
|
|||||||
def test_attention_outputs(self):
|
def test_attention_outputs(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
decoder_seq_length = (
|
decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", self.model_tester.seq_length)
|
||||||
self.model_tester.decoder_seq_length
|
encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", self.model_tester.seq_length)
|
||||||
if hasattr(self.model_tester, "decoder_seq_length")
|
decoder_key_length = getattr(self.model_tester, "key_length", decoder_seq_length)
|
||||||
else self.model_tester.seq_length
|
encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
|
||||||
)
|
|
||||||
encoder_seq_length = (
|
|
||||||
self.model_tester.encoder_seq_length
|
|
||||||
if hasattr(self.model_tester, "encoder_seq_length")
|
|
||||||
else self.model_tester.seq_length
|
|
||||||
)
|
|
||||||
decoder_key_length = (
|
|
||||||
self.model_tester.key_length if hasattr(self.model_tester, "key_length") else decoder_seq_length
|
|
||||||
)
|
|
||||||
encoder_key_length = (
|
|
||||||
self.model_tester.key_length if hasattr(self.model_tester, "key_length") else encoder_seq_length
|
|
||||||
)
|
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
inputs_dict["output_attentions"] = True
|
inputs_dict["output_attentions"] = True
|
||||||
|
inputs_dict["use_cache"] = False
|
||||||
config.output_hidden_states = False
|
config.output_hidden_states = False
|
||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
outputs = model(self._prepare_for_class(inputs_dict, model_class))
|
model_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||||
|
outputs = model(model_inputs)
|
||||||
attentions = [t.numpy() for t in outputs[-1]]
|
attentions = [t.numpy() for t in outputs[-1]]
|
||||||
self.assertEqual(model.config.output_hidden_states, False)
|
self.assertEqual(model.config.output_hidden_states, False)
|
||||||
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
|
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
|
||||||
|
|||||||
@@ -279,8 +279,7 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_model_from_pretrained(self):
|
def test_model_from_pretrained(self):
|
||||||
for model_name in ["t5-small"]:
|
model = TFT5Model.from_pretrained("t5-small")
|
||||||
model = TFT5Model.from_pretrained(model_name)
|
|
||||||
self.assertIsNotNone(model)
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ FILL_MASK_FINETUNED_MODELS = ["sshleifer/tiny-distilroberta-base"]
|
|||||||
LARGE_FILL_MASK_FINETUNED_MODELS = ["distilroberta-base"] # @slow
|
LARGE_FILL_MASK_FINETUNED_MODELS = ["distilroberta-base"] # @slow
|
||||||
|
|
||||||
SUMMARIZATION_FINETUNED_MODELS = ["sshleifer/bart-tiny-random", "patrickvonplaten/t5-tiny-random"]
|
SUMMARIZATION_FINETUNED_MODELS = ["sshleifer/bart-tiny-random", "patrickvonplaten/t5-tiny-random"]
|
||||||
TF_SUMMARIZATION_FINETUNED_MODELS = ["patrickvonplaten/t5-tiny-random"]
|
TF_SUMMARIZATION_FINETUNED_MODELS = ["sshleifer/bart-tiny-random", "patrickvonplaten/t5-tiny-random"]
|
||||||
|
|
||||||
TRANSLATION_FINETUNED_MODELS = [
|
TRANSLATION_FINETUNED_MODELS = [
|
||||||
("patrickvonplaten/t5-tiny-random", "translation_en_to_de"),
|
("patrickvonplaten/t5-tiny-random", "translation_en_to_de"),
|
||||||
|
|||||||
Reference in New Issue
Block a user