Model templates (#10072)
This commit is contained in:
@@ -161,7 +161,7 @@ class TF{{cookiecutter.camelcase_modelname}}SelfAttention(tf.keras.layers.Layer)
|
|||||||
self.num_attention_heads = config.num_attention_heads
|
self.num_attention_heads = config.num_attention_heads
|
||||||
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
||||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||||
self.rsqrt_att_head_size = 1.0 / math.sqrt(self.attention_head_size)
|
self.sqrt_att_head_size = math.sqrt(self.attention_head_size)
|
||||||
|
|
||||||
self.query = tf.keras.layers.Dense(
|
self.query = tf.keras.layers.Dense(
|
||||||
units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
|
units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
|
||||||
@@ -201,8 +201,8 @@ class TF{{cookiecutter.camelcase_modelname}}SelfAttention(tf.keras.layers.Layer)
|
|||||||
# attention scores.
|
# attention scores.
|
||||||
# (batch size, num_heads, seq_len_q, seq_len_k)
|
# (batch size, num_heads, seq_len_q, seq_len_k)
|
||||||
attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
|
attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
|
||||||
dk = tf.cast(self.rsqrt_att_head_size, dtype=attention_scores.dtype)
|
dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype)
|
||||||
attention_scores = tf.multiply(attention_scores, dk)
|
attention_scores = tf.divide(attention_scores, dk)
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
# Apply the attention mask is (precomputed for all layers in TF{{cookiecutter.camelcase_modelname}}Model call() function)
|
# Apply the attention mask is (precomputed for all layers in TF{{cookiecutter.camelcase_modelname}}Model call() function)
|
||||||
|
|||||||
@@ -593,7 +593,7 @@ class {{cookiecutter.camelcase_modelname}}Encoder(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.bert.modeling_bert.BertPredictionHead with Bert->{{cookiecutter.camelcase_modelname}}
|
# Copied from transformers.models.bert.modeling_bert.BertPredictionHeadTransform with Bert->{{cookiecutter.camelcase_modelname}}
|
||||||
class {{cookiecutter.camelcase_modelname}}PredictionHeadTransform(nn.Module):
|
class {{cookiecutter.camelcase_modelname}}PredictionHeadTransform(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|||||||
Reference in New Issue
Block a user