adding TF 2.0 model

This commit is contained in:
thomwolf
2019-10-09 11:07:43 +02:00
parent 45dc04f33d
commit c56d921dda
6 changed files with 430 additions and 213 deletions

View File

@@ -25,7 +25,7 @@ import numpy as np
import tensorflow as tf
from .configuration_ctrl import CTRLConfig
from .modeling_tf_utils import TFPreTrainedModel, get_initializer
from .modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list, TFSharedEmbeddings
from .file_utils import add_start_docstrings
from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
@@ -33,12 +33,19 @@ logger = logging.getLogger(__name__)
TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP = {"ctrl": "https://s3.amazonaws.com/models.huggingface.co/bert/ctrl-tf_model.h5"}
def load_ctrl_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
# build the network
inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
tf_inputs = tf.constant(inputs_list)
tfo = tf_model(tf_inputs, training=False)
return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs)
def angle_defn(pos, i, d_model_size):
angle_rates = 1 / np.power(10000, (2 * (i//2)) / np.float32(d_model_size))
return pos * angle_rates
def positional_encoding(position, d_model_size, dtype):
def positional_encoding(position, d_model_size):
# create the sinusoidal pattern for the positional encoding
angle_rads = angle_defn(np.arange(position)[:, np.newaxis],
np.arange(d_model_size)[np.newaxis, :],
@@ -47,14 +54,15 @@ def positional_encoding(position, d_model_size, dtype):
sines = np.sin(angle_rads[:, 0::2])
cosines = np.cos(angle_rads[:, 1::2])
pos_encoding = tf.cast(np.concatenate([sines, cosines], axis=-1)[np.newaxis, ...], dtype=tf.float32)
# pos_encoding = tf.cast(np.concatenate([sines, cosines], axis=-1)[np.newaxis, ...], dtype=tf.float32)
pos_encoding = tf.cast(np.concatenate([sines, cosines], axis=-1), dtype=tf.float32)
return pos_encoding
def scaled_dot_product_attention(q, k, v, mask, attention_mask=None, head_mask=None):
# calculate attention
matmul_qk = tf.matmul(q, k, transpose_b=True)
dk = tf.cast(tf.shape(k)[-1], tf.float32)
dk = tf.cast(shape_list(k)[-1], tf.float32)
scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)
if mask is not None:
@@ -94,7 +102,7 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
return tf.transpose(x, perm=[0, 2, 1, 3])
def call(self, inputs, training=False)
def call(self, inputs, training=False):
v, k, q, mask, layer_past, attention_mask, head_mask = inputs
batch_size = q.shape[0]
@@ -124,31 +132,34 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
def point_wise_feed_forward_network(d_model_size, dff):
return tf.keras.Sequential([tf.keras.layers.Dense(dff, activation='relu'),
tf.keras.layers.Dense(d_model_size)])
def point_wise_feed_forward_network(d_model_size, dff, name=""):
return tf.keras.Sequential([
tf.keras.layers.Dense(dff, activation='relu', name="0"),
tf.keras.layers.Dense(d_model_size, name="2")
], name="ffn")
class TFEncoderLayer(tf.keras.layers.Layer):
def __init__(self, d_model_size, num_heads, dff, rate=0.1, output_attentions=False, **kwargs):
def __init__(self, d_model_size, num_heads, dff, rate=0.1, layer_norm_epsilon=1e-6, output_attentions=False, **kwargs):
super(TFEncoderLayer, self).__init__(**kwargs)
self.multi_head_attention = MultiHeadAttention(d_model_size, num_heads, output_attentions)
self.ffn = point_wise_feed_forward_network(d_model_size, dff)
self.multi_head_attention = TFMultiHeadAttention(d_model_size,
num_heads,
output_attentions,
name="multi_head_attention")
self.ffn = point_wise_feed_forward_network(d_model_size, dff, name="ffn")
self.layernorm1 = torch.nn.LayerNorm(d_model_size, eps=1e-6)
self.layernorm2 = torch.nn.LayerNorm(d_model_size, eps=1e-6)
self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=layer_norm_epsilon, name="layernorm1")
self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=layer_norm_epsilon, name="layernorm2")
self.dropout1 = torch.nn.Dropout(rate)
self.dropout2 = torch.nn.Dropout(rate)
self.dropout1 = tf.keras.layers.Dropout(rate)
self.dropout2 = tf.keras.layers.Dropout(rate)
def call(self, inputs, training=False):
x, mask, layer_past, attention_mask, head_mask = inputs
normed = self.layernorm1(x)
attn_outputs = self.multi_head_attention(normed, normed, normed, mask,
layer_past=layer_past,
attention_mask=attention_mask,
head_mask=head_mask)
attn_outputs = self.multi_head_attention([normed, normed, normed, mask, layer_past,
attention_mask, head_mask], training=training)
attn_output = attn_outputs[0]
attn_output = self.dropout1(attn_output, training=training)
out1 = x + attn_output
@@ -162,6 +173,152 @@ class TFEncoderLayer(tf.keras.layers.Layer):
return outputs
class TFCTRLMainLayer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super(TFCTRLMainLayer, self).__init__(**kwargs)
self.output_hidden_states = config.output_hidden_states
self.d_model_size = config.n_embd
self.num_layers = config.n_layer
self.pos_encoding = positional_encoding(config.n_positions, self.d_model_size)
self.output_attentions = config.output_attentions
self.w = TFSharedEmbeddings(config.vocab_size,
config.n_embd,
initializer_range=config.initializer_range,
name="w")
self.dropout = tf.keras.layers.Dropout(config.embd_pdrop)
self.h = [TFEncoderLayer(config.n_embd,
config.n_head,
config.dff,
config.resid_pdrop,
config.layer_norm_epsilon,
config.output_attentions,
name='h_._{}'.format(i)) for i in range(config.n_layer)]
self.layernorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="layernorm")
def _resize_token_embeddings(self, new_num_tokens):
raise NotImplementedError
def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
"""
raise NotImplementedError
def call(self, inputs, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, training=False):
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
past = inputs[1] if len(inputs) > 1 else past
attention_mask = inputs[2] if len(inputs) > 2 else attention_mask
token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
position_ids = inputs[4] if len(inputs) > 4 else position_ids
head_mask = inputs[5] if len(inputs) > 5 else head_mask
assert len(inputs) <= 6, "Too many inputs."
elif isinstance(inputs, dict):
input_ids = inputs.get('input_ids')
past = inputs.get('past', past)
attention_mask = inputs.get('attention_mask', attention_mask)
token_type_ids = inputs.get('token_type_ids', token_type_ids)
position_ids = inputs.get('position_ids', position_ids)
head_mask = inputs.get('head_mask', head_mask)
assert len(inputs) <= 6, "Too many inputs."
else:
input_ids = inputs
input_shape = shape_list(input_ids)
input_ids = tf.reshape(input_ids, [-1, input_shape[-1]])
if past is None:
past_length = 0
past = [None] * len(self.h)
else:
past_length = shape_list(past[0][0])[-2]
if position_ids is None:
position_ids = tf.range(past_length, shape_list(input_ids)[-1] + past_length, dtype=tf.int32)[tf.newaxis, :]
# Attention mask.
if attention_mask is not None:
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
attention_mask = tf.cast(attention_mask, tf.float32)
attention_mask = (1.0 - attention_mask) * -10000.0
else:
attention_mask = None
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# head_mask has shape n_layer x batch x n_heads x N x N
if head_mask is not None:
raise NotImplementedError
else:
head_mask = [None] * self.num_layers
if token_type_ids is not None:
token_type_ids = tf.reshape(token_type_ids, [-1, shape_list(token_type_ids)[-1]])
token_type_embeds = self.w(token_type_ids, mode='embedding')
token_type_embeds *= tf.math.sqrt(tf.cast(self.d_model_size, tf.float32))
else:
token_type_embeds = 0
position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]])
inputs_embeds = self.w(input_ids)
# x = embedded.unsqueeze(0) if len(input_ids.shape)<2 else embedded
seq_len = input_shape[-1]
mask = 1 - tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)
inputs_embeds *= tf.math.sqrt(tf.cast(self.d_model_size, tf.float32))
pos_embeds = tf.gather(self.pos_encoding, position_ids)
hidden_states = inputs_embeds + pos_embeds + token_type_embeds
hidden_states = self.dropout(hidden_states, training=training)
output_shape = input_shape + [shape_list(hidden_states)[-1]]
presents = ()
all_hidden_states = ()
all_attentions = []
for i, (h, layer_past) in enumerate(zip(self.h, past)):
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
outputs = h([hidden_states, mask, layer_past, attention_mask, head_mask[i]], training=training)
hidden_states, present = outputs[:2]
presents = presents + (present,)
if self.output_attentions:
all_attentions.append(outputs[2])
hidden_states = self.layernorm(hidden_states)
hidden_states = tf.reshape(hidden_states, output_shape)
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (hidden_states, presents)
if self.output_hidden_states:
outputs = outputs + (all_hidden_states,)
if self.output_attentions:
# let the number of heads free (-1) so we can extract attention even after head pruning
attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:]
all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions)
outputs = outputs + (all_attentions,)
return outputs
class TFCTRLPreTrainedModel(TFPreTrainedModel):
""" An abstract class to handle weights initialization and
a simple interface for dowloading and loading pretrained models.
@@ -169,20 +326,7 @@ class TFCTRLPreTrainedModel(TFPreTrainedModel):
config_class = CTRLConfig
pretrained_model_archive_map = TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
base_model_prefix = "transformer"
load_pt_weights = load_bert_pt_weights_in_tf2
def _init_weights(self, module):
""" Initialize the weights.
"""
if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
load_pt_weights = load_ctrl_pt_weights_in_tf2
CTRL_START_DOCSTRING = r""" CTRL model was proposed in
@@ -240,172 +384,68 @@ CTRL_INPUTS_DOCSTRING = r""" Inputs:
class TFCTRLModel(TFCTRLPreTrainedModel):
r"""
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
**last_hidden_state**: ``tf.Tensor`` of shape ``(batch_size, sequence_length, hidden_size)``
Sequence of hidden-states at the last layer of the model.
**past**:
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
list of ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
that contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding.
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
list of ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
list of ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
import tensorflow as tf
from transformers import CTRLTokenizer, TFCTRLModel
tokenizer = CTRLTokenizer.from_pretrained('ctrl')
model = CTRLModel.from_pretrained('ctrl')
input_ids = torch.tensor(tokenizer.encode("Links Hello, my dog is cute")).unsqueeze(0) # Batch size 1
model = TFCTRLModel.from_pretrained('ctrl')
input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :] # Batch size 1
outputs = model(input_ids)
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
"""
def __init__(self, config, **kwargs):
super(TFCTRLModel, self).__init__(**kwargs)
self.output_hidden_states = config.output_hidden_states
self.d_model_size = config.n_embd
self.num_layers = config.n_layer
self.pos_encoding = positional_encoding(config.n_positions, self.d_model_size, torch.float)
self.output_attentions = config.output_attentions
def __init__(self, config, *inputs, **kwargs):
super(TFCTRLModel, self).__init__(config, *inputs, **kwargs)
self.transformer = TFCTRLMainLayer(config, name='transformer')
self.w = nn.Embedding(config.vocab_size, config.n_embd)
self.dropout = nn.Dropout(config.embd_pdrop)
self.h = nn.ModuleList([EncoderLayer(config.n_embd,
config.n_head,
config.dff,
config.resid_pdrop,
config.output_attentions) for _ in range(config.n_layer)])
self.layernorm = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
self.init_weights()
def _resize_token_embeddings(self, new_num_tokens):
self.w = self._get_resized_embeddings(self.w, new_num_tokens)
return self.w
def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
"""
for layer, heads in heads_to_prune.items():
self.h[layer].attn.prune_heads(heads)
def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
if token_type_ids is not None:
token_type_ids = token_type_ids.view(-1, input_shape[-1])
if position_ids is not None:
position_ids = position_ids.view(-1, input_shape[-1])
if past is None:
past_length = 0
past = [None] * len(self.h)
else:
past_length = past[0][0].size(-2)
if position_ids is None:
position_ids = torch.arange(past_length, input_ids.size(-1) + past_length, dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
# Attention mask.
if attention_mask is not None:
attention_mask = attention_mask.view(-1, input_shape[-1])
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * -10000.0
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# head_mask has shape n_layer x batch x n_heads x N x N
if head_mask is not None:
if head_mask.dim() == 1:
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
head_mask = head_mask.expand(self.config.n_layer, -1, -1, -1, -1)
elif head_mask.dim() == 2:
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to fload if need + fp16 compatibility
else:
head_mask = [None] * self.config.n_layer
x = self.w(input_ids)
# x = embedded.unsqueeze(0) if len(input_ids.shape)<2 else embedded
seq_len = input_ids.shape[1]
mask = torch.triu(torch.ones(seq_len, seq_len), 1).to(x.device)
x *= np.sqrt(self.d_model_size)
pos_x = self.pos_encoding[position_ids, :].to(x.device)
x += pos_x
x = self.dropout(x)
output_shape = input_shape + (x.size(-1),)
presents = ()
all_hidden_states = ()
all_attentions = []
for i, (h, layer_past) in enumerate(zip(self.h, past)):
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (x.view(*output_shape),)
outputs = h(x,
mask,
layer_past=layer_past,
attention_mask=attention_mask,
head_mask=head_mask[i])
x, present = outputs[:2]
presents = presents + (present,)
if self.output_attentions:
all_attentions.append(outputs[2])
x = self.layernorm(x)
x = x.view(*output_shape)
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (x,)
outputs = (x, presents)
if self.output_hidden_states:
outputs = outputs + (all_hidden_states,)
if self.output_attentions:
# let the number of heads free (-1) so we can extract attention even after head pruning
attention_output_shape = input_shape[:-1] + (-1,) + all_attentions[0].shape[-2:]
all_attentions = tuple(t.view(*attention_output_shape) for t in all_attentions)
outputs = outputs + (all_attentions,)
def call(self, inputs, **kwargs):
outputs = self.transformer(inputs, **kwargs)
return outputs
class TFCTRLLMHead(tf.keras.layers.Layer):
def __init__(self, config, input_embeddings, **kwargs):
super(TFCTRLLMHead, self).__init__(**kwargs)
self.vocab_size = config.vocab_size
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self.input_embeddings = input_embeddings
def build(self, input_shape):
self.bias = self.add_weight(shape=(self.vocab_size,),
initializer='zeros',
trainable=True,
name='bias')
super(TFCTRLLMHead, self).build(input_shape)
def call(self, hidden_states):
hidden_states = self.input_embeddings(hidden_states, mode="linear")
hidden_states = hidden_states + self.bias
return hidden_states
@add_start_docstrings("""The CTRL Model transformer with a language modeling head on top
(linear layer with weights tied to the input embeddings). """, CTRL_START_DOCSTRING, CTRL_INPUTS_DOCSTRING)
class CTRLLMHeadModel(CTRLPreTrainedModel):
class TFCTRLLMHeadModel(TFCTRLPreTrainedModel):
r"""
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Labels for language modeling.
Note that the labels **are shifted** inside the model, i.e. you can set ``lm_labels = input_ids``
Indices are selected in ``[-1, 0, ..., config.vocab_size]``
All labels set to ``-1`` are ignored (masked), the loss is only
computed for labels in ``[0, ..., config.vocab_size]``
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Language modeling loss.
**prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
**past**:
@@ -423,53 +463,28 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
Examples::
import torch
from transformers import CTRLTokenizer, CTRLLMHeadModel
from transformers import CTRLTokenizer, TFCTRLLMHeadModel
tokenizer = CTRLTokenizer.from_pretrained('ctrl')
model = CTRLLMHeadModel.from_pretrained('ctrl')
model = TFCTRLLMHeadModel.from_pretrained('ctrl')
input_ids = torch.tensor(tokenizer.encode("Links Hello, my dog is cute")).unsqueeze(0) # Batch size 1
outputs = model(input_ids, labels=input_ids)
loss, logits = outputs[:2]
"""
def __init__(self, config):
super(CTRLLMHeadModel, self).__init__(config)
self.transformer = CTRLModel(config)
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=True)
def __init__(self, config, *inputs, **kwargs):
super(TFCTRLLMHeadModel, self).__init__(config, *inputs, **kwargs)
self.transformer = TFCTRLMainLayer(config, name='transformer')
self.init_weights()
self.tie_weights()
def tie_weights(self):
""" Make sure we are sharing the input and output embeddings.
Export to TorchScript can't handle parameter sharing so we are cloning them instead.
"""
self._tie_or_clone_weights(self.lm_head, self.transformer.w)
def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
labels=None):
transformer_outputs = self.transformer(input_ids,
past=past,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask)
self.lm_head = TFCTRLLMHead(config, self.transformer.w, name="lm_head")
def call(self, inputs, **kwargs):
transformer_outputs = self.transformer(inputs, **kwargs)
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
outputs = (lm_logits,) + transformer_outputs[1:]
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss(ignore_index=-1)
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1))
outputs = (loss,) + outputs
return outputs # (loss), lm_logits, presents, (all hidden_states), (attentions)
return outputs # lm_logits, presents, (all hidden_states), (attentions)