Multilingual BART - (#3602)

- support mbart-en-ro weights
- add MBartTokenizer
This commit is contained in:
Sam Shleifer
2020-04-10 11:25:39 -04:00
committed by GitHub
parent f98d0ef2a2
commit 7a7fdf71f8
7 changed files with 232 additions and 38 deletions

View File

@@ -14,6 +14,7 @@
# limitations under the License.
"""PyTorch BART model, ported from the fairseq repo."""
import logging
import math
import random
from typing import Dict, List, Optional, Tuple
@@ -35,6 +36,7 @@ BART_PRETRAINED_MODEL_ARCHIVE_MAP = {
"bart-large-mnli": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-mnli/pytorch_model.bin",
"bart-large-cnn": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-cnn/pytorch_model.bin",
"bart-large-xsum": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-xsum/pytorch_model.bin",
"mbart-large-en-ro": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/pytorch_model.bin",
}
BART_START_DOCSTRING = r"""
@@ -180,6 +182,7 @@ class EncoderLayer(nn.Module):
self.self_attn = SelfAttention(
self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout,
)
self.normalize_before = config.normalize_before
self.self_attn_layer_norm = LayerNorm(self.embed_dim)
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function]
@@ -201,20 +204,26 @@ class EncoderLayer(nn.Module):
encoded output of shape `(seq_len, batch, embed_dim)`
"""
residual = x
if self.normalize_before:
x = self.self_attn_layer_norm(x)
x, attn_weights = self.self_attn(
query=x, key=x, key_padding_mask=encoder_padding_mask, need_weights=self.output_attentions
)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.self_attn_layer_norm(x)
if not self.normalize_before:
x = self.self_attn_layer_norm(x)
residual = x
if self.normalize_before:
x = self.final_layer_norm(x)
x = self.activation_fn(self.fc1(x))
x = F.dropout(x, p=self.activation_dropout, training=self.training)
x = self.fc2(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.final_layer_norm(x)
if not self.normalize_before:
x = self.final_layer_norm(x)
return x, attn_weights
@@ -236,6 +245,7 @@ class BartEncoder(nn.Module):
self.output_hidden_states = config.output_hidden_states
embed_dim = embed_tokens.embedding_dim
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
self.padding_idx = embed_tokens.padding_idx
self.max_source_positions = config.max_position_embeddings
@@ -244,6 +254,8 @@ class BartEncoder(nn.Module):
self.embed_positions = LearnedPositionalEmbedding(config.max_position_embeddings, embed_dim, self.padding_idx,)
self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.encoder_layers)])
self.layernorm_embedding = LayerNorm(embed_dim)
# mbart has one extra layer_norm
self.layer_norm = LayerNorm(config.d_model) if config.normalize_before else None
def forward(
self, input_ids, attention_mask=None,
@@ -267,7 +279,7 @@ class BartEncoder(nn.Module):
if attention_mask is not None:
attention_mask = invert_mask(attention_mask)
inputs_embeds = self.embed_tokens(input_ids)
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
embed_pos = self.embed_positions(input_ids)
x = inputs_embeds + embed_pos
x = self.layernorm_embedding(x)
@@ -290,6 +302,8 @@ class BartEncoder(nn.Module):
if self.output_attentions:
all_attentions.append(attn)
if self.layer_norm:
x = self.layer_norm(x)
if self.output_hidden_states:
encoder_states.append(x)
@@ -311,6 +325,7 @@ class DecoderLayer(nn.Module):
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function]
self.activation_dropout = config.activation_dropout
self.normalize_before = config.normalize_before
self.self_attn_layer_norm = LayerNorm(self.embed_dim)
self.encoder_attn = SelfAttention(
@@ -337,21 +352,28 @@ class DecoderLayer(nn.Module):
if layer_state is None:
layer_state = {}
# next line mutates layer state
if self.normalize_before:
x = self.self_attn_layer_norm(x)
# Self Attention
x, self_attn_weights = self.self_attn(
query=x,
key=x,
layer_state=layer_state,
layer_state=layer_state, # adds keys to layer state
key_padding_mask=decoder_padding_mask,
attn_mask=causal_mask,
need_weights=self.output_attentions,
)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.self_attn_layer_norm(x)
if not self.normalize_before:
x = self.self_attn_layer_norm(x)
# Cross attention
residual = x
assert self.encoder_attn.cache_key != self.self_attn.cache_key
if self.normalize_before:
x = self.encoder_attn_layer_norm(x)
x, _ = self.encoder_attn(
query=x,
key=encoder_hidden_states,
@@ -360,16 +382,20 @@ class DecoderLayer(nn.Module):
)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
if not self.normalize_before:
x = self.encoder_attn_layer_norm(x)
x = self.encoder_attn_layer_norm(x)
# Fully Connected
residual = x
if self.normalize_before:
x = self.final_layer_norm(x)
x = self.activation_fn(self.fc1(x))
x = F.dropout(x, p=self.activation_dropout, training=self.training)
x = self.fc2(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.final_layer_norm(x)
if not self.normalize_before:
x = self.final_layer_norm(x)
return (
x,
self_attn_weights,
@@ -394,6 +420,7 @@ class BartDecoder(nn.Module):
self.layerdrop = config.decoder_layerdrop
self.padding_idx = embed_tokens.padding_idx
self.max_target_positions = config.max_position_embeddings
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
self.embed_tokens = embed_tokens
self.embed_positions = LearnedPositionalEmbedding(
config.max_position_embeddings, config.d_model, self.padding_idx,
@@ -402,6 +429,7 @@ class BartDecoder(nn.Module):
[DecoderLayer(config) for _ in range(config.decoder_layers)]
) # type: List[DecoderLayer]
self.layernorm_embedding = LayerNorm(config.d_model)
self.layer_norm = LayerNorm(config.d_model) if config.add_final_layer_norm else None
def forward(
self,
@@ -444,9 +472,8 @@ class BartDecoder(nn.Module):
positions = positions[:, -1:] # happens after we embed them
assert input_ids.ne(self.padding_idx).any()
x = self.embed_tokens(input_ids)
x = self.embed_tokens(input_ids) * self.embed_scale
x += positions
x = self.layernorm_embedding(x)
x = F.dropout(x, p=self.dropout, training=self.training)
@@ -458,14 +485,16 @@ class BartDecoder(nn.Module):
all_hidden_states = ()
all_self_attns = ()
next_decoder_cache = []
for i, decoder_layer in enumerate(self.layers):
decoder_layer # type: DecoderLayer
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if self.output_hidden_states:
all_hidden_states += (x,)
dropout_probability = random.uniform(0, 1)
if self.training and (dropout_probability < self.layerdrop):
continue
layer_state = decoder_cached_states[i] if decoder_cached_states is not None else None
layer_state = decoder_cached_states[idx] if decoder_cached_states is not None else None
x, layer_self_attn, layer_past = decoder_layer(
x,
encoder_hidden_states,
@@ -477,12 +506,13 @@ class BartDecoder(nn.Module):
if use_cache:
next_decoder_cache.append(layer_past.copy())
if self.output_hidden_states:
all_hidden_states += (x,)
if self.layer_norm and (idx == len(self.layers) - 1): # last layer of mbart
x = self.layer_norm(x)
if self.output_attentions:
all_self_attns += (layer_self_attn,)
# Convert to standart output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim)
# Convert to standard output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim)
all_hidden_states = [hidden_state.transpose(0, 1) for hidden_state in all_hidden_states]
x = x.transpose(0, 1)
encoder_hidden_states = encoder_hidden_states.transpose(0, 1)