[Marian] documentation and AutoModel support (#4152)
- MarianSentencepieceTokenizer -> MarianTokenizer
- Start using unk token.
- add docs page
- add better generation params to MarianConfig
- more conversion utilities
@@ -164,8 +164,9 @@ At some point in the future, you'll be able to seamlessly move from pre-training
|
|||||||
17. **[ELECTRA](https://huggingface.co/transformers/model_doc/electra.html)** (from Google Research/Stanford University) released with the paper [ELECTRA: Pre-training text encoders as discriminators rather than generators](https://arxiv.org/abs/2003.10555) by Kevin Clark, Minh-Thang Luong, Quoc V. Le, Christopher D. Manning.
|
17. **[ELECTRA](https://huggingface.co/transformers/model_doc/electra.html)** (from Google Research/Stanford University) released with the paper [ELECTRA: Pre-training text encoders as discriminators rather than generators](https://arxiv.org/abs/2003.10555) by Kevin Clark, Minh-Thang Luong, Quoc V. Le, Christopher D. Manning.
|
||||||
18. **[DialoGPT](https://huggingface.co/transformers/model_doc/dialogpt.html)** (from Microsoft Research) released with the paper [DialoGPT: Large-Scale Generative Pre-training for Conversational Response Generation](https://arxiv.org/abs/1911.00536) by Yizhe Zhang, Siqi Sun, Michel Galley, Yen-Chun Chen, Chris Brockett, Xiang Gao, Jianfeng Gao, Jingjing Liu, Bill Dolan.
|
18. **[DialoGPT](https://huggingface.co/transformers/model_doc/dialogpt.html)** (from Microsoft Research) released with the paper [DialoGPT: Large-Scale Generative Pre-training for Conversational Response Generation](https://arxiv.org/abs/1911.00536) by Yizhe Zhang, Siqi Sun, Michel Galley, Yen-Chun Chen, Chris Brockett, Xiang Gao, Jianfeng Gao, Jingjing Liu, Bill Dolan.
|
||||||
19. **[Reformer](https://huggingface.co/transformers/model_doc/reformer.html)** (from Google Research) released with the paper [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya.
|
19. **[Reformer](https://huggingface.co/transformers/model_doc/reformer.html)** (from Google Research) released with the paper [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya.
|
||||||
20. **[Other community models](https://huggingface.co/models)**, contributed by the [community](https://huggingface.co/users).
|
20. **[MarianMT](https://huggingface.co/transformers/model_doc/marian.html)** Machine translation models trained using [OPUS](http://opus.nlpl.eu/) data by Jörg Tiedemann. The [Marian Framework](https://marian-nmt.github.io/) is being developed by the Microsoft Translator Team.
|
||||||
21. Want to contribute a new model? We have added a **detailed guide and templates** to guide you in the process of adding a new model. You can find them in the [`templates`](./templates) folder of the repository. Be sure to check the [contributing guidelines](./CONTRIBUTING.md) and contact the maintainers or open an issue to collect feedback before starting your PR.
|
21. **[Other community models](https://huggingface.co/models)**, contributed by the [community](https://huggingface.co/users).
|
||||||
|
22. Want to contribute a new model? We have added a **detailed guide and templates** to guide you in the process of adding a new model. You can find them in the [`templates`](./templates) folder of the repository. Be sure to check the [contributing guidelines](./CONTRIBUTING.md) and contact the maintainers or open an issue to collect feedback before starting your PR.
|
||||||
|
|
||||||
These implementations have been tested on several datasets (see the example scripts) and should match the performances of the original implementations (e.g. ~93 F1 on SQuAD for BERT Whole-Word-Masking, ~88 F1 on RocStories for OpenAI GPT, ~18.3 perplexity on WikiText 103 for Transformer-XL, ~0.916 Pearson R coefficient on STS-B for XLNet). You can find more details on the performances in the Examples section of the [documentation](https://huggingface.co/transformers/examples.html).
|
These implementations have been tested on several datasets (see the example scripts) and should match the performances of the original implementations (e.g. ~93 F1 on SQuAD for BERT Whole-Word-Masking, ~88 F1 on RocStories for OpenAI GPT, ~18.3 perplexity on WikiText 103 for Transformer-XL, ~0.916 Pearson R coefficient on STS-B for XLNet). You can find more details on the performances in the Examples section of the [documentation](https://huggingface.co/transformers/examples.html).
|
||||||
|
|
||||||
|
|||||||
@@ -108,3 +108,4 @@ The library currently contains PyTorch and Tensorflow implementations, pre-train
|
|||||||
model_doc/electra
|
model_doc/electra
|
||||||
model_doc/dialogpt
|
model_doc/dialogpt
|
||||||
model_doc/reformer
|
model_doc/reformer
|
||||||
|
model_doc/marian
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
Bart
|
Bart
|
||||||
----------------------------------------------------
|
----------------------------------------------------
|
||||||
**DISCLAIMER:** This model is still a work in progress, if you see something strange,
|
**DISCLAIMER:** If you see something strange,
|
||||||
file a `Github Issue <https://github.com/huggingface/transformers/issues/new?assignees=&labels=&template=bug-report.md&title>`__ and assign
|
file a `Github Issue <https://github.com/huggingface/transformers/issues/new?assignees=&labels=&template=bug-report.md&title>`__ and assign
|
||||||
@sshleifer
|
@sshleifer
|
||||||
|
|
||||||
|
|||||||
43
docs/source/model_doc/marian.rst
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
MarianMTModel
|
||||||
|
----------------------------------------------------
|
||||||
|
**DISCLAIMER:** If you see something strange,
|
||||||
|
file a `Github Issue <https://github.com/huggingface/transformers/issues/new?assignees=&labels=&template=bug-report.md&title>`__ and assign
|
||||||
|
@sshleifer
|
||||||
|
These models are for machine translation. The list of supported language pairs can be found `here <https://huggingface.co/Helsinki-NLP>`__.
|
||||||
|
|
||||||
|
Opus Project
|
||||||
|
~~~~~~~~~~~~
|
||||||
|
The 1,000+ models were originally trained by `Jörg Tiedemann <https://researchportal.helsinki.fi/en/persons/j%C3%B6rg-tiedemann>`__ using the `Marian <https://marian-nmt.github.io/>`_ C++ library, which supports fast training and translation.
|
||||||
|
All models are transformer encoder-decoders with 6 layers in each component. Each model's performance is documented in a model card.
|
||||||
|
|
||||||
|
Implementation Notes
|
||||||
|
~~~~~~~~~~~~~~~~~~~~
|
||||||
|
- Each model is about 298 MB on disk; there are 1,000+ models.
|
||||||
|
- Models are named with the following pattern: 'Helsinki-NLP/opus-mt-{src_langs}-{targ_langs}'. If there are multiple source or target languages, they are joined by a '+' symbol.
|
||||||
|
- the 80 opus models that require BPE preprocessing are not supported.
|
||||||
|
- There is an outstanding issue w.r.t multilingual models and language codes.
|
||||||
|
- The modeling code is the same as ``BartModel`` with a few minor modifications:
|
||||||
|
- static (sinusoid) positional embeddings (``MarianConfig.static_position_embeddings=True``)
|
||||||
|
- a new final_logits_bias (``MarianConfig.add_bias_logits=True``)
|
||||||
|
- no layernorm_embedding (``MarianConfig.normalize_embedding=False``)
|
||||||
|
- the model starts generating with pad_token_id (which has 0 token_embedding) as the prefix. (Bart uses <s/>)
|
||||||
|
- Code to bulk convert models can be found in ``convert_marian_to_pytorch.py``
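
The naming pattern in the notes above can be sketched in a few lines. This is only an illustration: ``opus_mt_name`` is a hypothetical helper, not part of the library, and the second name it prints is constructed from the pattern and is not guaranteed to exist on the model hub.

```python
from typing import Sequence

PREFIX = "Helsinki-NLP/opus-mt-"  # prefix taken from the naming pattern above


def opus_mt_name(src_langs: Sequence[str], tgt_langs: Sequence[str]) -> str:
    """Build a model identifier; several languages on one side are joined by '+'.

    Hypothetical helper for illustration only.
    """
    if not src_langs or not tgt_langs:
        raise ValueError("need at least one source and one target language")
    return f"{PREFIX}{'+'.join(src_langs)}-{'+'.join(tgt_langs)}"


print(opus_mt_name(["fr"], ["en"]))
print(opus_mt_name(["en"], ["fr", "es"]))  # illustrative multi-target name
```

The resulting string is what would be passed to ``from_pretrained``; whether a given pair is available has to be checked against the model list.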
|
||||||
|
|
||||||
|
|
||||||
|
MarianMTModel
|
||||||
|
~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
Pytorch version of marian-nmt's transformer.h (c++). Designed for the OPUS-NMT translation checkpoints.
|
||||||
|
Model API is identical to BartForConditionalGeneration.
|
||||||
|
Available models are listed at `Model List <https://huggingface.co/models?search=Helsinki-NLP>`__
|
||||||
|
This class inherits all functionality from ``BartForConditionalGeneration``, see that page for method signatures.
|
||||||
|
|
||||||
|
.. autoclass:: transformers.MarianMTModel
|
||||||
|
:members:
|
||||||
|
|
||||||
|
|
||||||
|
MarianTokenizer
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.MarianTokenizer
|
||||||
|
:members: prepare_translation_batch
|
||||||
@@ -275,7 +275,7 @@ For a list that includes community-uploaded models, refer to `https://huggingfac
|
|||||||
| | | | FlauBERT large architecture |
|
| | | | FlauBERT large architecture |
|
||||||
| | | (see `details <https://github.com/getalp/Flaubert>`__) |
|
| | | (see `details <https://github.com/getalp/Flaubert>`__) |
|
||||||
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||||
| Bart | ``bart-large`` | | 12-layer, 1024-hidden, 16-heads, 406M parameters |
|
| Bart | ``bart-large`` | | 24-layer, 1024-hidden, 16-heads, 406M parameters |
|
||||||
| | | (see `details <https://github.com/pytorch/fairseq/tree/master/examples/bart>`_) |
|
| | | (see `details <https://github.com/pytorch/fairseq/tree/master/examples/bart>`_) |
|
||||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||||
| | ``bart-large-mnli`` | | Adds a 2 layer classification head with 1 million parameters |
|
| | ``bart-large-mnli`` | | Adds a 2 layer classification head with 1 million parameters |
|
||||||
@@ -299,3 +299,6 @@ For a list that includes community-uploaded models, refer to `https://huggingfac
|
|||||||
| Reformer | ``reformer-crime-and-punishment`` | | 6-layer, 256-hidden, 2-heads, 3M parameters |
|
| Reformer | ``reformer-crime-and-punishment`` | | 6-layer, 256-hidden, 2-heads, 3M parameters |
|
||||||
| | | | Trained on English text: Crime and Punishment novel by Fyodor Dostoyevsky |
|
| | | | Trained on English text: Crime and Punishment novel by Fyodor Dostoyevsky |
|
||||||
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||||
|
| MarianMT | ``Helsinki-NLP/opus-mt-{src}-{tgt}`` | | 12-layer, 512-hidden, 8-heads, ~74M parameter Machine translation models. Parameter counts vary depending on vocab size. |
|
||||||
|
| | | | (see `model list <https://huggingface.co/Helsinki-NLP>`_) |
|
||||||
|
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||||
|
|||||||
@@ -248,7 +248,7 @@ if is_torch_available():
|
|||||||
BART_PRETRAINED_MODEL_ARCHIVE_MAP,
|
BART_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
)
|
)
|
||||||
from .modeling_marian import MarianMTModel
|
from .modeling_marian import MarianMTModel
|
||||||
from .tokenization_marian import MarianSentencePieceTokenizer
|
from .tokenization_marian import MarianTokenizer
|
||||||
from .modeling_roberta import (
|
from .modeling_roberta import (
|
||||||
RobertaForMaskedLM,
|
RobertaForMaskedLM,
|
||||||
RobertaModel,
|
RobertaModel,
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ from .configuration_electra import ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, Electr
|
|||||||
from .configuration_encoder_decoder import EncoderDecoderConfig
|
from .configuration_encoder_decoder import EncoderDecoderConfig
|
||||||
from .configuration_flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, FlaubertConfig
|
from .configuration_flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, FlaubertConfig
|
||||||
from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config
|
from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config
|
||||||
|
from .configuration_marian import MarianConfig
|
||||||
from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
|
from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
|
||||||
from .configuration_reformer import ReformerConfig
|
from .configuration_reformer import ReformerConfig
|
||||||
from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig
|
from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig
|
||||||
@@ -73,6 +74,7 @@ CONFIG_MAPPING = OrderedDict(
|
|||||||
("albert", AlbertConfig,),
|
("albert", AlbertConfig,),
|
||||||
("camembert", CamembertConfig,),
|
("camembert", CamembertConfig,),
|
||||||
("xlm-roberta", XLMRobertaConfig,),
|
("xlm-roberta", XLMRobertaConfig,),
|
||||||
|
("marian", MarianConfig,),
|
||||||
("bart", BartConfig,),
|
("bart", BartConfig,),
|
||||||
("reformer", ReformerConfig,),
|
("reformer", ReformerConfig,),
|
||||||
("roberta", RobertaConfig,),
|
("roberta", RobertaConfig,),
|
||||||
|
|||||||
@@ -23,4 +23,5 @@ PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
|||||||
|
|
||||||
|
|
||||||
class MarianConfig(BartConfig):
|
class MarianConfig(BartConfig):
|
||||||
|
model_type = "marian"
|
||||||
pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
|
pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||||
|
|||||||
@@ -11,7 +11,8 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from transformers import MarianConfig, MarianMTModel, MarianSentencePieceTokenizer
|
from transformers import MarianConfig, MarianMTModel, MarianTokenizer
|
||||||
|
from transformers.hf_api import HfApi
|
||||||
|
|
||||||
|
|
||||||
def remove_prefix(text: str, prefix: str):
|
def remove_prefix(text: str, prefix: str):
|
||||||
@@ -38,6 +39,19 @@ def load_layers_(layer_lst: torch.nn.ModuleList, opus_state: dict, converter, is
|
|||||||
layer.load_state_dict(sd, strict=True)
|
layer.load_state_dict(sd, strict=True)
|
||||||
|
|
||||||
|
|
||||||
|
def find_pretrained_model(src_lang: str, tgt_lang: str) -> List[str]:
|
||||||
|
"""Find models that can accept src_lang as input and return tgt_lang as output."""
|
||||||
|
prefix = "Helsinki-NLP/opus-mt-"
|
||||||
|
api = HfApi()
|
||||||
|
model_list = api.model_list()
|
||||||
|
model_ids = [x.modelId for x in model_list if x.modelId.startswith("Helsinki-NLP")]
|
||||||
|
src_and_targ = [
|
||||||
|
remove_prefix(m, prefix).lower().split("-") for m in model_ids if "+" not in m
|
||||||
|
] # + cant be loaded.
|
||||||
|
matching = [f"{prefix}{a}-{b}" for (a, b) in src_and_targ if src_lang in a and tgt_lang in b]
|
||||||
|
return matching
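
The filtering step of ``find_pretrained_model`` can be tried without network access. The sketch below assumes a hard-coded id list in place of ``HfApi().model_list()``; ``match_models`` is a hypothetical name, and unlike the function above it skips (rather than fails on) names that do not split into exactly one source and one target part.

```python
from typing import List

PREFIX = "Helsinki-NLP/opus-mt-"


def match_models(model_ids: List[str], src_lang: str, tgt_lang: str) -> List[str]:
    """Keep ids whose source part contains src_lang and whose target part contains tgt_lang.

    Ids containing '+' are ignored, mirroring the filter in the diff.
    """
    matching = []
    for model_id in model_ids:
        if not model_id.startswith(PREFIX) or "+" in model_id:
            continue
        parts = model_id[len(PREFIX):].lower().split("-")
        if len(parts) != 2:  # assumption: skip names that are not exactly src-tgt
            continue
        src, tgt = parts
        if src_lang in src and tgt_lang in tgt:
            matching.append(f"{PREFIX}{src}-{tgt}")
    return matching


# Hypothetical id list standing in for the hub listing.
ids = [
    "Helsinki-NLP/opus-mt-fr-en",
    "Helsinki-NLP/opus-mt-en-de",
    "Helsinki-NLP/opus-mt-en-fr+es",
    "bert-base-uncased",
]
print(match_models(ids, "fr", "en"))
```

Note that the match is a substring test, so a short language code can also match a longer code that contains it.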
|
||||||
|
|
||||||
|
|
||||||
def add_emb_entries(wemb, final_bias, n_special_tokens=1):
|
def add_emb_entries(wemb, final_bias, n_special_tokens=1):
|
||||||
vsize, d_model = wemb.shape
|
vsize, d_model = wemb.shape
|
||||||
embs_to_add = np.zeros((n_special_tokens, d_model))
|
embs_to_add = np.zeros((n_special_tokens, d_model))
|
||||||
@@ -81,7 +95,12 @@ def find_model_file(dest_dir): # this one better
|
|||||||
return model_file
|
return model_file
|
||||||
|
|
||||||
|
|
||||||
def parse_readmes(repo_path):
|
def make_registry(repo_path="Opus-MT-train/models"):
|
||||||
|
if not (Path(repo_path) / "fr-en" / "README.md").exists():
|
||||||
|
raise ValueError(
|
||||||
|
f"repo_path:{repo_path} does not exist: "
|
||||||
|
"You must run: git clone git@github.com:Helsinki-NLP/Opus-MT-train.git before calling."
|
||||||
|
)
|
||||||
results = {}
|
results = {}
|
||||||
for p in Path(repo_path).ls():
|
for p in Path(repo_path).ls():
|
||||||
n_dash = p.name.count("-")
|
n_dash = p.name.count("-")
|
||||||
@@ -90,22 +109,53 @@ def parse_readmes(repo_path):
|
|||||||
else:
|
else:
|
||||||
lns = list(open(p / "README.md").readlines())
|
lns = list(open(p / "README.md").readlines())
|
||||||
results[p.name] = _parse_readme(lns)
|
results[p.name] = _parse_readme(lns)
|
||||||
return results
|
return [(k, v["pre-processing"], v["download"]) for k, v in results.items()]
|
||||||
|
|
||||||
|
|
||||||
def download_all_sentencepiece_models(repo_path="Opus-MT-train/models"):
|
CH_GROUP = "cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh"
|
||||||
|
|
||||||
|
|
||||||
|
def convert_all_sentencepiece_models(model_list=None, repo_path=None):
|
||||||
"""Requires 300GB"""
|
"""Requires 300GB"""
|
||||||
save_dir = Path("marian_ckpt")
|
save_dir = Path("marian_ckpt")
|
||||||
if not Path(repo_path).exists():
|
dest_dir = Path("marian_converted")
|
||||||
raise ValueError("You must run: git clone git@github.com:Helsinki-NLP/Opus-MT-train.git")
|
dest_dir.mkdir(exist_ok=True)
|
||||||
results: dict = parse_readmes(repo_path)
|
if model_list is None:
|
||||||
for k, v in tqdm(list(results.items())):
|
model_list: list = make_registry(repo_path=repo_path)
|
||||||
if os.path.exists(save_dir / k):
|
for k, prepro, download in tqdm(model_list):
|
||||||
print(f"already have path {k}")
|
if "SentencePiece" not in prepro: # dont convert BPE models.
|
||||||
continue
|
continue
|
||||||
if "SentencePiece" not in v["pre-processing"]:
|
if not os.path.exists(save_dir / k / "pytorch_model.bin"):
|
||||||
|
download_and_unzip(download, save_dir / k)
|
||||||
|
pair_name = k.replace(CH_GROUP, "ch_group")
|
||||||
|
convert(save_dir / k, dest_dir / f"opus-mt-{pair_name}")
|
||||||
|
|
||||||
|
|
||||||
|
def lmap(f, x) -> List:
|
||||||
|
return list(map(f, x))
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_test_set(readmes_raw, pair):
|
||||||
|
import wget
|
||||||
|
|
||||||
|
download_url = readmes_raw[pair]["download"]
|
||||||
|
test_set_url = download_url[:-4] + ".test.txt"
|
||||||
|
fname = wget.download(test_set_url, f"opus_test_{pair}.txt")
|
||||||
|
lns = Path(fname).open().readlines()
|
||||||
|
src = lmap(str.strip, lns[::4])
|
||||||
|
gold = lmap(str.strip, lns[1::4])
|
||||||
|
mar_model = lmap(str.strip, lns[2::4])
|
||||||
|
assert len(gold) == len(mar_model) == len(src)
|
||||||
|
os.remove(fname)
|
||||||
|
return src, mar_model, gold
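
The slicing in ``fetch_test_set`` implies a test file made of four-line records: source, reference, Marian output, and a fourth line that is ignored. A minimal offline sketch of that parsing step, with made-up sample lines and a hypothetical function name, could look like this:

```python
from typing import List, Tuple


def split_test_records(lines: List[str]) -> Tuple[List[str], List[str], List[str]]:
    """Split flat lines into (source, marian output, reference).

    Assumes four-line records: source, reference, model output, ignored line.
    """
    src = [ln.strip() for ln in lines[0::4]]
    gold = [ln.strip() for ln in lines[1::4]]
    mar_model = [ln.strip() for ln in lines[2::4]]
    if not (len(src) == len(gold) == len(mar_model)):
        raise ValueError("input does not consist of complete four-line records")
    return src, mar_model, gold


# Made-up sample data, not taken from a real OPUS test set.
sample = [
    "Bonjour\n", "Hello\n", "Hello\n", "\n",
    "Merci\n", "Thank you\n", "Thanks\n", "\n",
]
src, mar_model, gold = split_test_records(sample)
print(src, mar_model, gold)
```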
|
||||||
|
|
||||||
|
|
||||||
|
def convert_whole_dir(path=Path("marian_ckpt/")):
|
||||||
|
for subdir in tqdm(list(path.ls())):
|
||||||
|
dest_dir = Path("marian_converted") / subdir.name  # Path, so the "/" below works
|
||||||
|
if (dest_dir / "pytorch_model.bin").exists():
|
||||||
continue
|
continue
|
||||||
download_and_unzip(v["download"], save_dir / k)
|
convert(subdir, dest_dir)  # source_dir is not defined in this function; subdir is the checkpoint dir
|
||||||
|
|
||||||
|
|
||||||
def _parse_readme(lns):
|
def _parse_readme(lns):
|
||||||
@@ -131,7 +181,7 @@ def _parse_readme(lns):
|
|||||||
return subres
|
return subres
|
||||||
|
|
||||||
|
|
||||||
def write_metadata(dest_dir: Path):
|
def save_tokenizer_config(dest_dir: Path):
|
||||||
dname = dest_dir.name.split("-")
|
dname = dest_dir.name.split("-")
|
||||||
dct = dict(target_lang=dname[-1], source_lang="-".join(dname[:-1]))
|
dct = dict(target_lang=dname[-1], source_lang="-".join(dname[:-1]))
|
||||||
save_json(dct, dest_dir / "tokenizer_config.json")
|
save_json(dct, dest_dir / "tokenizer_config.json")
|
||||||
@@ -148,13 +198,17 @@ def add_to_vocab_(vocab: Dict[str, int], special_tokens: List[str]):
|
|||||||
return added
|
return added
|
||||||
|
|
||||||
|
|
||||||
|
def find_vocab_file(model_dir):
|
||||||
|
return list(model_dir.glob("*vocab.yml"))[0]
|
||||||
|
|
||||||
|
|
||||||
def add_special_tokens_to_vocab(model_dir: Path) -> None:
|
def add_special_tokens_to_vocab(model_dir: Path) -> None:
|
||||||
vocab = load_yaml(model_dir / "opus.spm32k-spm32k.vocab.yml")
|
vocab = load_yaml(find_vocab_file(model_dir))
|
||||||
vocab = {k: int(v) for k, v in vocab.items()}
|
vocab = {k: int(v) for k, v in vocab.items()}
|
||||||
num_added = add_to_vocab_(vocab, ["<pad>"])
|
num_added = add_to_vocab_(vocab, ["<pad>"])
|
||||||
print(f"added {num_added} tokens to vocab")
|
print(f"added {num_added} tokens to vocab")
|
||||||
save_json(vocab, model_dir / "vocab.json")
|
save_json(vocab, model_dir / "vocab.json")
|
||||||
write_metadata(model_dir)
|
save_tokenizer_config(model_dir)
|
||||||
|
|
||||||
|
|
||||||
def save_tokenizer(self, save_directory):
|
def save_tokenizer(self, save_directory):
|
||||||
@@ -251,7 +305,6 @@ class OpusState:
|
|||||||
|
|
||||||
# Process decoder.yml
|
# Process decoder.yml
|
||||||
decoder_yml = cast_marian_config(load_yaml(source_dir / "decoder.yml"))
|
decoder_yml = cast_marian_config(load_yaml(source_dir / "decoder.yml"))
|
||||||
# TODO: what are normalize and word-penalty?
|
|
||||||
check_marian_cfg_assumptions(cfg)
|
check_marian_cfg_assumptions(cfg)
|
||||||
self.hf_config = MarianConfig(
|
self.hf_config = MarianConfig(
|
||||||
vocab_size=cfg["vocab_size"],
|
vocab_size=cfg["vocab_size"],
|
||||||
@@ -273,6 +326,9 @@ class OpusState:
|
|||||||
dropout=0.1, # see opus-mt-train repo/transformer-dropout param.
|
dropout=0.1, # see opus-mt-train repo/transformer-dropout param.
|
||||||
# default: add_final_layer_norm=False,
|
# default: add_final_layer_norm=False,
|
||||||
num_beams=decoder_yml["beam-size"],
|
num_beams=decoder_yml["beam-size"],
|
||||||
|
decoder_start_token_id=self.pad_token_id,
|
||||||
|
bad_words_ids=[[self.pad_token_id]],
|
||||||
|
max_length=512,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _check_layer_entries(self):
|
def _check_layer_entries(self):
|
||||||
@@ -349,12 +405,12 @@ def download_and_unzip(url, dest_dir):
|
|||||||
os.remove(filename)
|
os.remove(filename)
|
||||||
|
|
||||||
|
|
||||||
def main(source_dir, dest_dir):
|
def convert(source_dir: Path, dest_dir):
|
||||||
dest_dir = Path(dest_dir)
|
dest_dir = Path(dest_dir)
|
||||||
dest_dir.mkdir(exist_ok=True)
|
dest_dir.mkdir(exist_ok=True)
|
||||||
|
|
||||||
add_special_tokens_to_vocab(source_dir)
|
add_special_tokens_to_vocab(source_dir)
|
||||||
tokenizer = MarianSentencePieceTokenizer.from_pretrained(str(source_dir))
|
tokenizer = MarianTokenizer.from_pretrained(str(source_dir))
|
||||||
save_tokenizer(tokenizer, dest_dir)
|
save_tokenizer(tokenizer, dest_dir)
|
||||||
|
|
||||||
opus_state = OpusState(source_dir)
|
opus_state = OpusState(source_dir)
|
||||||
@@ -377,7 +433,7 @@ if __name__ == "__main__":
|
|||||||
source_dir = Path(args.src)
|
source_dir = Path(args.src)
|
||||||
assert source_dir.exists()
|
assert source_dir.exists()
|
||||||
dest_dir = f"converted-{source_dir.name}" if args.dest is None else args.dest
|
dest_dir = f"converted-{source_dir.name}" if args.dest is None else args.dest
|
||||||
main(source_dir, dest_dir)
|
convert(source_dir, dest_dir)
|
||||||
|
|
||||||
|
|
||||||
def load_yaml(path):
|
def load_yaml(path):
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ from .configuration_auto import (
|
|||||||
XLMRobertaConfig,
|
XLMRobertaConfig,
|
||||||
XLNetConfig,
|
XLNetConfig,
|
||||||
)
|
)
|
||||||
|
from .configuration_marian import MarianConfig
|
||||||
from .configuration_utils import PretrainedConfig
|
from .configuration_utils import PretrainedConfig
|
||||||
from .modeling_albert import (
|
from .modeling_albert import (
|
||||||
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
@@ -98,6 +99,7 @@ from .modeling_flaubert import (
|
|||||||
FlaubertWithLMHeadModel,
|
FlaubertWithLMHeadModel,
|
||||||
)
|
)
|
||||||
from .modeling_gpt2 import GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, GPT2LMHeadModel, GPT2Model
|
from .modeling_gpt2 import GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, GPT2LMHeadModel, GPT2Model
|
||||||
|
from .modeling_marian import MarianMTModel
|
||||||
from .modeling_openai import OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP, OpenAIGPTLMHeadModel, OpenAIGPTModel
|
from .modeling_openai import OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP, OpenAIGPTLMHeadModel, OpenAIGPTModel
|
||||||
from .modeling_reformer import ReformerModel, ReformerModelWithLMHead
|
from .modeling_reformer import ReformerModel, ReformerModelWithLMHead
|
||||||
from .modeling_roberta import (
|
from .modeling_roberta import (
|
||||||
@@ -214,6 +216,7 @@ MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
|
|||||||
(AlbertConfig, AlbertForMaskedLM),
|
(AlbertConfig, AlbertForMaskedLM),
|
||||||
(CamembertConfig, CamembertForMaskedLM),
|
(CamembertConfig, CamembertForMaskedLM),
|
||||||
(XLMRobertaConfig, XLMRobertaForMaskedLM),
|
(XLMRobertaConfig, XLMRobertaForMaskedLM),
|
||||||
|
(MarianConfig, MarianMTModel),
|
||||||
(BartConfig, BartForConditionalGeneration),
|
(BartConfig, BartForConditionalGeneration),
|
||||||
(RobertaConfig, RobertaForMaskedLM),
|
(RobertaConfig, RobertaForMaskedLM),
|
||||||
(BertConfig, BertForMaskedLM),
|
(BertConfig, BertForMaskedLM),
|
||||||
|
|||||||
@@ -18,16 +18,30 @@
|
|||||||
from transformers.modeling_bart import BartForConditionalGeneration
|
from transformers.modeling_bart import BartForConditionalGeneration
|
||||||
|
|
||||||
|
|
||||||
PRETRAINED_MODEL_ARCHIVE_MAP = {
|
|
||||||
"opus-mt-en-de": "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-de/pytorch_model.bin",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class MarianMTModel(BartForConditionalGeneration):
|
class MarianMTModel(BartForConditionalGeneration):
|
||||||
"""Pytorch version of marian-nmt's transformer.h (c++). Designed for the OPUS-NMT translation checkpoints.
|
r"""
|
||||||
Model API is identical to BartForConditionalGeneration"""
|
Pytorch version of marian-nmt's transformer.h (c++). Designed for the OPUS-NMT translation checkpoints.
|
||||||
|
Model API is identical to BartForConditionalGeneration.
|
||||||
|
Available models are listed at `Model List <https://huggingface.co/models?search=Helsinki-NLP>`__
|
||||||
|
|
||||||
pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
|
Examples::
|
||||||
|
|
||||||
|
from transformers import MarianTokenizer, MarianMTModel
|
||||||
|
from typing import List
|
||||||
|
src = 'fr' # source language
|
||||||
|
trg = 'en' # target language
|
||||||
|
sample_text = "où est l'arrêt de bus ?"
|
||||||
|
mname = f'Helsinki-NLP/opus-mt-{src}-{trg}' # `Model List`__
|
||||||
|
|
||||||
|
model = MarianMTModel.from_pretrained(mname)
|
||||||
|
tok = MarianTokenizer.from_pretrained(mname)
|
||||||
|
batch = tok.prepare_translation_batch(src_texts=[sample_text]) # don't need tgt_text for inference
|
||||||
|
gen = model.generate(**batch) # for forward pass: model(**batch)
|
||||||
|
words: List[str] = tok.decode_batch(gen, skip_special_tokens=True) # returns "Where is the the bus stop ?"
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
pretrained_model_archive_map = {} # see https://huggingface.co/models?search=Helsinki-NLP
|
||||||
|
|
||||||
def prepare_scores_for_generation(self, scores, cur_len, max_length):
|
def prepare_scores_for_generation(self, scores, cur_len, max_length):
|
||||||
if cur_len == max_length - 1 and self.config.eos_token_id is not None:
|
if cur_len == max_length - 1 and self.config.eos_token_id is not None:
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ from .configuration_auto import (
|
|||||||
XLMRobertaConfig,
|
XLMRobertaConfig,
|
||||||
XLNetConfig,
|
XLNetConfig,
|
||||||
)
|
)
|
||||||
|
from .configuration_marian import MarianConfig
|
||||||
from .configuration_utils import PretrainedConfig
|
from .configuration_utils import PretrainedConfig
|
||||||
from .tokenization_albert import AlbertTokenizer
|
from .tokenization_albert import AlbertTokenizer
|
||||||
from .tokenization_bart import BartTokenizer
|
from .tokenization_bart import BartTokenizer
|
||||||
@@ -49,6 +50,7 @@ from .tokenization_distilbert import DistilBertTokenizer, DistilBertTokenizerFas
|
|||||||
from .tokenization_electra import ElectraTokenizer, ElectraTokenizerFast
|
from .tokenization_electra import ElectraTokenizer, ElectraTokenizerFast
|
||||||
from .tokenization_flaubert import FlaubertTokenizer
|
from .tokenization_flaubert import FlaubertTokenizer
|
||||||
from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast
|
from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast
|
||||||
|
from .tokenization_marian import MarianTokenizer
|
||||||
from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast
|
from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast
|
||||||
from .tokenization_reformer import ReformerTokenizer
|
from .tokenization_reformer import ReformerTokenizer
|
||||||
from .tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast
|
from .tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast
|
||||||
@@ -69,6 +71,7 @@ TOKENIZER_MAPPING = OrderedDict(
|
|||||||
(AlbertConfig, (AlbertTokenizer, None)),
|
(AlbertConfig, (AlbertTokenizer, None)),
|
||||||
(CamembertConfig, (CamembertTokenizer, None)),
|
(CamembertConfig, (CamembertTokenizer, None)),
|
||||||
(XLMRobertaConfig, (XLMRobertaTokenizer, None)),
|
(XLMRobertaConfig, (XLMRobertaTokenizer, None)),
|
||||||
|
(MarianConfig, (MarianTokenizer, None)),
|
||||||
(BartConfig, (BartTokenizer, None)),
|
(BartConfig, (BartTokenizer, None)),
|
||||||
(RobertaConfig, (RobertaTokenizer, RobertaTokenizerFast)),
|
(RobertaConfig, (RobertaTokenizer, RobertaTokenizerFast)),
|
||||||
(ReformerConfig, (ReformerTokenizer, None)),
|
(ReformerConfig, (ReformerTokenizer, None)),
|
||||||
|
|||||||
@@ -22,7 +22,21 @@ PRETRAINED_VOCAB_FILES_MAP = {
|
|||||||
# Example URL https://s3.amazonaws.com/models.huggingface.co/bert/Helsinki-NLP/opus-mt-en-de/vocab.json
|
# Example URL https://s3.amazonaws.com/models.huggingface.co/bert/Helsinki-NLP/opus-mt-en-de/vocab.json
|
||||||
|
|
||||||
|
|
||||||
class MarianSentencePieceTokenizer(PreTrainedTokenizer):
|
class MarianTokenizer(PreTrainedTokenizer):
|
||||||
|
"""Sentencepiece tokenizer for marian. Source and target languages have different SPM models.
|
||||||
|
The logic is to use the relevant source_spm or target_spm to encode txt as pieces, then look up each piece in a vocab dictionary.
|
||||||
|
|
||||||
|
Examples::
|
||||||
|
|
||||||
|
from transformers import MarianTokenizer
|
||||||
|
tok = MarianTokenizer.from_pretrained('Helsinki-NLP/opus-mt-en-de')
|
||||||
|
src_texts = [ "I am a small frog.", "Tom asked his teacher for advice."]
|
||||||
|
tgt_texts = ["Ich bin ein kleiner Frosch.", "Tom bat seinen Lehrer um Rat."] # optional
|
||||||
|
batch_enc: BatchEncoding = tok.prepare_translation_batch(src_texts, tgt_texts=tgt_texts)
|
||||||
|
# keys [input_ids, attention_mask, decoder_input_ids, decoder_attention_mask].
|
||||||
|
# model(**batch_enc) should work
|
||||||
|
"""
|
||||||
|
|
||||||
vocab_files_names = vocab_files_names
|
vocab_files_names = vocab_files_names
|
||||||
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
||||||
max_model_input_sizes = {m: 512 for m in MODEL_NAMES}
|
max_model_input_sizes = {m: 512 for m in MODEL_NAMES}
|
||||||
@@ -49,6 +63,8 @@ class MarianSentencePieceTokenizer(PreTrainedTokenizer):
|
|||||||
pad_token=pad_token,
|
pad_token=pad_token,
|
||||||
)
|
)
|
||||||
self.encoder = load_json(vocab)
|
self.encoder = load_json(vocab)
|
||||||
|
if self.unk_token not in self.encoder:
|
||||||
|
raise KeyError("<unk> token must be in vocab")
|
||||||
assert self.pad_token in self.encoder
|
assert self.pad_token in self.encoder
|
||||||
self.decoder = {v: k for k, v in self.encoder.items()}
|
self.decoder = {v: k for k, v in self.encoder.items()}
|
||||||
|
|
||||||
@@ -64,8 +80,11 @@ class MarianSentencePieceTokenizer(PreTrainedTokenizer):
|
|||||||
self.spm_target = sentencepiece.SentencePieceProcessor()
|
self.spm_target = sentencepiece.SentencePieceProcessor()
|
||||||
self.spm_target.Load(target_spm)
|
self.spm_target.Load(target_spm)
|
||||||
|
|
||||||
# Note(SS): splitter would require lots of book-keeping.
|
# Multilingual target side: default to using first supported language code.
|
||||||
# self.sentence_splitter = MosesSentenceSplitter(source_lang)
|
self.supported_language_codes: list = [k for k in self.encoder if k.startswith(">>") and k.endswith("<<")]
|
||||||
|
self.tgt_lang_id = None # will not be used unless it is set through prepare_translation_batch
|
||||||
|
|
||||||
|
# Note(SS): sentence_splitter would require lots of book-keeping.
|
||||||
try:
|
try:
|
||||||
from mosestokenizer import MosesPunctuationNormalizer
|
from mosestokenizer import MosesPunctuationNormalizer
|
||||||
|
|
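The `supported_language_codes` line added in this hunk keeps every vocab key wrapped in `>>` and `<<`. A minimal sketch of the same filter on a toy vocab follows; the codes `>>cmn<<` and `>>yue<<` and the helper name `find_language_codes` are illustrative assumptions, not taken from a real checkpoint.

```python
from typing import Dict, List


def find_language_codes(vocab: Dict[str, int]) -> List[str]:
    """Return vocab keys that look like target-language codes, e.g. '>>xx<<'."""
    return [k for k in vocab if k.startswith(">>") and k.endswith("<<")]


# Toy vocab (assumption): two language codes mixed with ordinary entries.
toy_vocab = {"</s>": 0, "<unk>": 1, ">>cmn<<": 2, ">>yue<<": 3, "▁Hallo": 4, ">>": 5}

codes = find_language_codes(toy_vocab)
```

Both conditions are needed: the bare `>>` entry starts with the prefix but lacks the `<<` suffix, so it is not treated as a language code.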
||||||
@@ -75,11 +94,10 @@ class MarianSentencePieceTokenizer(PreTrainedTokenizer):
|
|||||||
self.punc_normalizer = lambda x: x
|
self.punc_normalizer = lambda x: x
|
||||||
|
|
||||||
def _convert_token_to_id(self, token):
|
def _convert_token_to_id(self, token):
|
||||||
return self.encoder[token]
|
return self.encoder.get(token, self.encoder[self.unk_token])
|
||||||
|
|
||||||
def _tokenize(self, text: str, src=True) -> List[str]:
|
def _tokenize(self, text: str) -> List[str]:
|
||||||
spm = self.spm_source if src else self.spm_target
|
return self.current_spm.EncodeAsPieces(text)
|
||||||
return spm.EncodeAsPieces(text)
|
|
||||||
|
|
||||||
def _convert_id_to_token(self, index: int) -> str:
|
def _convert_id_to_token(self, index: int) -> str:
|
||||||
"""Converts an index (integer) in a token (str) using the encoder."""
|
"""Converts an index (integer) in a token (str) using the encoder."""
|
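This hunk changes `_convert_token_to_id` from a plain dictionary lookup to a lookup that falls back to the unknown-token id, and an earlier hunk raises `KeyError` if the unk token is missing from the vocab. The sketch below combines both ideas in a standalone function; the function name and the toy vocab are assumptions for illustration.

```python
from typing import Dict, List


def convert_tokens_to_ids(tokens: List[str], vocab: Dict[str, int], unk_token: str = "<unk>") -> List[int]:
    """Map tokens to ids, using the unk id for anything outside the vocab."""
    if unk_token not in vocab:
        # Without an unk entry there is nothing to fall back to, so fail early.
        raise KeyError(f"{unk_token} token must be in vocab")
    unk_id = vocab[unk_token]
    return [vocab.get(token, unk_id) for token in tokens]


toy_vocab = {"</s>": 0, "<unk>": 1, "▁frog": 2}
ids = convert_tokens_to_ids(["▁frog", "▁||"], toy_vocab)
```

With the old `self.encoder[token]` form, an out-of-vocab piece would raise `KeyError` during encoding; the fallback turns that into an unk id instead.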
||||||
@@ -89,10 +107,6 @@ class MarianSentencePieceTokenizer(PreTrainedTokenizer):
|
|||||||
"""Uses target language sentencepiece model"""
|
"""Uses target language sentencepiece model"""
|
||||||
return self.spm_target.DecodePieces(tokens)
|
return self.spm_target.DecodePieces(tokens)
|
||||||
|
|
||||||
def _append_special_tokens_and_truncate(self, tokens: str, max_length: int,) -> List[int]:
|
|
||||||
ids: list = self.convert_tokens_to_ids(tokens)[:max_length]
|
|
||||||
return ids + [self.eos_token_id]
|
|
||||||
|
|
||||||
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
|
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
|
||||||
"""Build model inputs from a sequence by appending eos_token_id."""
|
"""Build model inputs from a sequence by appending eos_token_id."""
|
||||||
if token_ids_1 is None:
|
if token_ids_1 is None:
|
||||||
@@ -100,7 +114,7 @@ class MarianSentencePieceTokenizer(PreTrainedTokenizer):
|
|||||||
# We don't expect to process pairs, but leave the pair logic for API consistency
|
# We don't expect to process pairs, but leave the pair logic for API consistency
|
||||||
return token_ids_0 + token_ids_1 + [self.eos_token_id]
|
return token_ids_0 + token_ids_1 + [self.eos_token_id]
|
||||||
|
|
||||||
def decode_batch(self, token_ids, **kwargs) -> List[str]:
|
def batch_decode(self, token_ids, **kwargs) -> List[str]:
|
||||||
return [self.decode(ids, **kwargs) for ids in token_ids]
|
return [self.decode(ids, **kwargs) for ids in token_ids]
|
||||||
|
|
||||||
def prepare_translation_batch(
|
def prepare_translation_batch(
|
||||||
@@ -114,40 +128,38 @@ class MarianSentencePieceTokenizer(PreTrainedTokenizer):
|
|||||||
"""
|
"""
|
||||||
Arguments:
|
Arguments:
|
||||||
src_texts: list of src language texts
|
src_texts: list of src language texts
|
||||||
src_lang: default en_XX (english)
|
|
||||||
tgt_texts: list of tgt language texts
|
tgt_texts: list of tgt language texts
|
||||||
tgt_lang: default ro_RO (romanian)
|
|
||||||
max_length: (None) defer to config (1024 for mbart-large-en-ro)
|
max_length: (None) defer to config (1024 for mbart-large-en-ro)
|
||||||
pad_to_max_length: (bool)
|
pad_to_max_length: (bool)
|
||||||
|
return_tensors: (str) default "pt" returns pytorch tensors, pass None to return lists.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
BatchEncoding: with keys [input_ids, attention_mask, decoder_input_ids, decoder_attention_mask]
|
BatchEncoding: with keys [input_ids, attention_mask, decoder_input_ids, decoder_attention_mask]
|
||||||
all shaped bs, seq_len. (BatchEncoding is a dict of string -> tensor or lists)
|
all shaped bs, seq_len. (BatchEncoding is a dict of string -> tensor or lists).
|
||||||
|
If no tgt_texts are specified, the only keys will be input_ids and attention_mask.
|
||||||
Examples:
|
|
||||||
from transformers import MarianS
|
|
||||||
"""
|
"""
|
||||||
|
self.current_spm = self.spm_source
|
||||||
model_inputs: BatchEncoding = self.batch_encode_plus(
|
model_inputs: BatchEncoding = self.batch_encode_plus(
|
||||||
src_texts,
|
src_texts,
|
||||||
add_special_tokens=True,
|
add_special_tokens=True,
|
||||||
return_tensors=return_tensors,
|
return_tensors=return_tensors,
|
||||||
max_length=max_length,
|
max_length=max_length,
|
||||||
pad_to_max_length=pad_to_max_length,
|
pad_to_max_length=pad_to_max_length,
|
||||||
src=True,
|
|
||||||
)
|
)
|
||||||
if tgt_texts is None:
|
if tgt_texts is None:
|
||||||
return model_inputs
|
return model_inputs
|
||||||
|
|
||||||
|
self.current_spm = self.spm_target
|
||||||
decoder_inputs: BatchEncoding = self.batch_encode_plus(
|
decoder_inputs: BatchEncoding = self.batch_encode_plus(
|
||||||
tgt_texts,
|
tgt_texts,
|
||||||
add_special_tokens=True,
|
add_special_tokens=True,
|
||||||
return_tensors=return_tensors,
|
return_tensors=return_tensors,
|
||||||
max_length=max_length,
|
max_length=max_length,
|
||||||
pad_to_max_length=pad_to_max_length,
|
pad_to_max_length=pad_to_max_length,
|
||||||
src=False,
|
|
||||||
)
|
)
|
||||||
for k, v in decoder_inputs.items():
|
for k, v in decoder_inputs.items():
|
||||||
model_inputs[f"decoder_{k}"] = v
|
model_inputs[f"decoder_{k}"] = v
|
||||||
|
self.current_spm = self.spm_source
|
||||||
return model_inputs
|
return model_inputs
|
||||||
|
|
||||||
@property
|
@property
|
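The new `prepare_translation_batch` replaces the old `src=True/False` argument with a `current_spm` attribute that is pointed at the target model while target texts are encoded, then reset. Target-side outputs are merged under a `decoder_` prefix. The sketch below mimics that control flow with a hypothetical `ToyTranslationTokenizer`; the lowercase/uppercase "models" and the use of pieces in place of integer ids are simplifications, and the `try/finally` reset is a defensive addition that is not in the diff.

```python
from typing import Callable, Dict, List, Optional


class ToyTranslationTokenizer:
    """Hypothetical stand-in illustrating the current_spm switch; not the real MarianTokenizer."""

    def __init__(self) -> None:
        # Two toy "models": the source side lowercases, the target side uppercases.
        self.spm_source: Callable[[str], List[str]] = lambda text: text.lower().split()
        self.spm_target: Callable[[str], List[str]] = lambda text: text.upper().split()
        self.current_spm = self.spm_source

    def _encode_batch(self, texts: List[str]) -> Dict[str, list]:
        # Pieces stand in for integer ids to keep the sketch short.
        pieces = [self.current_spm(text) for text in texts]
        return {"input_ids": pieces, "attention_mask": [[1] * len(p) for p in pieces]}

    def prepare_batch(self, src_texts: List[str], tgt_texts: Optional[List[str]] = None) -> Dict[str, list]:
        self.current_spm = self.spm_source
        model_inputs = self._encode_batch(src_texts)
        if tgt_texts is None:
            return model_inputs
        self.current_spm = self.spm_target
        try:
            decoder_inputs = self._encode_batch(tgt_texts)
        finally:
            # Always restore the source model so later calls are unaffected.
            self.current_spm = self.spm_source
        for key, value in decoder_inputs.items():
            model_inputs[f"decoder_{key}"] = value
        return model_inputs


tok = ToyTranslationTokenizer()
batch = tok.prepare_batch(["I am"], tgt_texts=["Ich bin"])
```

Keeping the active model in instance state means the encoding helpers need no extra argument, at the cost of the tokenizer being stateful during a call.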
||||||
|
|||||||
@@ -18,35 +18,94 @@ import unittest
|
|||||||
|
|
||||||
from transformers import is_torch_available
|
from transformers import is_torch_available
|
||||||
from transformers.file_utils import cached_property
|
from transformers.file_utils import cached_property
|
||||||
|
from transformers.hf_api import HfApi
|
||||||
|
|
||||||
from .utils import require_torch, slow, torch_device
|
from .utils import require_torch, slow, torch_device
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
from transformers import MarianMTModel, MarianSentencePieceTokenizer
|
from transformers import (
|
||||||
|
AutoTokenizer,
|
||||||
|
MarianConfig,
|
||||||
|
AutoConfig,
|
||||||
|
AutoModelWithLMHead,
|
||||||
|
MarianTokenizer,
|
||||||
|
MarianMTModel,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelManagementTests(unittest.TestCase):
|
||||||
|
@slow
|
||||||
|
def test_model_count(self):
|
||||||
|
model_list = HfApi().model_list()
|
||||||
|
expected_num_models = 1011
|
||||||
|
actual_num_models = len([x for x in model_list if x.modelId.startswith("Helsinki-NLP")])
|
||||||
|
self.assertEqual(expected_num_models, actual_num_models)
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class IntegrationTests(unittest.TestCase):
|
class MarianIntegrationTest(unittest.TestCase):
|
||||||
|
src = "en"
|
||||||
|
tgt = "de"
|
||||||
|
src_text = [
|
||||||
|
"I am a small frog.",
|
||||||
|
"Now I can forget the 100 words of german that I know.",
|
||||||
|
"Tom asked his teacher for advice.",
|
||||||
|
"That's how I would do it.",
|
||||||
|
"Tom really admired Mary's courage.",
|
||||||
|
"Turn around and close your eyes.",
|
||||||
|
]
|
||||||
|
expected_text = [
|
||||||
|
"Ich bin ein kleiner Frosch.",
|
||||||
|
"Jetzt kann ich die 100 Wörter des Deutschen vergessen, die ich kenne.",
|
||||||
|
"Tom bat seinen Lehrer um Rat.",
|
||||||
|
"So würde ich das machen.",
|
||||||
|
"Tom bewunderte Marias Mut wirklich.",
|
||||||
|
"Drehen Sie sich um und schließen Sie die Augen.",
|
||||||
|
]
|
||||||
|
# ^^ actual C++ output differs slightly: (1) des Deutschen removed, (2) ""-> "O", (3) tun -> machen
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls) -> None:
|
def setUpClass(cls) -> None:
|
||||||
cls.model_name = "Helsinki-NLP/opus-mt-en-de"
|
cls.model_name = f"Helsinki-NLP/opus-mt-{cls.src}-{cls.tgt}"
|
||||||
cls.tokenizer = MarianSentencePieceTokenizer.from_pretrained(cls.model_name)
|
cls.tokenizer: MarianTokenizer = AutoTokenizer.from_pretrained(cls.model_name)
|
||||||
cls.eos_token_id = cls.tokenizer.eos_token_id
|
cls.eos_token_id = cls.tokenizer.eos_token_id
|
||||||
return cls
|
return cls
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def model(self):
|
def model(self):
|
||||||
model = MarianMTModel.from_pretrained(self.model_name).to(torch_device)
|
model: MarianMTModel = AutoModelWithLMHead.from_pretrained(self.model_name).to(torch_device)
|
||||||
|
c = model.config
|
||||||
|
self.assertListEqual(c.bad_words_ids, [[c.pad_token_id]])
|
||||||
|
self.assertEqual(c.max_length, 512)
|
||||||
|
self.assertEqual(c.decoder_start_token_id, c.pad_token_id)
|
||||||
|
|
||||||
if torch_device == "cuda":
|
if torch_device == "cuda":
|
||||||
return model.half()
|
return model.half()
|
||||||
else:
|
else:
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
def _assert_generated_batch_equal_expected(self, **tokenizer_kwargs):
|
||||||
|
generated_words = self.translate_src_text(**tokenizer_kwargs)
|
||||||
|
self.assertListEqual(self.expected_text, generated_words)
|
||||||
|
|
||||||
|
def translate_src_text(self, **tokenizer_kwargs):
|
||||||
|
model_inputs: dict = self.tokenizer.prepare_translation_batch(src_texts=self.src_text, **tokenizer_kwargs).to(
|
||||||
|
torch_device
|
||||||
|
)
|
||||||
|
self.assertEqual(self.model.device, model_inputs["input_ids"].device)
|
||||||
|
generated_ids = self.model.generate(
|
||||||
|
model_inputs["input_ids"], attention_mask=model_inputs["attention_mask"], num_beams=2
|
||||||
|
)
|
||||||
|
generated_words = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||||
|
return generated_words
|
||||||
|
|
||||||
|
|
||||||
|
class TestMarian_EN_DE_More(MarianIntegrationTest):
|
||||||
@slow
|
@slow
|
||||||
def test_forward(self):
|
def test_forward(self):
|
||||||
src, tgt = ["I am a small frog"], ["▁Ich ▁bin ▁ein ▁kleiner ▁Fro sch"]
|
src, tgt = ["I am a small frog"], ["Ich bin ein kleiner Frosch."]
|
||||||
expected = [38, 121, 14, 697, 38848, 0]
|
expected = [38, 121, 14, 697, 38848, 0]
|
||||||
|
|
||||||
model_inputs: dict = self.tokenizer.prepare_translation_batch(src, tgt_texts=tgt).to(torch_device)
|
model_inputs: dict = self.tokenizer.prepare_translation_batch(src, tgt_texts=tgt).to(torch_device)
|
||||||
@@ -62,57 +121,112 @@ class IntegrationTests(unittest.TestCase):
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
logits, *enc_features = self.model(**model_inputs)
|
logits, *enc_features = self.model(**model_inputs)
|
||||||
max_indices = logits.argmax(-1)
|
max_indices = logits.argmax(-1)
|
||||||
self.tokenizer.decode_batch(max_indices)
|
self.tokenizer.batch_decode(max_indices)
|
||||||
|
|
||||||
@slow
|
def test_tokenizer_equivalence(self):
|
||||||
def test_repl_generate_one(self):
|
|
||||||
src = ["I am a small frog.", "Hello"]
|
|
||||||
model_inputs: dict = self.tokenizer.prepare_translation_batch(src).to(torch_device)
|
|
||||||
self.assertEqual(self.model.device, model_inputs["input_ids"].device)
|
|
||||||
generated_ids = self.model.generate(model_inputs["input_ids"], num_beams=6,)
|
|
||||||
generated_words = self.tokenizer.decode_batch(generated_ids)[0]
|
|
||||||
expected_words = "Ich bin ein kleiner Frosch."
|
|
||||||
self.assertEqual(expected_words, generated_words)
|
|
||||||
|
|
||||||
@slow
|
|
||||||
def test_repl_generate_batch(self):
|
|
||||||
src = [
|
|
||||||
"I am a small frog.",
|
|
||||||
"Now I can forget the 100 words of german that I know.",
|
|
||||||
"O",
|
|
||||||
"Tom asked his teacher for advice.",
|
|
||||||
"That's how I would do it.",
|
|
||||||
"Tom really admired Mary's courage.",
|
|
||||||
"Turn around and close your eyes.",
|
|
||||||
]
|
|
||||||
model_inputs: dict = self.tokenizer.prepare_translation_batch(src).to(torch_device)
|
|
||||||
self.assertEqual(self.model.device, model_inputs["input_ids"].device)
|
|
||||||
generated_ids = self.model.generate(
|
|
||||||
model_inputs["input_ids"],
|
|
||||||
length_penalty=1.0,
|
|
||||||
num_beams=2, # 6 is the default
|
|
||||||
bad_words_ids=[[self.tokenizer.pad_token_id]],
|
|
||||||
)
|
|
||||||
expected = [
|
|
||||||
"Ich bin ein kleiner Frosch.",
|
|
||||||
"Jetzt kann ich die 100 Wörter des Deutschen vergessen, die ich kenne.",
|
|
||||||
"",
|
|
||||||
"Tom bat seinen Lehrer um Rat.",
|
|
||||||
"So würde ich das tun.",
|
|
||||||
"Tom bewunderte Marias Mut wirklich.",
|
|
||||||
"Umdrehen und die Augen schließen.",
|
|
||||||
]
|
|
||||||
# actual C++ output differences: (1) des Deutschen removed, (2) ""-> "O", (3) tun -> machen
|
|
||||||
generated_words = self.tokenizer.decode_batch(generated_ids, skip_special_tokens=True)
|
|
||||||
self.assertListEqual(expected, generated_words)
|
|
||||||
|
|
||||||
def test_marian_equivalence(self):
|
|
||||||
batch = self.tokenizer.prepare_translation_batch(["I am a small frog"]).to(torch_device)
|
batch = self.tokenizer.prepare_translation_batch(["I am a small frog"]).to(torch_device)
|
||||||
input_ids = batch["input_ids"][0]
|
input_ids = batch["input_ids"][0]
|
||||||
expected = [38, 121, 14, 697, 38848, 0]
|
expected = [38, 121, 14, 697, 38848, 0]
|
||||||
self.assertListEqual(expected, input_ids.tolist())
|
self.assertListEqual(expected, input_ids.tolist())
|
||||||
|
|
||||||
|
def test_unk_support(self):
|
||||||
|
t = self.tokenizer
|
||||||
|
ids = t.prepare_translation_batch(["||"]).to(torch_device)["input_ids"][0].tolist()
|
||||||
|
expected = [t.unk_token_id, t.unk_token_id, t.eos_token_id]
|
||||||
|
self.assertEqual(expected, ids)
|
||||||
|
|
||||||
def test_pad_not_split(self):
|
def test_pad_not_split(self):
|
||||||
input_ids_w_pad = self.tokenizer.prepare_translation_batch(["I am a small frog <pad>"])["input_ids"][0]
|
input_ids_w_pad = self.tokenizer.prepare_translation_batch(["I am a small frog <pad>"])["input_ids"][0]
|
||||||
expected_w_pad = [38, 121, 14, 697, 38848, self.tokenizer.pad_token_id, 0] # pad
|
expected_w_pad = [38, 121, 14, 697, 38848, self.tokenizer.pad_token_id, 0] # pad
|
||||||
self.assertListEqual(expected_w_pad, input_ids_w_pad.tolist())
|
self.assertListEqual(expected_w_pad, input_ids_w_pad.tolist())
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_batch_generation_en_de(self):
|
||||||
|
self._assert_generated_batch_equal_expected()
|
||||||
|
|
||||||
|
def test_auto_config(self):
|
||||||
|
config = AutoConfig.from_pretrained(self.model_name)
|
||||||
|
self.assertIsInstance(config, MarianConfig)
|
||||||
|
|
||||||
|
|
||||||
|
class TestMarian_EN_FR(MarianIntegrationTest):
|
||||||
|
src = "en"
|
||||||
|
tgt = "fr"
|
||||||
|
src_text = [
|
||||||
|
"I am a small frog.",
|
||||||
|
"Now I can forget the 100 words of german that I know.",
|
||||||
|
]
|
||||||
|
expected_text = [
|
||||||
|
"Je suis une petite grenouille.",
|
||||||
|
"Maintenant, je peux oublier les 100 mots d'allemand que je connais.",
|
||||||
|
]
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_batch_generation_en_fr(self):
|
||||||
|
self._assert_generated_batch_equal_expected()
|
||||||
|
|
||||||
|
|
||||||
|
class TestMarian_FR_EN(MarianIntegrationTest):
|
||||||
|
src = "fr"
|
||||||
|
tgt = "en"
|
||||||
|
src_text = [
|
||||||
|
"Donnez moi le micro.",
|
||||||
|
"Tom et Mary étaient assis à une table.", # Accents
|
||||||
|
]
|
||||||
|
expected_text = [
|
||||||
|
"Give me the microphone.",
|
||||||
|
"Tom and Mary were sitting at a table.",
|
||||||
|
]
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_batch_generation_fr_en(self):
|
||||||
|
self._assert_generated_batch_equal_expected()
|
||||||
|
|
||||||
|
|
||||||
|
class TestMarian_RU_FR(MarianIntegrationTest):
|
||||||
|
src = "ru"
|
||||||
|
tgt = "fr"
|
||||||
|
src_text = ["Он показал мне рукопись своей новой пьесы."]
|
||||||
|
expected_text = ["Il me montre un manuscrit de sa nouvelle pièce."]
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_batch_generation_ru_fr(self):
|
||||||
|
self._assert_generated_batch_equal_expected()
|
||||||
|
|
||||||
|
|
||||||
|
class TestMarian_MT_EN(MarianIntegrationTest):
|
||||||
|
src = "mt"
|
||||||
|
tgt = "en"
|
||||||
|
src_text = ["Il - Babiloniżi b'mod żbaljat ikkonkludew li l - Alla l - veru kien dgħajjef."]
|
||||||
|
expected_text = ["The Babylonians wrongly concluded that the true God was weak."]
|
||||||
|
|
||||||
|
@unittest.skip("") # Known Issue: This model generates a string of .... at the end of the translation.
|
||||||
|
def test_batch_generation_mt_en(self):
|
||||||
|
self._assert_generated_batch_equal_expected()
|
||||||
|
|
||||||
|
|
||||||
|
class TestMarian_DE_Multi(MarianIntegrationTest):
|
||||||
|
src = "de"
|
||||||
|
tgt = "ch_group"
|
||||||
|
src_text = ["Er aber sprach: Das ist die Gottlosigkeit."]
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_translation_de_multi_does_not_error(self):
|
||||||
|
self.translate_src_text()
|
||||||
|
|
||||||
|
@unittest.skip("") # "Language codes are not yet supported."
|
||||||
|
def test_batch_generation_de_multi_tgt(self):
|
||||||
|
self._assert_generated_batch_equal_expected()
|
||||||
|
|
||||||
|
@unittest.skip("") # "Language codes are not yet supported."
|
||||||
|
def test_lang_code(self):
|
||||||
|
t = "Er aber sprach"
|
||||||
|
zh_code = self.code
|
||||||
|
tok_fn = self.tokenizer.prepare_translation_batch
|
||||||
|
pass_code = tok_fn(src_texts=[t], tgt_lang_code=zh_code)["input_ids"][0]
|
||||||
|
preprocess_with_code = tok_fn(src_texts=[zh_code + " " + t])["input_ids"][0]
|
||||||
|
self.assertListEqual(pass_code.tolist(), preprocess_with_code.tolist())
|
||||||
|
for code in self.tokenizer.supported_language_codes:
|
||||||
|
self.assertIn(code, self.tokenizer.encoder)
|
||||||
|
pass_only_code = tok_fn(src_texts=[""], tgt_lang_code=zh_code)["input_ids"][0].tolist()
|
||||||
|
self.assertListEqual(pass_only_code, [self.tokenizer.encoder[zh_code], self.tokenizer.eos_token_id])
|
||||||
|
|||||||
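The test classes in this diff share one base class and differ only in class attributes (`src`, `tgt`, `src_text`, `expected_text`), with the model name derived from them in `setUpClass`. The sketch below shows that pattern in isolation. Everything in it is a toy assumption: `ToyTranslationTest`, the `toy-mt-...` name, and the lookup-table "translation" replace the real model download and generation.

```python
import unittest
from typing import Dict, List, Tuple


class ToyTranslationTest(unittest.TestCase):
    # Subclasses override these class attributes to cover another language pair.
    src = "en"
    tgt = "de"
    src_text: List[str] = ["frog"]
    expected_text: List[str] = ["Frosch"]

    # Hypothetical lookup tables standing in for pretrained translation models.
    toy_tables: Dict[Tuple[str, str], Dict[str, str]] = {
        ("en", "de"): {"frog": "Frosch"},
        ("en", "fr"): {"frog": "grenouille"},
    }

    @classmethod
    def setUpClass(cls) -> None:
        # The name is built from the class attributes, as in the diff's setUpClass.
        cls.model_name = f"toy-mt-{cls.src}-{cls.tgt}"

    def translate_src_text(self) -> List[str]:
        table = self.toy_tables[(self.src, self.tgt)]
        return [table[text] for text in self.src_text]

    def test_batch_generation(self) -> None:
        self.assertListEqual(self.expected_text, self.translate_src_text())


class ToyTest_EN_FR(ToyTranslationTest):
    tgt = "fr"
    expected_text = ["grenouille"]
```

Saved to a file, these could be run with `python -m unittest`. Note that the inherited `test_batch_generation` runs once per subclass, whereas the diff gives each subclass its own named test method.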