MarianMTModel.from_pretrained('Helsinki-NLP/opus-marian-en-de') (#3908)
Co-Authored-By: Stefan Schweter <stefan@schweter.it>
This commit is contained in:
@@ -44,6 +44,7 @@ from .configuration_electra import ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, Electr
|
||||
from .configuration_encoder_decoder import EncoderDecoderConfig
|
||||
from .configuration_flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, FlaubertConfig
|
||||
from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config
|
||||
from .configuration_marian import MarianConfig
|
||||
from .configuration_mmbt import MMBTConfig
|
||||
from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
|
||||
from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig
|
||||
@@ -241,6 +242,8 @@ if is_torch_available():
|
||||
BartForConditionalGeneration,
|
||||
BART_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
)
|
||||
from .modeling_marian import MarianMTModel
|
||||
from .tokenization_marian import MarianSentencePieceTokenizer
|
||||
from .modeling_roberta import (
|
||||
RobertaForMaskedLM,
|
||||
RobertaModel,
|
||||
|
||||
@@ -65,6 +65,9 @@ class BartConfig(PretrainedConfig):
|
||||
normalize_before=False,
|
||||
add_final_layer_norm=False,
|
||||
scale_embedding=False,
|
||||
normalize_embedding=True,
|
||||
static_position_embeddings=False,
|
||||
add_bias_logits=False,
|
||||
**common_kwargs
|
||||
):
|
||||
r"""
|
||||
@@ -73,6 +76,8 @@ class BartConfig(PretrainedConfig):
|
||||
config = BartConfig.from_pretrained('bart-large')
|
||||
model = BartModel(config)
|
||||
"""
|
||||
if "hidden_size" in common_kwargs:
|
||||
raise ValueError("hidden size is called d_model")
|
||||
super().__init__(
|
||||
num_labels=num_labels,
|
||||
pad_token_id=pad_token_id,
|
||||
@@ -94,12 +99,17 @@ class BartConfig(PretrainedConfig):
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.init_std = init_std # Normal(0, this parameter)
|
||||
self.activation_function = activation_function
|
||||
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
|
||||
|
||||
# True for mbart, False otherwise
|
||||
# Params introduced for Mbart
|
||||
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
|
||||
self.normalize_embedding = normalize_embedding # True for mbart, False otherwise
|
||||
self.normalize_before = normalize_before # combo of fairseq's encoder_ and decoder_normalize_before
|
||||
self.add_final_layer_norm = add_final_layer_norm
|
||||
|
||||
# Params introduced for Marian
|
||||
self.add_bias_logits = add_bias_logits
|
||||
self.static_position_embeddings = static_position_embeddings
|
||||
|
||||
# 3 Types of Dropout
|
||||
self.attention_dropout = attention_dropout
|
||||
self.activation_dropout = activation_dropout
|
||||
|
||||
26
src/transformers/configuration_marian.py
Normal file
26
src/transformers/configuration_marian.py
Normal file
@@ -0,0 +1,26 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The OPUS-NMT Team, Marian team, and The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Marian model configuration """
|
||||
|
||||
from .configuration_bart import BartConfig
|
||||
|
||||
|
||||
PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||
"marian-en-de": "https://s3.amazonaws.com/models.huggingface.co/bert/Helsinki-NLP/opus-mt-en-de/config.json",
|
||||
}
|
||||
|
||||
|
||||
class MarianConfig(BartConfig):
|
||||
pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||
397
src/transformers/convert_marian_to_pytorch.py
Normal file
397
src/transformers/convert_marian_to_pytorch.py
Normal file
@@ -0,0 +1,397 @@
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Union
|
||||
from zipfile import ZipFile
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from transformers import MarianConfig, MarianMTModel, MarianSentencePieceTokenizer
|
||||
|
||||
|
||||
def remove_prefix(text: str, prefix: str):
|
||||
if text.startswith(prefix):
|
||||
return text[len(prefix) :]
|
||||
return text # or whatever
|
||||
|
||||
|
||||
def convert_encoder_layer(opus_dict, layer_prefix: str, converter: dict):
|
||||
sd = {}
|
||||
for k in opus_dict:
|
||||
if not k.startswith(layer_prefix):
|
||||
continue
|
||||
stripped = remove_prefix(k, layer_prefix)
|
||||
v = opus_dict[k].T # besides embeddings, everything must be transposed.
|
||||
sd[converter[stripped]] = torch.tensor(v).squeeze()
|
||||
return sd
|
||||
|
||||
|
||||
def load_layers_(layer_lst: torch.nn.ModuleList, opus_state: dict, converter, is_decoder=False):
|
||||
for i, layer in enumerate(layer_lst):
|
||||
layer_tag = f"decoder_l{i + 1}_" if is_decoder else f"encoder_l{i + 1}_"
|
||||
sd = convert_encoder_layer(opus_state, layer_tag, converter)
|
||||
layer.load_state_dict(sd, strict=True)
|
||||
|
||||
|
||||
def add_emb_entries(wemb, final_bias, n_special_tokens=1):
|
||||
vsize, d_model = wemb.shape
|
||||
embs_to_add = np.zeros((n_special_tokens, d_model))
|
||||
new_embs = np.concatenate([wemb, embs_to_add])
|
||||
bias_to_add = np.zeros((n_special_tokens, 1))
|
||||
new_bias = np.concatenate((final_bias, bias_to_add), axis=1)
|
||||
return new_embs, new_bias
|
||||
|
||||
|
||||
def _cast_yaml_str(v):
|
||||
bool_dct = {"true": True, "false": False}
|
||||
if not isinstance(v, str):
|
||||
return v
|
||||
elif v in bool_dct:
|
||||
return bool_dct[v]
|
||||
try:
|
||||
return int(v)
|
||||
except (TypeError, ValueError):
|
||||
return v
|
||||
|
||||
|
||||
def cast_marian_config(raw_cfg: Dict[str, str]) -> Dict:
|
||||
return {k: _cast_yaml_str(v) for k, v in raw_cfg.items()}
|
||||
|
||||
|
||||
CONFIG_KEY = "special:model.yml"
|
||||
|
||||
|
||||
def load_config_from_state_dict(opus_dict):
|
||||
import yaml
|
||||
|
||||
cfg_str = "".join([chr(x) for x in opus_dict[CONFIG_KEY]])
|
||||
yaml_cfg = yaml.load(cfg_str[:-1], Loader=yaml.BaseLoader)
|
||||
return cast_marian_config(yaml_cfg)
|
||||
|
||||
|
||||
def find_model_file(dest_dir): # this one better
|
||||
model_files = list(Path(dest_dir).glob("*.npz"))
|
||||
assert len(model_files) == 1, model_files
|
||||
model_file = model_files[0]
|
||||
return model_file
|
||||
|
||||
|
||||
def parse_readmes(repo_path):
|
||||
results = {}
|
||||
for p in Path(repo_path).ls():
|
||||
n_dash = p.name.count("-")
|
||||
if n_dash == 0:
|
||||
continue
|
||||
else:
|
||||
lns = list(open(p / "README.md").readlines())
|
||||
results[p.name] = _parse_readme(lns)
|
||||
return results
|
||||
|
||||
|
||||
def download_all_sentencepiece_models(repo_path="Opus-MT-train/models"):
|
||||
"""Requires 300GB"""
|
||||
save_dir = Path("marian_ckpt")
|
||||
if not Path(repo_path).exists():
|
||||
raise ValueError("You must run: git clone git@github.com:Helsinki-NLP/Opus-MT-train.git")
|
||||
results: dict = parse_readmes(repo_path)
|
||||
for k, v in tqdm(list(results.items())):
|
||||
if os.path.exists(save_dir / k):
|
||||
print(f"already have path {k}")
|
||||
continue
|
||||
if "SentencePiece" not in v["pre-processing"]:
|
||||
continue
|
||||
download_and_unzip(v["download"], save_dir / k)
|
||||
|
||||
|
||||
def _parse_readme(lns):
|
||||
"""Get link and metadata from opus model card equivalent."""
|
||||
subres = {}
|
||||
for ln in [x.strip() for x in lns]:
|
||||
if not ln.startswith("*"):
|
||||
continue
|
||||
ln = ln[1:].strip()
|
||||
|
||||
for k in ["download", "dataset", "models", "model", "pre-processing"]:
|
||||
if ln.startswith(k):
|
||||
break
|
||||
else:
|
||||
continue
|
||||
if k in ["dataset", "model", "pre-processing"]:
|
||||
splat = ln.split(":")
|
||||
_, v = splat
|
||||
subres[k] = v
|
||||
elif k == "download":
|
||||
v = ln.split("(")[-1][:-1]
|
||||
subres[k] = v
|
||||
return subres
|
||||
|
||||
|
||||
def write_metadata(dest_dir: Path):
|
||||
dname = dest_dir.name.split("-")
|
||||
dct = dict(target_lang=dname[-1], source_lang="-".join(dname[:-1]))
|
||||
save_json(dct, dest_dir / "tokenizer_config.json")
|
||||
|
||||
|
||||
def add_to_vocab_(vocab: Dict[str, int], special_tokens: List[str]):
|
||||
start = max(vocab.values()) + 1
|
||||
added = 0
|
||||
for tok in special_tokens:
|
||||
if tok in vocab:
|
||||
continue
|
||||
vocab[tok] = start + added
|
||||
added += 1
|
||||
return added
|
||||
|
||||
|
||||
def add_special_tokens_to_vocab(model_dir: Path) -> None:
|
||||
vocab = load_yaml(model_dir / "opus.spm32k-spm32k.vocab.yml")
|
||||
vocab = {k: int(v) for k, v in vocab.items()}
|
||||
num_added = add_to_vocab_(vocab, ["<pad>"])
|
||||
print(f"added {num_added} tokens to vocab")
|
||||
save_json(vocab, model_dir / "vocab.json")
|
||||
write_metadata(model_dir)
|
||||
|
||||
|
||||
def save_tokenizer(self, save_directory):
|
||||
dest = Path(save_directory)
|
||||
src_path = Path(self.init_kwargs["source_spm"])
|
||||
|
||||
for dest_name in {"source.spm", "target.spm", "tokenizer_config.json"}:
|
||||
shutil.copyfile(src_path.parent / dest_name, dest / dest_name)
|
||||
save_json(self.encoder, dest / "vocab.json")
|
||||
|
||||
|
||||
def check_equal(marian_cfg, k1, k2):
|
||||
v1, v2 = marian_cfg[k1], marian_cfg[k2]
|
||||
assert v1 == v2, f"hparams {k1},{k2} differ: {v1} != {v2}"
|
||||
|
||||
|
||||
def check_marian_cfg_assumptions(marian_cfg):
|
||||
assumed_settings = {
|
||||
"tied-embeddings-all": True,
|
||||
"layer-normalization": False,
|
||||
"right-left": False,
|
||||
"transformer-ffn-depth": 2,
|
||||
"transformer-aan-depth": 2,
|
||||
"transformer-no-projection": False,
|
||||
"transformer-postprocess-emb": "d",
|
||||
"transformer-postprocess": "dan", # Dropout, add, normalize
|
||||
"transformer-preprocess": "",
|
||||
"type": "transformer",
|
||||
"ulr-dim-emb": 0,
|
||||
"dec-cell-base-depth": 2,
|
||||
"dec-cell-high-depth": 1,
|
||||
"transformer-aan-nogate": False,
|
||||
}
|
||||
for k, v in assumed_settings.items():
|
||||
actual = marian_cfg[k]
|
||||
assert actual == v, f"Unexpected config value for {k} expected {v} got {actual}"
|
||||
check_equal(marian_cfg, "transformer-ffn-activation", "transformer-aan-activation")
|
||||
check_equal(marian_cfg, "transformer-ffn-depth", "transformer-aan-depth")
|
||||
check_equal(marian_cfg, "transformer-dim-ffn", "transformer-dim-aan")
|
||||
|
||||
|
||||
BIAS_KEY = "decoder_ff_logit_out_b"
|
||||
BART_CONVERTER = { # for each encoder and decoder layer
|
||||
"self_Wq": "self_attn.q_proj.weight",
|
||||
"self_Wk": "self_attn.k_proj.weight",
|
||||
"self_Wv": "self_attn.v_proj.weight",
|
||||
"self_Wo": "self_attn.out_proj.weight",
|
||||
"self_bq": "self_attn.q_proj.bias",
|
||||
"self_bk": "self_attn.k_proj.bias",
|
||||
"self_bv": "self_attn.v_proj.bias",
|
||||
"self_bo": "self_attn.out_proj.bias",
|
||||
"self_Wo_ln_scale": "self_attn_layer_norm.weight",
|
||||
"self_Wo_ln_bias": "self_attn_layer_norm.bias",
|
||||
"ffn_W1": "fc1.weight",
|
||||
"ffn_b1": "fc1.bias",
|
||||
"ffn_W2": "fc2.weight",
|
||||
"ffn_b2": "fc2.bias",
|
||||
"ffn_ffn_ln_scale": "final_layer_norm.weight",
|
||||
"ffn_ffn_ln_bias": "final_layer_norm.bias",
|
||||
# Decoder Cross Attention
|
||||
"context_Wk": "encoder_attn.k_proj.weight",
|
||||
"context_Wo": "encoder_attn.out_proj.weight",
|
||||
"context_Wq": "encoder_attn.q_proj.weight",
|
||||
"context_Wv": "encoder_attn.v_proj.weight",
|
||||
"context_bk": "encoder_attn.k_proj.bias",
|
||||
"context_bo": "encoder_attn.out_proj.bias",
|
||||
"context_bq": "encoder_attn.q_proj.bias",
|
||||
"context_bv": "encoder_attn.v_proj.bias",
|
||||
"context_Wo_ln_scale": "encoder_attn_layer_norm.weight",
|
||||
"context_Wo_ln_bias": "encoder_attn_layer_norm.bias",
|
||||
}
|
||||
|
||||
|
||||
class OpusState:
|
||||
def __init__(self, source_dir):
|
||||
npz_path = find_model_file(source_dir)
|
||||
self.state_dict = np.load(npz_path)
|
||||
cfg = load_config_from_state_dict(self.state_dict)
|
||||
assert cfg["dim-vocabs"][0] == cfg["dim-vocabs"][1]
|
||||
assert "Wpos" not in self.state_dict
|
||||
self.state_dict = dict(self.state_dict)
|
||||
self.wemb, self.final_bias = add_emb_entries(self.state_dict["Wemb"], self.state_dict[BIAS_KEY], 1)
|
||||
self.pad_token_id = self.wemb.shape[0] - 1
|
||||
cfg["vocab_size"] = self.pad_token_id + 1
|
||||
# self.state_dict['Wemb'].sha
|
||||
self.state_keys = list(self.state_dict.keys())
|
||||
if "Wtype" in self.state_dict:
|
||||
raise ValueError("found Wtype key")
|
||||
self._check_layer_entries()
|
||||
self.source_dir = source_dir
|
||||
self.cfg = cfg
|
||||
hidden_size, intermediate_shape = self.state_dict["encoder_l1_ffn_W1"].shape
|
||||
assert hidden_size == cfg["dim-emb"] == 512
|
||||
|
||||
# Process decoder.yml
|
||||
decoder_yml = cast_marian_config(load_yaml(source_dir / "decoder.yml"))
|
||||
# TODO: what are normalize and word-penalty?
|
||||
check_marian_cfg_assumptions(cfg)
|
||||
self.hf_config = MarianConfig(
|
||||
vocab_size=cfg["vocab_size"],
|
||||
decoder_layers=cfg["dec-depth"],
|
||||
encoder_layers=cfg["enc-depth"],
|
||||
decoder_attention_heads=cfg["transformer-heads"],
|
||||
encoder_attention_heads=cfg["transformer-heads"],
|
||||
decoder_ffn_dim=cfg["transformer-dim-ffn"],
|
||||
encoder_ffn_dim=cfg["transformer-dim-ffn"],
|
||||
d_model=cfg["dim-emb"],
|
||||
activation_function=cfg["transformer-aan-activation"],
|
||||
pad_token_id=self.pad_token_id,
|
||||
eos_token_id=0,
|
||||
bos_token_id=0,
|
||||
max_position_embeddings=cfg["dim-emb"],
|
||||
scale_embedding=True,
|
||||
normalize_embedding="n" in cfg["transformer-preprocess"],
|
||||
static_position_embeddings=not cfg["transformer-train-position-embeddings"],
|
||||
dropout=0.1, # see opus-mt-train repo/transformer-dropout param.
|
||||
# default: add_final_layer_norm=False,
|
||||
num_beams=decoder_yml["beam-size"],
|
||||
)
|
||||
|
||||
def _check_layer_entries(self):
|
||||
self.encoder_l1 = self.sub_keys("encoder_l1")
|
||||
self.decoder_l1 = self.sub_keys("decoder_l1")
|
||||
self.decoder_l2 = self.sub_keys("decoder_l2")
|
||||
if len(self.encoder_l1) != 16:
|
||||
warnings.warn(f"Expected 16 keys for each encoder layer, got {len(self.encoder_l1)}")
|
||||
if len(self.decoder_l1) != 26:
|
||||
warnings.warn(f"Expected 26 keys for each decoder layer, got {len(self.decoder_l1)}")
|
||||
if len(self.decoder_l2) != 26:
|
||||
warnings.warn(f"Expected 26 keys for each decoder layer, got {len(self.decoder_l1)}")
|
||||
|
||||
@property
|
||||
def extra_keys(self):
|
||||
extra = []
|
||||
for k in self.state_keys:
|
||||
if (
|
||||
k.startswith("encoder_l")
|
||||
or k.startswith("decoder_l")
|
||||
or k in [CONFIG_KEY, "Wemb", "Wpos", "decoder_ff_logit_out_b"]
|
||||
):
|
||||
continue
|
||||
else:
|
||||
extra.append(k)
|
||||
return extra
|
||||
|
||||
def sub_keys(self, layer_prefix):
|
||||
return [remove_prefix(k, layer_prefix) for k in self.state_dict if k.startswith(layer_prefix)]
|
||||
|
||||
def load_marian_model(self) -> MarianMTModel:
|
||||
state_dict, cfg = self.state_dict, self.hf_config
|
||||
|
||||
assert cfg.static_position_embeddings
|
||||
model = MarianMTModel(cfg)
|
||||
|
||||
assert "hidden_size" not in cfg.to_dict()
|
||||
load_layers_(
|
||||
model.model.encoder.layers, state_dict, BART_CONVERTER,
|
||||
)
|
||||
load_layers_(model.model.decoder.layers, state_dict, BART_CONVERTER, is_decoder=True)
|
||||
|
||||
# handle tensors not associated with layers
|
||||
wemb_tensor = torch.nn.Parameter(torch.FloatTensor(self.wemb))
|
||||
bias_tensor = torch.nn.Parameter(torch.FloatTensor(self.final_bias))
|
||||
model.model.shared.weight = wemb_tensor
|
||||
model.model.encoder.embed_tokens = model.model.decoder.embed_tokens = model.model.shared
|
||||
|
||||
model.final_logits_bias = bias_tensor
|
||||
|
||||
if "Wpos" in state_dict:
|
||||
print("Unexpected: got Wpos")
|
||||
wpos_tensor = torch.tensor(state_dict["Wpos"])
|
||||
model.model.encoder.embed_positions.weight = wpos_tensor
|
||||
model.model.decoder.embed_positions.weight = wpos_tensor
|
||||
|
||||
if cfg.normalize_embedding:
|
||||
assert "encoder_emb_ln_scale_pre" in state_dict
|
||||
raise NotImplementedError("Need to convert layernorm_embedding")
|
||||
|
||||
assert not self.extra_keys, f"Failed to convert {self.extra_keys}"
|
||||
assert model.model.shared.padding_idx == self.pad_token_id
|
||||
return model
|
||||
|
||||
|
||||
def download_and_unzip(url, dest_dir):
|
||||
try:
|
||||
import wget
|
||||
except ImportError:
|
||||
raise ImportError("you must pip install wget")
|
||||
|
||||
filename = wget.download(url)
|
||||
unzip(filename, dest_dir)
|
||||
os.remove(filename)
|
||||
|
||||
|
||||
def main(source_dir, dest_dir):
|
||||
dest_dir = Path(dest_dir)
|
||||
dest_dir.mkdir(exist_ok=True)
|
||||
|
||||
add_special_tokens_to_vocab(source_dir)
|
||||
tokenizer = MarianSentencePieceTokenizer.from_pretrained(str(source_dir))
|
||||
save_tokenizer(tokenizer, dest_dir)
|
||||
|
||||
opus_state = OpusState(source_dir)
|
||||
assert opus_state.cfg["vocab_size"] == len(tokenizer.encoder)
|
||||
# save_json(opus_state.cfg, dest_dir / "marian_original_config.json")
|
||||
# ^^ Save human readable marian config for debugging
|
||||
|
||||
model = opus_state.load_marian_model()
|
||||
model.save_pretrained(dest_dir)
|
||||
model.from_pretrained(dest_dir) # sanity check
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument("--src", type=str, help="path to marian model dir", default="en-de")
|
||||
parser.add_argument("--dest", type=str, default=None, help="Path to the output PyTorch model.")
|
||||
args = parser.parse_args()
|
||||
|
||||
source_dir = Path(args.src)
|
||||
assert source_dir.exists()
|
||||
dest_dir = f"converted-{source_dir.name}" if args.dest is None else args.dest
|
||||
main(source_dir, dest_dir)
|
||||
|
||||
|
||||
def load_yaml(path):
|
||||
import yaml
|
||||
|
||||
with open(path) as f:
|
||||
return yaml.load(f, Loader=yaml.BaseLoader)
|
||||
|
||||
|
||||
def save_json(content: Union[Dict, List], path: str) -> None:
|
||||
with open(path, "w") as f:
|
||||
json.dump(content, f)
|
||||
|
||||
|
||||
def unzip(zip_path: str, dest_dir: str) -> None:
|
||||
with ZipFile(zip_path, "r") as zipObj:
|
||||
zipObj.extractall(dest_dir)
|
||||
@@ -18,6 +18,7 @@ import math
|
||||
import random
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor, nn
|
||||
@@ -125,7 +126,9 @@ class PretrainedBartModel(PreTrainedModel):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
if isinstance(module, nn.Embedding):
|
||||
elif isinstance(module, SinusoidalPositionalEmbedding):
|
||||
pass
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
@@ -250,10 +253,16 @@ class BartEncoder(nn.Module):
|
||||
self.max_source_positions = config.max_position_embeddings
|
||||
|
||||
self.embed_tokens = embed_tokens
|
||||
|
||||
self.embed_positions = LearnedPositionalEmbedding(config.max_position_embeddings, embed_dim, self.padding_idx,)
|
||||
if config.static_position_embeddings:
|
||||
self.embed_positions = SinusoidalPositionalEmbedding(
|
||||
config.max_position_embeddings, embed_dim, self.padding_idx
|
||||
)
|
||||
else:
|
||||
self.embed_positions = LearnedPositionalEmbedding(
|
||||
config.max_position_embeddings, embed_dim, self.padding_idx,
|
||||
)
|
||||
self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.encoder_layers)])
|
||||
self.layernorm_embedding = LayerNorm(embed_dim)
|
||||
self.layernorm_embedding = LayerNorm(embed_dim) if config.normalize_embedding else nn.Identity()
|
||||
# mbart has one extra layer_norm
|
||||
self.layer_norm = LayerNorm(config.d_model) if config.normalize_before else None
|
||||
|
||||
@@ -422,13 +431,18 @@ class BartDecoder(nn.Module):
|
||||
self.max_target_positions = config.max_position_embeddings
|
||||
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
|
||||
self.embed_tokens = embed_tokens
|
||||
self.embed_positions = LearnedPositionalEmbedding(
|
||||
config.max_position_embeddings, config.d_model, self.padding_idx,
|
||||
)
|
||||
if config.static_position_embeddings:
|
||||
self.embed_positions = SinusoidalPositionalEmbedding(
|
||||
config.max_position_embeddings, config.d_model, config.pad_token_id
|
||||
)
|
||||
else:
|
||||
self.embed_positions = LearnedPositionalEmbedding(
|
||||
config.max_position_embeddings, config.d_model, self.padding_idx,
|
||||
)
|
||||
self.layers = nn.ModuleList(
|
||||
[DecoderLayer(config) for _ in range(config.decoder_layers)]
|
||||
) # type: List[DecoderLayer]
|
||||
self.layernorm_embedding = LayerNorm(config.d_model)
|
||||
self.layernorm_embedding = LayerNorm(config.d_model) if config.normalize_embedding else nn.Identity()
|
||||
self.layer_norm = LayerNorm(config.d_model) if config.add_final_layer_norm else None
|
||||
|
||||
def forward(
|
||||
@@ -470,7 +484,7 @@ class BartDecoder(nn.Module):
|
||||
if use_cache:
|
||||
input_ids = input_ids[:, -1:]
|
||||
positions = positions[:, -1:] # happens after we embed them
|
||||
assert input_ids.ne(self.padding_idx).any()
|
||||
# assert input_ids.ne(self.padding_idx).any()
|
||||
|
||||
x = self.embed_tokens(input_ids) * self.embed_scale
|
||||
x += positions
|
||||
@@ -859,6 +873,22 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
||||
super().__init__(config)
|
||||
base_model = BartModel(config)
|
||||
self.model = base_model
|
||||
self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
|
||||
|
||||
def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
|
||||
old_num_tokens = self.model.shared.num_embeddings
|
||||
new_embeddings = super().resize_token_embeddings(new_num_tokens)
|
||||
self.model.shared = new_embeddings
|
||||
self._resize_final_logits_bias(new_num_tokens, old_num_tokens)
|
||||
return new_embeddings
|
||||
|
||||
def _resize_final_logits_bias(self, new_num_tokens: int, old_num_tokens: int) -> None:
|
||||
if new_num_tokens <= old_num_tokens:
|
||||
new_bias = self.final_logits_bias[:, :new_num_tokens]
|
||||
else:
|
||||
extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens))
|
||||
new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
|
||||
self.register_buffer("final_logits_bias", new_bias)
|
||||
|
||||
@add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
@@ -923,8 +953,8 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
||||
decoder_cached_states=decoder_cached_states,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
lm_logits = F.linear(outputs[0], self.model.shared.weight)
|
||||
outputs = (lm_logits,) + outputs[1:] # Add hidden states and attention if they are here
|
||||
lm_logits = F.linear(outputs[0], self.model.shared.weight, bias=self.final_logits_bias)
|
||||
outputs = (lm_logits,) + outputs[1:] # Add cache, hidden states and attention if they are here
|
||||
if lm_labels is not None:
|
||||
loss_fct = nn.CrossEntropyLoss()
|
||||
# TODO(SS): do we need to ignore pad tokens in lm_labels?
|
||||
@@ -957,6 +987,18 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
||||
self._force_token_ids_generation(scores, self.config.eos_token_id)
|
||||
return scores
|
||||
|
||||
def _force_token_ids_generation(self, scores, token_ids) -> None:
|
||||
"""force one of token_ids to be generated by setting prob of all other tokens to 0"""
|
||||
if isinstance(token_ids, int):
|
||||
token_ids = [token_ids]
|
||||
all_but_token_ids_mask = torch.tensor(
|
||||
[x for x in range(self.config.vocab_size) if x not in token_ids],
|
||||
dtype=torch.long,
|
||||
device=next(self.parameters()).device,
|
||||
)
|
||||
assert len(scores.shape) == 2, "scores should be of rank 2 with shape: [batch_size, vocab_size]"
|
||||
scores[:, all_but_token_ids_mask] = -float("inf")
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past, beam_idx):
|
||||
((enc_out, enc_mask), decoder_cached_states) = past
|
||||
@@ -1061,3 +1103,39 @@ class BartForSequenceClassification(PretrainedBartModel):
|
||||
outputs = (loss,) + outputs
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class SinusoidalPositionalEmbedding(nn.Embedding):
|
||||
"""This module produces sinusoidal positional embeddings of any length."""
|
||||
|
||||
def __init__(self, num_positions, embedding_dim, padding_idx=None):
|
||||
super().__init__(num_positions, embedding_dim)
|
||||
if embedding_dim % 2 != 0:
|
||||
raise NotImplementedError(f"odd embedding_dim {embedding_dim} not supported")
|
||||
self.weight = self._init_weight(self.weight)
|
||||
|
||||
@staticmethod
|
||||
def _init_weight(out: nn.Parameter):
|
||||
"""Identical to the XLM create_sinusoidal_embeddings except features are not interleaved.
|
||||
The cos features are in the 2nd half of the vector. [dim // 2:]
|
||||
"""
|
||||
n_pos, dim = out.shape
|
||||
position_enc = np.array(
|
||||
[[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
|
||||
)
|
||||
out[:, 0 : dim // 2] = torch.FloatTensor(np.sin(position_enc[:, 0::2])) # This line breaks for odd n_pos
|
||||
out[:, dim // 2 :] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
|
||||
out.detach_()
|
||||
out.requires_grad = False
|
||||
return out
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, input_ids, use_cache=False):
|
||||
"""Input is expected to be of size [bsz x seqlen]."""
|
||||
bsz, seq_len = input_ids.shape[:2]
|
||||
if use_cache:
|
||||
positions = input_ids.data.new(1, 1).fill_(seq_len - 1) # called before slicing
|
||||
else:
|
||||
# starts at 0, ends at 1-seq_len
|
||||
positions = torch.arange(seq_len, dtype=torch.long, device=self.weight.device)
|
||||
return super().forward(positions)
|
||||
|
||||
35
src/transformers/modeling_marian.py
Normal file
35
src/transformers/modeling_marian.py
Normal file
@@ -0,0 +1,35 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 Marian Team Authors and The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PyTorch MarianMTModel model, ported from the Marian C++ repo."""
|
||||
|
||||
|
||||
from transformers.modeling_bart import BartForConditionalGeneration
|
||||
|
||||
|
||||
PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||
"opus-mt-en-de": "https://s3.amazonaws.com/models.huggingface.co/bert/Helsinki-NLP/opus-mt-en-de/pytorch_model.bin",
|
||||
}
|
||||
|
||||
|
||||
class MarianMTModel(BartForConditionalGeneration):
|
||||
"""Pytorch version of marian-nmt's transformer.h (c++). Designed for the OPUS-NMT translation checkpoints.
|
||||
Model API is identical to BartForConditionalGeneration"""
|
||||
|
||||
pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
|
||||
def prepare_scores_for_generation(self, scores, cur_len, max_length):
|
||||
if cur_len == max_length - 1 and self.config.eos_token_id is not None:
|
||||
self._force_token_ids_generation(scores, self.config.eos_token_id)
|
||||
return scores
|
||||
@@ -1530,18 +1530,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
|
||||
return decoded
|
||||
|
||||
# force one of token_ids to be generated by setting prob of all other tokens to 0.
|
||||
def _force_token_ids_generation(self, scores, token_ids) -> None:
|
||||
if isinstance(token_ids, int):
|
||||
token_ids = [token_ids]
|
||||
all_but_token_ids_mask = torch.tensor(
|
||||
[x for x in range(self.config.vocab_size) if x not in token_ids],
|
||||
dtype=torch.long,
|
||||
device=next(self.parameters()).device,
|
||||
)
|
||||
assert len(scores.shape) == 2, "scores should be of rank 2 with shape: [batch_size, vocab_size]"
|
||||
scores[:, all_but_token_ids_mask] = -float("inf")
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past: Tuple, beam_idx: Tensor) -> Tuple[Tensor]:
|
||||
return tuple(layer_past.index_select(1, beam_idx) for layer_past in past)
|
||||
|
||||
160
src/transformers/tokenization_marian.py
Normal file
160
src/transformers/tokenization_marian.py
Normal file
@@ -0,0 +1,160 @@
|
||||
import json
|
||||
import warnings
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import sentencepiece
|
||||
|
||||
from .file_utils import S3_BUCKET_PREFIX
|
||||
from .tokenization_utils import BatchEncoding, PreTrainedTokenizer
|
||||
|
||||
|
||||
vocab_files_names = {
|
||||
"source_spm": "source.spm",
|
||||
"target_spm": "target.spm",
|
||||
"vocab": "vocab.json",
|
||||
"tokenizer_config_file": "tokenizer_config.json",
|
||||
}
|
||||
MODEL_NAMES = ("opus-mt-en-de",)
|
||||
PRETRAINED_VOCAB_FILES_MAP = {
|
||||
k: {m: f"{S3_BUCKET_PREFIX}/Helsinki-NLP/{m}/{fname}" for m in MODEL_NAMES}
|
||||
for k, fname in vocab_files_names.items()
|
||||
}
|
||||
# Example URL https://s3.amazonaws.com/models.huggingface.co/bert/Helsinki-NLP/opus-mt-en-de/vocab.json
|
||||
|
||||
|
||||
class MarianSentencePieceTokenizer(PreTrainedTokenizer):
|
||||
vocab_files_names = vocab_files_names
|
||||
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
||||
max_model_input_sizes = {m: 512 for m in MODEL_NAMES}
|
||||
model_input_names = ["attention_mask"] # actually attention_mask, decoder_attention_mask
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab=None,
|
||||
source_spm=None,
|
||||
target_spm=None,
|
||||
source_lang=None,
|
||||
target_lang=None,
|
||||
unk_token="<unk>",
|
||||
eos_token="</s>",
|
||||
pad_token="<pad>",
|
||||
max_len=512,
|
||||
):
|
||||
|
||||
super().__init__(
|
||||
# bos_token=bos_token,
|
||||
max_len=max_len,
|
||||
eos_token=eos_token,
|
||||
unk_token=unk_token,
|
||||
pad_token=pad_token,
|
||||
)
|
||||
self.encoder = load_json(vocab)
|
||||
assert self.pad_token in self.encoder
|
||||
self.decoder = {v: k for k, v in self.encoder.items()}
|
||||
|
||||
self.source_lang = source_lang
|
||||
self.target_lang = target_lang
|
||||
|
||||
# load SentencePiece model for pre-processing
|
||||
self.paths = {}
|
||||
|
||||
self.spm_source = sentencepiece.SentencePieceProcessor()
|
||||
self.spm_source.Load(source_spm)
|
||||
|
||||
self.spm_target = sentencepiece.SentencePieceProcessor()
|
||||
self.spm_target.Load(target_spm)
|
||||
|
||||
# Note(SS): splitter would require lots of book-keeping.
|
||||
# self.sentence_splitter = MosesSentenceSplitter(source_lang)
|
||||
try:
|
||||
from mosestokenizer import MosesPunctuationNormalizer
|
||||
|
||||
self.punc_normalizer = MosesPunctuationNormalizer(source_lang)
|
||||
except ImportError:
|
||||
warnings.warn("Recommended: pip install mosestokenizer")
|
||||
self.punc_normalizer = lambda x: x
|
||||
|
||||
def _convert_token_to_id(self, token):
|
||||
return self.encoder[token]
|
||||
|
||||
def _tokenize(self, text: str, src=True) -> List[str]:
|
||||
spm = self.spm_source if src else self.spm_target
|
||||
return spm.EncodeAsPieces(text)
|
||||
|
||||
def _convert_id_to_token(self, index: int) -> str:
|
||||
"""Converts an index (integer) in a token (str) using the encoder."""
|
||||
return self.decoder.get(index, self.unk_token)
|
||||
|
||||
def convert_tokens_to_string(self, tokens: List[str]) -> str:
|
||||
"""Uses target language sentencepiece model"""
|
||||
return self.spm_target.DecodePieces(tokens)
|
||||
|
||||
def _append_special_tokens_and_truncate(self, tokens: str, max_length: int,) -> List[int]:
|
||||
ids: list = self.convert_tokens_to_ids(tokens)[:max_length]
|
||||
return ids + [self.eos_token_id]
|
||||
|
||||
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
|
||||
"""Build model inputs from a sequence by appending eos_token_id."""
|
||||
if token_ids_1 is None:
|
||||
return token_ids_0 + [self.eos_token_id]
|
||||
# We don't expect to process pairs, but leave the pair logic for API consistency
|
||||
return token_ids_0 + token_ids_1 + [self.eos_token_id]
|
||||
|
||||
def decode_batch(self, token_ids, **kwargs) -> List[str]:
|
||||
return [self.decode(ids, **kwargs) for ids in token_ids]
|
||||
|
||||
def prepare_translation_batch(
|
||||
self,
|
||||
src_texts: List[str],
|
||||
tgt_texts: Optional[List[str]] = None,
|
||||
max_length: Optional[int] = None,
|
||||
pad_to_max_length: bool = True,
|
||||
return_tensors: str = "pt",
|
||||
) -> BatchEncoding:
|
||||
"""
|
||||
Arguments:
|
||||
src_texts: list of src language texts
|
||||
src_lang: default en_XX (english)
|
||||
tgt_texts: list of tgt language texts
|
||||
tgt_lang: default ro_RO (romanian)
|
||||
max_length: (None) defer to config (1024 for mbart-large-en-ro)
|
||||
pad_to_max_length: (bool)
|
||||
|
||||
Returns:
|
||||
BatchEncoding: with keys [input_ids, attention_mask, decoder_input_ids, decoder_attention_mask]
|
||||
all shaped bs, seq_len. (BatchEncoding is a dict of string -> tensor or lists)
|
||||
|
||||
Examples:
|
||||
from transformers import MarianS
|
||||
"""
|
||||
model_inputs: BatchEncoding = self.batch_encode_plus(
|
||||
src_texts,
|
||||
add_special_tokens=True,
|
||||
return_tensors=return_tensors,
|
||||
max_length=max_length,
|
||||
pad_to_max_length=pad_to_max_length,
|
||||
src=True,
|
||||
)
|
||||
if tgt_texts is None:
|
||||
return model_inputs
|
||||
|
||||
decoder_inputs: BatchEncoding = self.batch_encode_plus(
|
||||
tgt_texts,
|
||||
add_special_tokens=True,
|
||||
return_tensors=return_tensors,
|
||||
max_length=max_length,
|
||||
pad_to_max_length=pad_to_max_length,
|
||||
src=False,
|
||||
)
|
||||
for k, v in decoder_inputs.items():
|
||||
model_inputs[f"decoder_{k}"] = v
|
||||
return model_inputs
|
||||
|
||||
@property
|
||||
def vocab_size(self) -> int:
|
||||
return len(self.encoder)
|
||||
|
||||
|
||||
def load_json(path: str) -> Union[Dict, List]:
|
||||
with open(path, "r") as f:
|
||||
return json.load(f)
|
||||
@@ -31,7 +31,7 @@ from tokenizers import Encoding as EncodingFast
|
||||
from tokenizers.decoders import Decoder as DecoderFast
|
||||
from tokenizers.implementations import BaseTokenizer as BaseTokenizerFast
|
||||
|
||||
from .file_utils import cached_path, hf_bucket_url, is_remote_url, is_tf_available, is_torch_available
|
||||
from .file_utils import cached_path, hf_bucket_url, is_remote_url, is_tf_available, is_torch_available, torch_required
|
||||
|
||||
|
||||
if is_tf_available():
|
||||
@@ -458,6 +458,12 @@ class BatchEncoding(UserDict):
|
||||
char_index = batch_or_char_index
|
||||
return self._encodings[batch_index].char_to_word(char_index)
|
||||
|
||||
@torch_required
|
||||
def to(self, device: str):
|
||||
"""Send all values to device by calling v.to(device)"""
|
||||
self.data = {k: v.to(device) for k, v in self.data.items()}
|
||||
return self
|
||||
|
||||
|
||||
class SpecialTokensMixin:
|
||||
""" SpecialTokensMixin is derived by ``PreTrainedTokenizer`` and ``PreTrainedTokenizerFast`` and
|
||||
|
||||
Reference in New Issue
Block a user