Model templates (#10072)

This commit is contained in:
Lysandre Debut
2021-02-08 15:07:02 +01:00
committed by GitHub
parent 3b7e612a5e
commit c9df1b1d53
2 changed files with 4 additions and 4 deletions

View File

@@ -161,7 +161,7 @@ class TF{{cookiecutter.camelcase_modelname}}SelfAttention(tf.keras.layers.Layer)
self.num_attention_heads = config.num_attention_heads self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads) self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size self.all_head_size = self.num_attention_heads * self.attention_head_size
self.rsqrt_att_head_size = 1.0 / math.sqrt(self.attention_head_size) self.sqrt_att_head_size = math.sqrt(self.attention_head_size)
self.query = tf.keras.layers.Dense( self.query = tf.keras.layers.Dense(
units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query" units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
@@ -201,8 +201,8 @@ class TF{{cookiecutter.camelcase_modelname}}SelfAttention(tf.keras.layers.Layer)
# attention scores. # attention scores.
# (batch size, num_heads, seq_len_q, seq_len_k) # (batch size, num_heads, seq_len_q, seq_len_k)
attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True) attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
dk = tf.cast(self.rsqrt_att_head_size, dtype=attention_scores.dtype) dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype)
attention_scores = tf.multiply(attention_scores, dk) attention_scores = tf.divide(attention_scores, dk)
if attention_mask is not None: if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in TF{{cookiecutter.camelcase_modelname}}Model call() function) # Apply the attention mask is (precomputed for all layers in TF{{cookiecutter.camelcase_modelname}}Model call() function)

View File

@@ -593,7 +593,7 @@ class {{cookiecutter.camelcase_modelname}}Encoder(nn.Module):
) )
# Copied from transformers.models.bert.modeling_bert.BertPredictionHead with Bert->{{cookiecutter.camelcase_modelname}} # Copied from transformers.models.bert.modeling_bert.BertPredictionHeadTransform with Bert->{{cookiecutter.camelcase_modelname}}
class {{cookiecutter.camelcase_modelname}}PredictionHeadTransform(nn.Module): class {{cookiecutter.camelcase_modelname}}PredictionHeadTransform(nn.Module):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()