Multilingual BART - (#3602)
- support mbart-en-ro weights - add MBartTokenizer
This commit is contained in:
@@ -122,7 +122,7 @@ from .pipelines import (
|
||||
)
|
||||
from .tokenization_albert import AlbertTokenizer
|
||||
from .tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
|
||||
from .tokenization_bart import BartTokenizer
|
||||
from .tokenization_bart import BartTokenizer, MBartTokenizer
|
||||
from .tokenization_bert import BasicTokenizer, BertTokenizer, BertTokenizerFast, WordpieceTokenizer
|
||||
from .tokenization_bert_japanese import BertJapaneseTokenizer, CharacterTokenizer, MecabTokenizer
|
||||
from .tokenization_camembert import CamembertTokenizer
|
||||
|
||||
@@ -27,6 +27,7 @@ BART_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||
"bart-large-mnli": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-mnli/config.json",
|
||||
"bart-large-cnn": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-cnn/config.json",
|
||||
"bart-large-xsum": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-xsum/config.json",
|
||||
"mbart-large-en-ro": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/config.json",
|
||||
}
|
||||
|
||||
|
||||
@@ -61,6 +62,9 @@ class BartConfig(PretrainedConfig):
|
||||
pad_token_id=1,
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
normalize_before=False,
|
||||
add_final_layer_norm=False,
|
||||
scale_embedding=False,
|
||||
**common_kwargs
|
||||
):
|
||||
r"""
|
||||
@@ -90,6 +94,11 @@ class BartConfig(PretrainedConfig):
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.init_std = init_std # Normal(0, this parameter)
|
||||
self.activation_function = activation_function
|
||||
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
|
||||
|
||||
# True for mbart, False otherwise
|
||||
self.normalize_before = normalize_before # combo of fairseq's encoder_ and decoder_normalize_before
|
||||
self.add_final_layer_norm = add_final_layer_norm
|
||||
|
||||
# 3 Types of Dropout
|
||||
self.attention_dropout = attention_dropout
|
||||
@@ -100,9 +109,17 @@ class BartConfig(PretrainedConfig):
|
||||
self.classif_dropout = classifier_dropout
|
||||
|
||||
@property
|
||||
def num_attention_heads(self):
|
||||
def num_attention_heads(self) -> int:
|
||||
return self.encoder_attention_heads
|
||||
|
||||
@property
|
||||
def hidden_size(self):
|
||||
def hidden_size(self) -> int:
|
||||
return self.d_model
|
||||
|
||||
def is_valid_mbart(self) -> bool:
    """Report whether all three MBART-specific switches are enabled.

    The switches are ``normalize_before``, ``add_final_layer_norm`` and
    ``scale_embedding``.

    Returns:
        True when every switch is truthy. False otherwise; when only some of
        them are set, an info message about the mixed configuration is logged
        before returning.
    """
    mbart_flags = (self.normalize_before, self.add_final_layer_norm, self.scale_embedding)
    if all(mbart_flags):
        return True
    if any(mbart_flags):
        # Some, but not all, of the switches are on.
        logger.info("This configuration is a mixture of MBART and BART settings")
    return False
|
||||
|
||||
@@ -45,13 +45,24 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
SAMPLE_TEXT = " Hello world! cécé herlolip"
|
||||
|
||||
rename_keys = [
|
||||
mnli_rename_keys = [
|
||||
("model.classification_heads.mnli.dense.weight", "classification_head.dense.weight"),
|
||||
("model.classification_heads.mnli.dense.bias", "classification_head.dense.bias"),
|
||||
("model.classification_heads.mnli.out_proj.weight", "classification_head.out_proj.weight"),
|
||||
("model.classification_heads.mnli.out_proj.bias", "classification_head.out_proj.bias"),
|
||||
]
|
||||
IGNORE_KEYS = ["encoder.version", "decoder.version", "model.encoder.version", "model.decoder.version", "_float_tensor"]
|
||||
|
||||
|
||||
def remove_ignore_keys_(state_dict):
    """Drop the entries the conversion ignores from ``state_dict``, in place.

    The removed keys are the encoder/decoder ``version`` markers (with and
    without the ``model.`` prefix) and ``_float_tensor``. Keys that are not
    present are skipped without error. Nothing is returned.
    """
    for unwanted in (
        "encoder.version",
        "decoder.version",
        "model.encoder.version",
        "model.decoder.version",
        "_float_tensor",
    ):
        state_dict.pop(unwanted, None)
|
||||
|
||||
|
||||
def rename_key(dct, old, new):
|
||||
@@ -67,6 +78,19 @@ def load_xsum_checkpoint(checkpoint_path):
|
||||
return hub_interface
|
||||
|
||||
|
||||
def convert_checkpoint_from_disk(checkpoint_path, **config_kwargs):
    """Build a ``BartForConditionalGeneration`` from a checkpoint file on disk.

    Args:
        checkpoint_path: path passed to ``torch.load``; the loaded object must
            expose the state dict under its ``"model"`` entry.
        **config_kwargs: forwarded to ``BartConfig`` alongside the vocab size
            read from the checkpoint.

    Returns:
        The model, with the checkpoint weights loaded into ``model.model``.
    """
    # NOTE(review): torch.load unpickles the file - only use trusted checkpoints.
    weights = torch.load(checkpoint_path, map_location="cpu")["model"]
    remove_ignore_keys_(weights)
    # Vocab size is the row count of the encoder embedding matrix.
    n_tokens = weights["encoder.embed_tokens.weight"].shape[0]
    # The shared embedding entry is filled from the decoder embedding weights.
    weights["shared.weight"] = weights["decoder.embed_tokens.weight"]
    config = BartConfig(vocab_size=n_tokens, **config_kwargs)
    model = BartForConditionalGeneration(config)
    model.model.load_state_dict(weights)
    if hasattr(model, "lm_head"):
        # Rebuild the output projection from the (now loaded) shared embedding.
        model.lm_head = _make_linear_from_emb(model.model.shared)
    return model
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path, hf_checkpoint_name=None):
|
||||
"""
|
||||
@@ -89,7 +113,7 @@ def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path, hf_checkp
|
||||
state_dict = bart.state_dict()
|
||||
remove_ignore_keys_(state_dict)
|
||||
state_dict["model.shared.weight"] = state_dict["model.decoder.embed_tokens.weight"]
|
||||
for src, dest in rename_keys:
|
||||
for src, dest in mnli_rename_keys:
|
||||
rename_key(state_dict, src, dest)
|
||||
model = BartForSequenceClassification(config).eval()
|
||||
model.load_state_dict(state_dict)
|
||||
@@ -118,11 +142,6 @@ def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path, hf_checkp
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
|
||||
def remove_ignore_keys_(state_dict):
    """Pop every key named in ``IGNORE_KEYS`` from ``state_dict``, in place.

    Keys that are not present are skipped without error. Nothing is returned.
    """
    for unwanted in IGNORE_KEYS:
        state_dict.pop(unwanted, None)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
