Move TF building to an actual build() method (#23760)
* A fun new PR where I break the entire codebase again * A fun new PR where I break the entire codebase again * Handle cross-attention * Move calls to model(model.dummy_inputs) to the new build() method * Seeing what fails with the build context thing * make fix-copies * Let's see what fails with new build methods * Fix the pytorch crossload build calls * Fix the overridden build methods in vision_text_dual_encoder * Make sure all our build methods set self.built or call super().build(), which also sets it * make fix-copies * Remove finished TODO * Tentatively remove unneeded (?) line * Transpose b in deberta correctly and remove unused threading local * Get rid of build_with_dummies and all it stands for * Rollback some changes to TF-PT crossloading * Correctly call super().build()
This commit is contained in:
@@ -341,9 +341,6 @@ def load_pytorch_state_dict_in_tf2_model(
|
|||||||
|
|
||||||
K.batch_set_value(weight_value_tuples)
|
K.batch_set_value(weight_value_tuples)
|
||||||
|
|
||||||
if tf_inputs is not None:
|
|
||||||
tf_model(tf_inputs, training=False) # Make sure restore ops are run
|
|
||||||
|
|
||||||
logger.info(f"Loaded {tf_loaded_numel:,} parameters in the TF 2.0 model.")
|
logger.info(f"Loaded {tf_loaded_numel:,} parameters in the TF 2.0 model.")
|
||||||
|
|
||||||
unexpected_keys = list(all_pytorch_weights)
|
unexpected_keys = list(all_pytorch_weights)
|
||||||
|
|||||||
@@ -40,7 +40,12 @@ from .activations_tf import get_tf_activation
|
|||||||
from .configuration_utils import PretrainedConfig
|
from .configuration_utils import PretrainedConfig
|
||||||
from .dynamic_module_utils import custom_object_save
|
from .dynamic_module_utils import custom_object_save
|
||||||
from .generation import GenerationConfig, TFGenerationMixin
|
from .generation import GenerationConfig, TFGenerationMixin
|
||||||
from .tf_utils import expand_1d, load_attributes_from_hdf5_group, save_attributes_to_hdf5_group, shape_list
|
from .tf_utils import (
|
||||||
|
expand_1d,
|
||||||
|
load_attributes_from_hdf5_group,
|
||||||
|
save_attributes_to_hdf5_group,
|
||||||
|
shape_list,
|
||||||
|
)
|
||||||
from .utils import (
|
from .utils import (
|
||||||
SAFE_WEIGHTS_INDEX_NAME,
|
SAFE_WEIGHTS_INDEX_NAME,
|
||||||
SAFE_WEIGHTS_NAME,
|
SAFE_WEIGHTS_NAME,
|
||||||
@@ -69,11 +74,14 @@ from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
|
|||||||
if parse(tf.__version__).minor >= 13:
|
if parse(tf.__version__).minor >= 13:
|
||||||
from keras import backend as K
|
from keras import backend as K
|
||||||
from keras.__internal__ import KerasTensor
|
from keras.__internal__ import KerasTensor
|
||||||
|
from keras.engine.base_layer_utils import call_context
|
||||||
elif parse(tf.__version__).minor >= 11:
|
elif parse(tf.__version__).minor >= 11:
|
||||||
from keras import backend as K
|
from keras import backend as K
|
||||||
|
from keras.engine.base_layer_utils import call_context
|
||||||
from keras.engine.keras_tensor import KerasTensor
|
from keras.engine.keras_tensor import KerasTensor
|
||||||
else:
|
else:
|
||||||
from tensorflow.python.keras import backend as K
|
from tensorflow.python.keras import backend as K
|
||||||
|
from tensorflow.python.keras.engine import call_context
|
||||||
from tensorflow.python.keras.engine.keras_tensor import KerasTensor
|
from tensorflow.python.keras.engine.keras_tensor import KerasTensor
|
||||||
|
|
||||||
|
|
||||||
@@ -1140,6 +1148,13 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
"""
|
"""
|
||||||
return "tf"
|
return "tf"
|
||||||
|
|
||||||
|
def build(self, input_shape=None):
|
||||||
|
if self.built or call_context().in_call:
|
||||||
|
self.built = True
|
||||||
|
else:
|
||||||
|
self(self.dummy_inputs, training=False)
|
||||||
|
self.built = True
|
||||||
|
|
||||||
def __init__(self, config, *inputs, **kwargs):
|
def __init__(self, config, *inputs, **kwargs):
|
||||||
super().__init__(*inputs, **kwargs)
|
super().__init__(*inputs, **kwargs)
|
||||||
if not isinstance(config, PretrainedConfig):
|
if not isinstance(config, PretrainedConfig):
|
||||||
@@ -1867,7 +1882,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
main_layer.set_input_embeddings(value)
|
main_layer.set_input_embeddings(value)
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
logger.info("Building the model")
|
logger.info("Building the model")
|
||||||
self(self.dummy_inputs)
|
self.build()
|
||||||
main_layer.set_input_embeddings(value)
|
main_layer.set_input_embeddings(value)
|
||||||
|
|
||||||
def get_output_embeddings(self) -> Union[None, tf.keras.layers.Layer]:
|
def get_output_embeddings(self) -> Union[None, tf.keras.layers.Layer]:
|
||||||
@@ -1884,7 +1899,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
return lm_head.get_output_embeddings()
|
return lm_head.get_output_embeddings()
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
logger.info("Building the model")
|
logger.info("Building the model")
|
||||||
self(self.dummy_inputs)
|
self.build()
|
||||||
|
|
||||||
return lm_head().get_output_embeddings()
|
return lm_head().get_output_embeddings()
|
||||||
|
|
||||||
@@ -1904,7 +1919,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
lm_head.set_output_embeddings(value)
|
lm_head.set_output_embeddings(value)
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
logger.info("Building the model")
|
logger.info("Building the model")
|
||||||
self(self.dummy_inputs)
|
self.build()
|
||||||
lm_head.set_output_embeddings(value)
|
lm_head.set_output_embeddings(value)
|
||||||
|
|
||||||
def get_output_layer_with_bias(self) -> Union[None, tf.keras.layers.Layer]:
|
def get_output_layer_with_bias(self) -> Union[None, tf.keras.layers.Layer]:
|
||||||
@@ -1942,7 +1957,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
try:
|
try:
|
||||||
return lm_head.get_bias()
|
return lm_head.get_bias()
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
self(self.dummy_inputs)
|
self.build()
|
||||||
|
|
||||||
return lm_head.get_bias()
|
return lm_head.get_bias()
|
||||||
return None
|
return None
|
||||||
@@ -1960,7 +1975,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
try:
|
try:
|
||||||
lm_head.set_bias(value)
|
lm_head.set_bias(value)
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
self(self.dummy_inputs)
|
self.build()
|
||||||
lm_head.set_bias(value)
|
lm_head.set_bias(value)
|
||||||
|
|
||||||
def get_lm_head(self) -> tf.keras.layers.Layer:
|
def get_lm_head(self) -> tf.keras.layers.Layer:
|
||||||
@@ -2047,7 +2062,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
# The reason why the attributes don't exist might be
|
# The reason why the attributes don't exist might be
|
||||||
# because the model is not built, so retry getting
|
# because the model is not built, so retry getting
|
||||||
# the argument after building the model
|
# the argument after building the model
|
||||||
model(model.dummy_inputs)
|
model.build()
|
||||||
|
|
||||||
embeds = getattr(embedding_layer, "weight", None)
|
embeds = getattr(embedding_layer, "weight", None)
|
||||||
if embeds is not None:
|
if embeds is not None:
|
||||||
@@ -2870,9 +2885,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
# we might need to extend the variable scope for composite models
|
# we might need to extend the variable scope for composite models
|
||||||
if load_weight_prefix is not None:
|
if load_weight_prefix is not None:
|
||||||
with tf.compat.v1.variable_scope(load_weight_prefix):
|
with tf.compat.v1.variable_scope(load_weight_prefix):
|
||||||
model(model.dummy_inputs) # build the network with dummy inputs
|
model.build() # build the network with dummy inputs
|
||||||
else:
|
else:
|
||||||
model(model.dummy_inputs) # build the network with dummy inputs
|
model.build() # build the network with dummy inputs
|
||||||
|
|
||||||
if safetensors_from_pt:
|
if safetensors_from_pt:
|
||||||
from .modeling_tf_pytorch_utils import load_pytorch_state_dict_in_tf2_model
|
from .modeling_tf_pytorch_utils import load_pytorch_state_dict_in_tf2_model
|
||||||
@@ -2925,8 +2940,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
"If you tried to load a TF 2.0 model from a PyTorch checkpoint, please set from_pt=True. "
|
"If you tried to load a TF 2.0 model from a PyTorch checkpoint, please set from_pt=True. "
|
||||||
)
|
)
|
||||||
|
|
||||||
model(model.dummy_inputs) # Make sure restore ops are run
|
|
||||||
|
|
||||||
if cls._keys_to_ignore_on_load_missing is not None:
|
if cls._keys_to_ignore_on_load_missing is not None:
|
||||||
for pat in cls._keys_to_ignore_on_load_missing:
|
for pat in cls._keys_to_ignore_on_load_missing:
|
||||||
missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
|
missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
|
||||||
|
|||||||
@@ -258,6 +258,7 @@ class TFBlipVisionEmbeddings(tf.keras.layers.Layer):
|
|||||||
trainable=True,
|
trainable=True,
|
||||||
name="position_embedding",
|
name="position_embedding",
|
||||||
)
|
)
|
||||||
|
super().build(input_shape)
|
||||||
|
|
||||||
def call(self, pixel_values: tf.Tensor) -> tf.Tensor:
|
def call(self, pixel_values: tf.Tensor) -> tf.Tensor:
|
||||||
# Input is channels-first, we transpose. PyTorch transposes after the conv because PyTorch
|
# Input is channels-first, we transpose. PyTorch transposes after the conv because PyTorch
|
||||||
@@ -282,7 +283,7 @@ class TFBlipTextEmbeddings(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
def build(self, input_shape: tf.TensorShape):
|
def build(self, input_shape: tf.TensorShape = None):
|
||||||
with tf.name_scope("token_embedding"):
|
with tf.name_scope("token_embedding"):
|
||||||
self.weight = self.add_weight(
|
self.weight = self.add_weight(
|
||||||
shape=(self.config.vocab_size, self.embed_dim),
|
shape=(self.config.vocab_size, self.embed_dim),
|
||||||
@@ -757,13 +758,14 @@ class TFBlipMainLayer(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
def build(self, input_shape):
|
def build(self, input_shape=None):
|
||||||
self.logit_scale = self.add_weight(
|
self.logit_scale = self.add_weight(
|
||||||
name="logit_scale",
|
name="logit_scale",
|
||||||
shape=[],
|
shape=[],
|
||||||
initializer=tf.keras.initializers.Constant(self.config.logit_scale_init_value),
|
initializer=tf.keras.initializers.Constant(self.config.logit_scale_init_value),
|
||||||
trainable=True,
|
trainable=True,
|
||||||
)
|
)
|
||||||
|
super().build(input_shape)
|
||||||
|
|
||||||
@unpack_inputs
|
@unpack_inputs
|
||||||
def call(
|
def call(
|
||||||
|
|||||||
@@ -543,8 +543,9 @@ class TFBlipTextLMPredictionHead(tf.keras.layers.Layer):
|
|||||||
)
|
)
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
def build(self, input_shape):
|
def build(self, input_shape=None):
|
||||||
self.bias = self.add_weight(name="bias", shape=(self.config.vocab_size,), initializer="zeros", trainable=True)
|
self.bias = self.add_weight(name="bias", shape=(self.config.vocab_size,), initializer="zeros", trainable=True)
|
||||||
|
super().build(input_shape)
|
||||||
|
|
||||||
def call(self, hidden_states):
|
def call(self, hidden_states):
|
||||||
hidden_states = self.transform(hidden_states)
|
hidden_states = self.transform(hidden_states)
|
||||||
|
|||||||
@@ -151,7 +151,7 @@ class TFCLIPVisionEmbeddings(tf.keras.layers.Layer):
|
|||||||
name="patch_embedding",
|
name="patch_embedding",
|
||||||
)
|
)
|
||||||
|
|
||||||
def build(self, input_shape: tf.TensorShape):
|
def build(self, input_shape: tf.TensorShape = None):
|
||||||
factor = self.config.initializer_factor
|
factor = self.config.initializer_factor
|
||||||
|
|
||||||
self.class_embedding = self.add_weight(
|
self.class_embedding = self.add_weight(
|
||||||
@@ -204,7 +204,7 @@ class TFCLIPTextEmbeddings(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
def build(self, input_shape: tf.TensorShape):
|
def build(self, input_shape: tf.TensorShape = None):
|
||||||
with tf.name_scope("token_embedding"):
|
with tf.name_scope("token_embedding"):
|
||||||
self.weight = self.add_weight(
|
self.weight = self.add_weight(
|
||||||
shape=(self.config.vocab_size, self.embed_dim),
|
shape=(self.config.vocab_size, self.embed_dim),
|
||||||
@@ -739,7 +739,7 @@ class TFCLIPMainLayer(tf.keras.layers.Layer):
|
|||||||
name="text_projection",
|
name="text_projection",
|
||||||
)
|
)
|
||||||
|
|
||||||
def build(self, input_shape: tf.TensorShape):
|
def build(self, input_shape: tf.TensorShape = None):
|
||||||
self.logit_scale = self.add_weight(
|
self.logit_scale = self.add_weight(
|
||||||
shape=(1,),
|
shape=(1,),
|
||||||
initializer=tf.keras.initializers.Constant(self.config.logit_scale_init_value),
|
initializer=tf.keras.initializers.Constant(self.config.logit_scale_init_value),
|
||||||
|
|||||||
@@ -346,7 +346,7 @@ class GroupedLinearLayer(tf.keras.layers.Layer):
|
|||||||
self.group_in_dim = self.input_size // self.num_groups
|
self.group_in_dim = self.input_size // self.num_groups
|
||||||
self.group_out_dim = self.output_size // self.num_groups
|
self.group_out_dim = self.output_size // self.num_groups
|
||||||
|
|
||||||
def build(self, input_shape):
|
def build(self, input_shape=None):
|
||||||
self.kernel = self.add_weight(
|
self.kernel = self.add_weight(
|
||||||
"kernel",
|
"kernel",
|
||||||
shape=[self.group_out_dim, self.group_in_dim, self.num_groups],
|
shape=[self.group_out_dim, self.group_in_dim, self.num_groups],
|
||||||
@@ -357,6 +357,7 @@ class GroupedLinearLayer(tf.keras.layers.Layer):
|
|||||||
self.bias = self.add_weight(
|
self.bias = self.add_weight(
|
||||||
"bias", shape=[self.output_size], initializer=self.kernel_initializer, dtype=self.dtype, trainable=True
|
"bias", shape=[self.output_size], initializer=self.kernel_initializer, dtype=self.dtype, trainable=True
|
||||||
)
|
)
|
||||||
|
super().build(input_shape)
|
||||||
|
|
||||||
def call(self, hidden_states):
|
def call(self, hidden_states):
|
||||||
batch_size = shape_list(hidden_states)[0]
|
batch_size = shape_list(hidden_states)[0]
|
||||||
|
|||||||
@@ -155,7 +155,7 @@ class TFConvNextLayer(tf.keras.layers.Layer):
|
|||||||
else tf.keras.layers.Activation("linear", name="drop_path")
|
else tf.keras.layers.Activation("linear", name="drop_path")
|
||||||
)
|
)
|
||||||
|
|
||||||
def build(self, input_shape: tf.TensorShape):
|
def build(self, input_shape: tf.TensorShape = None):
|
||||||
# PT's `nn.Parameters` must be mapped to a TF layer weight to inherit the same name hierarchy (and vice-versa)
|
# PT's `nn.Parameters` must be mapped to a TF layer weight to inherit the same name hierarchy (and vice-versa)
|
||||||
self.layer_scale_parameter = (
|
self.layer_scale_parameter = (
|
||||||
self.add_weight(
|
self.add_weight(
|
||||||
|
|||||||
@@ -576,7 +576,7 @@ class TFCTRLLMHead(tf.keras.layers.Layer):
|
|||||||
# an output-only bias for each token.
|
# an output-only bias for each token.
|
||||||
self.input_embeddings = input_embeddings
|
self.input_embeddings = input_embeddings
|
||||||
|
|
||||||
def build(self, input_shape):
|
def build(self, input_shape=None):
|
||||||
self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias")
|
self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias")
|
||||||
super().build(input_shape)
|
super().build(input_shape)
|
||||||
|
|
||||||
|
|||||||
@@ -464,7 +464,7 @@ class TFData2VecVisionLayer(tf.keras.layers.Layer):
|
|||||||
)
|
)
|
||||||
self.init_values = config.layer_scale_init_value
|
self.init_values = config.layer_scale_init_value
|
||||||
|
|
||||||
def build(self, input_shape: tf.TensorShape):
|
def build(self, input_shape: tf.TensorShape = None):
|
||||||
if self.init_values > 0:
|
if self.init_values > 0:
|
||||||
self.lambda_1 = self.add_weight(
|
self.lambda_1 = self.add_weight(
|
||||||
shape=(self.config.hidden_size),
|
shape=(self.config.hidden_size),
|
||||||
|
|||||||
@@ -593,11 +593,10 @@ class TFDebertaDisentangledSelfAttention(tf.keras.layers.Layer):
|
|||||||
else:
|
else:
|
||||||
|
|
||||||
def linear(w, b, x):
|
def linear(w, b, x):
|
||||||
return tf.cond(
|
out = tf.matmul(x, w, transpose_b=True)
|
||||||
b is not None,
|
if b is not None:
|
||||||
lambda: tf.matmul(x, w, transpose_b=True) + tf.transpose(b),
|
out += tf.transpose(b)
|
||||||
lambda: tf.matmul(x, w, transpose_b=True),
|
return out
|
||||||
)
|
|
||||||
|
|
||||||
ws = tf.split(
|
ws = tf.split(
|
||||||
tf.transpose(self.in_proj.weight[0]), num_or_size_splits=self.num_attention_heads * 3, axis=0
|
tf.transpose(self.in_proj.weight[0]), num_or_size_splits=self.num_attention_heads * 3, axis=0
|
||||||
|
|||||||
@@ -532,7 +532,7 @@ class TFDPRContextEncoder(TFDPRPretrainedContextEncoder):
|
|||||||
try:
|
try:
|
||||||
return self.ctx_encoder.bert_model.get_input_embeddings()
|
return self.ctx_encoder.bert_model.get_input_embeddings()
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
self(self.dummy_inputs)
|
self.build()
|
||||||
return self.ctx_encoder.bert_model.get_input_embeddings()
|
return self.ctx_encoder.bert_model.get_input_embeddings()
|
||||||
|
|
||||||
@unpack_inputs
|
@unpack_inputs
|
||||||
@@ -613,7 +613,7 @@ class TFDPRQuestionEncoder(TFDPRPretrainedQuestionEncoder):
|
|||||||
try:
|
try:
|
||||||
return self.question_encoder.bert_model.get_input_embeddings()
|
return self.question_encoder.bert_model.get_input_embeddings()
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
self(self.dummy_inputs)
|
self.build()
|
||||||
return self.question_encoder.bert_model.get_input_embeddings()
|
return self.question_encoder.bert_model.get_input_embeddings()
|
||||||
|
|
||||||
@unpack_inputs
|
@unpack_inputs
|
||||||
@@ -693,7 +693,7 @@ class TFDPRReader(TFDPRPretrainedReader):
|
|||||||
try:
|
try:
|
||||||
return self.span_predictor.encoder.bert_model.get_input_embeddings()
|
return self.span_predictor.encoder.bert_model.get_input_embeddings()
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
self(self.dummy_inputs)
|
self.build()
|
||||||
return self.span_predictor.encoder.bert_model.get_input_embeddings()
|
return self.span_predictor.encoder.bert_model.get_input_embeddings()
|
||||||
|
|
||||||
@unpack_inputs
|
@unpack_inputs
|
||||||
|
|||||||
@@ -538,7 +538,7 @@ class TFGroupViTTextEmbeddings(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
def build(self, input_shape: tf.TensorShape):
|
def build(self, input_shape: tf.TensorShape = None):
|
||||||
with tf.name_scope("token_embedding"):
|
with tf.name_scope("token_embedding"):
|
||||||
self.weight = self.add_weight(
|
self.weight = self.add_weight(
|
||||||
shape=(self.config.vocab_size, self.embed_dim),
|
shape=(self.config.vocab_size, self.embed_dim),
|
||||||
|
|||||||
@@ -135,6 +135,7 @@ class TFLEDLearnedPositionalEmbedding(tf.keras.layers.Embedding):
|
|||||||
class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
|
class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
|
||||||
def __init__(self, config, layer_id, **kwargs):
|
def __init__(self, config, layer_id, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
self.config = config
|
||||||
|
|
||||||
if config.hidden_size % config.num_attention_heads != 0:
|
if config.hidden_size % config.num_attention_heads != 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -191,6 +192,16 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
self.one_sided_attn_window_size = attention_window // 2
|
self.one_sided_attn_window_size = attention_window // 2
|
||||||
|
|
||||||
|
def build(self, input_shape=None):
|
||||||
|
if not self.built:
|
||||||
|
with tf.name_scope("query_global"):
|
||||||
|
self.query_global.build((self.config.hidden_size,))
|
||||||
|
with tf.name_scope("key_global"):
|
||||||
|
self.key_global.build((self.config.hidden_size,))
|
||||||
|
with tf.name_scope("value_global"):
|
||||||
|
self.value_global.build((self.config.hidden_size,))
|
||||||
|
super().build(input_shape)
|
||||||
|
|
||||||
def call(
|
def call(
|
||||||
self,
|
self,
|
||||||
inputs,
|
inputs,
|
||||||
@@ -271,9 +282,8 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
|
|||||||
) = self._get_global_attn_indices(is_index_global_attn)
|
) = self._get_global_attn_indices(is_index_global_attn)
|
||||||
|
|
||||||
# this function is only relevant for global attention
|
# this function is only relevant for global attention
|
||||||
attn_scores = tf.cond(
|
if is_global_attn:
|
||||||
is_global_attn,
|
attn_scores = self._concat_with_global_key_attn_probs(
|
||||||
lambda: self._concat_with_global_key_attn_probs(
|
|
||||||
attn_scores=attn_scores,
|
attn_scores=attn_scores,
|
||||||
query_vectors=query_vectors,
|
query_vectors=query_vectors,
|
||||||
key_vectors=key_vectors,
|
key_vectors=key_vectors,
|
||||||
@@ -281,26 +291,24 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
|
|||||||
is_index_global_attn_nonzero=is_index_global_attn_nonzero,
|
is_index_global_attn_nonzero=is_index_global_attn_nonzero,
|
||||||
is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
|
is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
|
||||||
is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
|
is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
|
||||||
),
|
)
|
||||||
lambda: attn_scores,
|
|
||||||
)
|
|
||||||
attn_probs = stable_softmax(attn_scores, axis=-1)
|
attn_probs = stable_softmax(attn_scores, axis=-1)
|
||||||
|
|
||||||
# softmax sometimes inserts NaN if all positions are masked, replace them with 0
|
# softmax sometimes inserts NaN if all positions are masked, replace them with 0
|
||||||
# Make sure to create a mask with the proper shape:
|
# Make sure to create a mask with the proper shape:
|
||||||
# if is_global_attn==True => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1]
|
# if is_global_attn==True => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1]
|
||||||
# if is_global_attn==False => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1]
|
# if is_global_attn==False => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1]
|
||||||
masked_index = tf.cond(
|
if is_global_attn:
|
||||||
is_global_attn,
|
masked_index = tf.tile(
|
||||||
lambda: tf.tile(
|
|
||||||
is_index_masked[:, :, None, None],
|
is_index_masked[:, :, None, None],
|
||||||
(1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1),
|
(1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1),
|
||||||
),
|
)
|
||||||
lambda: tf.tile(
|
else:
|
||||||
|
masked_index = tf.tile(
|
||||||
is_index_masked[:, :, None, None],
|
is_index_masked[:, :, None, None],
|
||||||
(1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + 1),
|
(1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + 1),
|
||||||
),
|
)
|
||||||
)
|
|
||||||
attn_probs = tf.where(
|
attn_probs = tf.where(
|
||||||
masked_index,
|
masked_index,
|
||||||
tf.zeros(shape_list(masked_index), dtype=attn_probs.dtype),
|
tf.zeros(shape_list(masked_index), dtype=attn_probs.dtype),
|
||||||
@@ -324,19 +332,19 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
|
|||||||
value_vectors = tf.reshape(value_vectors, (batch_size, seq_len, self.num_heads, self.head_dim))
|
value_vectors = tf.reshape(value_vectors, (batch_size, seq_len, self.num_heads, self.head_dim))
|
||||||
|
|
||||||
# if global attention, compute sum of global and local attn
|
# if global attention, compute sum of global and local attn
|
||||||
attn_output = tf.cond(
|
|
||||||
is_global_attn,
|
if is_global_attn:
|
||||||
lambda: self._compute_attn_output_with_global_indices(
|
attn_output = self._compute_attn_output_with_global_indices(
|
||||||
value_vectors=value_vectors,
|
value_vectors=value_vectors,
|
||||||
attn_probs=attn_probs,
|
attn_probs=attn_probs,
|
||||||
max_num_global_attn_indices=max_num_global_attn_indices,
|
max_num_global_attn_indices=max_num_global_attn_indices,
|
||||||
is_index_global_attn_nonzero=is_index_global_attn_nonzero,
|
is_index_global_attn_nonzero=is_index_global_attn_nonzero,
|
||||||
is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
|
is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
|
||||||
),
|
)
|
||||||
lambda: self._sliding_chunks_matmul_attn_probs_value(
|
else:
|
||||||
|
attn_output = self._sliding_chunks_matmul_attn_probs_value(
|
||||||
attn_probs, value_vectors, self.one_sided_attn_window_size
|
attn_probs, value_vectors, self.one_sided_attn_window_size
|
||||||
),
|
)
|
||||||
)
|
|
||||||
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attn_output), [batch_size, seq_len, self.num_heads, self.head_dim], message="Unexpected size"
|
shape_list(attn_output), [batch_size, seq_len, self.num_heads, self.head_dim], message="Unexpected size"
|
||||||
@@ -345,10 +353,8 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
|
|||||||
attn_output = tf.reshape(attn_output, (batch_size, seq_len, embed_dim))
|
attn_output = tf.reshape(attn_output, (batch_size, seq_len, embed_dim))
|
||||||
|
|
||||||
# compute value for global attention and overwrite to attention output
|
# compute value for global attention and overwrite to attention output
|
||||||
# TODO: remove the redundant computation
|
if is_global_attn:
|
||||||
attn_output, global_attn_probs = tf.cond(
|
attn_output, global_attn_probs = self._compute_global_attn_output_from_hidden(
|
||||||
is_global_attn,
|
|
||||||
lambda: self._compute_global_attn_output_from_hidden(
|
|
||||||
attn_output=attn_output,
|
attn_output=attn_output,
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
max_num_global_attn_indices=max_num_global_attn_indices,
|
max_num_global_attn_indices=max_num_global_attn_indices,
|
||||||
@@ -358,25 +364,25 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
|
|||||||
is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
|
is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
|
||||||
is_index_masked=is_index_masked,
|
is_index_masked=is_index_masked,
|
||||||
training=training,
|
training=training,
|
||||||
),
|
)
|
||||||
lambda: (attn_output, tf.zeros((batch_size, self.num_heads, max_num_global_attn_indices, seq_len))),
|
else:
|
||||||
)
|
# Leave attn_output unchanged
|
||||||
|
global_attn_probs = tf.zeros((batch_size, self.num_heads, max_num_global_attn_indices, seq_len))
|
||||||
|
|
||||||
# make sure that local attention probabilities are set to 0 for indices of global attn
|
# make sure that local attention probabilities are set to 0 for indices of global attn
|
||||||
# Make sure to create a mask with the proper shape:
|
# Make sure to create a mask with the proper shape:
|
||||||
# if is_global_attn==True => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1]
|
# if is_global_attn==True => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1]
|
||||||
# if is_global_attn==False => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1]
|
# if is_global_attn==False => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1]
|
||||||
masked_global_attn_index = tf.cond(
|
if is_global_attn:
|
||||||
is_global_attn,
|
masked_global_attn_index = tf.tile(
|
||||||
lambda: tf.tile(
|
|
||||||
is_index_global_attn[:, :, None, None],
|
is_index_global_attn[:, :, None, None],
|
||||||
(1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1),
|
(1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1),
|
||||||
),
|
)
|
||||||
lambda: tf.tile(
|
else:
|
||||||
|
masked_global_attn_index = tf.tile(
|
||||||
is_index_global_attn[:, :, None, None],
|
is_index_global_attn[:, :, None, None],
|
||||||
(1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + 1),
|
(1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + 1),
|
||||||
),
|
)
|
||||||
)
|
|
||||||
attn_probs = tf.where(
|
attn_probs = tf.where(
|
||||||
masked_global_attn_index,
|
masked_global_attn_index,
|
||||||
tf.zeros(shape_list(masked_global_attn_index), dtype=attn_probs.dtype),
|
tf.zeros(shape_list(masked_global_attn_index), dtype=attn_probs.dtype),
|
||||||
@@ -1864,13 +1870,10 @@ class TFLEDEncoder(tf.keras.layers.Layer):
|
|||||||
input_ids = tf.pad(input_ids, paddings, constant_values=pad_token_id)
|
input_ids = tf.pad(input_ids, paddings, constant_values=pad_token_id)
|
||||||
|
|
||||||
if inputs_embeds is not None:
|
if inputs_embeds is not None:
|
||||||
|
if padding_len > 0:
|
||||||
def pad_embeddings():
|
|
||||||
input_ids_padding = tf.fill((batch_size, padding_len), pad_token_id)
|
input_ids_padding = tf.fill((batch_size, padding_len), pad_token_id)
|
||||||
inputs_embeds_padding = self.embed_tokens(input_ids_padding)
|
inputs_embeds_padding = self.embed_tokens(input_ids_padding)
|
||||||
return tf.concat([inputs_embeds, inputs_embeds_padding], axis=-2)
|
inputs_embeds = tf.concat([inputs_embeds, inputs_embeds_padding], axis=-2)
|
||||||
|
|
||||||
inputs_embeds = tf.cond(tf.math.greater(padding_len, 0), pad_embeddings, lambda: inputs_embeds)
|
|
||||||
|
|
||||||
attention_mask = tf.pad(attention_mask, paddings, constant_values=False) # no attention on the padding tokens
|
attention_mask = tf.pad(attention_mask, paddings, constant_values=False) # no attention on the padding tokens
|
||||||
|
|
||||||
|
|||||||
@@ -652,6 +652,7 @@ class TFLongformerSelfOutput(tf.keras.layers.Layer):
|
|||||||
class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||||
def __init__(self, config, layer_id, **kwargs):
|
def __init__(self, config, layer_id, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
self.config = config
|
||||||
|
|
||||||
if config.hidden_size % config.num_attention_heads != 0:
|
if config.hidden_size % config.num_attention_heads != 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -708,6 +709,16 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
self.one_sided_attn_window_size = attention_window // 2
|
self.one_sided_attn_window_size = attention_window // 2
|
||||||
|
|
||||||
|
def build(self, input_shape=None):
|
||||||
|
if not self.built:
|
||||||
|
with tf.name_scope("query_global"):
|
||||||
|
self.query_global.build((self.config.hidden_size,))
|
||||||
|
with tf.name_scope("key_global"):
|
||||||
|
self.key_global.build((self.config.hidden_size,))
|
||||||
|
with tf.name_scope("value_global"):
|
||||||
|
self.value_global.build((self.config.hidden_size,))
|
||||||
|
super().build(input_shape)
|
||||||
|
|
||||||
def call(
|
def call(
|
||||||
self,
|
self,
|
||||||
inputs,
|
inputs,
|
||||||
@@ -788,9 +799,8 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
|||||||
) = self._get_global_attn_indices(is_index_global_attn)
|
) = self._get_global_attn_indices(is_index_global_attn)
|
||||||
|
|
||||||
# this function is only relevant for global attention
|
# this function is only relevant for global attention
|
||||||
attn_scores = tf.cond(
|
if is_global_attn:
|
||||||
is_global_attn,
|
attn_scores = self._concat_with_global_key_attn_probs(
|
||||||
lambda: self._concat_with_global_key_attn_probs(
|
|
||||||
attn_scores=attn_scores,
|
attn_scores=attn_scores,
|
||||||
query_vectors=query_vectors,
|
query_vectors=query_vectors,
|
||||||
key_vectors=key_vectors,
|
key_vectors=key_vectors,
|
||||||
@@ -798,26 +808,24 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
|||||||
is_index_global_attn_nonzero=is_index_global_attn_nonzero,
|
is_index_global_attn_nonzero=is_index_global_attn_nonzero,
|
||||||
is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
|
is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
|
||||||
is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
|
is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
|
||||||
),
|
)
|
||||||
lambda: attn_scores,
|
|
||||||
)
|
|
||||||
attn_probs = stable_softmax(attn_scores, axis=-1)
|
attn_probs = stable_softmax(attn_scores, axis=-1)
|
||||||
|
|
||||||
# softmax sometimes inserts NaN if all positions are masked, replace them with 0
|
# softmax sometimes inserts NaN if all positions are masked, replace them with 0
|
||||||
# Make sure to create a mask with the proper shape:
|
# Make sure to create a mask with the proper shape:
|
||||||
# if is_global_attn==True => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1]
|
# if is_global_attn==True => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1]
|
||||||
# if is_global_attn==False => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1]
|
# if is_global_attn==False => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1]
|
||||||
masked_index = tf.cond(
|
if is_global_attn:
|
||||||
is_global_attn,
|
masked_index = tf.tile(
|
||||||
lambda: tf.tile(
|
|
||||||
is_index_masked[:, :, None, None],
|
is_index_masked[:, :, None, None],
|
||||||
(1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1),
|
(1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1),
|
||||||
),
|
)
|
||||||
lambda: tf.tile(
|
else:
|
||||||
|
masked_index = tf.tile(
|
||||||
is_index_masked[:, :, None, None],
|
is_index_masked[:, :, None, None],
|
||||||
(1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + 1),
|
(1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + 1),
|
||||||
),
|
)
|
||||||
)
|
|
||||||
attn_probs = tf.where(
|
attn_probs = tf.where(
|
||||||
masked_index,
|
masked_index,
|
||||||
tf.zeros(shape_list(masked_index), dtype=attn_probs.dtype),
|
tf.zeros(shape_list(masked_index), dtype=attn_probs.dtype),
|
||||||
@@ -841,19 +849,19 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
|||||||
value_vectors = tf.reshape(value_vectors, (batch_size, seq_len, self.num_heads, self.head_dim))
|
value_vectors = tf.reshape(value_vectors, (batch_size, seq_len, self.num_heads, self.head_dim))
|
||||||
|
|
||||||
# if global attention, compute sum of global and local attn
|
# if global attention, compute sum of global and local attn
|
||||||
attn_output = tf.cond(
|
|
||||||
is_global_attn,
|
if is_global_attn:
|
||||||
lambda: self._compute_attn_output_with_global_indices(
|
attn_output = self._compute_attn_output_with_global_indices(
|
||||||
value_vectors=value_vectors,
|
value_vectors=value_vectors,
|
||||||
attn_probs=attn_probs,
|
attn_probs=attn_probs,
|
||||||
max_num_global_attn_indices=max_num_global_attn_indices,
|
max_num_global_attn_indices=max_num_global_attn_indices,
|
||||||
is_index_global_attn_nonzero=is_index_global_attn_nonzero,
|
is_index_global_attn_nonzero=is_index_global_attn_nonzero,
|
||||||
is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
|
is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
|
||||||
),
|
)
|
||||||
lambda: self._sliding_chunks_matmul_attn_probs_value(
|
else:
|
||||||
|
attn_output = self._sliding_chunks_matmul_attn_probs_value(
|
||||||
attn_probs, value_vectors, self.one_sided_attn_window_size
|
attn_probs, value_vectors, self.one_sided_attn_window_size
|
||||||
),
|
)
|
||||||
)
|
|
||||||
|
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
shape_list(attn_output), [batch_size, seq_len, self.num_heads, self.head_dim], message="Unexpected size"
|
shape_list(attn_output), [batch_size, seq_len, self.num_heads, self.head_dim], message="Unexpected size"
|
||||||
@@ -862,10 +870,8 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
|||||||
attn_output = tf.reshape(attn_output, (batch_size, seq_len, embed_dim))
|
attn_output = tf.reshape(attn_output, (batch_size, seq_len, embed_dim))
|
||||||
|
|
||||||
# compute value for global attention and overwrite to attention output
|
# compute value for global attention and overwrite to attention output
|
||||||
# TODO: remove the redundant computation
|
if is_global_attn:
|
||||||
attn_output, global_attn_probs = tf.cond(
|
attn_output, global_attn_probs = self._compute_global_attn_output_from_hidden(
|
||||||
is_global_attn,
|
|
||||||
lambda: self._compute_global_attn_output_from_hidden(
|
|
||||||
attn_output=attn_output,
|
attn_output=attn_output,
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
max_num_global_attn_indices=max_num_global_attn_indices,
|
max_num_global_attn_indices=max_num_global_attn_indices,
|
||||||
@@ -875,25 +881,25 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
|||||||
is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
|
is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
|
||||||
is_index_masked=is_index_masked,
|
is_index_masked=is_index_masked,
|
||||||
training=training,
|
training=training,
|
||||||
),
|
)
|
||||||
lambda: (attn_output, tf.zeros((batch_size, self.num_heads, max_num_global_attn_indices, seq_len))),
|
else:
|
||||||
)
|
# Leave attn_output unchanged
|
||||||
|
global_attn_probs = tf.zeros((batch_size, self.num_heads, max_num_global_attn_indices, seq_len))
|
||||||
|
|
||||||
# make sure that local attention probabilities are set to 0 for indices of global attn
|
# make sure that local attention probabilities are set to 0 for indices of global attn
|
||||||
# Make sure to create a mask with the proper shape:
|
# Make sure to create a mask with the proper shape:
|
||||||
# if is_global_attn==True => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1]
|
# if is_global_attn==True => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1]
|
||||||
# if is_global_attn==False => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1]
|
# if is_global_attn==False => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1]
|
||||||
masked_global_attn_index = tf.cond(
|
if is_global_attn:
|
||||||
is_global_attn,
|
masked_global_attn_index = tf.tile(
|
||||||
lambda: tf.tile(
|
|
||||||
is_index_global_attn[:, :, None, None],
|
is_index_global_attn[:, :, None, None],
|
||||||
(1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1),
|
(1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1),
|
||||||
),
|
)
|
||||||
lambda: tf.tile(
|
else:
|
||||||
|
masked_global_attn_index = tf.tile(
|
||||||
is_index_global_attn[:, :, None, None],
|
is_index_global_attn[:, :, None, None],
|
||||||
(1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + 1),
|
(1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + 1),
|
||||||
),
|
)
|
||||||
)
|
|
||||||
attn_probs = tf.where(
|
attn_probs = tf.where(
|
||||||
masked_global_attn_index,
|
masked_global_attn_index,
|
||||||
tf.zeros(shape_list(masked_global_attn_index), dtype=attn_probs.dtype),
|
tf.zeros(shape_list(masked_global_attn_index), dtype=attn_probs.dtype),
|
||||||
@@ -1828,13 +1834,10 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
|
|||||||
position_ids = tf.pad(position_ids, paddings, constant_values=pad_token_id)
|
position_ids = tf.pad(position_ids, paddings, constant_values=pad_token_id)
|
||||||
|
|
||||||
if inputs_embeds is not None:
|
if inputs_embeds is not None:
|
||||||
|
if padding_len > 0:
|
||||||
def pad_embeddings():
|
|
||||||
input_ids_padding = tf.cast(tf.fill((batch_size, padding_len), self.pad_token_id), tf.int64)
|
input_ids_padding = tf.cast(tf.fill((batch_size, padding_len), self.pad_token_id), tf.int64)
|
||||||
inputs_embeds_padding = self.embeddings(input_ids_padding)
|
inputs_embeds_padding = self.embeddings(input_ids_padding)
|
||||||
return tf.concat([inputs_embeds, inputs_embeds_padding], axis=-2)
|
inputs_embeds = tf.concat([inputs_embeds, inputs_embeds_padding], axis=-2)
|
||||||
|
|
||||||
inputs_embeds = tf.cond(tf.math.greater(padding_len, 0), pad_embeddings, lambda: inputs_embeds)
|
|
||||||
|
|
||||||
attention_mask = tf.pad(attention_mask, paddings, constant_values=False) # no attention on the padding tokens
|
attention_mask = tf.pad(attention_mask, paddings, constant_values=False) # no attention on the padding tokens
|
||||||
token_type_ids = tf.pad(token_type_ids, paddings, constant_values=0) # pad with token_type_id = 0
|
token_type_ids = tf.pad(token_type_ids, paddings, constant_values=0) # pad with token_type_id = 0
|
||||||
|
|||||||
@@ -151,6 +151,7 @@ class TFNoNorm(tf.keras.layers.Layer):
|
|||||||
def build(self, input_shape):
|
def build(self, input_shape):
|
||||||
self.bias = self.add_weight("bias", shape=[self.feat_size], initializer="zeros")
|
self.bias = self.add_weight("bias", shape=[self.feat_size], initializer="zeros")
|
||||||
self.weight = self.add_weight("weight", shape=[self.feat_size], initializer="ones")
|
self.weight = self.add_weight("weight", shape=[self.feat_size], initializer="ones")
|
||||||
|
super().build(input_shape)
|
||||||
|
|
||||||
def call(self, inputs: tf.Tensor):
|
def call(self, inputs: tf.Tensor):
|
||||||
return inputs * self.weight + self.bias
|
return inputs * self.weight + self.bias
|
||||||
|
|||||||
@@ -581,6 +581,7 @@ class TFSamPositionalEmbedding(tf.keras.layers.Layer):
|
|||||||
initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=self.scale),
|
initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=self.scale),
|
||||||
trainable=False,
|
trainable=False,
|
||||||
)
|
)
|
||||||
|
super().build(input_shape)
|
||||||
|
|
||||||
def call(self, input_coords, input_shape=None):
|
def call(self, input_coords, input_shape=None):
|
||||||
"""Positionally encode points that are normalized to [0,1]."""
|
"""Positionally encode points that are normalized to [0,1]."""
|
||||||
|
|||||||
@@ -225,6 +225,7 @@ class TFVisionTextDualEncoderModel(TFPreTrainedModel):
|
|||||||
# Build in the build() method to make sure the names are right
|
# Build in the build() method to make sure the names are right
|
||||||
initializer = tf.keras.initializers.Constant(self.config.logit_scale_init_value)
|
initializer = tf.keras.initializers.Constant(self.config.logit_scale_init_value)
|
||||||
self.logit_scale = self.add_weight(shape=(1,), initializer=initializer, name="logit_scale")
|
self.logit_scale = self.add_weight(shape=(1,), initializer=initializer, name="logit_scale")
|
||||||
|
super().build(input_shape)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||||
@@ -591,7 +592,7 @@ class TFVisionTextDualEncoderModel(TFPreTrainedModel):
|
|||||||
if text_model.name != "text_model":
|
if text_model.name != "text_model":
|
||||||
raise ValueError("text model must be created with the name `text_model`.")
|
raise ValueError("text model must be created with the name `text_model`.")
|
||||||
|
|
||||||
model(model.dummy_inputs) # Ensure model is fully built
|
model.build() # Ensure model is fully built
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|||||||
@@ -966,11 +966,8 @@ class TFViTMAEForPreTraining(TFViTMAEPreTrainedModel):
|
|||||||
"""
|
"""
|
||||||
patch_size, num_channels = self.config.patch_size, self.config.num_channels
|
patch_size, num_channels = self.config.patch_size, self.config.num_channels
|
||||||
# make sure channels are last
|
# make sure channels are last
|
||||||
pixel_values = tf.cond(
|
if shape_list(pixel_values)[1] == num_channels:
|
||||||
tf.math.equal(shape_list(pixel_values)[1], num_channels),
|
pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))
|
||||||
lambda: tf.transpose(pixel_values, perm=(0, 2, 3, 1)),
|
|
||||||
lambda: pixel_values,
|
|
||||||
)
|
|
||||||
|
|
||||||
# sanity checks
|
# sanity checks
|
||||||
tf.debugging.assert_equal(
|
tf.debugging.assert_equal(
|
||||||
|
|||||||
@@ -766,11 +766,12 @@ class TFWhisperDecoder(tf.keras.layers.Layer):
|
|||||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||||
batch_size, seq_len = input_shape[0], input_shape[1]
|
batch_size, seq_len = input_shape[0], input_shape[1]
|
||||||
|
|
||||||
combined_attention_mask = tf.cond(
|
if seq_len > 1:
|
||||||
tf.math.greater(seq_len, 1),
|
combined_attention_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
|
||||||
lambda: _make_causal_mask(input_shape, past_key_values_length=past_key_values_length),
|
else:
|
||||||
lambda: _expand_mask(tf.ones((batch_size, seq_len + past_key_values_length)), tgt_len=seq_len),
|
combined_attention_mask = _expand_mask(
|
||||||
)
|
tf.ones((batch_size, seq_len + past_key_values_length)), tgt_len=seq_len
|
||||||
|
)
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||||
|
|||||||
@@ -476,6 +476,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
|
|||||||
self.mask_emb = self.add_weight(
|
self.mask_emb = self.add_weight(
|
||||||
shape=(1, 1, self.d_model), initializer=initializer, trainable=True, name="mask_emb"
|
shape=(1, 1, self.d_model), initializer=initializer, trainable=True, name="mask_emb"
|
||||||
)
|
)
|
||||||
|
super().build(input_shape)
|
||||||
|
|
||||||
def _prune_heads(self, heads_to_prune):
|
def _prune_heads(self, heads_to_prune):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|||||||
@@ -328,7 +328,7 @@ class TFBartModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, PipelineTester
|
|||||||
old_total_size = config.vocab_size
|
old_total_size = config.vocab_size
|
||||||
new_total_size = old_total_size + new_tokens_size
|
new_total_size = old_total_size + new_tokens_size
|
||||||
model = model_class(config=copy.deepcopy(config)) # `resize_token_embeddings` mutates `config`
|
model = model_class(config=copy.deepcopy(config)) # `resize_token_embeddings` mutates `config`
|
||||||
model(model.dummy_inputs) # builds the embeddings layer
|
model.build()
|
||||||
model.resize_token_embeddings(new_total_size)
|
model.resize_token_embeddings(new_total_size)
|
||||||
|
|
||||||
# fetch the output for an input exclusively made of new members of the vocabulary
|
# fetch the output for an input exclusively made of new members of the vocabulary
|
||||||
|
|||||||
@@ -1070,9 +1070,9 @@ class TFEncoderDecoderModelSaveLoadTests(unittest.TestCase):
|
|||||||
|
|
||||||
# create two random BERT models for bert2bert & initialize weights (+cross_attention weights)
|
# create two random BERT models for bert2bert & initialize weights (+cross_attention weights)
|
||||||
encoder = TFBertModel(config.encoder)
|
encoder = TFBertModel(config.encoder)
|
||||||
encoder(encoder.dummy_inputs)
|
encoder.build()
|
||||||
decoder = TFBertLMHeadModel(config.decoder)
|
decoder = TFBertLMHeadModel(config.decoder)
|
||||||
decoder(decoder.dummy_inputs)
|
decoder.build()
|
||||||
|
|
||||||
encoder_decoder_orig = TFEncoderDecoderModel(encoder=encoder, decoder=decoder)
|
encoder_decoder_orig = TFEncoderDecoderModel(encoder=encoder, decoder=decoder)
|
||||||
|
|
||||||
|
|||||||
@@ -463,7 +463,7 @@ class TFGPT2ModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, PipelineTester
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
model(model.dummy_inputs)
|
model.build()
|
||||||
|
|
||||||
onnx_model_proto, _ = tf2onnx.convert.from_keras(model, opset=self.onnx_min_opset)
|
onnx_model_proto, _ = tf2onnx.convert.from_keras(model, opset=self.onnx_min_opset)
|
||||||
|
|
||||||
|
|||||||
@@ -194,7 +194,7 @@ class TFOPTModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
|
|||||||
else:
|
else:
|
||||||
# Here we build the word embeddings weights if not exists.
|
# Here we build the word embeddings weights if not exists.
|
||||||
# And then we retry to get the attribute once built.
|
# And then we retry to get the attribute once built.
|
||||||
model(model.dummy_inputs)
|
model.build()
|
||||||
if hasattr(embedding_layer, "weight"):
|
if hasattr(embedding_layer, "weight"):
|
||||||
return embedding_layer.weight
|
return embedding_layer.weight
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -729,9 +729,9 @@ class TFVisionEncoderDecoderModelSaveLoadTests(unittest.TestCase):
|
|||||||
|
|
||||||
# create two random ViT/GPT2 models for vit-gpt2 & initialize weights (+cross_attention weights)
|
# create two random ViT/GPT2 models for vit-gpt2 & initialize weights (+cross_attention weights)
|
||||||
encoder = TFViTModel(config.encoder)
|
encoder = TFViTModel(config.encoder)
|
||||||
encoder(encoder.dummy_inputs)
|
encoder.build()
|
||||||
decoder = TFGPT2LMHeadModel(config.decoder)
|
decoder = TFGPT2LMHeadModel(config.decoder)
|
||||||
decoder(decoder.dummy_inputs)
|
decoder.build()
|
||||||
|
|
||||||
encoder_decoder_orig = TFVisionEncoderDecoderModel(encoder=encoder, decoder=decoder)
|
encoder_decoder_orig = TFVisionEncoderDecoderModel(encoder=encoder, decoder=decoder)
|
||||||
|
|
||||||
|
|||||||
@@ -281,7 +281,7 @@ class TFWhisperModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestC
|
|||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
|
|
||||||
model(model.dummy_inputs)
|
model.build()
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
model.save_pretrained(tmpdirname, saved_model=False)
|
model.save_pretrained(tmpdirname, saved_model=False)
|
||||||
|
|||||||
@@ -348,7 +348,7 @@ class TFModelTesterMixin:
|
|||||||
|
|
||||||
with tf.Graph().as_default() as g:
|
with tf.Graph().as_default() as g:
|
||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
model(model.dummy_inputs)
|
model.build()
|
||||||
|
|
||||||
for op in g.get_operations():
|
for op in g.get_operations():
|
||||||
model_op_names.add(op.node_def.op)
|
model_op_names.add(op.node_def.op)
|
||||||
@@ -375,7 +375,7 @@ class TFModelTesterMixin:
|
|||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
model(model.dummy_inputs)
|
model.build()
|
||||||
|
|
||||||
onnx_model_proto, _ = tf2onnx.convert.from_keras(model, opset=self.onnx_min_opset)
|
onnx_model_proto, _ = tf2onnx.convert.from_keras(model, opset=self.onnx_min_opset)
|
||||||
|
|
||||||
@@ -1180,7 +1180,7 @@ class TFModelTesterMixin:
|
|||||||
def _get_word_embedding_weight(model, embedding_layer):
|
def _get_word_embedding_weight(model, embedding_layer):
|
||||||
if isinstance(embedding_layer, tf.keras.layers.Embedding):
|
if isinstance(embedding_layer, tf.keras.layers.Embedding):
|
||||||
# builds the embeddings layer
|
# builds the embeddings layer
|
||||||
model(model.dummy_inputs)
|
model.build()
|
||||||
return embedding_layer.embeddings
|
return embedding_layer.embeddings
|
||||||
else:
|
else:
|
||||||
return model._get_word_embedding_weight(embedding_layer)
|
return model._get_word_embedding_weight(embedding_layer)
|
||||||
@@ -1243,7 +1243,7 @@ class TFModelTesterMixin:
|
|||||||
old_total_size = config.vocab_size
|
old_total_size = config.vocab_size
|
||||||
new_total_size = old_total_size + new_tokens_size
|
new_total_size = old_total_size + new_tokens_size
|
||||||
model = model_class(config=copy.deepcopy(config)) # `resize_token_embeddings` mutates `config`
|
model = model_class(config=copy.deepcopy(config)) # `resize_token_embeddings` mutates `config`
|
||||||
model(model.dummy_inputs) # builds the embeddings layer
|
model.build()
|
||||||
model.resize_token_embeddings(new_total_size)
|
model.resize_token_embeddings(new_total_size)
|
||||||
|
|
||||||
# fetch the output for an input exclusively made of new members of the vocabulary
|
# fetch the output for an input exclusively made of new members of the vocabulary
|
||||||
@@ -2313,8 +2313,8 @@ class UtilsFunctionsTest(unittest.TestCase):
|
|||||||
# Finally, check the model can be reloaded
|
# Finally, check the model can be reloaded
|
||||||
new_model = TFBertModel.from_pretrained(tmp_dir)
|
new_model = TFBertModel.from_pretrained(tmp_dir)
|
||||||
|
|
||||||
model(model.dummy_inputs)
|
model.build()
|
||||||
new_model(model.dummy_inputs)
|
new_model.build()
|
||||||
|
|
||||||
for p1, p2 in zip(model.weights, new_model.weights):
|
for p1, p2 in zip(model.weights, new_model.weights):
|
||||||
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
||||||
@@ -2440,7 +2440,7 @@ class TFModelPushToHubTester(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
model = TFBertModel(config)
|
model = TFBertModel(config)
|
||||||
# Make sure model is properly initialized
|
# Make sure model is properly initialized
|
||||||
_ = model(model.dummy_inputs)
|
model.build()
|
||||||
|
|
||||||
logging.set_verbosity_info()
|
logging.set_verbosity_info()
|
||||||
logger = logging.get_logger("transformers.utils.hub")
|
logger = logging.get_logger("transformers.utils.hub")
|
||||||
@@ -2509,7 +2509,7 @@ class TFModelPushToHubTester(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
model = TFBertModel(config)
|
model = TFBertModel(config)
|
||||||
# Make sure model is properly initialized
|
# Make sure model is properly initialized
|
||||||
_ = model(model.dummy_inputs)
|
model.build()
|
||||||
|
|
||||||
model.push_to_hub("valid_org/test-model-tf-org", use_auth_token=self._token)
|
model.push_to_hub("valid_org/test-model-tf-org", use_auth_token=self._token)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user