[Almost all TF models] TF clean up: add missing CLM / MLM loss; fix T5 naming and keras compile (#5395)

* add first version of clm tf

* make style

* add more tests for bert

* update tf clm loss

* fix tests

* correct tf ner script

* add mlm loss

* delete bogus file

* clean tf auto model + add tests

* finish adding clm loss everywhere

* fix training in distilbert

* fix flake8

* save intermediate

* fix tf t5 naming

* remove prints

* finish up

* up

* fix tf gpt2

* fix new test utils import

* fix flake8

* keep backward compatibility

* Update src/transformers/modeling_tf_albert.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/modeling_tf_auto.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/modeling_tf_electra.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/modeling_tf_roberta.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/modeling_tf_mobilebert.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/modeling_tf_auto.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/modeling_tf_bert.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/modeling_tf_distilbert.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* apply sylvains suggestions

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Patrick von Platen
2020-07-07 18:15:53 +02:00
committed by GitHub
parent 33e43edddc
commit 4dc65591b5
23 changed files with 1516 additions and 315 deletions

View File

@@ -24,6 +24,7 @@ import tensorflow as tf
from .configuration_ctrl import CTRLConfig
from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
from .modeling_tf_utils import (
TFCausalLanguageModelingLoss,
TFPreTrainedModel,
TFSharedEmbeddings,
cast_bool_to_primitive,
@@ -542,7 +543,7 @@ class TFCTRLLMHead(tf.keras.layers.Layer):
(linear layer with weights tied to the input embeddings). """,
CTRL_START_DOCSTRING,
)
class TFCTRLLMHeadModel(TFCTRLPreTrainedModel):
class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.transformer = TFCTRLMainLayer(config, name="transformer")
@@ -561,8 +562,26 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel):
@add_start_docstrings_to_callable(CTRL_INPUTS_DOCSTRING)
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="ctrl")
def call(self, inputs, **kwargs):
def call(
self,
inputs,
past=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
labels=None,
training=False,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Labels for computing the cross entropy classification loss.
Indices should be in ``[0, ..., config.vocab_size - 1]``.
Return:
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.CTRLConfig`) and inputs:
prediction_scores (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
@@ -583,11 +602,37 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel):
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
transformer_outputs = self.transformer(inputs, **kwargs)
if isinstance(inputs, (tuple, list)):
labels = inputs[10] if len(inputs) > 10 else labels
if len(inputs) > 10:
inputs = inputs[:10]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
transformer_outputs = self.transformer(
inputs,
past=past,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
training=training,
)
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
logits = self.lm_head(hidden_states)
outputs = (lm_logits,) + transformer_outputs[1:]
outputs = (logits,) + transformer_outputs[1:]
if labels is not None:
# shift labels to the left and cut last logit token
logits = logits[:, :-1]
labels = labels[:, 1:]
loss = self.compute_loss(labels, logits)
outputs = (loss,) + outputs
return outputs # lm_logits, presents, (all hidden_states), (attentions)