[ProphetNet] Fix naming and wrong config (#9514)
* fix naming issues * better names
This commit is contained in:
committed by
GitHub
parent
7f28613213
commit
a051d8928a
@@ -559,7 +559,7 @@ class ProphetNetPreTrainedModel(PreTrainedModel):
|
|||||||
return shifted_input_ids
|
return shifted_input_ids
|
||||||
|
|
||||||
|
|
||||||
class ProhpetNetPositionalEmbeddings(nn.Embedding):
|
class ProphetNetPositionalEmbeddings(nn.Embedding):
|
||||||
"""
|
"""
|
||||||
This module learns positional embeddings up to a fixed maximum size. Padding ids are ignored by either offsetting
|
This module learns positional embeddings up to a fixed maximum size. Padding ids are ignored by either offsetting
|
||||||
based on padding_idx or by setting padding_idx to None and ensuring that the appropriate position ids are passed to
|
based on padding_idx or by setting padding_idx to None and ensuring that the appropriate position ids are passed to
|
||||||
@@ -598,7 +598,7 @@ class ProhpetNetPositionalEmbeddings(nn.Embedding):
|
|||||||
return super().forward(position_ids)
|
return super().forward(position_ids)
|
||||||
|
|
||||||
|
|
||||||
class ProphetNetSelfAttention(nn.Module):
|
class ProphetNetAttention(nn.Module):
|
||||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -726,7 +726,7 @@ class ProphetNetSelfAttention(nn.Module):
|
|||||||
return attn_output, attn_weights_reshaped
|
return attn_output, attn_weights_reshaped
|
||||||
|
|
||||||
|
|
||||||
class ProhpetNetFeedForward(nn.Module):
|
class ProphetNetFeedForward(nn.Module):
|
||||||
"""
|
"""
|
||||||
This is the residual two feed-forward layer block based on the original Transformer implementation.
|
This is the residual two feed-forward layer block based on the original Transformer implementation.
|
||||||
"""
|
"""
|
||||||
@@ -749,14 +749,14 @@ class ProhpetNetFeedForward(nn.Module):
|
|||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
class ProphetNetNgramProphetNetSelfAttention(nn.Module):
|
class ProphetNetNgramSelfAttention(nn.Module):
|
||||||
def __init__(self, config: ProphetNetConfig):
|
def __init__(self, config: ProphetNetConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
|
|
||||||
self.num_buckets = config.num_buckets
|
self.num_buckets = config.num_buckets
|
||||||
self.relative_max_distance = config.relative_max_distance
|
self.relative_max_distance = config.relative_max_distance
|
||||||
self.num_attn_heads = config.num_attention_heads
|
self.num_attn_heads = config.num_decoder_attention_heads
|
||||||
self.dropout = config.dropout
|
self.dropout = config.dropout
|
||||||
self.attention_dropout = config.attention_dropout
|
self.attention_dropout = config.attention_dropout
|
||||||
self.head_dim = config.hidden_size // self.num_attn_heads
|
self.head_dim = config.hidden_size // self.num_attn_heads
|
||||||
@@ -1046,11 +1046,11 @@ class ProphetNetEncoderLayer(nn.Module):
|
|||||||
def __init__(self, config: ProphetNetConfig):
|
def __init__(self, config: ProphetNetConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# 1st residual block
|
# 1st residual block
|
||||||
self.self_attn = ProphetNetSelfAttention(config, config.num_encoder_attention_heads)
|
self.self_attn = ProphetNetAttention(config, config.num_encoder_attention_heads)
|
||||||
self.self_attn_layer_norm = LayerNorm(config.hidden_size)
|
self.self_attn_layer_norm = LayerNorm(config.hidden_size)
|
||||||
|
|
||||||
# 2nd residual block
|
# 2nd residual block
|
||||||
self.feed_forward = ProhpetNetFeedForward(config, config.encoder_ffn_dim)
|
self.feed_forward = ProphetNetFeedForward(config, config.encoder_ffn_dim)
|
||||||
self.feed_forward_layer_norm = LayerNorm(config.hidden_size)
|
self.feed_forward_layer_norm = LayerNorm(config.hidden_size)
|
||||||
|
|
||||||
def forward(self, hidden_states, attention_mask):
|
def forward(self, hidden_states, attention_mask):
|
||||||
@@ -1075,16 +1075,16 @@ class ProphetNetDecoderLayer(nn.Module):
|
|||||||
def __init__(self, config: ProphetNetConfig):
|
def __init__(self, config: ProphetNetConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# 1st residual block
|
# 1st residual block
|
||||||
self.self_attn = ProphetNetNgramProphetNetSelfAttention(config)
|
self.self_attn = ProphetNetNgramSelfAttention(config)
|
||||||
self.self_attn_layer_norm = LayerNorm(config.hidden_size)
|
self.self_attn_layer_norm = LayerNorm(config.hidden_size)
|
||||||
|
|
||||||
# 2nd residual block
|
# 2nd residual block
|
||||||
if config.add_cross_attention:
|
if config.add_cross_attention:
|
||||||
self.cross_attn = ProphetNetSelfAttention(config, config.num_decoder_attention_heads)
|
self.cross_attn = ProphetNetAttention(config, config.num_decoder_attention_heads)
|
||||||
self.cross_attn_layer_norm = LayerNorm(config.hidden_size)
|
self.cross_attn_layer_norm = LayerNorm(config.hidden_size)
|
||||||
|
|
||||||
# 3rd residual block
|
# 3rd residual block
|
||||||
self.feed_forward = ProhpetNetFeedForward(config, config.decoder_ffn_dim)
|
self.feed_forward = ProphetNetFeedForward(config, config.decoder_ffn_dim)
|
||||||
self.feed_forward_layer_norm = LayerNorm(config.hidden_size)
|
self.feed_forward_layer_norm = LayerNorm(config.hidden_size)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@@ -1156,7 +1156,7 @@ class ProphetNetEncoder(ProphetNetPreTrainedModel):
|
|||||||
if word_embeddings is not None
|
if word_embeddings is not None
|
||||||
else nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
|
else nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
|
||||||
)
|
)
|
||||||
self.position_embeddings = ProhpetNetPositionalEmbeddings(config)
|
self.position_embeddings = ProphetNetPositionalEmbeddings(config)
|
||||||
self.embeddings_layer_norm = LayerNorm(config.hidden_size)
|
self.embeddings_layer_norm = LayerNorm(config.hidden_size)
|
||||||
|
|
||||||
self.layers = nn.ModuleList([ProphetNetEncoderLayer(config) for _ in range(config.num_encoder_layers)])
|
self.layers = nn.ModuleList([ProphetNetEncoderLayer(config) for _ in range(config.num_encoder_layers)])
|
||||||
@@ -1212,7 +1212,7 @@ class ProphetNetEncoder(ProphetNetPreTrainedModel):
|
|||||||
# prepare attention mask
|
# prepare attention mask
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
extended_attention_mask = (
|
extended_attention_mask = (
|
||||||
1.0 - attention_mask[:, None, :].repeat(self.config.num_attention_heads, 1, 1)
|
1.0 - attention_mask[:, None, :].repeat(self.config.num_encoder_attention_heads, 1, 1)
|
||||||
) * -10000.0
|
) * -10000.0
|
||||||
extended_attention_mask = extended_attention_mask.to(inputs_embeds.dtype)
|
extended_attention_mask = extended_attention_mask.to(inputs_embeds.dtype)
|
||||||
else:
|
else:
|
||||||
@@ -1273,7 +1273,7 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
|
|||||||
if word_embeddings is not None
|
if word_embeddings is not None
|
||||||
else nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
|
else nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
|
||||||
)
|
)
|
||||||
self.position_embeddings = ProhpetNetPositionalEmbeddings(config)
|
self.position_embeddings = ProphetNetPositionalEmbeddings(config)
|
||||||
|
|
||||||
self.ngram_embeddings = nn.Embedding(self.ngram, config.hidden_size, None)
|
self.ngram_embeddings = nn.Embedding(self.ngram, config.hidden_size, None)
|
||||||
self.layers = nn.ModuleList([ProphetNetDecoderLayer(config) for _ in range(config.num_decoder_layers)])
|
self.layers = nn.ModuleList([ProphetNetDecoderLayer(config) for _ in range(config.num_decoder_layers)])
|
||||||
@@ -1397,7 +1397,7 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
|
|||||||
# prepare encoder attention mask
|
# prepare encoder attention mask
|
||||||
if encoder_attention_mask is not None:
|
if encoder_attention_mask is not None:
|
||||||
extended_encoder_attention_mask = (
|
extended_encoder_attention_mask = (
|
||||||
1.0 - encoder_attention_mask[:, None, :].repeat(self.config.num_attention_heads, 1, 1)
|
1.0 - encoder_attention_mask[:, None, :].repeat(self.config.num_decoder_attention_heads, 1, 1)
|
||||||
) * -10000.0
|
) * -10000.0
|
||||||
extended_encoder_attention_mask = extended_encoder_attention_mask.to(inputs_embeds.dtype)
|
extended_encoder_attention_mask = extended_encoder_attention_mask.to(inputs_embeds.dtype)
|
||||||
else:
|
else:
|
||||||
|
|||||||
Reference in New Issue
Block a user