Multilingual BART (#3602)
- Support mbart-en-ro weights
- Add MBartTokenizer
This commit is contained in:
@@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
"""PyTorch BART model, ported from the fairseq repo."""
|
||||
import logging
|
||||
import math
|
||||
import random
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
@@ -35,6 +36,7 @@ BART_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||
"bart-large-mnli": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-mnli/pytorch_model.bin",
|
||||
"bart-large-cnn": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-cnn/pytorch_model.bin",
|
||||
"bart-large-xsum": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-xsum/pytorch_model.bin",
|
||||
"mbart-large-en-ro": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/pytorch_model.bin",
|
||||
}
|
||||
|
||||
BART_START_DOCSTRING = r"""
|
||||
@@ -180,6 +182,7 @@ class EncoderLayer(nn.Module):
|
||||
self.self_attn = SelfAttention(
|
||||
self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout,
|
||||
)
|
||||
self.normalize_before = config.normalize_before
|
||||
self.self_attn_layer_norm = LayerNorm(self.embed_dim)
|
||||
self.dropout = config.dropout
|
||||
self.activation_fn = ACT2FN[config.activation_function]
|
||||
@@ -201,20 +204,26 @@ class EncoderLayer(nn.Module):
|
||||
encoded output of shape `(seq_len, batch, embed_dim)`
|
||||
"""
|
||||
residual = x
|
||||
if self.normalize_before:
|
||||
x = self.self_attn_layer_norm(x)
|
||||
x, attn_weights = self.self_attn(
|
||||
query=x, key=x, key_padding_mask=encoder_padding_mask, need_weights=self.output_attentions
|
||||
)
|
||||
x = F.dropout(x, p=self.dropout, training=self.training)
|
||||
x = residual + x
|
||||
x = self.self_attn_layer_norm(x)
|
||||
if not self.normalize_before:
|
||||
x = self.self_attn_layer_norm(x)
|
||||
|
||||
residual = x
|
||||
if self.normalize_before:
|
||||
x = self.final_layer_norm(x)
|
||||
x = self.activation_fn(self.fc1(x))
|
||||
x = F.dropout(x, p=self.activation_dropout, training=self.training)
|
||||
x = self.fc2(x)
|
||||
x = F.dropout(x, p=self.dropout, training=self.training)
|
||||
x = residual + x
|
||||
x = self.final_layer_norm(x)
|
||||
if not self.normalize_before:
|
||||
x = self.final_layer_norm(x)
|
||||
return x, attn_weights
|
||||
|
||||
|
||||
@@ -236,6 +245,7 @@ class BartEncoder(nn.Module):
|
||||
self.output_hidden_states = config.output_hidden_states
|
||||
|
||||
embed_dim = embed_tokens.embedding_dim
|
||||
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
|
||||
self.padding_idx = embed_tokens.padding_idx
|
||||
self.max_source_positions = config.max_position_embeddings
|
||||
|
||||
@@ -244,6 +254,8 @@ class BartEncoder(nn.Module):
|
||||
self.embed_positions = LearnedPositionalEmbedding(config.max_position_embeddings, embed_dim, self.padding_idx,)
|
||||
self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.encoder_layers)])
|
||||
self.layernorm_embedding = LayerNorm(embed_dim)
|
||||
# mbart has one extra layer_norm
|
||||
self.layer_norm = LayerNorm(config.d_model) if config.normalize_before else None
|
||||
|
||||
def forward(
|
||||
self, input_ids, attention_mask=None,
|
||||
@@ -267,7 +279,7 @@ class BartEncoder(nn.Module):
|
||||
if attention_mask is not None:
|
||||
attention_mask = invert_mask(attention_mask)
|
||||
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
||||
embed_pos = self.embed_positions(input_ids)
|
||||
x = inputs_embeds + embed_pos
|
||||
x = self.layernorm_embedding(x)
|
||||
@@ -290,6 +302,8 @@ class BartEncoder(nn.Module):
|
||||
if self.output_attentions:
|
||||
all_attentions.append(attn)
|
||||
|
||||
if self.layer_norm:
|
||||
x = self.layer_norm(x)
|
||||
if self.output_hidden_states:
|
||||
encoder_states.append(x)
|
||||
|
||||
@@ -311,6 +325,7 @@ class DecoderLayer(nn.Module):
|
||||
self.dropout = config.dropout
|
||||
self.activation_fn = ACT2FN[config.activation_function]
|
||||
self.activation_dropout = config.activation_dropout
|
||||
self.normalize_before = config.normalize_before
|
||||
|
||||
self.self_attn_layer_norm = LayerNorm(self.embed_dim)
|
||||
self.encoder_attn = SelfAttention(
|
||||
@@ -337,21 +352,28 @@ class DecoderLayer(nn.Module):
|
||||
|
||||
if layer_state is None:
|
||||
layer_state = {}
|
||||
# next line mutates layer state
|
||||
if self.normalize_before:
|
||||
x = self.self_attn_layer_norm(x)
|
||||
# Self Attention
|
||||
|
||||
x, self_attn_weights = self.self_attn(
|
||||
query=x,
|
||||
key=x,
|
||||
layer_state=layer_state,
|
||||
layer_state=layer_state, # adds keys to layer state
|
||||
key_padding_mask=decoder_padding_mask,
|
||||
attn_mask=causal_mask,
|
||||
need_weights=self.output_attentions,
|
||||
)
|
||||
x = F.dropout(x, p=self.dropout, training=self.training)
|
||||
x = residual + x
|
||||
x = self.self_attn_layer_norm(x)
|
||||
if not self.normalize_before:
|
||||
x = self.self_attn_layer_norm(x)
|
||||
|
||||
# Cross attention
|
||||
residual = x
|
||||
assert self.encoder_attn.cache_key != self.self_attn.cache_key
|
||||
|
||||
if self.normalize_before:
|
||||
x = self.encoder_attn_layer_norm(x)
|
||||
x, _ = self.encoder_attn(
|
||||
query=x,
|
||||
key=encoder_hidden_states,
|
||||
@@ -360,16 +382,20 @@ class DecoderLayer(nn.Module):
|
||||
)
|
||||
x = F.dropout(x, p=self.dropout, training=self.training)
|
||||
x = residual + x
|
||||
if not self.normalize_before:
|
||||
x = self.encoder_attn_layer_norm(x)
|
||||
|
||||
x = self.encoder_attn_layer_norm(x)
|
||||
|
||||
# Fully Connected
|
||||
residual = x
|
||||
if self.normalize_before:
|
||||
x = self.final_layer_norm(x)
|
||||
x = self.activation_fn(self.fc1(x))
|
||||
x = F.dropout(x, p=self.activation_dropout, training=self.training)
|
||||
x = self.fc2(x)
|
||||
x = F.dropout(x, p=self.dropout, training=self.training)
|
||||
x = residual + x
|
||||
x = self.final_layer_norm(x)
|
||||
if not self.normalize_before:
|
||||
x = self.final_layer_norm(x)
|
||||
return (
|
||||
x,
|
||||
self_attn_weights,
|
||||
@@ -394,6 +420,7 @@ class BartDecoder(nn.Module):
|
||||
self.layerdrop = config.decoder_layerdrop
|
||||
self.padding_idx = embed_tokens.padding_idx
|
||||
self.max_target_positions = config.max_position_embeddings
|
||||
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
|
||||
self.embed_tokens = embed_tokens
|
||||
self.embed_positions = LearnedPositionalEmbedding(
|
||||
config.max_position_embeddings, config.d_model, self.padding_idx,
|
||||
@@ -402,6 +429,7 @@ class BartDecoder(nn.Module):
|
||||
[DecoderLayer(config) for _ in range(config.decoder_layers)]
|
||||
) # type: List[DecoderLayer]
|
||||
self.layernorm_embedding = LayerNorm(config.d_model)
|
||||
self.layer_norm = LayerNorm(config.d_model) if config.add_final_layer_norm else None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -444,9 +472,8 @@ class BartDecoder(nn.Module):
|
||||
positions = positions[:, -1:] # happens after we embed them
|
||||
assert input_ids.ne(self.padding_idx).any()
|
||||
|
||||
x = self.embed_tokens(input_ids)
|
||||
x = self.embed_tokens(input_ids) * self.embed_scale
|
||||
x += positions
|
||||
|
||||
x = self.layernorm_embedding(x)
|
||||
x = F.dropout(x, p=self.dropout, training=self.training)
|
||||
|
||||
@@ -458,14 +485,16 @@ class BartDecoder(nn.Module):
|
||||
all_hidden_states = ()
|
||||
all_self_attns = ()
|
||||
next_decoder_cache = []
|
||||
for i, decoder_layer in enumerate(self.layers):
|
||||
decoder_layer # type: DecoderLayer
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||
if self.output_hidden_states:
|
||||
all_hidden_states += (x,)
|
||||
dropout_probability = random.uniform(0, 1)
|
||||
if self.training and (dropout_probability < self.layerdrop):
|
||||
continue
|
||||
|
||||
layer_state = decoder_cached_states[i] if decoder_cached_states is not None else None
|
||||
layer_state = decoder_cached_states[idx] if decoder_cached_states is not None else None
|
||||
|
||||
x, layer_self_attn, layer_past = decoder_layer(
|
||||
x,
|
||||
encoder_hidden_states,
|
||||
@@ -477,12 +506,13 @@ class BartDecoder(nn.Module):
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache.append(layer_past.copy())
|
||||
if self.output_hidden_states:
|
||||
all_hidden_states += (x,)
|
||||
|
||||
if self.layer_norm and (idx == len(self.layers) - 1): # last layer of mbart
|
||||
x = self.layer_norm(x)
|
||||
if self.output_attentions:
|
||||
all_self_attns += (layer_self_attn,)
|
||||
|
||||
# Convert to standart output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim)
|
||||
# Convert to standard output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim)
|
||||
all_hidden_states = [hidden_state.transpose(0, 1) for hidden_state in all_hidden_states]
|
||||
x = x.transpose(0, 1)
|
||||
encoder_hidden_states = encoder_hidden_states.transpose(0, 1)
|
||||
|
||||
Reference in New Issue
Block a user