"""PyTorch BART model, ported from the fairseq repo."""
|
||||
import logging
|
||||
import math
|
||||
import random
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
@@ -35,6 +36,7 @@ BART_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||
"bart-large-mnli": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-mnli/pytorch_model.bin",
|
||||
"bart-large-cnn": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-cnn/pytorch_model.bin",
|
||||
"bart-large-xsum": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-xsum/pytorch_model.bin",
|
||||
"mbart-large-en-ro": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/pytorch_model.bin",
|
||||
}
|
||||
|
||||
BART_START_DOCSTRING = r"""
|
||||
@@ -180,6 +182,7 @@ class EncoderLayer(nn.Module):
|
||||
self.self_attn = SelfAttention(
|
||||
self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout,
|
||||
)
|
||||
self.normalize_before = config.normalize_before
|
||||
self.self_attn_layer_norm = LayerNorm(self.embed_dim)
|
||||
self.dropout = config.dropout
|
||||
self.activation_fn = ACT2FN[config.activation_function]
|
||||
@@ -201,20 +204,26 @@ class EncoderLayer(nn.Module):
|
||||
encoded output of shape `(seq_len, batch, embed_dim)`
|
||||
"""
|
||||
residual = x
|
||||
if self.normalize_before:
|
||||
x = self.self_attn_layer_norm(x)
|
||||
x, attn_weights = self.self_attn(
|
||||
query=x, key=x, key_padding_mask=encoder_padding_mask, need_weights=self.output_attentions
|
||||
)
|
||||
x = F.dropout(x, p=self.dropout, training=self.training)
|
||||
x = residual + x
|
||||
x = self.self_attn_layer_norm(x)
|
||||
if not self.normalize_before:
|
||||
x = self.self_attn_layer_norm(x)
|
||||
|
||||
residual = x
|
||||
if self.normalize_before:
|
||||
x = self.final_layer_norm(x)
|
||||
x = self.activation_fn(self.fc1(x))
|
||||
x = F.dropout(x, p=self.activation_dropout, training=self.training)
|
||||
x = self.fc2(x)
|
||||
x = F.dropout(x, p=self.dropout, training=self.training)
|
||||
x = residual + x
|
||||
x = self.final_layer_norm(x)
|
||||
if not self.normalize_before:
|
||||
x = self.final_layer_norm(x)
|
||||
return x, attn_weights
|
||||
|
||||
|
||||
@@ -236,6 +245,7 @@ class BartEncoder(nn.Module):
|
||||
self.output_hidden_states = config.output_hidden_states
|
||||
|
||||
embed_dim = embed_tokens.embedding_dim
|
||||
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
|
||||
self.padding_idx = embed_tokens.padding_idx
|
||||
self.max_source_positions = config.max_position_embeddings
|
||||
|
||||
@@ -244,6 +254,8 @@ class BartEncoder(nn.Module):
|
||||
self.embed_positions = LearnedPositionalEmbedding(config.max_position_embeddings, embed_dim, self.padding_idx,)
|
||||
self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.encoder_layers)])
|
||||
self.layernorm_embedding = LayerNorm(embed_dim)
|
||||
# mbart has one extra layer_norm
|
||||
self.layer_norm = LayerNorm(config.d_model) if config.normalize_before else None
|
||||
|
||||
def forward(
|
||||
self, input_ids, attention_mask=None,
|
||||
@@ -267,7 +279,7 @@ class BartEncoder(nn.Module):
|
||||
if attention_mask is not None:
|
||||
attention_mask = invert_mask(attention_mask)
|
||||
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
||||
embed_pos = self.embed_positions(input_ids)
|
||||
x = inputs_embeds + embed_pos
|
||||
x = self.layernorm_embedding(x)
|
||||
@@ -290,6 +302,8 @@ class BartEncoder(nn.Module):
|
||||
if self.output_attentions:
|
||||
all_attentions.append(attn)
|
||||
|
||||
if self.layer_norm:
|
||||
x = self.layer_norm(x)
|
||||
if self.output_hidden_states:
|
||||
encoder_states.append(x)
|
||||
|
||||
@@ -311,6 +325,7 @@ class DecoderLayer(nn.Module):
|
||||
self.dropout = config.dropout
|
||||
self.activation_fn = ACT2FN[config.activation_function]
|
||||
self.activation_dropout = config.activation_dropout
|
||||
self.normalize_before = config.normalize_before
|
||||
|
||||
self.self_attn_layer_norm = LayerNorm(self.embed_dim)
|
||||
self.encoder_attn = SelfAttention(
|
||||
@@ -337,21 +352,28 @@ class DecoderLayer(nn.Module):
|
||||
|
||||
if layer_state is None:
|
||||
layer_state = {}
|
||||
# next line mutates layer state
|
||||
if self.normalize_before:
|
||||
x = self.self_attn_layer_norm(x)
|
||||
# Self Attention
|
||||
|
||||
x, self_attn_weights = self.self_attn(
|
||||
query=x,
|
||||
key=x,
|
||||
layer_state=layer_state,
|
||||
layer_state=layer_state, # adds keys to layer state
|
||||
key_padding_mask=decoder_padding_mask,
|
||||
attn_mask=causal_mask,
|
||||
need_weights=self.output_attentions,
|
||||
)
|
||||
x = F.dropout(x, p=self.dropout, training=self.training)
|
||||
x = residual + x
|
||||
x = self.self_attn_layer_norm(x)
|
||||
if not self.normalize_before:
|
||||
x = self.self_attn_layer_norm(x)
|
||||
|
||||
# Cross attention
|
||||
residual = x
|
||||
assert self.encoder_attn.cache_key != self.self_attn.cache_key
|
||||
|
||||
if self.normalize_before:
|
||||
x = self.encoder_attn_layer_norm(x)
|
||||
x, _ = self.encoder_attn(
|
||||
query=x,
|
||||
key=encoder_hidden_states,
|
||||
@@ -360,16 +382,20 @@ class DecoderLayer(nn.Module):
|
||||
)
|
||||
x = F.dropout(x, p=self.dropout, training=self.training)
|
||||
x = residual + x
|
||||
if not self.normalize_before:
|
||||
x = self.encoder_attn_layer_norm(x)
|
||||
|
||||
x = self.encoder_attn_layer_norm(x)
|
||||
|
||||
# Fully Connected
|
||||
residual = x
|
||||
if self.normalize_before:
|
||||
x = self.final_layer_norm(x)
|
||||
x = self.activation_fn(self.fc1(x))
|
||||
x = F.dropout(x, p=self.activation_dropout, training=self.training)
|
||||
x = self.fc2(x)
|
||||
x = F.dropout(x, p=self.dropout, training=self.training)
|
||||
x = residual + x
|
||||
x = self.final_layer_norm(x)
|
||||
if not self.normalize_before:
|
||||
x = self.final_layer_norm(x)
|
||||
return (
|
||||
x,
|
||||
self_attn_weights,
|
||||
@@ -394,6 +420,7 @@ class BartDecoder(nn.Module):
|
||||
self.layerdrop = config.decoder_layerdrop
|
||||
self.padding_idx = embed_tokens.padding_idx
|
||||
self.max_target_positions = config.max_position_embeddings
|
||||
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
|
||||
self.embed_tokens = embed_tokens
|
||||
self.embed_positions = LearnedPositionalEmbedding(
|
||||
config.max_position_embeddings, config.d_model, self.padding_idx,
|
||||
@@ -402,6 +429,7 @@ class BartDecoder(nn.Module):
|
||||
[DecoderLayer(config) for _ in range(config.decoder_layers)]
|
||||
) # type: List[DecoderLayer]
|
||||
self.layernorm_embedding = LayerNorm(config.d_model)
|
||||
self.layer_norm = LayerNorm(config.d_model) if config.add_final_layer_norm else None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -444,9 +472,8 @@ class BartDecoder(nn.Module):
|
||||
positions = positions[:, -1:] # happens after we embed them
|
||||
assert input_ids.ne(self.padding_idx).any()
|
||||
|
||||
x = self.embed_tokens(input_ids)
|
||||
x = self.embed_tokens(input_ids) * self.embed_scale
|
||||
x += positions
|
||||
|
||||
x = self.layernorm_embedding(x)
|
||||
x = F.dropout(x, p=self.dropout, training=self.training)
|
||||
|
||||
@@ -458,14 +485,16 @@ class BartDecoder(nn.Module):
|
||||
all_hidden_states = ()
|
||||
all_self_attns = ()
|
||||
next_decoder_cache = []
|
||||
for i, decoder_layer in enumerate(self.layers):
|
||||
decoder_layer # type: DecoderLayer
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||
if self.output_hidden_states:
|
||||
all_hidden_states += (x,)
|
||||
dropout_probability = random.uniform(0, 1)
|
||||
if self.training and (dropout_probability < self.layerdrop):
|
||||
continue
|
||||
|
||||
layer_state = decoder_cached_states[i] if decoder_cached_states is not None else None
|
||||
layer_state = decoder_cached_states[idx] if decoder_cached_states is not None else None
|
||||
|
||||
x, layer_self_attn, layer_past = decoder_layer(
|
||||
x,
|
||||
encoder_hidden_states,
|
||||
@@ -477,12 +506,13 @@ class BartDecoder(nn.Module):
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache.append(layer_past.copy())
|
||||
if self.output_hidden_states:
|
||||
all_hidden_states += (x,)
|
||||
|
||||
if self.layer_norm and (idx == len(self.layers) - 1): # last layer of mbart
|
||||
x = self.layer_norm(x)
|
||||
if self.output_attentions:
|
||||
all_self_attns += (layer_self_attn,)
|
||||
|
||||
# Convert to standart output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim)
|
||||
# Convert to standard output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim)
|
||||
all_hidden_states = [hidden_state.transpose(0, 1) for hidden_state in all_hidden_states]
|
||||
x = x.transpose(0, 1)
|
||||
encoder_hidden_states = encoder_hidden_states.transpose(0, 1)
|
||||
|
||||
@@ -13,7 +13,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
|
||||
from .tokenization_roberta import RobertaTokenizer
|
||||
from .tokenization_xlm_roberta import XLMRobertaTokenizer
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# vocab and merges same as roberta
|
||||
@@ -21,6 +27,8 @@ vocab_url = "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-v
|
||||
merges_url = "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-merges.txt"
|
||||
_all_bart_models = ["bart-large", "bart-large-mnli", "bart-large-cnn", "bart-large-xsum"]
|
||||
|
||||
VOCAB_FILES_NAMES = {"vocab_file": "sentence.bpe.model"}
|
||||
|
||||
|
||||
class BartTokenizer(RobertaTokenizer):
|
||||
# merges and vocab same as Roberta
|
||||
@@ -29,3 +37,13 @@ class BartTokenizer(RobertaTokenizer):
|
||||
"vocab_file": {m: vocab_url for m in _all_bart_models},
|
||||
"merges_file": {m: merges_url for m in _all_bart_models},
|
||||
}
|
||||
|
||||
|
||||
_all_mbart_models = ["mbart-large-en-ro"]
|
||||
SPM_URL = "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/sentence.bpe.model"
|
||||
|
||||
|
||||
class MBartTokenizer(XLMRobertaTokenizer):
    """Tokenizer for the checkpoints listed in ``_all_mbart_models``.

    All behaviour comes from ``XLMRobertaTokenizer``; this subclass only
    overrides the vocab file name, the per-model vocab download URL and the
    per-model maximum input size.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    # Every listed model gets the same input-size limit and the same vocab URL.
    max_model_input_sizes = dict.fromkeys(_all_mbart_models, 1024)
    pretrained_vocab_files_map = {"vocab_file": dict.fromkeys(_all_mbart_models, SPM_URL)}
|
||||
|
||||
Reference in New Issue
Block a user