Restore TF embeddings and attention layers to their previous version (#9890)
* Refacto BERT * Restore all the concerned models * Remove print * Update template * Apply Sylvain's and Morgan's comments * Fix cast * Put the cast inside call * Remove cond in ebds * Fix funnel * Restore previous dot product (attention_scores) computation * Add ConvBERT and BART * Make all the S2S models ONNX compliant * Fix test * Fix check copies
This commit is contained in:
@@ -16,6 +16,7 @@
|
||||
|
||||
{% if cookiecutter.is_encoder_decoder_model == "False" %}
|
||||
|
||||
import math
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
@@ -68,122 +69,6 @@ TF_{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
]
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_tf_bert.TFBertWordEmbeddings
|
||||
class TF{{cookiecutter.camelcase_modelname}}WordEmbeddings(tf.keras.layers.Layer):
|
||||
def __init__(self, vocab_size: int, hidden_size: int, initializer_range: float, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.initializer_range = initializer_range
|
||||
|
||||
def build(self, input_shape: tf.TensorShape):
|
||||
self.weight = self.add_weight(
|
||||
name="weight",
|
||||
shape=[self.vocab_size, self.hidden_size],
|
||||
initializer=get_initializer(self.initializer_range),
|
||||
)
|
||||
|
||||
super().build(input_shape)
|
||||
|
||||
def get_config(self) -> Dict[str, Any]:
|
||||
config = {
|
||||
"vocab_size": self.vocab_size,
|
||||
"hidden_size": self.hidden_size,
|
||||
"initializer_range": self.initializer_range,
|
||||
}
|
||||
base_config = super().get_config()
|
||||
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
|
||||
def call(self, input_ids: tf.Tensor) -> tf.Tensor:
|
||||
flat_input_ids = tf.reshape(tensor=input_ids, shape=[-1])
|
||||
embeddings = tf.gather(params=self.weight, indices=flat_input_ids)
|
||||
embeddings = tf.reshape(
|
||||
tensor=embeddings, shape=tf.concat(values=[shape_list(input_ids), [self.hidden_size]], axis=0)
|
||||
)
|
||||
|
||||
embeddings.set_shape(input_ids.shape.as_list() + [self.hidden_size])
|
||||
|
||||
return embeddings
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_tf_bert.TFBertTokenTypeEmbeddings
|
||||
class TF{{cookiecutter.camelcase_modelname}}TokenTypeEmbeddings(tf.keras.layers.Layer):
|
||||
def __init__(self, type_vocab_size: int, hidden_size: int, initializer_range: float, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.type_vocab_size = type_vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.initializer_range = initializer_range
|
||||
|
||||
def build(self, input_shape: tf.TensorShape):
|
||||
self.token_type_embeddings = self.add_weight(
|
||||
name="embeddings",
|
||||
shape=[self.type_vocab_size, self.hidden_size],
|
||||
initializer=get_initializer(self.initializer_range),
|
||||
)
|
||||
|
||||
super().build(input_shape)
|
||||
|
||||
def get_config(self) -> Dict[str, Any]:
|
||||
config = {
|
||||
"type_vocab_size": self.type_vocab_size,
|
||||
"hidden_size": self.hidden_size,
|
||||
"initializer_range": self.initializer_range,
|
||||
}
|
||||
base_config = super().get_config()
|
||||
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
|
||||
def call(self, token_type_ids: tf.Tensor) -> tf.Tensor:
|
||||
flat_token_type_ids = tf.reshape(tensor=token_type_ids, shape=[-1])
|
||||
one_hot_data = tf.one_hot(indices=flat_token_type_ids, depth=self.type_vocab_size, dtype=self._compute_dtype)
|
||||
embeddings = tf.matmul(a=one_hot_data, b=self.token_type_embeddings)
|
||||
embeddings = tf.reshape(
|
||||
tensor=embeddings, shape=tf.concat(values=[shape_list(token_type_ids), [self.hidden_size]], axis=0)
|
||||
)
|
||||
|
||||
embeddings.set_shape(token_type_ids.shape.as_list() + [self.hidden_size])
|
||||
|
||||
return embeddings
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_tf_bert.TFBertPositionEmbeddings
|
||||
class TF{{cookiecutter.camelcase_modelname}}PositionEmbeddings(tf.keras.layers.Layer):
|
||||
def __init__(self, max_position_embeddings: int, hidden_size: int, initializer_range: float, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.initializer_range = initializer_range
|
||||
|
||||
def build(self, input_shape: tf.TensorShape):
|
||||
self.position_embeddings = self.add_weight(
|
||||
name="embeddings",
|
||||
shape=[self.max_position_embeddings, self.hidden_size],
|
||||
initializer=get_initializer(self.initializer_range),
|
||||
)
|
||||
|
||||
super().build(input_shape)
|
||||
|
||||
def get_config(self) -> Dict[str, Any]:
|
||||
config = {
|
||||
"max_position_embeddings": self.max_position_embeddings,
|
||||
"hidden_size": self.hidden_size,
|
||||
"initializer_range": self.initializer_range,
|
||||
}
|
||||
base_config = super().get_config()
|
||||
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
|
||||
def call(self, position_ids: tf.Tensor) -> tf.Tensor:
|
||||
input_shape = shape_list(position_ids)
|
||||
position_embeddings = self.position_embeddings[: input_shape[1], :]
|
||||
|
||||
return tf.broadcast_to(input=position_embeddings, shape=input_shape)
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_tf_bert.TFBertEmbeddings with Bert->{{cookiecutter.camelcase_modelname}}
|
||||
class TF{{cookiecutter.camelcase_modelname}}Embeddings(tf.keras.layers.Layer):
|
||||
"""Construct the embeddings from word, position and token_type embeddings."""
|
||||
@@ -191,34 +76,45 @@ class TF{{cookiecutter.camelcase_modelname}}Embeddings(tf.keras.layers.Layer):
|
||||
def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.word_embeddings = TF{{cookiecutter.camelcase_modelname}}WordEmbeddings(
|
||||
vocab_size=config.vocab_size,
|
||||
hidden_size=config.hidden_size,
|
||||
initializer_range=config.initializer_range,
|
||||
name="word_embeddings",
|
||||
)
|
||||
self.position_embeddings = TF{{cookiecutter.camelcase_modelname}}PositionEmbeddings(
|
||||
max_position_embeddings=config.max_position_embeddings,
|
||||
hidden_size=config.hidden_size,
|
||||
initializer_range=config.initializer_range,
|
||||
name="position_embeddings",
|
||||
)
|
||||
self.token_type_embeddings = TF{{cookiecutter.camelcase_modelname}}TokenTypeEmbeddings(
|
||||
type_vocab_size=config.type_vocab_size,
|
||||
hidden_size=config.hidden_size,
|
||||
initializer_range=config.initializer_range,
|
||||
name="token_type_embeddings",
|
||||
)
|
||||
self.vocab_size = config.vocab_size
|
||||
self.type_vocab_size = config.type_vocab_size
|
||||
self.hidden_size = config.hidden_size
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.initializer_range = config.initializer_range
|
||||
self.embeddings_sum = tf.keras.layers.Add()
|
||||
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
|
||||
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
|
||||
|
||||
def build(self, input_shape: tf.TensorShape):
|
||||
with tf.name_scope("word_embeddings"):
|
||||
self.weight = self.add_weight(
|
||||
name="weight",
|
||||
shape=[self.vocab_size, self.hidden_size],
|
||||
initializer=get_initializer(initializer_range=self.initializer_range),
|
||||
)
|
||||
|
||||
with tf.name_scope("token_type_embeddings"):
|
||||
self.token_type_embeddings = self.add_weight(
|
||||
name="embeddings",
|
||||
shape=[self.type_vocab_size, self.hidden_size],
|
||||
initializer=get_initializer(initializer_range=self.initializer_range),
|
||||
)
|
||||
|
||||
with tf.name_scope("position_embeddings"):
|
||||
self.position_embeddings = self.add_weight(
|
||||
name="embeddings",
|
||||
shape=[self.max_position_embeddings, self.hidden_size],
|
||||
initializer=get_initializer(initializer_range=self.initializer_range),
|
||||
)
|
||||
|
||||
super().build(input_shape)
|
||||
|
||||
def call(
|
||||
self,
|
||||
input_ids: tf.Tensor,
|
||||
position_ids: tf.Tensor,
|
||||
token_type_ids: tf.Tensor,
|
||||
inputs_embeds: tf.Tensor,
|
||||
input_ids: tf.Tensor = None,
|
||||
position_ids: tf.Tensor = None,
|
||||
token_type_ids: tf.Tensor = None,
|
||||
inputs_embeds: tf.Tensor = None,
|
||||
training: bool = False,
|
||||
) -> tf.Tensor:
|
||||
"""
|
||||
@@ -230,18 +126,19 @@ class TF{{cookiecutter.camelcase_modelname}}Embeddings(tf.keras.layers.Layer):
|
||||
assert not (input_ids is None and inputs_embeds is None)
|
||||
|
||||
if input_ids is not None:
|
||||
inputs_embeds = self.word_embeddings(input_ids)
|
||||
inputs_embeds = tf.gather(params=self.weight, indices=input_ids)
|
||||
|
||||
input_shape = shape_list(inputs_embeds)[:-1]
|
||||
|
||||
if token_type_ids is None:
|
||||
input_shape = shape_list(inputs_embeds)[:-1]
|
||||
token_type_ids = tf.fill(dims=input_shape, value=0)
|
||||
|
||||
if position_ids is None:
|
||||
position_embeds = self.position_embeddings(inputs_embeds)
|
||||
else:
|
||||
position_embeds = self.position_embeddings(position_ids)
|
||||
position_ids = tf.range(start=0, limit=input_shape[-1])[tf.newaxis, :]
|
||||
|
||||
token_type_embeds = self.token_type_embeddings(token_type_ids)
|
||||
position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
|
||||
position_embeds = tf.tile(input=position_embeds, multiples=(input_shape[0], 1, 1))
|
||||
token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids)
|
||||
final_embeddings = self.embeddings_sum(inputs=[inputs_embeds, position_embeds, token_type_embeds])
|
||||
final_embeddings = self.LayerNorm(inputs=final_embeddings)
|
||||
final_embeddings = self.dropout(inputs=final_embeddings, training=training)
|
||||
@@ -261,31 +158,29 @@ class TF{{cookiecutter.camelcase_modelname}}SelfAttention(tf.keras.layers.Layer)
|
||||
f"of attention heads ({config.num_attention_heads})"
|
||||
)
|
||||
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
self.rsqrt_att_head_size = 1.0 / math.sqrt(self.attention_head_size)
|
||||
|
||||
self.query = tf.keras.layers.experimental.EinsumDense(
|
||||
equation="abc,cde->abde",
|
||||
output_shape=(None, config.num_attention_heads, self.attention_head_size),
|
||||
bias_axes="de",
|
||||
kernel_initializer=get_initializer(config.initializer_range),
|
||||
name="query",
|
||||
self.query = tf.keras.layers.Dense(
|
||||
units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
|
||||
)
|
||||
self.key = tf.keras.layers.experimental.EinsumDense(
|
||||
equation="abc,cde->abde",
|
||||
output_shape=(None, config.num_attention_heads, self.attention_head_size),
|
||||
bias_axes="de",
|
||||
kernel_initializer=get_initializer(config.initializer_range),
|
||||
name="key",
|
||||
self.key = tf.keras.layers.Dense(
|
||||
units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
|
||||
)
|
||||
self.value = tf.keras.layers.experimental.EinsumDense(
|
||||
equation="abc,cde->abde",
|
||||
output_shape=(None, config.num_attention_heads, self.attention_head_size),
|
||||
bias_axes="de",
|
||||
kernel_initializer=get_initializer(config.initializer_range),
|
||||
name="value",
|
||||
self.value = tf.keras.layers.Dense(
|
||||
units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
|
||||
)
|
||||
self.dropout = tf.keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
|
||||
|
||||
def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
|
||||
# Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
|
||||
tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))
|
||||
|
||||
# Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size]
|
||||
return tf.transpose(tensor, perm=[0, 2, 1, 3])
|
||||
|
||||
def call(
|
||||
self,
|
||||
hidden_states: tf.Tensor,
|
||||
@@ -294,15 +189,20 @@ class TF{{cookiecutter.camelcase_modelname}}SelfAttention(tf.keras.layers.Layer)
|
||||
output_attentions: bool,
|
||||
training: bool = False,
|
||||
) -> Tuple[tf.Tensor]:
|
||||
query_layer = self.query(inputs=hidden_states)
|
||||
key_layer = self.key(inputs=hidden_states)
|
||||
value_layer = self.value(inputs=hidden_states)
|
||||
batch_size = shape_list(hidden_states)[0]
|
||||
mixed_query_layer = self.query(inputs=hidden_states)
|
||||
mixed_key_layer = self.key(inputs=hidden_states)
|
||||
mixed_value_layer = self.value(inputs=hidden_states)
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
|
||||
key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
|
||||
value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)
|
||||
|
||||
# Take the dot product between "query" and "key" to get the raw
|
||||
# attention scores.
|
||||
dk = tf.cast(self.attention_head_size, dtype=query_layer.dtype)
|
||||
query_layer = tf.multiply(query_layer, tf.math.rsqrt(dk))
|
||||
attention_scores = tf.einsum("aecd,abcd->acbe", key_layer, query_layer)
|
||||
# (batch size, num_heads, seq_len_q, seq_len_k)
|
||||
attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
|
||||
dk = tf.cast(self.rsqrt_att_head_size, dtype=attention_scores.dtype)
|
||||
attention_scores = tf.multiply(attention_scores, dk)
|
||||
|
||||
if attention_mask is not None:
|
||||
# Apply the attention mask is (precomputed for all layers in TF{{cookiecutter.camelcase_modelname}}Model call() function)
|
||||
@@ -319,7 +219,11 @@ class TF{{cookiecutter.camelcase_modelname}}SelfAttention(tf.keras.layers.Layer)
|
||||
if head_mask is not None:
|
||||
attention_probs = tf.multiply(attention_probs, head_mask)
|
||||
|
||||
attention_output = tf.einsum("acbe,aecd->abcd", attention_probs, value_layer)
|
||||
attention_output = tf.matmul(attention_probs, value_layer)
|
||||
attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3])
|
||||
|
||||
# (batch_size, seq_len_q, all_head_size)
|
||||
attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size))
|
||||
outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)
|
||||
|
||||
return outputs
|
||||
@@ -330,21 +234,8 @@ class TF{{cookiecutter.camelcase_modelname}}SelfOutput(tf.keras.layers.Layer):
|
||||
def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
if config.hidden_size % config.num_attention_heads != 0:
|
||||
raise ValueError(
|
||||
f"The hidden size ({config.hidden_size}) is not a multiple of the number "
|
||||
f"of attention heads ({config.num_attention_heads})"
|
||||
)
|
||||
|
||||
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
||||
self.all_head_size = config.num_attention_heads * self.attention_head_size
|
||||
|
||||
self.dense = tf.keras.layers.experimental.EinsumDense(
|
||||
equation="abcd,cde->abe",
|
||||
output_shape=(None, self.all_head_size),
|
||||
bias_axes="e",
|
||||
kernel_initializer=get_initializer(config.initializer_range),
|
||||
name="dense",
|
||||
self.dense = tf.keras.layers.Dense(
|
||||
units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
|
||||
)
|
||||
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
|
||||
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
|
||||
@@ -396,12 +287,8 @@ class TF{{cookiecutter.camelcase_modelname}}Intermediate(tf.keras.layers.Layer):
|
||||
def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.dense = tf.keras.layers.experimental.EinsumDense(
|
||||
equation="abc,cd->abd",
|
||||
output_shape=(None, config.intermediate_size),
|
||||
bias_axes="d",
|
||||
kernel_initializer=get_initializer(config.initializer_range),
|
||||
name="dense",
|
||||
self.dense = tf.keras.layers.Dense(
|
||||
units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
|
||||
)
|
||||
|
||||
if isinstance(config.hidden_act, str):
|
||||
@@ -418,15 +305,11 @@ class TF{{cookiecutter.camelcase_modelname}}Intermediate(tf.keras.layers.Layer):
|
||||
|
||||
# Copied from transformers.models.bert.modeling_tf_bert.TFBertOutput with Bert->{{cookiecutter.camelcase_modelname}}
|
||||
class TF{{cookiecutter.camelcase_modelname}}Output(tf.keras.layers.Layer):
|
||||
def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, **kwargs):
|
||||
def __init__(self, config: BertConfig, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.dense = tf.keras.layers.experimental.EinsumDense(
|
||||
equation="abc,cd->abd",
|
||||
bias_axes="d",
|
||||
output_shape=(None, config.hidden_size),
|
||||
kernel_initializer=get_initializer(config.initializer_range),
|
||||
name="dense",
|
||||
self.dense = tf.keras.layers.Dense(
|
||||
units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
|
||||
)
|
||||
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
|
||||
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
|
||||
@@ -614,12 +497,12 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
|
||||
|
||||
# Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.get_input_embeddings
|
||||
def get_input_embeddings(self) -> tf.keras.layers.Layer:
|
||||
return self.embeddings.word_embeddings
|
||||
return self.embeddings
|
||||
|
||||
# Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.set_input_embeddings
|
||||
def set_input_embeddings(self, value: tf.Variable):
|
||||
self.embeddings.word_embeddings.weight = value
|
||||
self.embeddings.word_embeddings.vocab_size = shape_list(value)[0]
|
||||
self.embeddings.weight = value
|
||||
self.embeddings.vocab_size = shape_list(value)[0]
|
||||
|
||||
# Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer._prune_heads
|
||||
def _prune_heads(self, heads_to_prune):
|
||||
@@ -917,7 +800,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMaskedLM(TF{{cookiecutter.camelca
|
||||
)
|
||||
|
||||
self.{{cookiecutter.lowercase_modelname}} = TF{{cookiecutter.camelcase_modelname}}MainLayer(config, name="{{cookiecutter.lowercase_modelname}}")
|
||||
self.mlm = TF{{cookiecutter.camelcase_modelname}}MLMHead(config, input_embeddings=self.{{cookiecutter.lowercase_modelname}}.embeddings.word_embeddings, name="mlm___cls")
|
||||
self.mlm = TF{{cookiecutter.camelcase_modelname}}MLMHead(config, input_embeddings=self.{{cookiecutter.lowercase_modelname}}.embeddings, name="mlm___cls")
|
||||
|
||||
def get_lm_head(self) -> tf.keras.layers.Layer:
|
||||
return self.mlm.predictions
|
||||
@@ -1014,7 +897,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForCausalLM(TF{{cookiecutter.camelca
|
||||
logger.warning("If you want to use `TF{{cookiecutter.camelcase_modelname}}ForCausalLM` as a standalone, add `is_decoder=True.`")
|
||||
|
||||
self.{{cookiecutter.lowercase_modelname}} = TF{{cookiecutter.camelcase_modelname}}MainLayer(config, name="{{cookiecutter.lowercase_modelname}}")
|
||||
self.mlm = TF{{cookiecutter.camelcase_modelname}}MLMHead(config, input_embeddings=self.{{cookiecutter.lowercase_modelname}}.embeddings.word_embeddings, name="mlm___cls")
|
||||
self.mlm = TF{{cookiecutter.camelcase_modelname}}MLMHead(config, input_embeddings=self.{{cookiecutter.lowercase_modelname}}.embeddings, name="mlm___cls")
|
||||
|
||||
def get_lm_head(self) -> tf.keras.layers.Layer:
|
||||
return self.mlm.predictions
|
||||
@@ -1662,17 +1545,17 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i
|
||||
|
||||
if past_key_values_length > 0:
|
||||
mask = tf.concat([tf.zeros((tgt_len, past_key_values_length), dtype=tf.float32), mask], axis=-1)
|
||||
return tf.broadcast_to(mask[None, None, :, :], (bsz, 1, tgt_len, tgt_len + past_key_values_length))
|
||||
|
||||
return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1))
|
||||
|
||||
|
||||
def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values_length: int = 0):
|
||||
"""
|
||||
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
||||
"""
|
||||
bsz, src_len = shape_list(mask)
|
||||
src_len = shape_list(mask)[1]
|
||||
tgt_len = tgt_len if tgt_len is not None else src_len
|
||||
|
||||
expanded_mask = tf.cast(tf.broadcast_to(mask[:, None, None, :], (bsz, 1, tgt_len, src_len)), tf.float32)
|
||||
expanded_mask = tf.cast(tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)), tf.float32)
|
||||
|
||||
return (1.0 - expanded_mask) * LARGE_NEGATIVE
|
||||
|
||||
|
||||
Reference in New Issue
Block a user