TF: CTRL with native embedding layers (#23456)
This commit is contained in:
@@ -15,10 +15,8 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" TF 2.0 CTRL model."""
|
""" TF 2.0 CTRL model."""
|
||||||
|
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import warnings
|
|
||||||
from typing import Optional, Tuple, Union
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -30,7 +28,6 @@ from ...modeling_tf_utils import (
|
|||||||
TFModelInputType,
|
TFModelInputType,
|
||||||
TFPreTrainedModel,
|
TFPreTrainedModel,
|
||||||
TFSequenceClassificationLoss,
|
TFSequenceClassificationLoss,
|
||||||
TFSharedEmbeddings,
|
|
||||||
get_initializer,
|
get_initializer,
|
||||||
keras_serializable,
|
keras_serializable,
|
||||||
unpack_inputs,
|
unpack_inputs,
|
||||||
@@ -224,8 +221,11 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
self.pos_encoding = positional_encoding(config.n_positions, self.d_model_size)
|
self.pos_encoding = positional_encoding(config.n_positions, self.d_model_size)
|
||||||
|
|
||||||
self.w = TFSharedEmbeddings(
|
self.w = tf.keras.layers.Embedding(
|
||||||
config.vocab_size, config.n_embd, initializer_range=config.initializer_range, name="w"
|
input_dim=config.vocab_size,
|
||||||
|
output_dim=config.n_embd,
|
||||||
|
embeddings_initializer=get_initializer(config.initializer_range),
|
||||||
|
name="w",
|
||||||
)
|
)
|
||||||
|
|
||||||
self.dropout = tf.keras.layers.Dropout(config.embd_pdrop)
|
self.dropout = tf.keras.layers.Dropout(config.embd_pdrop)
|
||||||
@@ -246,9 +246,8 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
|
|||||||
def get_input_embeddings(self):
|
def get_input_embeddings(self):
|
||||||
return self.w
|
return self.w
|
||||||
|
|
||||||
def set_input_embeddings(self, value):
|
def set_input_embeddings(self, new_embeddings):
|
||||||
self.w.weight = value
|
self.w = new_embeddings
|
||||||
self.w.vocab_size = shape_list(value)[0]
|
|
||||||
|
|
||||||
def _prune_heads(self, heads_to_prune):
|
def _prune_heads(self, heads_to_prune):
|
||||||
"""
|
"""
|
||||||
@@ -308,7 +307,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
|
|||||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
||||||
# this attention mask is more simple than the triangular masking of causal attention
|
# this attention mask is more simple than the triangular masking of causal attention
|
||||||
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
||||||
attention_mask = tf.reshape(attention_mask, (input_shape[0], 1, 1, input_shape[1]))
|
attention_mask = tf.reshape(attention_mask, (input_shape[0], 1, 1, input_shape[1] + past_length))
|
||||||
|
|
||||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||||
# masked positions, this operation will create a tensor which is 0.0 for
|
# masked positions, this operation will create a tensor which is 0.0 for
|
||||||
@@ -332,15 +331,15 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
if token_type_ids is not None:
|
if token_type_ids is not None:
|
||||||
token_type_ids = tf.reshape(token_type_ids, [-1, shape_list(token_type_ids)[-1]])
|
token_type_ids = tf.reshape(token_type_ids, [-1, shape_list(token_type_ids)[-1]])
|
||||||
token_type_embeds = self.w(token_type_ids, mode="embedding")
|
token_type_embeds = self.w(token_type_ids)
|
||||||
token_type_embeds *= tf.math.sqrt(tf.cast(self.d_model_size, dtype=token_type_embeds.dtype))
|
token_type_embeds *= tf.math.sqrt(tf.cast(self.d_model_size, dtype=token_type_embeds.dtype))
|
||||||
else:
|
else:
|
||||||
token_type_embeds = tf.constant(0.0)
|
token_type_embeds = tf.constant(0.0)
|
||||||
position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]])
|
position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]])
|
||||||
|
|
||||||
if inputs_embeds is None:
|
if inputs_embeds is None:
|
||||||
check_embeddings_within_bounds(input_ids, self.w.vocab_size)
|
check_embeddings_within_bounds(input_ids, self.w.input_dim)
|
||||||
inputs_embeds = self.w(input_ids, mode="embedding")
|
inputs_embeds = self.w(input_ids)
|
||||||
seq_len = input_shape[-1]
|
seq_len = input_shape[-1]
|
||||||
mask = 1 - tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)
|
mask = 1 - tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)
|
||||||
|
|
||||||
@@ -565,39 +564,26 @@ class TFCTRLModel(TFCTRLPreTrainedModel):
|
|||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
class TFCTRLLMHead(tf.keras.layers.Layer):
|
class TFCTRLBiasLayer(tf.keras.layers.Layer):
|
||||||
def __init__(self, config, input_embeddings, **kwargs):
|
"""
|
||||||
super().__init__(**kwargs)
|
Bias as a layer. It is used for serialization purposes: `tf.keras.Model.save_weights` stores on a per-layer basis,
|
||||||
self.config = config
|
so all weights have to be registered in a layer.
|
||||||
# CTRL has numerical issues in XLA generate
|
"""
|
||||||
self.supports_xla_generation = False
|
|
||||||
|
|
||||||
# The output weights are the same as the input embeddings, but there is
|
def __init__(self, shape, initializer, trainable, name, **kwargs):
|
||||||
# an output-only bias for each token.
|
super().__init__(name=name, **kwargs)
|
||||||
self.input_embeddings = input_embeddings
|
self.shape = shape
|
||||||
|
self.initializer = initializer
|
||||||
|
self.trainable = trainable
|
||||||
|
|
||||||
def build(self, input_shape=None):
|
def build(self, input_shape):
|
||||||
self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias")
|
self.bias = self.add_weight(
|
||||||
|
name="bias", shape=self.shape, initializer=self.initializer, trainable=self.trainable
|
||||||
|
)
|
||||||
super().build(input_shape)
|
super().build(input_shape)
|
||||||
|
|
||||||
def get_output_embeddings(self):
|
def call(self, x):
|
||||||
return self.input_embeddings
|
return x + self.bias
|
||||||
|
|
||||||
def set_output_embeddings(self, value):
|
|
||||||
self.input_embeddings.weight = value
|
|
||||||
self.input_embeddings.vocab_size = shape_list(value)[0]
|
|
||||||
|
|
||||||
def get_bias(self):
|
|
||||||
return {"bias": self.bias}
|
|
||||||
|
|
||||||
def set_bias(self, value):
|
|
||||||
self.bias = value["bias"]
|
|
||||||
self.config.vocab_size = shape_list(value["bias"])[0]
|
|
||||||
|
|
||||||
def call(self, hidden_states):
|
|
||||||
hidden_states = self.input_embeddings(hidden_states, mode="linear")
|
|
||||||
hidden_states = hidden_states + self.bias
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
@@ -611,24 +597,53 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
|
|||||||
def __init__(self, config, *inputs, **kwargs):
|
def __init__(self, config, *inputs, **kwargs):
|
||||||
super().__init__(config, *inputs, **kwargs)
|
super().__init__(config, *inputs, **kwargs)
|
||||||
self.transformer = TFCTRLMainLayer(config, name="transformer")
|
self.transformer = TFCTRLMainLayer(config, name="transformer")
|
||||||
|
self.bias_layer = TFCTRLBiasLayer(
|
||||||
|
name="lm_head", shape=[1, config.vocab_size], initializer="zeros", trainable=True
|
||||||
|
)
|
||||||
|
|
||||||
self.lm_head = TFCTRLLMHead(config, self.transformer.w, name="lm_head")
|
def get_output_embeddings(self):
|
||||||
# CTRL has numerical issues in XLA generate
|
return self.get_input_embeddings()
|
||||||
self.supports_xla_generation = False
|
|
||||||
|
|
||||||
def get_lm_head(self):
|
def set_output_embeddings(self, value):
|
||||||
return self.lm_head
|
self.set_input_embeddings(value)
|
||||||
|
|
||||||
def get_prefix_bias_name(self):
|
def get_bias(self):
|
||||||
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
|
return {"lm_head.bias": self.bias_layer.bias}
|
||||||
return self.name + "/" + self.lm_head.name
|
|
||||||
|
|
||||||
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, use_cache=None, **kwargs):
|
def set_bias(self, value):
|
||||||
|
# Replaces the existing layers containing bias for correct (de)serialization.
|
||||||
|
vocab_size = value["lm_head.bias"].shape[-1]
|
||||||
|
self.bias_layer = TFCTRLBiasLayer(
|
||||||
|
name="final_logits_bias", shape=[1, vocab_size], initializer="zeros", trainable=True
|
||||||
|
)
|
||||||
|
self.bias_layer.build(None)
|
||||||
|
self.bias_layer.bias.assign(value["lm_head.bias"])
|
||||||
|
|
||||||
|
# Copied from transformers.models.gpt2.modeling_tf_gpt2.TFGPT2LMHeadModel.prepare_inputs_for_generation
|
||||||
|
def prepare_inputs_for_generation(self, inputs, past_key_values=None, use_cache=None, **kwargs):
|
||||||
|
token_type_ids = kwargs.get("token_type_ids", None)
|
||||||
# only last token for inputs_ids if past is defined in kwargs
|
# only last token for inputs_ids if past is defined in kwargs
|
||||||
if past_key_values:
|
if past_key_values:
|
||||||
input_ids = tf.expand_dims(input_ids[:, -1], -1)
|
inputs = tf.expand_dims(inputs[:, -1], -1)
|
||||||
|
if token_type_ids is not None:
|
||||||
|
token_type_ids = tf.expand_dims(token_type_ids[:, -1], -1)
|
||||||
|
|
||||||
return {"input_ids": input_ids, "past_key_values": past_key_values, "use_cache": use_cache}
|
position_ids = kwargs.get("position_ids", None)
|
||||||
|
attention_mask = kwargs.get("attention_mask", None)
|
||||||
|
|
||||||
|
if attention_mask is not None and position_ids is None:
|
||||||
|
position_ids = tf.math.cumsum(attention_mask, axis=-1, exclusive=True)
|
||||||
|
if past_key_values:
|
||||||
|
position_ids = tf.expand_dims(position_ids[:, -1], -1)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"input_ids": inputs,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
"position_ids": position_ids,
|
||||||
|
"past_key_values": past_key_values,
|
||||||
|
"use_cache": use_cache,
|
||||||
|
"token_type_ids": token_type_ids,
|
||||||
|
}
|
||||||
|
|
||||||
@unpack_inputs
|
@unpack_inputs
|
||||||
@add_start_docstrings_to_model_forward(CTRL_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(CTRL_INPUTS_DOCSTRING)
|
||||||
@@ -672,10 +687,9 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
|
|||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
training=training,
|
training=training,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = transformer_outputs[0]
|
hidden_states = transformer_outputs[0]
|
||||||
|
logits = tf.matmul(hidden_states, self.transformer.w.weights, transpose_b=True)
|
||||||
logits = self.lm_head(hidden_states)
|
logits = self.bias_layer(logits)
|
||||||
|
|
||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
|
|||||||
@@ -225,6 +225,7 @@ class TFCTRLModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase
|
|||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
|
model.build() # may be needed for the get_bias() call below
|
||||||
assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer)
|
assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer)
|
||||||
|
|
||||||
if model_class in list_lm_models:
|
if model_class in list_lm_models:
|
||||||
|
|||||||
Reference in New Issue
Block a user