[Almost all TF models] TF clean up: add missing CLM / MLM loss; fix T5 naming and keras compile (#5395)
* add first version of clm tf * make style * add more tests for bert * update tf clm loss * fix tests * correct tf ner script * add mlm loss * delete bogus file * clean tf auto model + add tests * finish adding clm loss everywhere * fix training in distilbert * fix flake8 * save intermediate * fix tf t5 naming * remove prints * finish up * up * fix tf gpt2 * fix new test utils import * fix flake8 * keep backward compatibility * Update src/transformers/modeling_tf_albert.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/modeling_tf_auto.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/modeling_tf_electra.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/modeling_tf_roberta.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/modeling_tf_mobilebert.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/modeling_tf_auto.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/modeling_tf_bert.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/modeling_tf_distilbert.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * apply sylvains suggestions Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
33e43edddc
commit
4dc65591b5
@@ -17,6 +17,7 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import warnings
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
@@ -184,7 +185,12 @@ def main():
|
|||||||
|
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
for j in range(seq_len):
|
for j in range(seq_len):
|
||||||
if label_ids[i, j] != -1:
|
if label_ids[i, j] == -1:
|
||||||
|
label_ids[i, j] = -100
|
||||||
|
warnings.warn(
|
||||||
|
"Using `-1` to mask the loss for the token is depreciated. Please use `-100` instead."
|
||||||
|
)
|
||||||
|
if label_ids[i, j] != -100:
|
||||||
out_label_list[i].append(label_map[label_ids[i][j]])
|
out_label_list[i].append(label_map[label_ids[i][j]])
|
||||||
preds_list[i].append(label_map[preds[i][j]])
|
preds_list[i].append(label_map[preds[i][j]])
|
||||||
|
|
||||||
|
|||||||
@@ -453,6 +453,9 @@ if is_tf_available():
|
|||||||
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||||
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||||
TF_MODEL_WITH_LM_HEAD_MAPPING,
|
TF_MODEL_WITH_LM_HEAD_MAPPING,
|
||||||
|
TF_MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||||
|
TF_MODEL_FOR_MASKED_LM_MAPPING,
|
||||||
|
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
||||||
TFAutoModel,
|
TFAutoModel,
|
||||||
TFAutoModelForMultipleChoice,
|
TFAutoModelForMultipleChoice,
|
||||||
TFAutoModelForPreTraining,
|
TFAutoModelForPreTraining,
|
||||||
@@ -460,6 +463,9 @@ if is_tf_available():
|
|||||||
TFAutoModelForSequenceClassification,
|
TFAutoModelForSequenceClassification,
|
||||||
TFAutoModelForTokenClassification,
|
TFAutoModelForTokenClassification,
|
||||||
TFAutoModelWithLMHead,
|
TFAutoModelWithLMHead,
|
||||||
|
TFAutoModelForCausalLM,
|
||||||
|
TFAutoModelForMaskedLM,
|
||||||
|
TFAutoModelForSeq2SeqLM,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .modeling_tf_albert import (
|
from .modeling_tf_albert import (
|
||||||
@@ -478,6 +484,7 @@ if is_tf_available():
|
|||||||
from .modeling_tf_bert import (
|
from .modeling_tf_bert import (
|
||||||
TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
TFBertEmbeddings,
|
TFBertEmbeddings,
|
||||||
|
TFBertLMHeadModel,
|
||||||
TFBertForMaskedLM,
|
TFBertForMaskedLM,
|
||||||
TFBertForMultipleChoice,
|
TFBertForMultipleChoice,
|
||||||
TFBertForNextSentencePrediction,
|
TFBertForNextSentencePrediction,
|
||||||
|
|||||||
@@ -73,6 +73,7 @@ from .modeling_bert import (
|
|||||||
from .modeling_camembert import (
|
from .modeling_camembert import (
|
||||||
CamembertForMaskedLM,
|
CamembertForMaskedLM,
|
||||||
CamembertForMultipleChoice,
|
CamembertForMultipleChoice,
|
||||||
|
CamembertForQuestionAnswering,
|
||||||
CamembertForSequenceClassification,
|
CamembertForSequenceClassification,
|
||||||
CamembertForTokenClassification,
|
CamembertForTokenClassification,
|
||||||
CamembertModel,
|
CamembertModel,
|
||||||
@@ -306,6 +307,7 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
|
|||||||
[
|
[
|
||||||
(DistilBertConfig, DistilBertForQuestionAnswering),
|
(DistilBertConfig, DistilBertForQuestionAnswering),
|
||||||
(AlbertConfig, AlbertForQuestionAnswering),
|
(AlbertConfig, AlbertForQuestionAnswering),
|
||||||
|
(CamembertConfig, CamembertForQuestionAnswering),
|
||||||
(BartConfig, BartForQuestionAnswering),
|
(BartConfig, BartForQuestionAnswering),
|
||||||
(LongformerConfig, LongformerForQuestionAnswering),
|
(LongformerConfig, LongformerForQuestionAnswering),
|
||||||
(XLMRobertaConfig, XLMRobertaForQuestionAnswering),
|
(XLMRobertaConfig, XLMRobertaForQuestionAnswering),
|
||||||
@@ -336,7 +338,6 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict(
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
|
MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
|
||||||
[
|
[
|
||||||
(CamembertConfig, CamembertForMultipleChoice),
|
(CamembertConfig, CamembertForMultipleChoice),
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ from .file_utils import (
|
|||||||
)
|
)
|
||||||
from .modeling_tf_bert import ACT2FN, TFBertSelfAttention
|
from .modeling_tf_bert import ACT2FN, TFBertSelfAttention
|
||||||
from .modeling_tf_utils import (
|
from .modeling_tf_utils import (
|
||||||
|
TFMaskedLanguageModelingLoss,
|
||||||
TFMultipleChoiceLoss,
|
TFMultipleChoiceLoss,
|
||||||
TFPreTrainedModel,
|
TFPreTrainedModel,
|
||||||
TFQuestionAnsweringLoss,
|
TFQuestionAnsweringLoss,
|
||||||
@@ -822,7 +823,7 @@ class TFAlbertSOPHead(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
|
|
||||||
@add_start_docstrings("""Albert Model with a `language modeling` head on top. """, ALBERT_START_DOCSTRING)
|
@add_start_docstrings("""Albert Model with a `language modeling` head on top. """, ALBERT_START_DOCSTRING)
|
||||||
class TFAlbertForMaskedLM(TFAlbertPreTrainedModel):
|
class TFAlbertForMaskedLM(TFAlbertPreTrainedModel, TFMaskedLanguageModelingLoss):
|
||||||
def __init__(self, config, *inputs, **kwargs):
|
def __init__(self, config, *inputs, **kwargs):
|
||||||
super().__init__(config, *inputs, **kwargs)
|
super().__init__(config, *inputs, **kwargs)
|
||||||
|
|
||||||
@@ -834,8 +835,26 @@ class TFAlbertForMaskedLM(TFAlbertPreTrainedModel):
|
|||||||
|
|
||||||
@add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
@add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
||||||
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="albert-base-v2")
|
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="albert-base-v2")
|
||||||
def call(self, inputs, **kwargs):
|
def call(
|
||||||
|
self,
|
||||||
|
inputs=None,
|
||||||
|
attention_mask=None,
|
||||||
|
token_type_ids=None,
|
||||||
|
position_ids=None,
|
||||||
|
head_mask=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
output_attentions=None,
|
||||||
|
output_hidden_states=None,
|
||||||
|
labels=None,
|
||||||
|
training=False,
|
||||||
|
):
|
||||||
r"""
|
r"""
|
||||||
|
labels (:obj::obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||||
|
Labels for computing the masked language modeling loss.
|
||||||
|
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
|
||||||
|
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
|
||||||
|
in ``[0, ..., config.vocab_size]``
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.AlbertConfig`) and inputs:
|
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.AlbertConfig`) and inputs:
|
||||||
prediction_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`
|
prediction_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`
|
||||||
@@ -852,14 +871,35 @@ class TFAlbertForMaskedLM(TFAlbertPreTrainedModel):
|
|||||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||||
heads.
|
heads.
|
||||||
"""
|
"""
|
||||||
outputs = self.albert(inputs, **kwargs)
|
if isinstance(inputs, (tuple, list)):
|
||||||
|
labels = inputs[8] if len(inputs) > 8 else labels
|
||||||
|
if len(inputs) > 8:
|
||||||
|
inputs = inputs[:8]
|
||||||
|
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||||
|
labels = inputs.pop("labels", labels)
|
||||||
|
|
||||||
|
outputs = self.albert(
|
||||||
|
inputs,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
token_type_ids=token_type_ids,
|
||||||
|
position_ids=position_ids,
|
||||||
|
head_mask=head_mask,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
training=training,
|
||||||
|
)
|
||||||
|
|
||||||
sequence_output = outputs[0]
|
sequence_output = outputs[0]
|
||||||
prediction_scores = self.predictions(sequence_output, training=kwargs.get("training", False))
|
prediction_scores = self.predictions(sequence_output, training=training)
|
||||||
|
|
||||||
# Add hidden states and attention if they are here
|
# Add hidden states and attention if they are here
|
||||||
outputs = (prediction_scores,) + outputs[2:]
|
outputs = (prediction_scores,) + outputs[2:]
|
||||||
|
|
||||||
|
if labels is not None:
|
||||||
|
loss = self.compute_loss(labels, prediction_scores)
|
||||||
|
outputs = (loss,) + outputs
|
||||||
|
|
||||||
return outputs # prediction_scores, (hidden_states), (attentions)
|
return outputs # prediction_scores, (hidden_states), (attentions)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -29,6 +29,8 @@ from .file_utils import (
|
|||||||
add_start_docstrings_to_callable,
|
add_start_docstrings_to_callable,
|
||||||
)
|
)
|
||||||
from .modeling_tf_utils import (
|
from .modeling_tf_utils import (
|
||||||
|
TFCausalLanguageModelingLoss,
|
||||||
|
TFMaskedLanguageModelingLoss,
|
||||||
TFMultipleChoiceLoss,
|
TFMultipleChoiceLoss,
|
||||||
TFPreTrainedModel,
|
TFPreTrainedModel,
|
||||||
TFQuestionAnsweringLoss,
|
TFQuestionAnsweringLoss,
|
||||||
@@ -803,9 +805,12 @@ class TFBertForPreTraining(TFBertPreTrainedModel):
|
|||||||
|
|
||||||
|
|
||||||
@add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
|
@add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
|
||||||
class TFBertForMaskedLM(TFBertPreTrainedModel):
|
class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
|
||||||
def __init__(self, config, *inputs, **kwargs):
|
def __init__(self, config, *inputs, **kwargs):
|
||||||
super().__init__(config, *inputs, **kwargs)
|
super().__init__(config, *inputs, **kwargs)
|
||||||
|
assert (
|
||||||
|
not config.is_decoder
|
||||||
|
), "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for bi-directional self-attention."
|
||||||
|
|
||||||
self.bert = TFBertMainLayer(config, name="bert")
|
self.bert = TFBertMainLayer(config, name="bert")
|
||||||
self.mlm = TFBertMLMHead(config, self.bert.embeddings, name="mlm___cls")
|
self.mlm = TFBertMLMHead(config, self.bert.embeddings, name="mlm___cls")
|
||||||
@@ -815,8 +820,26 @@ class TFBertForMaskedLM(TFBertPreTrainedModel):
|
|||||||
|
|
||||||
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
||||||
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-cased")
|
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-cased")
|
||||||
def call(self, inputs, **kwargs):
|
def call(
|
||||||
|
self,
|
||||||
|
inputs=None,
|
||||||
|
attention_mask=None,
|
||||||
|
token_type_ids=None,
|
||||||
|
position_ids=None,
|
||||||
|
head_mask=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
output_attentions=None,
|
||||||
|
output_hidden_states=None,
|
||||||
|
labels=None,
|
||||||
|
training=False,
|
||||||
|
):
|
||||||
r"""
|
r"""
|
||||||
|
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||||
|
Labels for computing the masked language modeling loss.
|
||||||
|
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
|
||||||
|
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
|
||||||
|
in ``[0, ..., config.vocab_size]``
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
|
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
|
||||||
prediction_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
|
prediction_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
|
||||||
@@ -833,13 +856,113 @@ class TFBertForMaskedLM(TFBertPreTrainedModel):
|
|||||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||||
heads.
|
heads.
|
||||||
"""
|
"""
|
||||||
outputs = self.bert(inputs, **kwargs)
|
if isinstance(inputs, (tuple, list)):
|
||||||
|
labels = inputs[8] if len(inputs) > 8 else labels
|
||||||
|
if len(inputs) > 8:
|
||||||
|
inputs = inputs[:8]
|
||||||
|
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||||
|
labels = inputs.pop("labels", labels)
|
||||||
|
|
||||||
|
outputs = self.bert(
|
||||||
|
inputs,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
token_type_ids=token_type_ids,
|
||||||
|
position_ids=position_ids,
|
||||||
|
head_mask=head_mask,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
training=training,
|
||||||
|
)
|
||||||
|
|
||||||
sequence_output = outputs[0]
|
sequence_output = outputs[0]
|
||||||
prediction_scores = self.mlm(sequence_output, training=kwargs.get("training", False))
|
prediction_scores = self.mlm(sequence_output, training=training)
|
||||||
|
|
||||||
outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here
|
outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here
|
||||||
|
|
||||||
|
if labels is not None:
|
||||||
|
loss = self.compute_loss(labels, prediction_scores)
|
||||||
|
outputs = (loss,) + outputs
|
||||||
|
|
||||||
|
return outputs # (loss), prediction_scores, (hidden_states), (attentions)
|
||||||
|
|
||||||
|
|
||||||
|
class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||||
|
def __init__(self, config, *inputs, **kwargs):
|
||||||
|
super().__init__(config, *inputs, **kwargs)
|
||||||
|
assert config.is_decoder, "If you want to use `TFBertLMHeadModel` as a standalone, add `is_decoder=True.`"
|
||||||
|
|
||||||
|
self.bert = TFBertMainLayer(config, name="bert")
|
||||||
|
self.mlm = TFBertMLMHead(config, self.bert.embeddings, name="mlm___cls")
|
||||||
|
|
||||||
|
def get_output_embeddings(self):
|
||||||
|
return self.bert.embeddings
|
||||||
|
|
||||||
|
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-cased")
|
||||||
|
def call(
|
||||||
|
self,
|
||||||
|
inputs=None,
|
||||||
|
attention_mask=None,
|
||||||
|
token_type_ids=None,
|
||||||
|
position_ids=None,
|
||||||
|
head_mask=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
output_attentions=None,
|
||||||
|
output_hidden_states=None,
|
||||||
|
labels=None,
|
||||||
|
training=False,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||||
|
Labels for computing the cross entropy classification loss.
|
||||||
|
Indices should be in ``[0, ..., config.vocab_size - 1]``.
|
||||||
|
|
||||||
|
Return:
|
||||||
|
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
|
||||||
|
prediction_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
|
||||||
|
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||||
|
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
|
tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
|
||||||
|
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||||
|
|
||||||
|
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||||
|
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||||
|
tuple of :obj:`tf.Tensor` (one for each layer) of shape
|
||||||
|
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
|
||||||
|
|
||||||
|
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||||
|
heads.
|
||||||
|
"""
|
||||||
|
if isinstance(inputs, (tuple, list)):
|
||||||
|
labels = inputs[8] if len(inputs) > 8 else labels
|
||||||
|
if len(inputs) > 8:
|
||||||
|
inputs = inputs[:8]
|
||||||
|
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||||
|
labels = inputs.pop("labels", labels)
|
||||||
|
|
||||||
|
outputs = self.bert(
|
||||||
|
inputs,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
token_type_ids=token_type_ids,
|
||||||
|
position_ids=position_ids,
|
||||||
|
head_mask=head_mask,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
training=training,
|
||||||
|
)
|
||||||
|
|
||||||
|
sequence_output = outputs[0]
|
||||||
|
logits = self.mlm(sequence_output, training=training)
|
||||||
|
|
||||||
|
outputs = (logits,) + outputs[2:] # Add hidden states and attention if they are here
|
||||||
|
if labels is not None:
|
||||||
|
# shift labels to the left and cut last logit token
|
||||||
|
logits = logits[:, :-1]
|
||||||
|
labels = labels[:, 1:]
|
||||||
|
loss = self.compute_loss(labels, logits)
|
||||||
|
outputs = (loss,) + outputs
|
||||||
|
|
||||||
return outputs # prediction_scores, (hidden_states), (attentions)
|
return outputs # prediction_scores, (hidden_states), (attentions)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ import tensorflow as tf
|
|||||||
from .configuration_ctrl import CTRLConfig
|
from .configuration_ctrl import CTRLConfig
|
||||||
from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
|
from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
|
||||||
from .modeling_tf_utils import (
|
from .modeling_tf_utils import (
|
||||||
|
TFCausalLanguageModelingLoss,
|
||||||
TFPreTrainedModel,
|
TFPreTrainedModel,
|
||||||
TFSharedEmbeddings,
|
TFSharedEmbeddings,
|
||||||
cast_bool_to_primitive,
|
cast_bool_to_primitive,
|
||||||
@@ -542,7 +543,7 @@ class TFCTRLLMHead(tf.keras.layers.Layer):
|
|||||||
(linear layer with weights tied to the input embeddings). """,
|
(linear layer with weights tied to the input embeddings). """,
|
||||||
CTRL_START_DOCSTRING,
|
CTRL_START_DOCSTRING,
|
||||||
)
|
)
|
||||||
class TFCTRLLMHeadModel(TFCTRLPreTrainedModel):
|
class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||||
def __init__(self, config, *inputs, **kwargs):
|
def __init__(self, config, *inputs, **kwargs):
|
||||||
super().__init__(config, *inputs, **kwargs)
|
super().__init__(config, *inputs, **kwargs)
|
||||||
self.transformer = TFCTRLMainLayer(config, name="transformer")
|
self.transformer = TFCTRLMainLayer(config, name="transformer")
|
||||||
@@ -561,8 +562,26 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel):
|
|||||||
|
|
||||||
@add_start_docstrings_to_callable(CTRL_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_callable(CTRL_INPUTS_DOCSTRING)
|
||||||
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="ctrl")
|
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="ctrl")
|
||||||
def call(self, inputs, **kwargs):
|
def call(
|
||||||
|
self,
|
||||||
|
inputs,
|
||||||
|
past=None,
|
||||||
|
attention_mask=None,
|
||||||
|
token_type_ids=None,
|
||||||
|
position_ids=None,
|
||||||
|
head_mask=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
use_cache=None,
|
||||||
|
output_attentions=None,
|
||||||
|
output_hidden_states=None,
|
||||||
|
labels=None,
|
||||||
|
training=False,
|
||||||
|
):
|
||||||
r"""
|
r"""
|
||||||
|
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||||
|
Labels for computing the cross entropy classification loss.
|
||||||
|
Indices should be in ``[0, ..., config.vocab_size - 1]``.
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.CTRLConfig`) and inputs:
|
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.CTRLConfig`) and inputs:
|
||||||
prediction_scores (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
|
prediction_scores (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
|
||||||
@@ -583,11 +602,37 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel):
|
|||||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||||
heads.
|
heads.
|
||||||
"""
|
"""
|
||||||
transformer_outputs = self.transformer(inputs, **kwargs)
|
if isinstance(inputs, (tuple, list)):
|
||||||
|
labels = inputs[10] if len(inputs) > 10 else labels
|
||||||
|
if len(inputs) > 10:
|
||||||
|
inputs = inputs[:10]
|
||||||
|
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||||
|
labels = inputs.pop("labels", labels)
|
||||||
|
|
||||||
|
transformer_outputs = self.transformer(
|
||||||
|
inputs,
|
||||||
|
past=past,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
token_type_ids=token_type_ids,
|
||||||
|
position_ids=position_ids,
|
||||||
|
head_mask=head_mask,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
training=training,
|
||||||
|
)
|
||||||
|
|
||||||
hidden_states = transformer_outputs[0]
|
hidden_states = transformer_outputs[0]
|
||||||
|
|
||||||
lm_logits = self.lm_head(hidden_states)
|
logits = self.lm_head(hidden_states)
|
||||||
|
|
||||||
outputs = (lm_logits,) + transformer_outputs[1:]
|
outputs = (logits,) + transformer_outputs[1:]
|
||||||
|
if labels is not None:
|
||||||
|
# shift labels to the left and cut last logit token
|
||||||
|
logits = logits[:, :-1]
|
||||||
|
labels = labels[:, 1:]
|
||||||
|
loss = self.compute_loss(labels, logits)
|
||||||
|
outputs = (loss,) + outputs
|
||||||
|
|
||||||
return outputs # lm_logits, presents, (all hidden_states), (attentions)
|
return outputs # lm_logits, presents, (all hidden_states), (attentions)
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ from .file_utils import (
|
|||||||
add_start_docstrings_to_callable,
|
add_start_docstrings_to_callable,
|
||||||
)
|
)
|
||||||
from .modeling_tf_utils import (
|
from .modeling_tf_utils import (
|
||||||
|
TFMaskedLanguageModelingLoss,
|
||||||
TFMultipleChoiceLoss,
|
TFMultipleChoiceLoss,
|
||||||
TFPreTrainedModel,
|
TFPreTrainedModel,
|
||||||
TFQuestionAnsweringLoss,
|
TFQuestionAnsweringLoss,
|
||||||
@@ -116,7 +117,7 @@ class TFEmbeddings(tf.keras.layers.Layer):
|
|||||||
def call(self, inputs, inputs_embeds=None, mode="embedding", training=False):
|
def call(self, inputs, inputs_embeds=None, mode="embedding", training=False):
|
||||||
"""Get token embeddings of inputs.
|
"""Get token embeddings of inputs.
|
||||||
Args:
|
Args:
|
||||||
inputs: list of three int64 tensors with shape [batch_size, length]: (input_ids, position_ids, token_type_ids)
|
inputs: list of two int64 tensors with shape [batch_size, length]: (input_ids, position_ids)
|
||||||
mode: string, a valid value is one of "embedding" and "linear".
|
mode: string, a valid value is one of "embedding" and "linear".
|
||||||
Returns:
|
Returns:
|
||||||
outputs: (1) If mode == "embedding", output embedding tensor, float32 with
|
outputs: (1) If mode == "embedding", output embedding tensor, float32 with
|
||||||
@@ -528,9 +529,9 @@ DISTILBERT_START_DOCSTRING = r"""
|
|||||||
|
|
||||||
- a single Tensor with input_ids only and nothing else: :obj:`model(inputs_ids)`
|
- a single Tensor with input_ids only and nothing else: :obj:`model(inputs_ids)`
|
||||||
- a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
|
- a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
|
||||||
:obj:`model([input_ids, attention_mask])` or :obj:`model([input_ids, attention_mask, token_type_ids])`
|
:obj:`model([input_ids, attention_mask])`
|
||||||
- a dictionary with one or several input Tensors associated to the input names given in the docstring:
|
- a dictionary with one or several input Tensors associated to the input names given in the docstring:
|
||||||
:obj:`model({'input_ids': input_ids, 'token_type_ids': token_type_ids})`
|
:obj:`model({'input_ids': input_ids})`
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
config (:class:`~transformers.DistilBertConfig`): Model configuration class with all the parameters of the model.
|
config (:class:`~transformers.DistilBertConfig`): Model configuration class with all the parameters of the model.
|
||||||
@@ -626,7 +627,7 @@ class TFDistilBertLMHead(tf.keras.layers.Layer):
|
|||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
"""DistilBert Model with a `masked language modeling` head on top. """, DISTILBERT_START_DOCSTRING,
|
"""DistilBert Model with a `masked language modeling` head on top. """, DISTILBERT_START_DOCSTRING,
|
||||||
)
|
)
|
||||||
class TFDistilBertForMaskedLM(TFDistilBertPreTrainedModel):
|
class TFDistilBertForMaskedLM(TFDistilBertPreTrainedModel, TFMaskedLanguageModelingLoss):
|
||||||
def __init__(self, config, *inputs, **kwargs):
|
def __init__(self, config, *inputs, **kwargs):
|
||||||
super().__init__(config, *inputs, **kwargs)
|
super().__init__(config, *inputs, **kwargs)
|
||||||
self.vocab_size = config.vocab_size
|
self.vocab_size = config.vocab_size
|
||||||
@@ -644,8 +645,23 @@ class TFDistilBertForMaskedLM(TFDistilBertPreTrainedModel):
|
|||||||
|
|
||||||
@add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING)
|
||||||
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="distilbert-base-uncased")
|
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="distilbert-base-uncased")
|
||||||
def call(self, inputs, **kwargs):
|
def call(
|
||||||
|
self,
|
||||||
|
inputs=None,
|
||||||
|
attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
output_attentions=None,
|
||||||
|
output_hidden_states=None,
|
||||||
|
labels=None,
|
||||||
|
training=False,
|
||||||
|
):
|
||||||
r"""
|
r"""
|
||||||
|
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||||
|
Labels for computing the masked language modeling loss.
|
||||||
|
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
|
||||||
|
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
|
||||||
|
in ``[0, ..., config.vocab_size]``
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers,DistilBertConfig`) and inputs:
|
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers,DistilBertConfig`) and inputs:
|
||||||
@@ -663,7 +679,22 @@ class TFDistilBertForMaskedLM(TFDistilBertPreTrainedModel):
|
|||||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||||
heads.
|
heads.
|
||||||
"""
|
"""
|
||||||
distilbert_output = self.distilbert(inputs, **kwargs)
|
if isinstance(inputs, (tuple, list)):
|
||||||
|
labels = inputs[6] if len(inputs) > 6 else labels
|
||||||
|
if len(inputs) > 6:
|
||||||
|
inputs = inputs[:6]
|
||||||
|
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||||
|
labels = inputs.pop("labels", labels)
|
||||||
|
|
||||||
|
distilbert_output = self.distilbert(
|
||||||
|
inputs,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
training=training,
|
||||||
|
)
|
||||||
|
|
||||||
hidden_states = distilbert_output[0] # (bs, seq_length, dim)
|
hidden_states = distilbert_output[0] # (bs, seq_length, dim)
|
||||||
prediction_logits = self.vocab_transform(hidden_states) # (bs, seq_length, dim)
|
prediction_logits = self.vocab_transform(hidden_states) # (bs, seq_length, dim)
|
||||||
@@ -672,6 +703,11 @@ class TFDistilBertForMaskedLM(TFDistilBertPreTrainedModel):
|
|||||||
prediction_logits = self.vocab_projector(prediction_logits)
|
prediction_logits = self.vocab_projector(prediction_logits)
|
||||||
|
|
||||||
outputs = (prediction_logits,) + distilbert_output[1:]
|
outputs = (prediction_logits,) + distilbert_output[1:]
|
||||||
|
|
||||||
|
if labels is not None:
|
||||||
|
loss = self.compute_loss(labels, prediction_logits)
|
||||||
|
outputs = (loss,) + outputs
|
||||||
|
|
||||||
return outputs # logits, (hidden_states), (attentions)
|
return outputs # logits, (hidden_states), (attentions)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from transformers import ElectraConfig
|
|||||||
from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
|
from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
|
||||||
from .modeling_tf_bert import ACT2FN, TFBertEncoder, TFBertPreTrainedModel
|
from .modeling_tf_bert import ACT2FN, TFBertEncoder, TFBertPreTrainedModel
|
||||||
from .modeling_tf_utils import (
|
from .modeling_tf_utils import (
|
||||||
|
TFMaskedLanguageModelingLoss,
|
||||||
TFQuestionAnsweringLoss,
|
TFQuestionAnsweringLoss,
|
||||||
TFTokenClassificationLoss,
|
TFTokenClassificationLoss,
|
||||||
get_initializer,
|
get_initializer,
|
||||||
@@ -506,7 +507,7 @@ class TFElectraMaskedLMHead(tf.keras.layers.Layer):
|
|||||||
the only model of the two to have been trained for the masked language modeling task.""",
|
the only model of the two to have been trained for the masked language modeling task.""",
|
||||||
ELECTRA_START_DOCSTRING,
|
ELECTRA_START_DOCSTRING,
|
||||||
)
|
)
|
||||||
class TFElectraForMaskedLM(TFElectraPreTrainedModel):
|
class TFElectraForMaskedLM(TFElectraPreTrainedModel, TFMaskedLanguageModelingLoss):
|
||||||
def __init__(self, config, **kwargs):
|
def __init__(self, config, **kwargs):
|
||||||
super().__init__(config, **kwargs)
|
super().__init__(config, **kwargs)
|
||||||
|
|
||||||
@@ -534,9 +535,16 @@ class TFElectraForMaskedLM(TFElectraPreTrainedModel):
|
|||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
|
labels=None,
|
||||||
training=False,
|
training=False,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
|
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||||
|
Labels for computing the masked language modeling loss.
|
||||||
|
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
|
||||||
|
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
|
||||||
|
in ``[0, ..., config.vocab_size]``
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.ElectraConfig`) and inputs:
|
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.ElectraConfig`) and inputs:
|
||||||
prediction_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
|
prediction_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
|
||||||
@@ -553,6 +561,12 @@ class TFElectraForMaskedLM(TFElectraPreTrainedModel):
|
|||||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||||
heads.
|
heads.
|
||||||
"""
|
"""
|
||||||
|
if isinstance(input_ids, (tuple, list)):
|
||||||
|
labels = input_ids[8] if len(input_ids) > 8 else labels
|
||||||
|
if len(input_ids) > 8:
|
||||||
|
input_ids = input_ids[:8]
|
||||||
|
elif isinstance(input_ids, (dict, BatchEncoding)):
|
||||||
|
labels = input_ids.pop("labels", labels)
|
||||||
|
|
||||||
generator_hidden_states = self.electra(
|
generator_hidden_states = self.electra(
|
||||||
input_ids,
|
input_ids,
|
||||||
@@ -571,6 +585,10 @@ class TFElectraForMaskedLM(TFElectraPreTrainedModel):
|
|||||||
output = (prediction_scores,)
|
output = (prediction_scores,)
|
||||||
output += generator_hidden_states[1:]
|
output += generator_hidden_states[1:]
|
||||||
|
|
||||||
|
if labels is not None:
|
||||||
|
loss = self.compute_loss(labels, prediction_scores)
|
||||||
|
output = (loss,) + output
|
||||||
|
|
||||||
return output # (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
|
return output # (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ import tensorflow as tf
|
|||||||
from .configuration_gpt2 import GPT2Config
|
from .configuration_gpt2 import GPT2Config
|
||||||
from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
|
from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
|
||||||
from .modeling_tf_utils import (
|
from .modeling_tf_utils import (
|
||||||
|
TFCausalLanguageModelingLoss,
|
||||||
TFConv1D,
|
TFConv1D,
|
||||||
TFPreTrainedModel,
|
TFPreTrainedModel,
|
||||||
TFSequenceSummary,
|
TFSequenceSummary,
|
||||||
@@ -272,8 +273,8 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
|
|||||||
head_mask = inputs[5] if len(inputs) > 5 else head_mask
|
head_mask = inputs[5] if len(inputs) > 5 else head_mask
|
||||||
inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
|
inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
|
||||||
use_cache = inputs[7] if len(inputs) > 7 else use_cache
|
use_cache = inputs[7] if len(inputs) > 7 else use_cache
|
||||||
output_attentions = inputs[8] if len(inputs) > 7 else output_attentions
|
output_attentions = inputs[8] if len(inputs) > 8 else output_attentions
|
||||||
output_hidden_states = inputs[9] if len(inputs) > 8 else output_hidden_states
|
output_hidden_states = inputs[9] if len(inputs) > 9 else output_hidden_states
|
||||||
assert len(inputs) <= 10, "Too many inputs."
|
assert len(inputs) <= 10, "Too many inputs."
|
||||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||||
input_ids = inputs.get("input_ids")
|
input_ids = inputs.get("input_ids")
|
||||||
@@ -524,7 +525,7 @@ class TFGPT2Model(TFGPT2PreTrainedModel):
|
|||||||
(linear layer with weights tied to the input embeddings). """,
|
(linear layer with weights tied to the input embeddings). """,
|
||||||
GPT2_START_DOCSTRING,
|
GPT2_START_DOCSTRING,
|
||||||
)
|
)
|
||||||
class TFGPT2LMHeadModel(TFGPT2PreTrainedModel):
|
class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
|
||||||
def __init__(self, config, *inputs, **kwargs):
|
def __init__(self, config, *inputs, **kwargs):
|
||||||
super().__init__(config, *inputs, **kwargs)
|
super().__init__(config, *inputs, **kwargs)
|
||||||
self.transformer = TFGPT2MainLayer(config, name="transformer")
|
self.transformer = TFGPT2MainLayer(config, name="transformer")
|
||||||
@@ -541,8 +542,26 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel):
|
|||||||
|
|
||||||
@add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
|
||||||
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="gpt2")
|
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="gpt2")
|
||||||
def call(self, inputs, **kwargs):
|
def call(
|
||||||
|
self,
|
||||||
|
inputs,
|
||||||
|
past=None,
|
||||||
|
attention_mask=None,
|
||||||
|
token_type_ids=None,
|
||||||
|
position_ids=None,
|
||||||
|
head_mask=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
use_cache=None,
|
||||||
|
output_attentions=None,
|
||||||
|
output_hidden_states=None,
|
||||||
|
labels=None,
|
||||||
|
training=False,
|
||||||
|
):
|
||||||
r"""
|
r"""
|
||||||
|
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||||
|
Labels for computing the cross entropy classification loss.
|
||||||
|
Indices should be in ``[0, ..., config.vocab_size - 1]``.
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.GPT2Config`) and inputs:
|
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.GPT2Config`) and inputs:
|
||||||
prediction_scores (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
|
prediction_scores (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
|
||||||
@@ -563,12 +582,38 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel):
|
|||||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||||
heads.
|
heads.
|
||||||
"""
|
"""
|
||||||
transformer_outputs = self.transformer(inputs, **kwargs)
|
if isinstance(inputs, (tuple, list)):
|
||||||
|
labels = inputs[10] if len(inputs) > 10 else labels
|
||||||
|
if len(inputs) > 10:
|
||||||
|
inputs = inputs[:10]
|
||||||
|
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||||
|
labels = inputs.pop("labels", labels)
|
||||||
|
|
||||||
|
transformer_outputs = self.transformer(
|
||||||
|
inputs,
|
||||||
|
past=past,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
token_type_ids=token_type_ids,
|
||||||
|
position_ids=position_ids,
|
||||||
|
head_mask=head_mask,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
training=training,
|
||||||
|
)
|
||||||
|
|
||||||
hidden_states = transformer_outputs[0]
|
hidden_states = transformer_outputs[0]
|
||||||
|
|
||||||
lm_logits = self.transformer.wte(hidden_states, mode="linear")
|
logits = self.transformer.wte(hidden_states, mode="linear")
|
||||||
|
|
||||||
outputs = (lm_logits,) + transformer_outputs[1:]
|
outputs = (logits,) + transformer_outputs[1:]
|
||||||
|
if labels is not None:
|
||||||
|
# shift labels to the left and cut last logit token
|
||||||
|
logits = logits[:, :-1]
|
||||||
|
labels = labels[:, 1:]
|
||||||
|
loss = self.compute_loss(labels, logits)
|
||||||
|
outputs = (loss,) + outputs
|
||||||
|
|
||||||
return outputs # lm_logits, presents, (all hidden_states), (attentions)
|
return outputs # lm_logits, presents, (all hidden_states), (attentions)
|
||||||
|
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ from .file_utils import (
|
|||||||
)
|
)
|
||||||
from .modeling_tf_bert import TFBertIntermediate, gelu, gelu_new, swish
|
from .modeling_tf_bert import TFBertIntermediate, gelu, gelu_new, swish
|
||||||
from .modeling_tf_utils import (
|
from .modeling_tf_utils import (
|
||||||
|
TFMaskedLanguageModelingLoss,
|
||||||
TFMultipleChoiceLoss,
|
TFMultipleChoiceLoss,
|
||||||
TFPreTrainedModel,
|
TFPreTrainedModel,
|
||||||
TFQuestionAnsweringLoss,
|
TFQuestionAnsweringLoss,
|
||||||
@@ -929,7 +930,7 @@ class TFMobileBertForPreTraining(TFMobileBertPreTrainedModel):
|
|||||||
|
|
||||||
|
|
||||||
@add_start_docstrings("""MobileBert Model with a `language modeling` head on top. """, MOBILEBERT_START_DOCSTRING)
|
@add_start_docstrings("""MobileBert Model with a `language modeling` head on top. """, MOBILEBERT_START_DOCSTRING)
|
||||||
class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel):
|
class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModelingLoss):
|
||||||
def __init__(self, config, *inputs, **kwargs):
|
def __init__(self, config, *inputs, **kwargs):
|
||||||
super().__init__(config, *inputs, **kwargs)
|
super().__init__(config, *inputs, **kwargs)
|
||||||
|
|
||||||
@@ -941,8 +942,25 @@ class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel):
|
|||||||
|
|
||||||
@add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
@add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
||||||
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="google/mobilebert-uncased")
|
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="google/mobilebert-uncased")
|
||||||
def call(self, inputs, **kwargs):
|
def call(
|
||||||
|
self,
|
||||||
|
inputs=None,
|
||||||
|
attention_mask=None,
|
||||||
|
token_type_ids=None,
|
||||||
|
position_ids=None,
|
||||||
|
head_mask=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
output_attentions=None,
|
||||||
|
output_hidden_states=None,
|
||||||
|
labels=None,
|
||||||
|
training=False,
|
||||||
|
):
|
||||||
r"""
|
r"""
|
||||||
|
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||||
|
Labels for computing the masked language modeling loss.
|
||||||
|
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
|
||||||
|
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.MobileBertConfig`) and inputs:
|
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.MobileBertConfig`) and inputs:
|
||||||
prediction_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
|
prediction_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
|
||||||
@@ -959,14 +977,34 @@ class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel):
|
|||||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||||
heads.
|
heads.
|
||||||
"""
|
"""
|
||||||
outputs = self.mobilebert(inputs, **kwargs)
|
if isinstance(inputs, (tuple, list)):
|
||||||
|
labels = inputs[8] if len(inputs) > 8 else labels
|
||||||
|
if len(inputs) > 8:
|
||||||
|
inputs = inputs[:8]
|
||||||
|
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||||
|
labels = inputs.pop("labels", labels)
|
||||||
|
|
||||||
|
outputs = self.mobilebert(
|
||||||
|
inputs,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
token_type_ids=token_type_ids,
|
||||||
|
position_ids=position_ids,
|
||||||
|
head_mask=head_mask,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
training=training,
|
||||||
|
)
|
||||||
|
|
||||||
sequence_output = outputs[0]
|
sequence_output = outputs[0]
|
||||||
prediction_scores = self.mlm(sequence_output, training=kwargs.get("training", False))
|
prediction_scores = self.mlm(sequence_output, training=training)
|
||||||
|
|
||||||
outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here
|
outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here
|
||||||
|
if labels is not None:
|
||||||
|
loss = self.compute_loss(labels, prediction_scores)
|
||||||
|
outputs = (loss,) + outputs
|
||||||
|
|
||||||
return outputs # prediction_scores, (hidden_states), (attentions)
|
return outputs # (loss), prediction_scores, (hidden_states), (attentions)
|
||||||
|
|
||||||
|
|
||||||
class TFMobileBertOnlyNSPHead(tf.keras.layers.Layer):
|
class TFMobileBertOnlyNSPHead(tf.keras.layers.Layer):
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ import tensorflow as tf
|
|||||||
from .configuration_openai import OpenAIGPTConfig
|
from .configuration_openai import OpenAIGPTConfig
|
||||||
from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
|
from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
|
||||||
from .modeling_tf_utils import (
|
from .modeling_tf_utils import (
|
||||||
|
TFCausalLanguageModelingLoss,
|
||||||
TFConv1D,
|
TFConv1D,
|
||||||
TFPreTrainedModel,
|
TFPreTrainedModel,
|
||||||
TFSequenceSummary,
|
TFSequenceSummary,
|
||||||
@@ -479,7 +480,7 @@ class TFOpenAIGPTModel(TFOpenAIGPTPreTrainedModel):
|
|||||||
(linear layer with weights tied to the input embeddings). """,
|
(linear layer with weights tied to the input embeddings). """,
|
||||||
OPENAI_GPT_START_DOCSTRING,
|
OPENAI_GPT_START_DOCSTRING,
|
||||||
)
|
)
|
||||||
class TFOpenAIGPTLMHeadModel(TFOpenAIGPTPreTrainedModel):
|
class TFOpenAIGPTLMHeadModel(TFOpenAIGPTPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||||
def __init__(self, config, *inputs, **kwargs):
|
def __init__(self, config, *inputs, **kwargs):
|
||||||
super().__init__(config, *inputs, **kwargs)
|
super().__init__(config, *inputs, **kwargs)
|
||||||
self.transformer = TFOpenAIGPTMainLayer(config, name="transformer")
|
self.transformer = TFOpenAIGPTMainLayer(config, name="transformer")
|
||||||
@@ -489,8 +490,24 @@ class TFOpenAIGPTLMHeadModel(TFOpenAIGPTPreTrainedModel):
|
|||||||
|
|
||||||
@add_start_docstrings_to_callable(OPENAI_GPT_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_callable(OPENAI_GPT_INPUTS_DOCSTRING)
|
||||||
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="openai-gpt")
|
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="openai-gpt")
|
||||||
def call(self, inputs, **kwargs):
|
def call(
|
||||||
|
self,
|
||||||
|
inputs,
|
||||||
|
attention_mask=None,
|
||||||
|
token_type_ids=None,
|
||||||
|
position_ids=None,
|
||||||
|
head_mask=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
output_attentions=None,
|
||||||
|
output_hidden_states=None,
|
||||||
|
labels=None,
|
||||||
|
training=False,
|
||||||
|
):
|
||||||
r"""
|
r"""
|
||||||
|
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||||
|
Labels for computing the cross entropy classification loss.
|
||||||
|
Indices should be in ``[0, ..., config.vocab_size - 1]``.
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.OpenAIGPTConfig`) and inputs:
|
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.OpenAIGPTConfig`) and inputs:
|
||||||
prediction_scores (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
|
prediction_scores (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
|
||||||
@@ -507,12 +524,35 @@ class TFOpenAIGPTLMHeadModel(TFOpenAIGPTPreTrainedModel):
|
|||||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||||
heads.
|
heads.
|
||||||
"""
|
"""
|
||||||
transformer_outputs = self.transformer(inputs, **kwargs)
|
if isinstance(inputs, (tuple, list)):
|
||||||
|
labels = inputs[8] if len(inputs) > 8 else labels
|
||||||
|
if len(inputs) > 8:
|
||||||
|
inputs = inputs[:8]
|
||||||
|
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||||
|
labels = inputs.pop("labels", labels)
|
||||||
|
|
||||||
|
transformer_outputs = self.transformer(
|
||||||
|
inputs,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
token_type_ids=token_type_ids,
|
||||||
|
position_ids=position_ids,
|
||||||
|
head_mask=head_mask,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
training=training,
|
||||||
|
)
|
||||||
hidden_states = transformer_outputs[0]
|
hidden_states = transformer_outputs[0]
|
||||||
|
|
||||||
lm_logits = self.transformer.tokens_embed(hidden_states, mode="linear")
|
logits = self.transformer.tokens_embed(hidden_states, mode="linear")
|
||||||
|
outputs = (logits,) + transformer_outputs[1:]
|
||||||
|
|
||||||
outputs = (lm_logits,) + transformer_outputs[1:]
|
if labels is not None:
|
||||||
|
# shift labels to the left and cut last logit token
|
||||||
|
logits = logits[:, :-1]
|
||||||
|
labels = labels[:, 1:]
|
||||||
|
loss = self.compute_loss(labels, logits)
|
||||||
|
outputs = (loss,) + outputs
|
||||||
|
|
||||||
return outputs # lm_logits, (all hidden_states), (attentions)
|
return outputs # lm_logits, (all hidden_states), (attentions)
|
||||||
|
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ from .file_utils import (
|
|||||||
)
|
)
|
||||||
from .modeling_tf_bert import TFBertEmbeddings, TFBertMainLayer, gelu
|
from .modeling_tf_bert import TFBertEmbeddings, TFBertMainLayer, gelu
|
||||||
from .modeling_tf_utils import (
|
from .modeling_tf_utils import (
|
||||||
|
TFMaskedLanguageModelingLoss,
|
||||||
TFMultipleChoiceLoss,
|
TFMultipleChoiceLoss,
|
||||||
TFPreTrainedModel,
|
TFPreTrainedModel,
|
||||||
TFQuestionAnsweringLoss,
|
TFQuestionAnsweringLoss,
|
||||||
@@ -264,7 +265,7 @@ class TFRobertaLMHead(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
|
|
||||||
@add_start_docstrings("""RoBERTa Model with a `language modeling` head on top. """, ROBERTA_START_DOCSTRING)
|
@add_start_docstrings("""RoBERTa Model with a `language modeling` head on top. """, ROBERTA_START_DOCSTRING)
|
||||||
class TFRobertaForMaskedLM(TFRobertaPreTrainedModel):
|
class TFRobertaForMaskedLM(TFRobertaPreTrainedModel, TFMaskedLanguageModelingLoss):
|
||||||
def __init__(self, config, *inputs, **kwargs):
|
def __init__(self, config, *inputs, **kwargs):
|
||||||
super().__init__(config, *inputs, **kwargs)
|
super().__init__(config, *inputs, **kwargs)
|
||||||
|
|
||||||
@@ -276,8 +277,26 @@ class TFRobertaForMaskedLM(TFRobertaPreTrainedModel):
|
|||||||
|
|
||||||
@add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING)
|
||||||
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="roberta-base")
|
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="roberta-base")
|
||||||
def call(self, inputs, **kwargs):
|
def call(
|
||||||
|
self,
|
||||||
|
inputs=None,
|
||||||
|
attention_mask=None,
|
||||||
|
token_type_ids=None,
|
||||||
|
position_ids=None,
|
||||||
|
head_mask=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
output_attentions=None,
|
||||||
|
output_hidden_states=None,
|
||||||
|
labels=None,
|
||||||
|
training=False,
|
||||||
|
):
|
||||||
r"""
|
r"""
|
||||||
|
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||||
|
Labels for computing the masked language modeling loss.
|
||||||
|
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
|
||||||
|
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
|
||||||
|
in ``[0, ..., config.vocab_size]``
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.RobertaConfig`) and inputs:
|
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.RobertaConfig`) and inputs:
|
||||||
prediction_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
|
prediction_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
|
||||||
@@ -294,14 +313,37 @@ class TFRobertaForMaskedLM(TFRobertaPreTrainedModel):
|
|||||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||||
heads.
|
heads.
|
||||||
"""
|
"""
|
||||||
outputs = self.roberta(inputs, **kwargs)
|
if isinstance(inputs, (tuple, list)):
|
||||||
|
labels = inputs[8] if len(inputs) > 8 else labels
|
||||||
|
if len(inputs) > 8:
|
||||||
|
inputs = inputs[:8]
|
||||||
|
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||||
|
labels = inputs.pop("labels", labels)
|
||||||
|
|
||||||
|
outputs = self.roberta(
|
||||||
|
inputs,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
token_type_ids=token_type_ids,
|
||||||
|
position_ids=position_ids,
|
||||||
|
head_mask=head_mask,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
training=training,
|
||||||
|
)
|
||||||
|
|
||||||
|
sequence_output = outputs[0]
|
||||||
|
|
||||||
sequence_output = outputs[0]
|
sequence_output = outputs[0]
|
||||||
prediction_scores = self.lm_head(sequence_output)
|
prediction_scores = self.lm_head(sequence_output)
|
||||||
|
|
||||||
outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here
|
outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here
|
||||||
|
|
||||||
return outputs # prediction_scores, (hidden_states), (attentions)
|
if labels is not None:
|
||||||
|
loss = self.compute_loss(labels, prediction_scores)
|
||||||
|
outputs = (loss,) + outputs
|
||||||
|
|
||||||
|
return outputs # (loss), prediction_scores, (hidden_states), (attentions)
|
||||||
|
|
||||||
|
|
||||||
class TFRobertaClassificationHead(tf.keras.layers.Layer):
|
class TFRobertaClassificationHead(tf.keras.layers.Layer):
|
||||||
|
|||||||
@@ -20,12 +20,14 @@ import copy
|
|||||||
import itertools
|
import itertools
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
|
import warnings
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from .configuration_t5 import T5Config
|
from .configuration_t5 import T5Config
|
||||||
from .file_utils import DUMMY_INPUTS, DUMMY_MASK, add_start_docstrings, add_start_docstrings_to_callable
|
from .file_utils import DUMMY_INPUTS, DUMMY_MASK, add_start_docstrings, add_start_docstrings_to_callable
|
||||||
from .modeling_tf_utils import (
|
from .modeling_tf_utils import (
|
||||||
|
TFCausalLanguageModelingLoss,
|
||||||
TFPreTrainedModel,
|
TFPreTrainedModel,
|
||||||
TFSharedEmbeddings,
|
TFSharedEmbeddings,
|
||||||
cast_bool_to_primitive,
|
cast_bool_to_primitive,
|
||||||
@@ -111,6 +113,7 @@ class TFT5Attention(tf.keras.layers.Layer):
|
|||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.layer_id = next(TFT5Attention.NEW_ID)
|
self.layer_id = next(TFT5Attention.NEW_ID)
|
||||||
self.is_decoder = config.is_decoder
|
self.is_decoder = config.is_decoder
|
||||||
|
self.use_cache = config.use_cache
|
||||||
self.has_relative_attention_bias = has_relative_attention_bias
|
self.has_relative_attention_bias = has_relative_attention_bias
|
||||||
|
|
||||||
self.relative_attention_num_buckets = config.relative_attention_num_buckets
|
self.relative_attention_num_buckets = config.relative_attention_num_buckets
|
||||||
@@ -258,9 +261,7 @@ class TFT5Attention(tf.keras.layers.Layer):
|
|||||||
k, v = past_key_value_state
|
k, v = past_key_value_state
|
||||||
|
|
||||||
# to cope with keras serialization
|
# to cope with keras serialization
|
||||||
use_cache = cast_bool_to_primitive(use_cache)
|
if self.is_decoder and cast_bool_to_primitive(use_cache, self.use_cache) is True:
|
||||||
|
|
||||||
if self.is_decoder and use_cache is True:
|
|
||||||
present_key_value_state = ((k, v),)
|
present_key_value_state = ((k, v),)
|
||||||
else:
|
else:
|
||||||
present_key_value_state = (None,)
|
present_key_value_state = (None,)
|
||||||
@@ -295,7 +296,7 @@ class TFT5Attention(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
outputs = (context,) + present_key_value_state
|
outputs = (context,) + present_key_value_state
|
||||||
|
|
||||||
if cast_bool_to_primitive(output_attentions) is True:
|
if cast_bool_to_primitive(output_attentions, True) is True:
|
||||||
outputs = outputs + (weights,)
|
outputs = outputs + (weights,)
|
||||||
if self.has_relative_attention_bias:
|
if self.has_relative_attention_bias:
|
||||||
outputs = outputs + (position_bias,)
|
outputs = outputs + (position_bias,)
|
||||||
@@ -572,18 +573,22 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
|||||||
inputs_embeds = inputs[4] if len(inputs) > 4 else inputs_embeds
|
inputs_embeds = inputs[4] if len(inputs) > 4 else inputs_embeds
|
||||||
head_mask = inputs[5] if len(inputs) > 5 else head_mask
|
head_mask = inputs[5] if len(inputs) > 5 else head_mask
|
||||||
past_key_value_states = inputs[6] if len(inputs) > 6 else past_key_value_states
|
past_key_value_states = inputs[6] if len(inputs) > 6 else past_key_value_states
|
||||||
output_attentions = inputs[7] if len(inputs) > 7 else output_attentions
|
use_cache = inputs[7] if len(inputs) > 7 else use_cache
|
||||||
assert len(inputs) <= 8, "Too many inputs."
|
output_attentions = inputs[8] if len(inputs) > 7 else output_attentions
|
||||||
|
output_hidden_states = inputs[9] if len(inputs) > 8 else output_hidden_states
|
||||||
|
assert len(inputs) <= 10, "Too many inputs."
|
||||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||||
input_ids = inputs.get("decoder_input_ids")
|
input_ids = inputs.get("input_ids")
|
||||||
attention_mask = inputs.get("decoder_attention_mask", attention_mask)
|
attention_mask = inputs.get("attention_mask", attention_mask)
|
||||||
encoder_hidden_states = inputs.get("encoder_hidden_states", encoder_hidden_states)
|
encoder_hidden_states = inputs.get("encoder_hidden_states", encoder_hidden_states)
|
||||||
encoder_attention_mask = inputs.get("encoder_attention_mask", encoder_attention_mask)
|
encoder_attention_mask = inputs.get("encoder_attention_mask", encoder_attention_mask)
|
||||||
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||||
head_mask = inputs.get("head_mask", head_mask)
|
head_mask = inputs.get("head_mask", head_mask)
|
||||||
past_key_value_states = inputs.get("past_key_value_states", past_key_value_states)
|
past_key_value_states = inputs.get("past_key_value_states", past_key_value_states)
|
||||||
|
use_cache = inputs.get("use_cache", use_cache)
|
||||||
output_attentions = inputs.get("output_attentions", output_attentions)
|
output_attentions = inputs.get("output_attentions", output_attentions)
|
||||||
assert len(inputs) <= 8, "Too many inputs."
|
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
|
||||||
|
assert len(inputs) <= 10, "Too many inputs."
|
||||||
else:
|
else:
|
||||||
input_ids = inputs
|
input_ids = inputs
|
||||||
|
|
||||||
@@ -733,8 +738,8 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
|||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
outputs = (hidden_states,)
|
outputs = (hidden_states,)
|
||||||
if use_cache is True:
|
# need to check if is decoder here as well for special cases when using keras compile
|
||||||
assert self.is_decoder, "`use_cache` can only be set to `True` if {} is used as a decoder".format(self)
|
if cast_bool_to_primitive(use_cache, self.use_cache) is True and self.is_decoder:
|
||||||
outputs = outputs + (present_key_value_states,)
|
outputs = outputs + (present_key_value_states,)
|
||||||
if cast_bool_to_primitive(output_hidden_states) is True:
|
if cast_bool_to_primitive(output_hidden_states) is True:
|
||||||
outputs = outputs + (all_hidden_states,)
|
outputs = outputs + (all_hidden_states,)
|
||||||
@@ -763,12 +768,38 @@ class TFT5PreTrainedModel(TFPreTrainedModel):
|
|||||||
inputs = tf.constant(DUMMY_INPUTS)
|
inputs = tf.constant(DUMMY_INPUTS)
|
||||||
input_mask = tf.constant(DUMMY_MASK)
|
input_mask = tf.constant(DUMMY_MASK)
|
||||||
dummy_inputs = {
|
dummy_inputs = {
|
||||||
"inputs": inputs,
|
"input_ids": inputs,
|
||||||
"decoder_input_ids": inputs,
|
"decoder_input_ids": inputs,
|
||||||
"decoder_attention_mask": input_mask,
|
"decoder_attention_mask": input_mask,
|
||||||
}
|
}
|
||||||
return dummy_inputs
|
return dummy_inputs
|
||||||
|
|
||||||
|
def _shift_right(self, input_ids):
|
||||||
|
decoder_start_token_id = self.config.decoder_start_token_id
|
||||||
|
pad_token_id = self.config.pad_token_id
|
||||||
|
|
||||||
|
assert (
|
||||||
|
decoder_start_token_id is not None
|
||||||
|
), "self.model.config.decoder_start_token_id has to be defined. In TF T5 it is usually set to the pad_token_id. See T5 docs for more information"
|
||||||
|
|
||||||
|
# shift inputs to the right
|
||||||
|
shifted_input_ids = tf.zeros_like(input_ids, dtype=tf.int32)
|
||||||
|
shifted_input_ids = tf.roll(shifted_input_ids, 1, axis=-1)
|
||||||
|
start_tokens = tf.fill((shape_list(shifted_input_ids)[0], 1), decoder_start_token_id)
|
||||||
|
shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, 1:]], -1)
|
||||||
|
|
||||||
|
assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
|
||||||
|
# replace possible -100 values in labels by `pad_token_id`
|
||||||
|
shifted_input_ids = tf.where(
|
||||||
|
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
|
||||||
|
)
|
||||||
|
|
||||||
|
assert tf.math.reduce_any(
|
||||||
|
shifted_input_ids >= 0
|
||||||
|
).numpy(), "Verify that `labels` has only positive values and -100"
|
||||||
|
|
||||||
|
return shifted_input_ids
|
||||||
|
|
||||||
|
|
||||||
T5_START_DOCSTRING = r"""
|
T5_START_DOCSTRING = r"""
|
||||||
The T5 model was proposed in `Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer
|
The T5 model was proposed in `Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer
|
||||||
@@ -900,7 +931,22 @@ class TFT5Model(TFT5PreTrainedModel):
|
|||||||
return self.decoder
|
return self.decoder
|
||||||
|
|
||||||
@add_start_docstrings_to_callable(T5_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_callable(T5_INPUTS_DOCSTRING)
|
||||||
def call(self, inputs, **kwargs):
|
def call(
|
||||||
|
self,
|
||||||
|
inputs,
|
||||||
|
attention_mask=None,
|
||||||
|
encoder_outputs=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
head_mask=None,
|
||||||
|
decoder_past_key_value_states=None,
|
||||||
|
decoder_input_ids=None,
|
||||||
|
decoder_attention_mask=None,
|
||||||
|
decoder_inputs_embeds=None,
|
||||||
|
use_cache=None,
|
||||||
|
output_attentions=None,
|
||||||
|
output_hidden_states=None,
|
||||||
|
training=False,
|
||||||
|
):
|
||||||
r"""
|
r"""
|
||||||
Returns:
|
Returns:
|
||||||
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.T5Config`) and inputs:
|
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.T5Config`) and inputs:
|
||||||
@@ -934,37 +980,58 @@ class TFT5Model(TFT5PreTrainedModel):
|
|||||||
>>> last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
|
>>> last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
if isinstance(inputs, (tuple, list)):
|
||||||
if isinstance(inputs, dict):
|
input_ids = inputs[0]
|
||||||
kwargs.update(inputs)
|
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
||||||
|
encoder_outputs = inputs[2] if len(inputs) > 2 else encoder_outputs
|
||||||
|
inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
|
||||||
|
head_mask = inputs[4] if len(inputs) > 4 else head_mask
|
||||||
|
decoder_past_key_value_states = inputs[5] if len(inputs) > 5 else decoder_past_key_value_states
|
||||||
|
decoder_input_ids = inputs[6] if len(inputs) > 6 else decoder_input_ids
|
||||||
|
decoder_attention_mask = inputs[7] if len(inputs) > 7 else decoder_attention_mask
|
||||||
|
decoder_inputs_embeds = inputs[8] if len(inputs) > 8 else decoder_inputs_embeds
|
||||||
|
use_cache = inputs[9] if len(inputs) > 9 else use_cache
|
||||||
|
output_attentions = inputs[10] if len(inputs) > 10 else output_attentions
|
||||||
|
output_hidden_states = inputs[11] if len(inputs) > 11 else output_hidden_states
|
||||||
|
assert len(inputs) <= 12, "Too many inputs."
|
||||||
|
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||||
|
if "inputs" in inputs:
|
||||||
|
warnings.warn("Using `inputs` as a keyword argument is deprecated. Please use `input_ids` instead.")
|
||||||
|
input_ids = inputs.get("inputs")
|
||||||
|
input_ids = inputs.get("input_ids")
|
||||||
|
attention_mask = inputs.get("attention_mask", attention_mask)
|
||||||
|
encoder_outputs = inputs.get("encoder_outputs", encoder_outputs)
|
||||||
|
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||||
|
head_mask = inputs.get("head_mask", head_mask)
|
||||||
|
decoder_past_key_value_states = inputs.get("past_key_value_states", decoder_past_key_value_states)
|
||||||
|
decoder_input_ids = inputs.get("decoder_input_ids", decoder_input_ids)
|
||||||
|
decoder_attention_mask = inputs.get("decoder_attention_mask", decoder_attention_mask)
|
||||||
|
decoder_inputs_embeds = inputs.get("decoder_inputs_embeds", decoder_inputs_embeds)
|
||||||
|
use_cache = inputs.get("use_cache", use_cache)
|
||||||
|
output_attentions = inputs.get("output_attentions", output_attentions)
|
||||||
|
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
|
||||||
|
assert len(inputs) <= 12, "Too many inputs."
|
||||||
else:
|
else:
|
||||||
kwargs["inputs"] = inputs
|
input_ids = inputs
|
||||||
|
|
||||||
# retrieve arguments
|
|
||||||
inputs = kwargs.get("inputs", None)
|
|
||||||
inputs_embeds = kwargs.get("inputs_embeds", None)
|
|
||||||
attention_mask = kwargs.get("attention_mask", None)
|
|
||||||
encoder_outputs = kwargs.get("encoder_outputs", None)
|
|
||||||
decoder_input_ids = kwargs.get("decoder_input_ids", None)
|
|
||||||
decoder_attention_mask = kwargs.get("decoder_attention_mask", None)
|
|
||||||
decoder_inputs_embeds = kwargs.get("decoder_inputs_embeds", None)
|
|
||||||
decoder_past_key_value_states = kwargs.get("decoder_past_key_value_states", None)
|
|
||||||
use_cache = kwargs.get("use_cache", None)
|
|
||||||
head_mask = kwargs.get("head_mask", None)
|
|
||||||
output_attentions = kwargs.get("output_attentions", None)
|
|
||||||
output_hidden_states = kwargs.get("output_hidden_states", None)
|
|
||||||
|
|
||||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
|
|
||||||
# Encode if needed (training, first prediction pass)
|
# Encode if needed (training, first prediction pass)
|
||||||
if encoder_outputs is None:
|
if encoder_outputs is None:
|
||||||
encoder_outputs = self.encoder(
|
encoder_outputs = self.encoder(
|
||||||
inputs,
|
[
|
||||||
attention_mask=attention_mask,
|
input_ids,
|
||||||
inputs_embeds=inputs_embeds,
|
attention_mask,
|
||||||
head_mask=head_mask,
|
None,
|
||||||
output_attentions=output_attentions,
|
None,
|
||||||
output_hidden_states=output_hidden_states,
|
inputs_embeds,
|
||||||
|
head_mask,
|
||||||
|
None,
|
||||||
|
False,
|
||||||
|
output_attentions,
|
||||||
|
output_hidden_states,
|
||||||
|
],
|
||||||
|
training=training,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = encoder_outputs[0]
|
hidden_states = encoder_outputs[0]
|
||||||
@@ -979,19 +1046,22 @@ class TFT5Model(TFT5PreTrainedModel):
|
|||||||
|
|
||||||
# Decode
|
# Decode
|
||||||
decoder_outputs = self.decoder(
|
decoder_outputs = self.decoder(
|
||||||
decoder_input_ids,
|
[
|
||||||
attention_mask=decoder_attention_mask,
|
decoder_input_ids,
|
||||||
inputs_embeds=decoder_inputs_embeds,
|
decoder_attention_mask,
|
||||||
past_key_value_states=decoder_past_key_value_states,
|
hidden_states,
|
||||||
encoder_hidden_states=hidden_states,
|
attention_mask,
|
||||||
encoder_attention_mask=attention_mask,
|
decoder_inputs_embeds,
|
||||||
head_mask=head_mask,
|
head_mask,
|
||||||
use_cache=use_cache,
|
decoder_past_key_value_states,
|
||||||
output_attentions=output_attentions,
|
use_cache,
|
||||||
output_hidden_states=output_hidden_states,
|
output_attentions,
|
||||||
|
output_hidden_states,
|
||||||
|
],
|
||||||
|
training=training,
|
||||||
)
|
)
|
||||||
|
|
||||||
if use_cache is True:
|
if cast_bool_to_primitive(use_cache, self.config.use_cache) is True:
|
||||||
past = ((encoder_outputs, decoder_outputs[1]),)
|
past = ((encoder_outputs, decoder_outputs[1]),)
|
||||||
decoder_outputs = decoder_outputs[:1] + past + decoder_outputs[2:]
|
decoder_outputs = decoder_outputs[:1] + past + decoder_outputs[2:]
|
||||||
|
|
||||||
@@ -999,7 +1069,7 @@ class TFT5Model(TFT5PreTrainedModel):
|
|||||||
|
|
||||||
|
|
||||||
@add_start_docstrings("""T5 Model with a `language modeling` head on top. """, T5_START_DOCSTRING)
|
@add_start_docstrings("""T5 Model with a `language modeling` head on top. """, T5_START_DOCSTRING)
|
||||||
class TFT5ForConditionalGeneration(TFT5PreTrainedModel):
|
class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModelingLoss):
|
||||||
def __init__(self, config, *inputs, **kwargs):
|
def __init__(self, config, *inputs, **kwargs):
|
||||||
super().__init__(config, *inputs, **kwargs)
|
super().__init__(config, *inputs, **kwargs)
|
||||||
self.model_dim = config.d_model
|
self.model_dim = config.d_model
|
||||||
@@ -1042,8 +1112,28 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel):
|
|||||||
return self.decoder
|
return self.decoder
|
||||||
|
|
||||||
@add_start_docstrings_to_callable(T5_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_callable(T5_INPUTS_DOCSTRING)
|
||||||
def call(self, inputs, **kwargs):
|
def call(
|
||||||
|
self,
|
||||||
|
inputs,
|
||||||
|
attention_mask=None,
|
||||||
|
encoder_outputs=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
head_mask=None,
|
||||||
|
decoder_past_key_value_states=None,
|
||||||
|
decoder_input_ids=None,
|
||||||
|
decoder_attention_mask=None,
|
||||||
|
decoder_inputs_embeds=None,
|
||||||
|
use_cache=None,
|
||||||
|
output_attentions=None,
|
||||||
|
output_hidden_states=None,
|
||||||
|
labels=None,
|
||||||
|
training=False,
|
||||||
|
):
|
||||||
r"""
|
r"""
|
||||||
|
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||||
|
Labels for computing the cross entropy classification loss.
|
||||||
|
Indices should be in ``[0, ..., config.vocab_size - 1]``.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.T5Config`) and inputs:
|
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.T5Config`) and inputs:
|
||||||
prediction_scores (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
|
prediction_scores (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
|
||||||
@@ -1080,25 +1170,41 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel):
|
|||||||
>>> result = model.generate(inputs)
|
>>> result = model.generate(inputs)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
if isinstance(inputs, (tuple, list)):
|
||||||
if isinstance(inputs, dict):
|
input_ids = inputs[0]
|
||||||
kwargs.update(inputs)
|
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
|
||||||
|
encoder_outputs = inputs[2] if len(inputs) > 2 else encoder_outputs
|
||||||
|
inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
|
||||||
|
head_mask = inputs[4] if len(inputs) > 4 else head_mask
|
||||||
|
decoder_past_key_value_states = inputs[5] if len(inputs) > 5 else decoder_past_key_value_states
|
||||||
|
decoder_input_ids = inputs[6] if len(inputs) > 6 else decoder_input_ids
|
||||||
|
decoder_attention_mask = inputs[7] if len(inputs) > 7 else decoder_attention_mask
|
||||||
|
decoder_inputs_embeds = inputs[8] if len(inputs) > 8 else decoder_inputs_embeds
|
||||||
|
use_cache = inputs[9] if len(inputs) > 9 else use_cache
|
||||||
|
output_attentions = inputs[10] if len(inputs) > 10 else output_attentions
|
||||||
|
output_hidden_states = inputs[11] if len(inputs) > 11 else output_hidden_states
|
||||||
|
labels = inputs[12] if len(inputs) > 12 else labels
|
||||||
|
assert len(inputs) <= 13, "Too many inputs."
|
||||||
|
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||||
|
if "inputs" in inputs:
|
||||||
|
warnings.warn("Using `inputs` as a keyword argument is deprecated. Please use `input_ids` instead.")
|
||||||
|
input_ids = inputs.get("inputs")
|
||||||
|
input_ids = inputs.get("input_ids")
|
||||||
|
attention_mask = inputs.get("attention_mask", attention_mask)
|
||||||
|
encoder_outputs = inputs.get("encoder_outputs", encoder_outputs)
|
||||||
|
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
|
||||||
|
head_mask = inputs.get("head_mask", head_mask)
|
||||||
|
decoder_past_key_value_states = inputs.get("past_key_value_states", decoder_past_key_value_states)
|
||||||
|
decoder_input_ids = inputs.get("decoder_input_ids", decoder_input_ids)
|
||||||
|
decoder_attention_mask = inputs.get("decoder_attention_mask", decoder_attention_mask)
|
||||||
|
decoder_inputs_embeds = inputs.get("decoder_inputs_embeds", decoder_inputs_embeds)
|
||||||
|
use_cache = inputs.get("use_cache", use_cache)
|
||||||
|
output_attentions = inputs.get("output_attentions", output_attentions)
|
||||||
|
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
|
||||||
|
labels = inputs.get("labels", labels)
|
||||||
|
assert len(inputs) <= 13, "Too many inputs."
|
||||||
else:
|
else:
|
||||||
kwargs["inputs"] = inputs
|
input_ids = inputs
|
||||||
|
|
||||||
# retrieve arguments
|
|
||||||
inputs = kwargs.get("inputs", None)
|
|
||||||
decoder_input_ids = kwargs.get("decoder_input_ids", None)
|
|
||||||
attention_mask = kwargs.get("attention_mask", None)
|
|
||||||
encoder_outputs = kwargs.get("encoder_outputs", None)
|
|
||||||
decoder_attention_mask = kwargs.get("decoder_attention_mask", None)
|
|
||||||
decoder_past_key_value_states = kwargs.get("decoder_past_key_value_states", None)
|
|
||||||
use_cache = kwargs.get("use_cache", None)
|
|
||||||
inputs_embeds = kwargs.get("inputs_embeds", None)
|
|
||||||
decoder_inputs_embeds = kwargs.get("decoder_inputs_embeds", None)
|
|
||||||
head_mask = kwargs.get("head_mask", None)
|
|
||||||
output_attentions = kwargs.get("output_attentions", None)
|
|
||||||
output_hidden_states = kwargs.get("output_hidden_states", None)
|
|
||||||
|
|
||||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
|
|
||||||
@@ -1106,16 +1212,27 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel):
|
|||||||
if encoder_outputs is None:
|
if encoder_outputs is None:
|
||||||
# Convert encoder inputs in embeddings if needed
|
# Convert encoder inputs in embeddings if needed
|
||||||
encoder_outputs = self.encoder(
|
encoder_outputs = self.encoder(
|
||||||
inputs,
|
[
|
||||||
attention_mask=attention_mask,
|
input_ids,
|
||||||
inputs_embeds=inputs_embeds,
|
attention_mask,
|
||||||
head_mask=head_mask,
|
None,
|
||||||
output_attentions=output_attentions,
|
None,
|
||||||
output_hidden_states=output_hidden_states,
|
inputs_embeds,
|
||||||
|
head_mask,
|
||||||
|
None,
|
||||||
|
False,
|
||||||
|
output_attentions,
|
||||||
|
output_hidden_states,
|
||||||
|
],
|
||||||
|
training=training,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = encoder_outputs[0]
|
hidden_states = encoder_outputs[0]
|
||||||
|
|
||||||
|
if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
|
||||||
|
# get decoder inputs from shifting lm labels to the right
|
||||||
|
decoder_input_ids = self._shift_right(labels)
|
||||||
|
|
||||||
# If decoding with past key value states, only the last tokens
|
# If decoding with past key value states, only the last tokens
|
||||||
# should be given as an input
|
# should be given as an input
|
||||||
if decoder_past_key_value_states is not None:
|
if decoder_past_key_value_states is not None:
|
||||||
@@ -1126,28 +1243,35 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel):
|
|||||||
|
|
||||||
# Decode
|
# Decode
|
||||||
decoder_outputs = self.decoder(
|
decoder_outputs = self.decoder(
|
||||||
decoder_input_ids,
|
[
|
||||||
attention_mask=decoder_attention_mask,
|
decoder_input_ids,
|
||||||
inputs_embeds=decoder_inputs_embeds,
|
decoder_attention_mask,
|
||||||
past_key_value_states=decoder_past_key_value_states,
|
hidden_states,
|
||||||
encoder_hidden_states=hidden_states,
|
attention_mask,
|
||||||
encoder_attention_mask=attention_mask,
|
decoder_inputs_embeds,
|
||||||
head_mask=head_mask,
|
head_mask,
|
||||||
use_cache=use_cache,
|
decoder_past_key_value_states,
|
||||||
output_attentions=output_attentions,
|
use_cache,
|
||||||
output_hidden_states=output_hidden_states,
|
output_attentions,
|
||||||
|
output_hidden_states,
|
||||||
|
],
|
||||||
|
training=training,
|
||||||
)
|
)
|
||||||
|
|
||||||
# insert decoder past at right place
|
# insert decoder past at right place
|
||||||
# to speed up decoding
|
# to speed up decoding
|
||||||
if use_cache is True:
|
if cast_bool_to_primitive(use_cache, self.config.use_cache) is True:
|
||||||
past = ((encoder_outputs, decoder_outputs[1]),)
|
past = ((encoder_outputs, decoder_outputs[1]),)
|
||||||
decoder_outputs = decoder_outputs[:1] + past + decoder_outputs[2:]
|
decoder_outputs = decoder_outputs[:1] + past + decoder_outputs[2:]
|
||||||
|
|
||||||
sequence_output = decoder_outputs[0] * (self.model_dim ** -0.5)
|
sequence_output = decoder_outputs[0] * (self.model_dim ** -0.5)
|
||||||
embed_tokens = self.get_output_embeddings()
|
embed_tokens = self.get_output_embeddings()
|
||||||
lm_logits = embed_tokens(sequence_output, mode="linear")
|
logits = embed_tokens(sequence_output, mode="linear")
|
||||||
decoder_outputs = (lm_logits,) + decoder_outputs[1:]
|
decoder_outputs = (logits,) + decoder_outputs[1:]
|
||||||
|
|
||||||
|
if labels is not None:
|
||||||
|
loss = self.compute_loss(labels, logits)
|
||||||
|
decoder_outputs = (loss,) + decoder_outputs
|
||||||
|
|
||||||
return decoder_outputs + encoder_outputs
|
return decoder_outputs + encoder_outputs
|
||||||
|
|
||||||
|
|||||||
@@ -17,6 +17,7 @@
|
|||||||
import functools
|
import functools
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import warnings
|
||||||
|
|
||||||
import h5py
|
import h5py
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -107,6 +108,19 @@ def keras_serializable(cls):
|
|||||||
return cls
|
return cls
|
||||||
|
|
||||||
|
|
||||||
|
class TFCausalLanguageModelingLoss:
|
||||||
|
def compute_loss(self, labels, logits):
|
||||||
|
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
|
||||||
|
from_logits=True, reduction=tf.keras.losses.Reduction.NONE
|
||||||
|
)
|
||||||
|
# make sure only labels that are not equal to -100
|
||||||
|
# are taken into account as loss
|
||||||
|
active_loss = tf.reshape(labels, (-1,)) != -100
|
||||||
|
reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
|
||||||
|
labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
|
||||||
|
return loss_fn(labels, reduced_logits)
|
||||||
|
|
||||||
|
|
||||||
class TFQuestionAnsweringLoss:
|
class TFQuestionAnsweringLoss:
|
||||||
def compute_loss(self, labels, logits):
|
def compute_loss(self, labels, logits):
|
||||||
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
|
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
|
||||||
@@ -123,7 +137,13 @@ class TFTokenClassificationLoss:
|
|||||||
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
|
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
|
||||||
from_logits=True, reduction=tf.keras.losses.Reduction.NONE
|
from_logits=True, reduction=tf.keras.losses.Reduction.NONE
|
||||||
)
|
)
|
||||||
active_loss = tf.reshape(labels, (-1,)) != -1
|
# make sure only labels that are not equal to -100
|
||||||
|
# are taken into account as loss
|
||||||
|
if tf.math.reduce_any(labels == -1).numpy() is True:
|
||||||
|
warnings.warn("Using `-1` to mask the loss for the token is depreciated. Please use `-100` instead.")
|
||||||
|
active_loss = tf.reshape(labels, (-1,)) != -1
|
||||||
|
else:
|
||||||
|
active_loss = tf.reshape(labels, (-1,)) != -100
|
||||||
reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
|
reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
|
||||||
labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
|
labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
|
||||||
|
|
||||||
@@ -143,6 +163,7 @@ class TFSequenceClassificationLoss:
|
|||||||
|
|
||||||
|
|
||||||
TFMultipleChoiceLoss = TFSequenceClassificationLoss
|
TFMultipleChoiceLoss = TFSequenceClassificationLoss
|
||||||
|
TFMaskedLanguageModelingLoss = TFCausalLanguageModelingLoss
|
||||||
|
|
||||||
|
|
||||||
class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
|
class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ from .file_utils import (
|
|||||||
add_start_docstrings_to_callable,
|
add_start_docstrings_to_callable,
|
||||||
)
|
)
|
||||||
from .modeling_tf_utils import (
|
from .modeling_tf_utils import (
|
||||||
|
TFCausalLanguageModelingLoss,
|
||||||
TFMultipleChoiceLoss,
|
TFMultipleChoiceLoss,
|
||||||
TFPreTrainedModel,
|
TFPreTrainedModel,
|
||||||
TFQuestionAnsweringLoss,
|
TFQuestionAnsweringLoss,
|
||||||
@@ -871,7 +872,7 @@ class TFXLNetModel(TFXLNetPreTrainedModel):
|
|||||||
(linear layer with weights tied to the input embeddings). """,
|
(linear layer with weights tied to the input embeddings). """,
|
||||||
XLNET_START_DOCSTRING,
|
XLNET_START_DOCSTRING,
|
||||||
)
|
)
|
||||||
class TFXLNetLMHeadModel(TFXLNetPreTrainedModel):
|
class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||||
def __init__(self, config, *inputs, **kwargs):
|
def __init__(self, config, *inputs, **kwargs):
|
||||||
super().__init__(config, *inputs, **kwargs)
|
super().__init__(config, *inputs, **kwargs)
|
||||||
self.transformer = TFXLNetMainLayer(config, name="transformer")
|
self.transformer = TFXLNetMainLayer(config, name="transformer")
|
||||||
@@ -912,8 +913,28 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel):
|
|||||||
return inputs
|
return inputs
|
||||||
|
|
||||||
@add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING)
|
||||||
def call(self, inputs, **kwargs):
|
def call(
|
||||||
|
self,
|
||||||
|
inputs,
|
||||||
|
attention_mask=None,
|
||||||
|
mems=None,
|
||||||
|
perm_mask=None,
|
||||||
|
target_mapping=None,
|
||||||
|
token_type_ids=None,
|
||||||
|
input_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
use_cache=True,
|
||||||
|
output_attentions=None,
|
||||||
|
output_hidden_states=None,
|
||||||
|
labels=None,
|
||||||
|
training=False,
|
||||||
|
):
|
||||||
r"""
|
r"""
|
||||||
|
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||||
|
Labels for computing the cross entropy classification loss.
|
||||||
|
Indices should be in ``[0, ..., config.vocab_size - 1]``.
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.XLNetConfig`) and inputs:
|
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.XLNetConfig`) and inputs:
|
||||||
prediction_scores (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
|
prediction_scores (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
|
||||||
@@ -957,12 +978,40 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel):
|
|||||||
next_token_logits = outputs[0] # Output has shape [target_mapping.size(0), target_mapping.size(1), config.vocab_size]
|
next_token_logits = outputs[0] # Output has shape [target_mapping.size(0), target_mapping.size(1), config.vocab_size]
|
||||||
|
|
||||||
"""
|
"""
|
||||||
transformer_outputs = self.transformer(inputs, **kwargs)
|
if isinstance(inputs, (tuple, list)):
|
||||||
|
labels = inputs[12] if len(inputs) > 12 else labels
|
||||||
|
if len(inputs) > 12:
|
||||||
|
inputs = inputs[:12]
|
||||||
|
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||||
|
labels = inputs.pop("labels", labels)
|
||||||
|
|
||||||
|
transformer_outputs = self.transformer(
|
||||||
|
inputs,
|
||||||
|
attention_mask=None,
|
||||||
|
mems=None,
|
||||||
|
perm_mask=None,
|
||||||
|
target_mapping=None,
|
||||||
|
token_type_ids=None,
|
||||||
|
input_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
use_cache=True,
|
||||||
|
output_attentions=None,
|
||||||
|
output_hidden_states=None,
|
||||||
|
training=training,
|
||||||
|
)
|
||||||
hidden_state = transformer_outputs[0]
|
hidden_state = transformer_outputs[0]
|
||||||
logits = self.lm_loss(hidden_state)
|
logits = self.lm_loss(hidden_state, training=training)
|
||||||
|
|
||||||
outputs = (logits,) + transformer_outputs[1:] # Keep mems, hidden states, attentions if there are in it
|
outputs = (logits,) + transformer_outputs[1:] # Keep mems, hidden states, attentions if there are in it
|
||||||
|
|
||||||
|
if labels is not None:
|
||||||
|
# shift labels to the left and cut last logit token
|
||||||
|
logits = logits[:, :-1]
|
||||||
|
labels = labels[:, 1:]
|
||||||
|
loss = self.compute_loss(labels, logits)
|
||||||
|
outputs = (loss,) + outputs
|
||||||
|
|
||||||
return outputs # return logits, (mems), (hidden states), (attentions)
|
return outputs # return logits, (mems), (hidden states), (attentions)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1041,9 +1041,9 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
|
|||||||
head_mask=None,
|
head_mask=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
use_cache=True,
|
use_cache=True,
|
||||||
labels=None,
|
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
|
labels=None,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, num_predict)`, `optional`, defaults to :obj:`None`):
|
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, num_predict)`, `optional`, defaults to :obj:`None`):
|
||||||
|
|||||||
@@ -17,7 +17,7 @@
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import is_torch_available
|
from transformers import is_torch_available
|
||||||
from transformers.testing_utils import require_torch, torch_device
|
from transformers.testing_utils import require_torch, slow, torch_device
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
from .test_modeling_common import ModelTesterMixin, ids_tensor
|
from .test_modeling_common import ModelTesterMixin, ids_tensor
|
||||||
@@ -32,6 +32,7 @@ if is_torch_available():
|
|||||||
DistilBertForTokenClassification,
|
DistilBertForTokenClassification,
|
||||||
DistilBertForQuestionAnswering,
|
DistilBertForQuestionAnswering,
|
||||||
DistilBertForSequenceClassification,
|
DistilBertForSequenceClassification,
|
||||||
|
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
)
|
)
|
||||||
|
|
||||||
class DistilBertModelTester(object):
|
class DistilBertModelTester(object):
|
||||||
@@ -276,8 +277,8 @@ class DistilBertModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_distilbert_for_multiple_choice(*config_and_inputs)
|
self.model_tester.create_and_check_distilbert_for_multiple_choice(*config_and_inputs)
|
||||||
|
|
||||||
# @slow
|
@slow
|
||||||
# def test_model_from_pretrained(self):
|
def test_model_from_pretrained(self):
|
||||||
# for model_name in DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
for model_name in DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||||
# model = DistilBertModel.from_pretrained(model_name)
|
model = DistilBertModel.from_pretrained(model_name)
|
||||||
# self.assertIsNotNone(model)
|
self.assertIsNotNone(model)
|
||||||
|
|||||||
@@ -24,6 +24,8 @@ if is_tf_available():
|
|||||||
from transformers import (
|
from transformers import (
|
||||||
AutoConfig,
|
AutoConfig,
|
||||||
BertConfig,
|
BertConfig,
|
||||||
|
GPT2Config,
|
||||||
|
T5Config,
|
||||||
TFAutoModel,
|
TFAutoModel,
|
||||||
TFBertModel,
|
TFBertModel,
|
||||||
TFAutoModelForPreTraining,
|
TFAutoModelForPreTraining,
|
||||||
@@ -35,6 +37,25 @@ if is_tf_available():
|
|||||||
TFBertForSequenceClassification,
|
TFBertForSequenceClassification,
|
||||||
TFAutoModelForQuestionAnswering,
|
TFAutoModelForQuestionAnswering,
|
||||||
TFBertForQuestionAnswering,
|
TFBertForQuestionAnswering,
|
||||||
|
TFAutoModelForCausalLM,
|
||||||
|
TFGPT2LMHeadModel,
|
||||||
|
TFAutoModelForMaskedLM,
|
||||||
|
TFAutoModelForSeq2SeqLM,
|
||||||
|
TFT5ForConditionalGeneration,
|
||||||
|
)
|
||||||
|
from transformers.modeling_tf_bert import TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||||
|
from transformers.modeling_tf_gpt2 import TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||||
|
from transformers.modeling_tf_t5 import TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||||
|
from transformers.modeling_tf_auto import (
|
||||||
|
TF_MODEL_MAPPING,
|
||||||
|
TF_MODEL_FOR_PRETRAINING_MAPPING,
|
||||||
|
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
||||||
|
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||||
|
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||||
|
TF_MODEL_WITH_LM_HEAD_MAPPING,
|
||||||
|
TF_MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||||
|
TF_MODEL_FOR_MASKED_LM_MAPPING,
|
||||||
|
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -72,10 +93,21 @@ class TFAutoModelTest(unittest.TestCase):
|
|||||||
self.assertIsNotNone(model)
|
self.assertIsNotNone(model)
|
||||||
self.assertIsInstance(model, TFBertForPreTraining)
|
self.assertIsInstance(model, TFBertForPreTraining)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_model_for_causal_lm(self):
|
||||||
|
for model_name in TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||||
|
config = AutoConfig.from_pretrained(model_name)
|
||||||
|
self.assertIsNotNone(config)
|
||||||
|
self.assertIsInstance(config, GPT2Config)
|
||||||
|
|
||||||
|
model = TFAutoModelForCausalLM.from_pretrained(model_name)
|
||||||
|
model, loading_info = TFAutoModelForCausalLM.from_pretrained(model_name, output_loading_info=True)
|
||||||
|
self.assertIsNotNone(model)
|
||||||
|
self.assertIsInstance(model, TFGPT2LMHeadModel)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_lmhead_model_from_pretrained(self):
|
def test_lmhead_model_from_pretrained(self):
|
||||||
# for model_name in TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
for model_name in TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||||
for model_name in ["bert-base-uncased"]:
|
|
||||||
config = AutoConfig.from_pretrained(model_name)
|
config = AutoConfig.from_pretrained(model_name)
|
||||||
self.assertIsNotNone(config)
|
self.assertIsNotNone(config)
|
||||||
self.assertIsInstance(config, BertConfig)
|
self.assertIsInstance(config, BertConfig)
|
||||||
@@ -84,6 +116,30 @@ class TFAutoModelTest(unittest.TestCase):
|
|||||||
self.assertIsNotNone(model)
|
self.assertIsNotNone(model)
|
||||||
self.assertIsInstance(model, TFBertForMaskedLM)
|
self.assertIsInstance(model, TFBertForMaskedLM)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_model_for_masked_lm(self):
|
||||||
|
for model_name in TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||||
|
config = AutoConfig.from_pretrained(model_name)
|
||||||
|
self.assertIsNotNone(config)
|
||||||
|
self.assertIsInstance(config, BertConfig)
|
||||||
|
|
||||||
|
model = TFAutoModelForMaskedLM.from_pretrained(model_name)
|
||||||
|
model, loading_info = TFAutoModelForMaskedLM.from_pretrained(model_name, output_loading_info=True)
|
||||||
|
self.assertIsNotNone(model)
|
||||||
|
self.assertIsInstance(model, TFBertForMaskedLM)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_model_for_encoder_decoder_lm(self):
|
||||||
|
for model_name in TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||||
|
config = AutoConfig.from_pretrained(model_name)
|
||||||
|
self.assertIsNotNone(config)
|
||||||
|
self.assertIsInstance(config, T5Config)
|
||||||
|
|
||||||
|
model = TFAutoModelForSeq2SeqLM.from_pretrained(model_name)
|
||||||
|
model, loading_info = TFAutoModelForSeq2SeqLM.from_pretrained(model_name, output_loading_info=True)
|
||||||
|
self.assertIsNotNone(model)
|
||||||
|
self.assertIsInstance(model, TFT5ForConditionalGeneration)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_sequence_classification_model_from_pretrained(self):
|
def test_sequence_classification_model_from_pretrained(self):
|
||||||
# for model_name in TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
# for model_name in TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||||
@@ -119,3 +175,28 @@ class TFAutoModelTest(unittest.TestCase):
|
|||||||
self.assertIsInstance(model, TFRobertaForMaskedLM)
|
self.assertIsInstance(model, TFRobertaForMaskedLM)
|
||||||
self.assertEqual(model.num_parameters(), 14830)
|
self.assertEqual(model.num_parameters(), 14830)
|
||||||
self.assertEqual(model.num_parameters(only_trainable=True), 14830)
|
self.assertEqual(model.num_parameters(only_trainable=True), 14830)
|
||||||
|
|
||||||
|
def test_parents_and_children_in_mappings(self):
|
||||||
|
# Test that the children are placed before the parents in the mappings, as the `instanceof` will be triggered
|
||||||
|
# by the parents and will return the wrong configuration type when using auto models
|
||||||
|
mappings = (
|
||||||
|
TF_MODEL_MAPPING,
|
||||||
|
TF_MODEL_FOR_PRETRAINING_MAPPING,
|
||||||
|
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
||||||
|
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||||
|
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||||
|
TF_MODEL_WITH_LM_HEAD_MAPPING,
|
||||||
|
TF_MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||||
|
TF_MODEL_FOR_MASKED_LM_MAPPING,
|
||||||
|
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
||||||
|
)
|
||||||
|
|
||||||
|
for mapping in mappings:
|
||||||
|
mapping = tuple(mapping.items())
|
||||||
|
for index, (child_config, child_model) in enumerate(mapping[1:]):
|
||||||
|
for parent_config, parent_model in mapping[: index + 1]:
|
||||||
|
with self.subTest(
|
||||||
|
msg="Testing if {} is child of {}".format(child_config.__name__, parent_config.__name__)
|
||||||
|
):
|
||||||
|
self.assertFalse(issubclass(child_config, parent_config))
|
||||||
|
self.assertFalse(issubclass(child_model, parent_model))
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ if is_tf_available():
|
|||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from transformers.modeling_tf_bert import (
|
from transformers.modeling_tf_bert import (
|
||||||
TFBertModel,
|
TFBertModel,
|
||||||
|
TFBertLMHeadModel,
|
||||||
TFBertForMaskedLM,
|
TFBertForMaskedLM,
|
||||||
TFBertForNextSentencePrediction,
|
TFBertForNextSentencePrediction,
|
||||||
TFBertForPreTraining,
|
TFBertForPreTraining,
|
||||||
@@ -142,11 +143,30 @@ class TFBertModelTester:
|
|||||||
)
|
)
|
||||||
self.parent.assertListEqual(list(result["pooled_output"].shape), [self.batch_size, self.hidden_size])
|
self.parent.assertListEqual(list(result["pooled_output"].shape), [self.batch_size, self.hidden_size])
|
||||||
|
|
||||||
|
def create_and_check_bert_lm_head(
|
||||||
|
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
|
):
|
||||||
|
config.is_decoder = True
|
||||||
|
model = TFBertLMHeadModel(config=config)
|
||||||
|
inputs = {
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"attention_mask": input_mask,
|
||||||
|
"token_type_ids": token_type_ids,
|
||||||
|
}
|
||||||
|
(prediction_scores,) = model(inputs)
|
||||||
|
self.parent.assertListEqual(
|
||||||
|
list(prediction_scores.numpy().shape), [self.batch_size, self.seq_length, self.vocab_size]
|
||||||
|
)
|
||||||
|
|
||||||
def create_and_check_bert_for_masked_lm(
|
def create_and_check_bert_for_masked_lm(
|
||||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
):
|
):
|
||||||
model = TFBertForMaskedLM(config=config)
|
model = TFBertForMaskedLM(config=config)
|
||||||
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
|
inputs = {
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"attention_mask": input_mask,
|
||||||
|
"token_type_ids": token_type_ids,
|
||||||
|
}
|
||||||
(prediction_scores,) = model(inputs)
|
(prediction_scores,) = model(inputs)
|
||||||
result = {
|
result = {
|
||||||
"prediction_scores": prediction_scores.numpy(),
|
"prediction_scores": prediction_scores.numpy(),
|
||||||
@@ -186,11 +206,14 @@ class TFBertModelTester:
|
|||||||
):
|
):
|
||||||
config.num_labels = self.num_labels
|
config.num_labels = self.num_labels
|
||||||
model = TFBertForSequenceClassification(config=config)
|
model = TFBertForSequenceClassification(config=config)
|
||||||
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
|
inputs = {
|
||||||
(logits,) = model(inputs)
|
"input_ids": input_ids,
|
||||||
result = {
|
"attention_mask": input_mask,
|
||||||
"logits": logits.numpy(),
|
"token_type_ids": token_type_ids,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
(logits,) = model(inputs)
|
||||||
|
result = {"logits": logits.numpy()}
|
||||||
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_labels])
|
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_labels])
|
||||||
|
|
||||||
def create_and_check_bert_for_multiple_choice(
|
def create_and_check_bert_for_multiple_choice(
|
||||||
@@ -207,9 +230,7 @@ class TFBertModelTester:
|
|||||||
"token_type_ids": multiple_choice_token_type_ids,
|
"token_type_ids": multiple_choice_token_type_ids,
|
||||||
}
|
}
|
||||||
(logits,) = model(inputs)
|
(logits,) = model(inputs)
|
||||||
result = {
|
result = {"logits": logits.numpy()}
|
||||||
"logits": logits.numpy(),
|
|
||||||
}
|
|
||||||
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_choices])
|
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_choices])
|
||||||
|
|
||||||
def create_and_check_bert_for_token_classification(
|
def create_and_check_bert_for_token_classification(
|
||||||
@@ -217,7 +238,11 @@ class TFBertModelTester:
|
|||||||
):
|
):
|
||||||
config.num_labels = self.num_labels
|
config.num_labels = self.num_labels
|
||||||
model = TFBertForTokenClassification(config=config)
|
model = TFBertForTokenClassification(config=config)
|
||||||
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
|
inputs = {
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"attention_mask": input_mask,
|
||||||
|
"token_type_ids": token_type_ids,
|
||||||
|
}
|
||||||
(logits,) = model(inputs)
|
(logits,) = model(inputs)
|
||||||
result = {
|
result = {
|
||||||
"logits": logits.numpy(),
|
"logits": logits.numpy(),
|
||||||
@@ -228,12 +253,14 @@ class TFBertModelTester:
|
|||||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
):
|
):
|
||||||
model = TFBertForQuestionAnswering(config=config)
|
model = TFBertForQuestionAnswering(config=config)
|
||||||
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
|
inputs = {
|
||||||
start_logits, end_logits = model(inputs)
|
"input_ids": input_ids,
|
||||||
result = {
|
"attention_mask": input_mask,
|
||||||
"start_logits": start_logits.numpy(),
|
"token_type_ids": token_type_ids,
|
||||||
"end_logits": end_logits.numpy(),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
start_logits, end_logits = model(inputs)
|
||||||
|
result = {"start_logits": start_logits.numpy(), "end_logits": end_logits.numpy()}
|
||||||
self.parent.assertListEqual(list(result["start_logits"].shape), [self.batch_size, self.seq_length])
|
self.parent.assertListEqual(list(result["start_logits"].shape), [self.batch_size, self.seq_length])
|
||||||
self.parent.assertListEqual(list(result["end_logits"].shape), [self.batch_size, self.seq_length])
|
self.parent.assertListEqual(list(result["end_logits"].shape), [self.batch_size, self.seq_length])
|
||||||
|
|
||||||
@@ -285,6 +312,10 @@ class TFBertModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_bert_for_masked_lm(*config_and_inputs)
|
self.model_tester.create_and_check_bert_for_masked_lm(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_for_causal_lm(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_bert_lm_head(*config_and_inputs)
|
||||||
|
|
||||||
def test_for_multiple_choice(self):
|
def test_for_multiple_choice(self):
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_bert_for_multiple_choice(*config_and_inputs)
|
self.model_tester.create_and_check_bert_for_multiple_choice(*config_and_inputs)
|
||||||
|
|||||||
@@ -38,6 +38,9 @@ if is_tf_available():
|
|||||||
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
||||||
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||||
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||||
|
TF_MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||||
|
TF_MODEL_FOR_MASKED_LM_MAPPING,
|
||||||
|
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
||||||
)
|
)
|
||||||
|
|
||||||
if _tf_gpu_memory_limit is not None:
|
if _tf_gpu_memory_limit is not None:
|
||||||
@@ -93,6 +96,12 @@ class TFModelTesterMixin:
|
|||||||
inputs_dict["labels"] = tf.zeros(self.model_tester.batch_size)
|
inputs_dict["labels"] = tf.zeros(self.model_tester.batch_size)
|
||||||
elif model_class in TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.values():
|
elif model_class in TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.values():
|
||||||
inputs_dict["labels"] = tf.zeros((self.model_tester.batch_size, self.model_tester.seq_length))
|
inputs_dict["labels"] = tf.zeros((self.model_tester.batch_size, self.model_tester.seq_length))
|
||||||
|
elif model_class in TF_MODEL_FOR_CAUSAL_LM_MAPPING.values():
|
||||||
|
inputs_dict["labels"] = tf.zeros((self.model_tester.batch_size, self.model_tester.seq_length))
|
||||||
|
elif model_class in TF_MODEL_FOR_MASKED_LM_MAPPING.values():
|
||||||
|
inputs_dict["labels"] = tf.zeros((self.model_tester.batch_size, self.model_tester.seq_length))
|
||||||
|
elif model_class in TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.values():
|
||||||
|
inputs_dict["labels"] = tf.zeros((self.model_tester.batch_size, self.model_tester.seq_length))
|
||||||
return inputs_dict
|
return inputs_dict
|
||||||
|
|
||||||
def test_initialization(self):
|
def test_initialization(self):
|
||||||
@@ -291,7 +300,7 @@ class TFModelTesterMixin:
|
|||||||
"decoder_input_ids": tf.keras.Input(
|
"decoder_input_ids": tf.keras.Input(
|
||||||
batch_shape=(2, 2000), name="decoder_input_ids", dtype="int32"
|
batch_shape=(2, 2000), name="decoder_input_ids", dtype="int32"
|
||||||
),
|
),
|
||||||
"inputs": tf.keras.Input(batch_shape=(2, 2000), name="inputs", dtype="int32"),
|
"input_ids": tf.keras.Input(batch_shape=(2, 2000), name="input_ids", dtype="int32"),
|
||||||
}
|
}
|
||||||
elif model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
|
elif model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
|
||||||
input_ids = tf.keras.Input(batch_shape=(4, 2, 2000), name="input_ids", dtype="int32")
|
input_ids = tf.keras.Input(batch_shape=(4, 2, 2000), name="input_ids", dtype="int32")
|
||||||
@@ -325,7 +334,7 @@ class TFModelTesterMixin:
|
|||||||
outputs_dict = model(self._prepare_for_class(inputs_dict, model_class))
|
outputs_dict = model(self._prepare_for_class(inputs_dict, model_class))
|
||||||
|
|
||||||
inputs_keywords = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
|
inputs_keywords = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
|
||||||
input_ids = inputs_keywords.pop("input_ids" if not self.is_encoder_decoder else "inputs", None,)
|
input_ids = inputs_keywords.pop("input_ids", None)
|
||||||
outputs_keywords = model(input_ids, **inputs_keywords)
|
outputs_keywords = model(input_ids, **inputs_keywords)
|
||||||
output_dict = outputs_dict[0].numpy()
|
output_dict = outputs_dict[0].numpy()
|
||||||
output_keywords = outputs_keywords[0].numpy()
|
output_keywords = outputs_keywords[0].numpy()
|
||||||
@@ -479,9 +488,9 @@ class TFModelTesterMixin:
|
|||||||
input_ids = inputs["input_ids"]
|
input_ids = inputs["input_ids"]
|
||||||
del inputs["input_ids"]
|
del inputs["input_ids"]
|
||||||
else:
|
else:
|
||||||
encoder_input_ids = inputs["inputs"]
|
encoder_input_ids = inputs["input_ids"]
|
||||||
decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids)
|
decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids)
|
||||||
del inputs["inputs"]
|
del inputs["input_ids"]
|
||||||
inputs.pop("decoder_input_ids", None)
|
inputs.pop("decoder_input_ids", None)
|
||||||
|
|
||||||
wte = model.get_input_embeddings()
|
wte = model.get_input_embeddings()
|
||||||
@@ -596,9 +605,15 @@ class TFModelTesterMixin:
|
|||||||
added_label = prepared_for_class[list(prepared_for_class.keys() - inputs_dict.keys())[0]]
|
added_label = prepared_for_class[list(prepared_for_class.keys() - inputs_dict.keys())[0]]
|
||||||
loss_size = tf.size(added_label)
|
loss_size = tf.size(added_label)
|
||||||
|
|
||||||
|
if model.__class__ in TF_MODEL_FOR_CAUSAL_LM_MAPPING.values():
|
||||||
|
# if loss is causal lm loss, labels are shift, so that one label per batch
|
||||||
|
# is cut
|
||||||
|
loss_size = loss_size - self.model_tester.batch_size
|
||||||
|
|
||||||
# Test that model correctly compute the loss with kwargs
|
# Test that model correctly compute the loss with kwargs
|
||||||
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
|
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
|
||||||
input_ids = prepared_for_class.pop("input_ids")
|
input_ids = prepared_for_class.pop("input_ids")
|
||||||
|
|
||||||
loss = model(input_ids, **prepared_for_class)[0]
|
loss = model(input_ids, **prepared_for_class)[0]
|
||||||
self.assertEqual(loss.shape, [loss_size])
|
self.assertEqual(loss.shape, [loss_size])
|
||||||
|
|
||||||
|
|||||||
@@ -17,7 +17,7 @@
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import DistilBertConfig, is_tf_available
|
from transformers import DistilBertConfig, is_tf_available
|
||||||
from transformers.testing_utils import require_tf
|
from transformers.testing_utils import require_tf, slow
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
|
from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
|
||||||
@@ -32,6 +32,7 @@ if is_tf_available():
|
|||||||
TFDistilBertForSequenceClassification,
|
TFDistilBertForSequenceClassification,
|
||||||
TFDistilBertForTokenClassification,
|
TFDistilBertForTokenClassification,
|
||||||
TFDistilBertForMultipleChoice,
|
TFDistilBertForMultipleChoice,
|
||||||
|
TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -118,9 +119,7 @@ class TFDistilBertModelTester:
|
|||||||
model = TFDistilBertForMaskedLM(config=config)
|
model = TFDistilBertForMaskedLM(config=config)
|
||||||
inputs = {"input_ids": input_ids, "attention_mask": input_mask}
|
inputs = {"input_ids": input_ids, "attention_mask": input_mask}
|
||||||
(prediction_scores,) = model(inputs)
|
(prediction_scores,) = model(inputs)
|
||||||
result = {
|
result = {"prediction_scores": prediction_scores.numpy()}
|
||||||
"prediction_scores": prediction_scores.numpy(),
|
|
||||||
}
|
|
||||||
self.parent.assertListEqual(
|
self.parent.assertListEqual(
|
||||||
list(result["prediction_scores"].shape), [self.batch_size, self.seq_length, self.vocab_size]
|
list(result["prediction_scores"].shape), [self.batch_size, self.seq_length, self.vocab_size]
|
||||||
)
|
)
|
||||||
@@ -129,12 +128,12 @@ class TFDistilBertModelTester:
|
|||||||
self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
|
self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
):
|
):
|
||||||
model = TFDistilBertForQuestionAnswering(config=config)
|
model = TFDistilBertForQuestionAnswering(config=config)
|
||||||
inputs = {"input_ids": input_ids, "attention_mask": input_mask}
|
inputs = {
|
||||||
start_logits, end_logits = model(inputs)
|
"input_ids": input_ids,
|
||||||
result = {
|
"attention_mask": input_mask,
|
||||||
"start_logits": start_logits.numpy(),
|
|
||||||
"end_logits": end_logits.numpy(),
|
|
||||||
}
|
}
|
||||||
|
start_logits, end_logits = model(inputs)
|
||||||
|
result = {"start_logits": start_logits.numpy(), "end_logits": end_logits.numpy()}
|
||||||
self.parent.assertListEqual(list(result["start_logits"].shape), [self.batch_size, self.seq_length])
|
self.parent.assertListEqual(list(result["start_logits"].shape), [self.batch_size, self.seq_length])
|
||||||
self.parent.assertListEqual(list(result["end_logits"].shape), [self.batch_size, self.seq_length])
|
self.parent.assertListEqual(list(result["end_logits"].shape), [self.batch_size, self.seq_length])
|
||||||
|
|
||||||
@@ -145,9 +144,7 @@ class TFDistilBertModelTester:
|
|||||||
model = TFDistilBertForSequenceClassification(config)
|
model = TFDistilBertForSequenceClassification(config)
|
||||||
inputs = {"input_ids": input_ids, "attention_mask": input_mask}
|
inputs = {"input_ids": input_ids, "attention_mask": input_mask}
|
||||||
(logits,) = model(inputs)
|
(logits,) = model(inputs)
|
||||||
result = {
|
result = {"logits": logits.numpy()}
|
||||||
"logits": logits.numpy(),
|
|
||||||
}
|
|
||||||
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_labels])
|
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_labels])
|
||||||
|
|
||||||
def create_and_check_distilbert_for_multiple_choice(
|
def create_and_check_distilbert_for_multiple_choice(
|
||||||
@@ -162,9 +159,7 @@ class TFDistilBertModelTester:
|
|||||||
"attention_mask": multiple_choice_input_mask,
|
"attention_mask": multiple_choice_input_mask,
|
||||||
}
|
}
|
||||||
(logits,) = model(inputs)
|
(logits,) = model(inputs)
|
||||||
result = {
|
result = {"logits": logits.numpy()}
|
||||||
"logits": logits.numpy(),
|
|
||||||
}
|
|
||||||
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_choices])
|
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_choices])
|
||||||
|
|
||||||
def create_and_check_distilbert_for_token_classification(
|
def create_and_check_distilbert_for_token_classification(
|
||||||
@@ -236,8 +231,8 @@ class TFDistilBertModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_distilbert_for_token_classification(*config_and_inputs)
|
self.model_tester.create_and_check_distilbert_for_token_classification(*config_and_inputs)
|
||||||
|
|
||||||
# @slow
|
@slow
|
||||||
# def test_model_from_pretrained(self):
|
def test_model_from_pretrained(self):
|
||||||
# for model_name in list(DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
for model_name in list(TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]):
|
||||||
# model = DistilBertModesss.from_pretrained(model_name)
|
model = TFDistilBertModel.from_pretrained(model_name)
|
||||||
# self.assertIsNotNone(model)
|
self.assertIsNotNone(model)
|
||||||
|
|||||||
@@ -77,6 +77,7 @@ class TFT5ModelTester:
|
|||||||
eos_token_id=self.eos_token_id,
|
eos_token_id=self.eos_token_id,
|
||||||
bos_token_id=self.pad_token_id,
|
bos_token_id=self.pad_token_id,
|
||||||
pad_token_id=self.pad_token_id,
|
pad_token_id=self.pad_token_id,
|
||||||
|
decoder_start_token_id=self.pad_token_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
return (config, input_ids, input_mask, token_labels)
|
return (config, input_ids, input_mask, token_labels)
|
||||||
@@ -84,7 +85,7 @@ class TFT5ModelTester:
|
|||||||
def create_and_check_t5_model(self, config, input_ids, input_mask, token_labels):
|
def create_and_check_t5_model(self, config, input_ids, input_mask, token_labels):
|
||||||
model = TFT5Model(config=config)
|
model = TFT5Model(config=config)
|
||||||
inputs = {
|
inputs = {
|
||||||
"inputs": input_ids,
|
"input_ids": input_ids,
|
||||||
"decoder_input_ids": input_ids,
|
"decoder_input_ids": input_ids,
|
||||||
"decoder_attention_mask": input_mask,
|
"decoder_attention_mask": input_mask,
|
||||||
}
|
}
|
||||||
@@ -115,7 +116,7 @@ class TFT5ModelTester:
|
|||||||
def create_and_check_t5_with_lm_head(self, config, input_ids, input_mask, token_labels):
|
def create_and_check_t5_with_lm_head(self, config, input_ids, input_mask, token_labels):
|
||||||
model = TFT5ForConditionalGeneration(config=config)
|
model = TFT5ForConditionalGeneration(config=config)
|
||||||
inputs_dict = {
|
inputs_dict = {
|
||||||
"inputs": input_ids,
|
"input_ids": input_ids,
|
||||||
"decoder_input_ids": input_ids,
|
"decoder_input_ids": input_ids,
|
||||||
"decoder_attention_mask": input_mask,
|
"decoder_attention_mask": input_mask,
|
||||||
}
|
}
|
||||||
@@ -209,7 +210,7 @@ class TFT5ModelTester:
|
|||||||
config_and_inputs = self.prepare_config_and_inputs()
|
config_and_inputs = self.prepare_config_and_inputs()
|
||||||
(config, input_ids, input_mask, token_labels) = config_and_inputs
|
(config, input_ids, input_mask, token_labels) = config_and_inputs
|
||||||
inputs_dict = {
|
inputs_dict = {
|
||||||
"inputs": input_ids,
|
"input_ids": input_ids,
|
||||||
"decoder_input_ids": input_ids,
|
"decoder_input_ids": input_ids,
|
||||||
"decoder_attention_mask": input_mask,
|
"decoder_attention_mask": input_mask,
|
||||||
"use_cache": tf.convert_to_tensor([False]),
|
"use_cache": tf.convert_to_tensor([False]),
|
||||||
|
|||||||
Reference in New Issue
Block a user