[Almost all TF models] TF clean up: add missing CLM / MLM loss; fix T5 naming and keras compile (#5395)
* add first version of clm tf * make style * add more tests for bert * update tf clm loss * fix tests * correct tf ner script * add mlm loss * delete bogus file * clean tf auto model + add tests * finish adding clm loss everywhere * fix training in distilbert * fix flake8 * save intermediate * fix tf t5 naming * remove prints * finish up * up * fix tf gpt2 * fix new test utils import * fix flake8 * keep backward compatibility * Update src/transformers/modeling_tf_albert.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/modeling_tf_auto.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/modeling_tf_electra.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/modeling_tf_roberta.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/modeling_tf_mobilebert.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/modeling_tf_auto.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/modeling_tf_bert.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/modeling_tf_distilbert.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * apply sylvains suggestions Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
33e43edddc
commit
4dc65591b5
@@ -24,6 +24,7 @@ import tensorflow as tf
|
||||
from .configuration_ctrl import CTRLConfig
|
||||
from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
|
||||
from .modeling_tf_utils import (
|
||||
TFCausalLanguageModelingLoss,
|
||||
TFPreTrainedModel,
|
||||
TFSharedEmbeddings,
|
||||
cast_bool_to_primitive,
|
||||
@@ -542,7 +543,7 @@ class TFCTRLLMHead(tf.keras.layers.Layer):
|
||||
(linear layer with weights tied to the input embeddings). """,
|
||||
CTRL_START_DOCSTRING,
|
||||
)
|
||||
class TFCTRLLMHeadModel(TFCTRLPreTrainedModel):
|
||||
class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
self.transformer = TFCTRLMainLayer(config, name="transformer")
|
||||
@@ -561,8 +562,26 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel):
|
||||
|
||||
@add_start_docstrings_to_callable(CTRL_INPUTS_DOCSTRING)
|
||||
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="ctrl")
|
||||
def call(self, inputs, **kwargs):
|
||||
def call(
|
||||
self,
|
||||
inputs,
|
||||
past=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
use_cache=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||
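Labels for computing the causal language modeling (next-token prediction) cross entropy loss. The logits and labels are shifted by one position inside the model before the loss is computed.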
|
||||
Indices should be in ``[0, ..., config.vocab_size - 1]``.
|
||||
|
||||
Return:
|
||||
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.CTRLConfig`) and inputs:
|
||||
prediction_scores (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
|
||||
@@ -583,11 +602,37 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel):
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
"""
|
||||
transformer_outputs = self.transformer(inputs, **kwargs)
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
labels = inputs[10] if len(inputs) > 10 else labels
|
||||
if len(inputs) > 10:
|
||||
inputs = inputs[:10]
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
labels = inputs.pop("labels", labels)
|
||||
|
||||
transformer_outputs = self.transformer(
|
||||
inputs,
|
||||
past=past,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
training=training,
|
||||
)
|
||||
|
||||
hidden_states = transformer_outputs[0]
|
||||
|
||||
lm_logits = self.lm_head(hidden_states)
|
||||
logits = self.lm_head(hidden_states)
|
||||
|
||||
outputs = (lm_logits,) + transformer_outputs[1:]
|
||||
outputs = (logits,) + transformer_outputs[1:]
|
||||
if labels is not None:
|
||||
# shift labels to the left and cut last logit token
|
||||
logits = logits[:, :-1]
|
||||
labels = labels[:, 1:]
|
||||
loss = self.compute_loss(labels, logits)
|
||||
outputs = (loss,) + outputs
|
||||
|
||||
return outputs # (loss), logits, presents, (all hidden_states), (attentions)
|
||||
|
||||
Reference in New Issue
Block a user