[Marian Fixes] prevent predicting pad_token_id before softmax, support language codes, name multilingual models (#4290)
This commit is contained in:
@@ -1,28 +1,90 @@
|
|||||||
MarianMTModel
|
MarianMT
|
||||||
----------------------------------------------------
|
----------------------------------------------------
|
||||||
**DISCLAIMER:** If you see something strange,
|
**DISCLAIMER:** If you see something strange,
|
||||||
file a `Github Issue <https://github.com/huggingface/transformers/issues/new?assignees=&labels=&template=bug-report.md&title>`__ and assign
|
file a `Github Issue <https://github.com/huggingface/transformers/issues/new?assignees=&labels=&template=bug-report.md&title>`__ and assign
|
||||||
@sshleifer
|
@sshleifer. Translations should be similar, but not identical to, output in the test set linked to in each model card.
|
||||||
These models are for machine translation. The list of supported language pairs can be found `here <https://huggingface.co/Helsinki-NLP>`__.
|
|
||||||
|
|
||||||
Opus Project
|
|
||||||
~~~~~~~~~~~~
|
|
||||||
The 1,000+ models were originally trained by `Jörg Tiedemann <https://researchportal.helsinki.fi/en/persons/j%C3%B6rg-tiedemann>`__ using the `Marian <https://marian-nmt.github.io/>`_ C++ library, which supports fast training and translation.
|
|
||||||
All models are transformer encoder-decoders with 6 layers in each component. Each model's performance is documented in a model card.
|
|
||||||
|
|
||||||
Implementation Notes
|
Implementation Notes
|
||||||
~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~
|
||||||
- each model is about 298 MB on disk, there are 1,000+ models.
|
- each model is about 298 MB on disk, there are 1,000+ models.
|
||||||
- Models are named with the following patter 'Helsinki-NLP/opus-mt-{src_langs}-{targ_langs}'. If there are multiple source or target languages they are joined by a '+' symbol.
|
- The list of supported language pairs can be found `here <https://huggingface.co/Helsinki-NLP>`__.
|
||||||
|
- The 1,000+ models were originally trained by `Jörg Tiedemann <https://researchportal.helsinki.fi/en/persons/j%C3%B6rg-tiedemann>`__ using the `Marian <https://marian-nmt.github.io/>`_ C++ library, which supports fast training and translation.
|
||||||
|
- All models are transformer encoder-decoders with 6 layers in each component. Each model's performance is documented in a model card.
|
||||||
- the 80 opus models that require BPE preprocessing are not supported.
|
- the 80 opus models that require BPE preprocessing are not supported.
|
||||||
- There is an outstanding issue w.r.t multilingual models and language codes.
|
- The modeling code is the same as ``BartForConditionalGeneration`` with a few minor modifications:
|
||||||
- The modeling code is the same as ``BartModel`` with a few minor modifications:
|
|
||||||
- static (sinusoid) positional embeddings (``MarianConfig.static_position_embeddings=True``)
|
- static (sinusoid) positional embeddings (``MarianConfig.static_position_embeddings=True``)
|
||||||
- a new final_logits_bias (``MarianConfig.add_bias_logits=True``)
|
- a new final_logits_bias (``MarianConfig.add_bias_logits=True``)
|
||||||
- no layernorm_embedding (``MarianConfig.normalize_embedding=False``)
|
- no layernorm_embedding (``MarianConfig.normalize_embedding=False``)
|
||||||
- the model starts generating with pad_token_id (which has 0 token_embedding) as the prefix. (Bart uses <s/>)
|
- the model starts generating with pad_token_id (which has 0 token_embedding) as the prefix. (Bart uses <s/>)
|
||||||
- Code to bulk convert models can be found in ``convert_marian_to_pytorch.py``
|
- Code to bulk convert models can be found in ``convert_marian_to_pytorch.py``
|
||||||
|
|
||||||
|
Naming
|
||||||
|
~~~~~~
|
||||||
|
- All model names use the following format: ``Helsinki-NLP/opus-mt-{src}-{tgt}``
|
||||||
|
- The language codes used to name models are inconsistent. Two digit codes can usually be found `here <https://developers.google.com/admin-sdk/directory/v1/languages>`_, three digit codes require googling "language code {code}".
|
||||||
|
- Codes formatted like ``es_AR`` are usually ``code_{region}``. That one is spanish documents from Argentina.
|
||||||
|
|
||||||
|
|
||||||
|
Multilingual Models
|
||||||
|
~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
All model names use the following format: ``Helsinki-NLP/opus-mt-{src}-{tgt}``:
|
||||||
|
- if ``src`` is in all caps, the model supports multiple input languages, you can figure out which ones by looking at the model card, or the Group Members `mapping <https://gist.github.com/sshleifer/6d20e7761931b08e73c3219027b97b8a>`_ .
|
||||||
|
- if ``tgt`` is in all caps, the model can output multiple languages, and you should specify a language code by prepending the desired output language to the src_text
|
||||||
|
- You can see a tokenizer's supported language codes in ``tokenizer.supported_language_codes``
|
||||||
|
|
||||||
|
Example of translating english to many romance languages, using language codes:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from transformers import MarianMTModel, MarianTokenizer
|
||||||
|
src_text = [
|
||||||
|
'>>fr<< this is a sentence in english that we want to translate to french',
|
||||||
|
'>>pt<< This should go to portuguese',
|
||||||
|
'>>es<< And this to Spanish'
|
||||||
|
]
|
||||||
|
|
||||||
|
model_name = 'Helsinki-NLP/opus-mt-en-ROMANCE'
|
||||||
|
tokenizer = MarianTokenizer.from_pretrained(model_name)
|
||||||
|
print(tokenizer.supported_language_codes)
|
||||||
|
model = MarianMTModel.from_pretrained(model_name)
|
||||||
|
translated = model.generate(**tokenizer.prepare_translation_batch(src_text))
|
||||||
|
tgt_text = [tokenizer.decode(t, skip_special_tokens=True) for t in translated]
|
||||||
|
# ["c'est une phrase en anglais que nous voulons traduire en français",
|
||||||
|
# 'Isto deve ir para o português.',
|
||||||
|
# 'Y esto al español']
|
||||||
|
|
||||||
|
Sometimes, models were trained on collections of languages that do not resolve to a group. In this case, _ is used as a separator for src or tgt, as in ``'Helsinki-NLP/opus-mt-en_el_es_fi-en_el_es_fi'``. These still require language codes.
|
||||||
|
There are many supported regional language codes, like ``>>es_ES<<`` (Spain) and ``>>es_AR<<`` (Argentina), that do not seem to change translations. I have not found these to provide different results than just using ``>>es<<``.
|
||||||
|
|
||||||
|
For Example:
|
||||||
|
- ``Helsinki-NLP/opus-mt-NORTH_EU-NORTH_EU``: translates from all NORTH_EU languages (see `mapping <https://gist.github.com/sshleifer/6d20e7761931b08e73c3219027b97b8a>`_) to all NORTH_EU languages. Use a special language code like ``>>de<<`` to specify output language.
|
||||||
|
- ``Helsinki-NLP/opus-mt-ROMANCE-en``: translates from many romance languages to english, no codes needed since there is only 1 tgt language.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
GROUP_MEMBERS = {
|
||||||
|
'ZH': ['cmn', 'cn', 'yue', 'ze_zh', 'zh_cn', 'zh_CN', 'zh_HK', 'zh_tw', 'zh_TW', 'zh_yue', 'zhs', 'zht', 'zh'],
|
||||||
|
'ROMANCE': ['fr', 'fr_BE', 'fr_CA', 'fr_FR', 'wa', 'frp', 'oc', 'ca', 'rm', 'lld', 'fur', 'lij', 'lmo', 'es', 'es_AR', 'es_CL', 'es_CO', 'es_CR', 'es_DO', 'es_EC', 'es_ES', 'es_GT', 'es_HN', 'es_MX', 'es_NI', 'es_PA', 'es_PE', 'es_PR', 'es_SV', 'es_UY', 'es_VE', 'pt', 'pt_br', 'pt_BR', 'pt_PT', 'gl', 'lad', 'an', 'mwl', 'it', 'it_IT', 'co', 'nap', 'scn', 'vec', 'sc', 'ro', 'la'],
|
||||||
|
'NORTH_EU': ['de', 'nl', 'fy', 'af', 'da', 'fo', 'is', 'no', 'nb', 'nn', 'sv'],
|
||||||
|
'SCANDINAVIA': ['da', 'fo', 'is', 'no', 'nb', 'nn', 'sv'],
|
||||||
|
'SAMI': ['se', 'sma', 'smj', 'smn', 'sms'],
|
||||||
|
'NORWAY': ['nb_NO', 'nb', 'nn_NO', 'nn', 'nog', 'no_nb', 'no'],
|
||||||
|
'CELTIC': ['ga', 'cy', 'br', 'gd', 'kw', 'gv']
|
||||||
|
}
|
||||||
|
|
||||||
|
Code to see available pretrained models:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from transformers.hf_api import HfApi
|
||||||
|
model_list = HfApi().model_list()
|
||||||
|
org = "Helsinki-NLP"
|
||||||
|
model_ids = [x.modelId for x in model_list if x.modelId.startswith(org)]
|
||||||
|
suffix = [x.split('/')[1] for x in model_ids]
|
||||||
|
multi_models = [f'{org}/{s}' for s in suffix if s != s.lower()]
|
||||||
|
|
||||||
MarianMTModel
|
MarianMTModel
|
||||||
~~~~~~~~~~~~~
|
~~~~~~~~~~~~~
|
||||||
|
|||||||
@@ -95,6 +95,97 @@ def find_model_file(dest_dir): # this one better
|
|||||||
return model_file
|
return model_file
|
||||||
|
|
||||||
|
|
||||||
|
# Group Names Logic: change long opus model names to something shorter, like opus-mt-en-ROMANCE
|
||||||
|
ROM_GROUP = "fr+fr_BE+fr_CA+fr_FR+wa+frp+oc+ca+rm+lld+fur+lij+lmo+es+es_AR+es_CL+es_CO+es_CR+es_DO+es_EC+es_ES+es_GT+es_HN+es_MX+es_NI+es_PA+es_PE+es_PR+es_SV+es_UY+es_VE+pt+pt_br+pt_BR+pt_PT+gl+lad+an+mwl+it+it_IT+co+nap+scn+vec+sc+ro+la"
|
||||||
|
GROUPS = [
|
||||||
|
("cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh", "ZH"),
|
||||||
|
(ROM_GROUP, "ROMANCE"),
|
||||||
|
("de+nl+fy+af+da+fo+is+no+nb+nn+sv", "NORTH_EU"),
|
||||||
|
("da+fo+is+no+nb+nn+sv", "SCANDINAVIA"),
|
||||||
|
("se+sma+smj+smn+sms", "SAMI"),
|
||||||
|
("nb_NO+nb+nn_NO+nn+nog+no_nb+no", "NORWAY"),
|
||||||
|
("ga+cy+br+gd+kw+gv", "CELTIC"), # https://en.wikipedia.org/wiki/Insular_Celtic_languages
|
||||||
|
]
|
||||||
|
GROUP_TO_OPUS_NAME = {
|
||||||
|
"opus-mt-ZH-de": "cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh-de",
|
||||||
|
"opus-mt-ZH-fi": "cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh-fi",
|
||||||
|
"opus-mt-ZH-sv": "cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh-sv",
|
||||||
|
"opus-mt-SCANDINAVIA-SCANDINAVIA": "da+fo+is+no+nb+nn+sv-da+fo+is+no+nb+nn+sv",
|
||||||
|
"opus-mt-NORTH_EU-NORTH_EU": "de+nl+fy+af+da+fo+is+no+nb+nn+sv-de+nl+fy+af+da+fo+is+no+nb+nn+sv",
|
||||||
|
"opus-mt-de-ZH": "de-cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh",
|
||||||
|
"opus-mt-en_el_es_fi-en_el_es_fi": "en+el+es+fi-en+el+es+fi",
|
||||||
|
"opus-mt-en-ROMANCE": "en-fr+fr_BE+fr_CA+fr_FR+wa+frp+oc+ca+rm+lld+fur+lij+lmo+es+es_AR+es_CL+es_CO+es_CR+es_DO"
|
||||||
|
"+es_EC+es_ES+es_GT+es_HN+es_MX+es_NI+es_PA+es_PE+es_PR+es_SV+es_UY+es_VE+pt+pt_br+pt_BR"
|
||||||
|
"+pt_PT+gl+lad+an+mwl+it+it_IT+co+nap+scn+vec+sc+ro+la",
|
||||||
|
"opus-mt-en-CELTIC": "en-ga+cy+br+gd+kw+gv",
|
||||||
|
"opus-mt-es-NORWAY": "es-nb_NO+nb+nn_NO+nn+nog+no_nb+no",
|
||||||
|
"opus-mt-fi_nb_no_nn_ru_sv_en-SAMI": "fi+nb+no+nn+ru+sv+en-se+sma+smj+smn+sms",
|
||||||
|
"opus-mt-fi-ZH": "fi-cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh",
|
||||||
|
"opus-mt-fi-NORWAY": "fi-nb_NO+nb+nn_NO+nn+nog+no_nb+no",
|
||||||
|
"opus-mt-ROMANCE-en": "fr+fr_BE+fr_CA+fr_FR+wa+frp+oc+ca+rm+lld+fur+lij+lmo+es+es_AR+es_CL+es_CO+es_CR+es_DO"
|
||||||
|
"+es_EC+es_ES+es_GT+es_HN+es_MX+es_NI+es_PA+es_PE+es_PR+es_SV+es_UY+es_VE+pt+pt_br+pt_BR"
|
||||||
|
"+pt_PT+gl+lad+an+mwl+it+it_IT+co+nap+scn+vec+sc+ro+la-en",
|
||||||
|
"opus-mt-CELTIC-en": "ga+cy+br+gd+kw+gv-en",
|
||||||
|
"opus-mt-sv-ZH": "sv-cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh",
|
||||||
|
"opus-mt-sv-NORWAY": "sv-nb_NO+nb+nn_NO+nn+nog+no_nb+no",
|
||||||
|
}
|
||||||
|
OPUS_GITHUB_URL = "https://github.com/Helsinki-NLP/OPUS-MT-train/blob/master/models/"
|
||||||
|
ORG_NAME = "Helsinki-NLP/"
|
||||||
|
|
||||||
|
|
||||||
|
def convert_opus_name_to_hf_name(x):
|
||||||
|
for substr, grp_name in GROUPS:
|
||||||
|
x = x.replace(substr, grp_name)
|
||||||
|
return x.replace("+", "_")
|
||||||
|
|
||||||
|
|
||||||
|
def convert_hf_name_to_opus_name(hf_model_name):
|
||||||
|
"""Relies on the assumption that there are no language codes like pt_br in models that are not in GROUP_TO_OPUS_NAME."""
|
||||||
|
hf_model_name = remove_prefix(hf_model_name, ORG_NAME)
|
||||||
|
if hf_model_name in GROUP_TO_OPUS_NAME:
|
||||||
|
opus_w_prefix = GROUP_TO_OPUS_NAME[hf_model_name]
|
||||||
|
else:
|
||||||
|
opus_w_prefix = hf_model_name.replace("_", "+")
|
||||||
|
return remove_prefix(opus_w_prefix, "opus-mt-")
|
||||||
|
|
||||||
|
|
||||||
|
def write_model_card(
|
||||||
|
hf_model_name: str,
|
||||||
|
repo_path="OPUS-MT-train/models/",
|
||||||
|
dry_run=False,
|
||||||
|
model_card_dir=Path("marian_converted/model_cards/Helsinki-NLP/"),
|
||||||
|
) -> str:
|
||||||
|
"""Copy the most recent model's readme section from opus, and add metadata.
|
||||||
|
upload command: s3cmd sync --recursive model_card_dir s3://models.huggingface.co/bert/Helsinki-NLP/
|
||||||
|
"""
|
||||||
|
hf_model_name = remove_prefix(hf_model_name, ORG_NAME)
|
||||||
|
opus_name: str = convert_hf_name_to_opus_name(hf_model_name)
|
||||||
|
opus_src, opus_tgt = [x.split("+") for x in opus_name.split("-")]
|
||||||
|
readme_url = OPUS_GITHUB_URL + f"{opus_name}/README.md"
|
||||||
|
s, t = ",".join(opus_src), ",".join(opus_tgt)
|
||||||
|
extra_markdown = f"### {hf_model_name}\n\n* source languages: {s}\n* target languages: {t}\n* OPUS readme: [{opus_name}]({readme_url})\n"
|
||||||
|
# combine with opus markdown
|
||||||
|
opus_readme_path = Path(f"{repo_path}{opus_name}/README.md")
|
||||||
|
assert opus_readme_path.exists(), opus_readme_path
|
||||||
|
content = opus_readme_path.open().read()
|
||||||
|
content = content.split("\n# ")[-1] # Get the lowest level 1 header in the README -- the most recent model.
|
||||||
|
content = "*".join(content.split("*")[1:])
|
||||||
|
content = extra_markdown + "\n* " + content.replace("download", "download original weights")
|
||||||
|
if dry_run:
|
||||||
|
return content
|
||||||
|
# Save string to model_cards/hf_model_name/readme.md
|
||||||
|
model_card_dir.mkdir(exist_ok=True)
|
||||||
|
sub_dir = model_card_dir / hf_model_name
|
||||||
|
sub_dir.mkdir(exist_ok=True)
|
||||||
|
dest = sub_dir / "README.md"
|
||||||
|
dest.open("w").write(content)
|
||||||
|
return content
|
||||||
|
|
||||||
|
|
||||||
|
def get_clean_model_id_mapping(multiling_model_ids):
|
||||||
|
return {x: convert_opus_name_to_hf_name(x) for x in multiling_model_ids}
|
||||||
|
|
||||||
|
|
||||||
def make_registry(repo_path="Opus-MT-train/models"):
|
def make_registry(repo_path="Opus-MT-train/models"):
|
||||||
if not (Path(repo_path) / "fr-en" / "README.md").exists():
|
if not (Path(repo_path) / "fr-en" / "README.md").exists():
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -109,10 +200,7 @@ def make_registry(repo_path="Opus-MT-train/models"):
|
|||||||
else:
|
else:
|
||||||
lns = list(open(p / "README.md").readlines())
|
lns = list(open(p / "README.md").readlines())
|
||||||
results[p.name] = _parse_readme(lns)
|
results[p.name] = _parse_readme(lns)
|
||||||
return [(k, v["pre-processing"], v["download"]) for k, v in results.items()]
|
return [(k, v["pre-processing"], v["download"], v["download"][:-4] + ".test.txt") for k, v in results.items()]
|
||||||
|
|
||||||
|
|
||||||
CH_GROUP = "cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh"
|
|
||||||
|
|
||||||
|
|
||||||
def convert_all_sentencepiece_models(model_list=None, repo_path=None):
|
def convert_all_sentencepiece_models(model_list=None, repo_path=None):
|
||||||
@@ -122,12 +210,12 @@ def convert_all_sentencepiece_models(model_list=None, repo_path=None):
|
|||||||
dest_dir.mkdir(exist_ok=True)
|
dest_dir.mkdir(exist_ok=True)
|
||||||
if model_list is None:
|
if model_list is None:
|
||||||
model_list: list = make_registry(repo_path=repo_path)
|
model_list: list = make_registry(repo_path=repo_path)
|
||||||
for k, prepro, download in tqdm(model_list):
|
for k, prepro, download, test_set_url in tqdm(model_list):
|
||||||
if "SentencePiece" not in prepro: # dont convert BPE models.
|
if "SentencePiece" not in prepro: # dont convert BPE models.
|
||||||
continue
|
continue
|
||||||
if not os.path.exists(save_dir / k / "pytorch_model.bin"):
|
if not os.path.exists(save_dir / k / "pytorch_model.bin"):
|
||||||
download_and_unzip(download, save_dir / k)
|
download_and_unzip(download, save_dir / k)
|
||||||
pair_name = k.replace(CH_GROUP, "ch_group")
|
pair_name = convert_opus_name_to_hf_name(k)
|
||||||
convert(save_dir / k, dest_dir / f"opus-mt-{pair_name}")
|
convert(save_dir / k, dest_dir / f"opus-mt-{pair_name}")
|
||||||
|
|
||||||
|
|
||||||
@@ -135,12 +223,10 @@ def lmap(f, x) -> List:
|
|||||||
return list(map(f, x))
|
return list(map(f, x))
|
||||||
|
|
||||||
|
|
||||||
def fetch_test_set(readmes_raw, pair):
|
def fetch_test_set(test_set_url):
|
||||||
import wget
|
import wget
|
||||||
|
|
||||||
download_url = readmes_raw[pair]["download"]
|
fname = wget.download(test_set_url, f"opus_test.txt")
|
||||||
test_set_url = download_url[:-4] + ".test.txt"
|
|
||||||
fname = wget.download(test_set_url, f"opus_test_{pair}.txt")
|
|
||||||
lns = Path(fname).open().readlines()
|
lns = Path(fname).open().readlines()
|
||||||
src = lmap(str.strip, lns[::4])
|
src = lmap(str.strip, lns[::4])
|
||||||
gold = lmap(str.strip, lns[1::4])
|
gold = lmap(str.strip, lns[1::4])
|
||||||
|
|||||||
@@ -980,12 +980,12 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
|||||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||||
}
|
}
|
||||||
|
|
||||||
def prepare_scores_for_generation(self, scores, cur_len, max_length):
|
def prepare_logits_for_generation(self, logits, cur_len, max_length):
|
||||||
if cur_len == 1:
|
if cur_len == 1:
|
||||||
self._force_token_ids_generation(scores, self.config.bos_token_id)
|
self._force_token_ids_generation(logits, self.config.bos_token_id)
|
||||||
if cur_len == max_length - 1 and self.config.eos_token_id is not None:
|
if cur_len == max_length - 1 and self.config.eos_token_id is not None:
|
||||||
self._force_token_ids_generation(scores, self.config.eos_token_id)
|
self._force_token_ids_generation(logits, self.config.eos_token_id)
|
||||||
return scores
|
return logits
|
||||||
|
|
||||||
def _force_token_ids_generation(self, scores, token_ids) -> None:
|
def _force_token_ids_generation(self, scores, token_ids) -> None:
|
||||||
"""force one of token_ids to be generated by setting prob of all other tokens to 0"""
|
"""force one of token_ids to be generated by setting prob of all other tokens to 0"""
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ class MarianMTModel(BartForConditionalGeneration):
|
|||||||
src = 'fr' # source language
|
src = 'fr' # source language
|
||||||
trg = 'en' # target language
|
trg = 'en' # target language
|
||||||
sample_text = "où est l'arrêt de bus ?"
|
sample_text = "où est l'arrêt de bus ?"
|
||||||
mname = f'Helsinki-NLP/opus-mt-{src}-{trg}' # `Model List`__
|
mname = f'Helsinki-NLP/opus-mt-{src}-{trg}'
|
||||||
|
|
||||||
model = MarianMTModel.from_pretrained(mname)
|
model = MarianMTModel.from_pretrained(mname)
|
||||||
tok = MarianTokenizer.from_pretrained(mname)
|
tok = MarianTokenizer.from_pretrained(mname)
|
||||||
@@ -43,7 +43,8 @@ class MarianMTModel(BartForConditionalGeneration):
|
|||||||
|
|
||||||
pretrained_model_archive_map = {} # see https://huggingface.co/models?search=Helsinki-NLP
|
pretrained_model_archive_map = {} # see https://huggingface.co/models?search=Helsinki-NLP
|
||||||
|
|
||||||
def prepare_scores_for_generation(self, scores, cur_len, max_length):
|
def prepare_logits_for_generation(self, logits, cur_len, max_length):
|
||||||
|
logits[:, self.config.pad_token_id] = float("-inf")
|
||||||
if cur_len == max_length - 1 and self.config.eos_token_id is not None:
|
if cur_len == max_length - 1 and self.config.eos_token_id is not None:
|
||||||
self._force_token_ids_generation(scores, self.config.eos_token_id)
|
self._force_token_ids_generation(logits, self.config.eos_token_id)
|
||||||
return scores
|
return logits
|
||||||
|
|||||||
@@ -744,8 +744,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
def prepare_inputs_for_generation(self, input_ids, **kwargs):
|
def prepare_inputs_for_generation(self, input_ids, **kwargs):
|
||||||
return {"input_ids": input_ids}
|
return {"input_ids": input_ids}
|
||||||
|
|
||||||
def prepare_scores_for_generation(self, scores, **kwargs):
|
def prepare_logits_for_generation(self, logits, **kwargs):
|
||||||
return scores
|
return logits
|
||||||
|
|
||||||
def _use_cache(self, outputs, use_cache):
|
def _use_cache(self, outputs, use_cache):
|
||||||
"""During generation, decide whether to pass the `past` variable to the next forward pass."""
|
"""During generation, decide whether to pass the `past` variable to the next forward pass."""
|
||||||
@@ -857,7 +857,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
|
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
|
||||||
Defaults to `None`.
|
Defaults to `None`.
|
||||||
|
|
||||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||||
|
|
||||||
decoder_start_token_id=None: (`optional`) int
|
decoder_start_token_id=None: (`optional`) int
|
||||||
If an encoder-decoder model starts decoding with a different token than BOS.
|
If an encoder-decoder model starts decoding with a different token than BOS.
|
||||||
@@ -1342,10 +1342,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
if temperature != 1.0:
|
if temperature != 1.0:
|
||||||
next_token_logits = next_token_logits / temperature
|
next_token_logits = next_token_logits / temperature
|
||||||
|
|
||||||
scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
|
|
||||||
if self.config.is_encoder_decoder and do_sample is False:
|
if self.config.is_encoder_decoder and do_sample is False:
|
||||||
# TODO (PVP) still a bit hacky here - there might be a better solutino
|
# TODO (PVP) still a bit hacky here - there might be a better solution
|
||||||
scores = self.prepare_scores_for_generation(scores, cur_len=cur_len, max_length=max_length)
|
next_token_logits = self.prepare_logits_for_generation(
|
||||||
|
next_token_logits, cur_len=cur_len, max_length=max_length
|
||||||
|
)
|
||||||
|
|
||||||
|
scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
|
||||||
|
|
||||||
# set eos token prob to zero if min_length is not reached
|
# set eos token prob to zero if min_length is not reached
|
||||||
if eos_token_id is not None and cur_len < min_length:
|
if eos_token_id is not None and cur_len < min_length:
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
|
import re
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Dict, List, Optional, Union
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
@@ -14,7 +15,7 @@ vocab_files_names = {
|
|||||||
"vocab": "vocab.json",
|
"vocab": "vocab.json",
|
||||||
"tokenizer_config_file": "tokenizer_config.json",
|
"tokenizer_config_file": "tokenizer_config.json",
|
||||||
}
|
}
|
||||||
MODEL_NAMES = ("opus-mt-en-de",)
|
MODEL_NAMES = ("opus-mt-en-de",) # TODO(SS): the only required constant is vocab_files_names
|
||||||
PRETRAINED_VOCAB_FILES_MAP = {
|
PRETRAINED_VOCAB_FILES_MAP = {
|
||||||
k: {m: f"{S3_BUCKET_PREFIX}/Helsinki-NLP/{m}/{fname}" for m in MODEL_NAMES}
|
k: {m: f"{S3_BUCKET_PREFIX}/Helsinki-NLP/{m}/{fname}" for m in MODEL_NAMES}
|
||||||
for k, fname in vocab_files_names.items()
|
for k, fname in vocab_files_names.items()
|
||||||
@@ -41,6 +42,7 @@ class MarianTokenizer(PreTrainedTokenizer):
|
|||||||
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
||||||
max_model_input_sizes = {m: 512 for m in MODEL_NAMES}
|
max_model_input_sizes = {m: 512 for m in MODEL_NAMES}
|
||||||
model_input_names = ["attention_mask"] # actually attention_mask, decoder_attention_mask
|
model_input_names = ["attention_mask"] # actually attention_mask, decoder_attention_mask
|
||||||
|
language_code_re = re.compile(">>.+<<") # type: re.Pattern
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -72,8 +74,6 @@ class MarianTokenizer(PreTrainedTokenizer):
|
|||||||
self.target_lang = target_lang
|
self.target_lang = target_lang
|
||||||
|
|
||||||
# load SentencePiece model for pre-processing
|
# load SentencePiece model for pre-processing
|
||||||
self.paths = {}
|
|
||||||
|
|
||||||
self.spm_source = sentencepiece.SentencePieceProcessor()
|
self.spm_source = sentencepiece.SentencePieceProcessor()
|
||||||
self.spm_source.Load(source_spm)
|
self.spm_source.Load(source_spm)
|
||||||
|
|
||||||
@@ -82,9 +82,7 @@ class MarianTokenizer(PreTrainedTokenizer):
|
|||||||
|
|
||||||
# Multilingual target side: default to using first supported language code.
|
# Multilingual target side: default to using first supported language code.
|
||||||
self.supported_language_codes: list = [k for k in self.encoder if k.startswith(">>") and k.endswith("<<")]
|
self.supported_language_codes: list = [k for k in self.encoder if k.startswith(">>") and k.endswith("<<")]
|
||||||
self.tgt_lang_id = None # will not be used unless it is set through prepare_translation_batch
|
|
||||||
|
|
||||||
# Note(SS): sentence_splitter would require lots of book-keeping.
|
|
||||||
try:
|
try:
|
||||||
from mosestokenizer import MosesPunctuationNormalizer
|
from mosestokenizer import MosesPunctuationNormalizer
|
||||||
|
|
||||||
@@ -93,11 +91,23 @@ class MarianTokenizer(PreTrainedTokenizer):
|
|||||||
warnings.warn("Recommended: pip install mosestokenizer")
|
warnings.warn("Recommended: pip install mosestokenizer")
|
||||||
self.punc_normalizer = lambda x: x
|
self.punc_normalizer = lambda x: x
|
||||||
|
|
||||||
|
def normalize(self, x: str) -> str:
|
||||||
|
"""Cover moses empty string edge case. They return empty list for '' input!"""
|
||||||
|
return self.punc_normalizer(x) if x else ""
|
||||||
|
|
||||||
def _convert_token_to_id(self, token):
|
def _convert_token_to_id(self, token):
|
||||||
return self.encoder.get(token, self.encoder[self.unk_token])
|
return self.encoder.get(token, self.encoder[self.unk_token])
|
||||||
|
|
||||||
|
def remove_language_code(self, text: str):
|
||||||
|
"""Remove language codes like <<fr>> before sentencepiece"""
|
||||||
|
match = self.language_code_re.match(text)
|
||||||
|
code: list = [match.group(0)] if match else []
|
||||||
|
return code, self.language_code_re.sub("", text)
|
||||||
|
|
||||||
def _tokenize(self, text: str) -> List[str]:
|
def _tokenize(self, text: str) -> List[str]:
|
||||||
return self.current_spm.EncodeAsPieces(text)
|
code, text = self.remove_language_code(text)
|
||||||
|
pieces = self.current_spm.EncodeAsPieces(text)
|
||||||
|
return code + pieces
|
||||||
|
|
||||||
def _convert_id_to_token(self, index: int) -> str:
|
def _convert_id_to_token(self, index: int) -> str:
|
||||||
"""Converts an index (integer) in a token (str) using the encoder."""
|
"""Converts an index (integer) in a token (str) using the encoder."""
|
||||||
@@ -125,7 +135,7 @@ class MarianTokenizer(PreTrainedTokenizer):
|
|||||||
pad_to_max_length: bool = True,
|
pad_to_max_length: bool = True,
|
||||||
return_tensors: str = "pt",
|
return_tensors: str = "pt",
|
||||||
) -> BatchEncoding:
|
) -> BatchEncoding:
|
||||||
"""
|
"""Prepare model inputs for translation. For best performance, translate one sentence at a time.
|
||||||
Arguments:
|
Arguments:
|
||||||
src_texts: list of src language texts
|
src_texts: list of src language texts
|
||||||
tgt_texts: list of tgt language texts
|
tgt_texts: list of tgt language texts
|
||||||
@@ -138,7 +148,10 @@ class MarianTokenizer(PreTrainedTokenizer):
|
|||||||
all shaped bs, seq_len. (BatchEncoding is a dict of string -> tensor or lists).
|
all shaped bs, seq_len. (BatchEncoding is a dict of string -> tensor or lists).
|
||||||
If no tgt_text is specified, the only keys will be input_ids and attention_mask.
|
If no tgt_text is specified, the only keys will be input_ids and attention_mask.
|
||||||
"""
|
"""
|
||||||
|
if "" in src_texts:
|
||||||
|
raise ValueError(f"found empty string in src_texts: {src_texts}")
|
||||||
self.current_spm = self.spm_source
|
self.current_spm = self.spm_source
|
||||||
|
src_texts = [self.normalize(t) for t in src_texts] # this does not appear to do much
|
||||||
model_inputs: BatchEncoding = self.batch_encode_plus(
|
model_inputs: BatchEncoding = self.batch_encode_plus(
|
||||||
src_texts,
|
src_texts,
|
||||||
add_special_tokens=True,
|
add_special_tokens=True,
|
||||||
|
|||||||
@@ -33,15 +33,21 @@ if is_torch_available():
|
|||||||
MarianTokenizer,
|
MarianTokenizer,
|
||||||
MarianMTModel,
|
MarianMTModel,
|
||||||
)
|
)
|
||||||
|
from transformers.convert_marian_to_pytorch import (
|
||||||
|
convert_hf_name_to_opus_name,
|
||||||
|
convert_opus_name_to_hf_name,
|
||||||
|
ORG_NAME,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ModelManagementTests(unittest.TestCase):
|
class ModelManagementTests(unittest.TestCase):
|
||||||
@slow
|
@slow
|
||||||
def test_model_count(self):
|
def test_model_names(self):
|
||||||
model_list = HfApi().model_list()
|
model_list = HfApi().model_list()
|
||||||
expected_num_models = 1011
|
model_ids = [x.modelId for x in model_list if x.modelId.startswith(ORG_NAME)]
|
||||||
actual_num_models = len([x for x in model_list if x.modelId.startswith("Helsinki-NLP")])
|
bad_model_ids = [mid for mid in model_ids if "+" in model_ids]
|
||||||
self.assertEqual(expected_num_models, actual_num_models)
|
self.assertListEqual([], bad_model_ids)
|
||||||
|
self.assertGreater(len(model_ids), 500)
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@@ -91,12 +97,12 @@ class MarianIntegrationTest(unittest.TestCase):
|
|||||||
self.assertListEqual(self.expected_text, generated_words)
|
self.assertListEqual(self.expected_text, generated_words)
|
||||||
|
|
||||||
def translate_src_text(self, **tokenizer_kwargs):
|
def translate_src_text(self, **tokenizer_kwargs):
|
||||||
model_inputs: dict = self.tokenizer.prepare_translation_batch(src_texts=self.src_text, **tokenizer_kwargs).to(
|
model_inputs = self.tokenizer.prepare_translation_batch(src_texts=self.src_text, **tokenizer_kwargs).to(
|
||||||
torch_device
|
torch_device
|
||||||
)
|
)
|
||||||
self.assertEqual(self.model.device, model_inputs["input_ids"].device)
|
self.assertEqual(self.model.device, model_inputs.input_ids.device)
|
||||||
generated_ids = self.model.generate(
|
generated_ids = self.model.generate(
|
||||||
model_inputs["input_ids"], attention_mask=model_inputs["attention_mask"], num_beams=2
|
model_inputs.input_ids, attention_mask=model_inputs.attention_mask, num_beams=2
|
||||||
)
|
)
|
||||||
generated_words = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
generated_words = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||||
return generated_words
|
return generated_words
|
||||||
@@ -106,10 +112,10 @@ class TestMarian_EN_DE_More(MarianIntegrationTest):
|
|||||||
@slow
|
@slow
|
||||||
def test_forward(self):
|
def test_forward(self):
|
||||||
src, tgt = ["I am a small frog"], ["Ich bin ein kleiner Frosch."]
|
src, tgt = ["I am a small frog"], ["Ich bin ein kleiner Frosch."]
|
||||||
expected = [38, 121, 14, 697, 38848, 0]
|
expected_ids = [38, 121, 14, 697, 38848, 0]
|
||||||
|
|
||||||
model_inputs: dict = self.tokenizer.prepare_translation_batch(src, tgt_texts=tgt).to(torch_device)
|
model_inputs: dict = self.tokenizer.prepare_translation_batch(src, tgt_texts=tgt).to(torch_device)
|
||||||
self.assertListEqual(expected, model_inputs["input_ids"][0].tolist())
|
self.assertListEqual(expected_ids, model_inputs.input_ids[0].tolist())
|
||||||
|
|
||||||
desired_keys = {
|
desired_keys = {
|
||||||
"input_ids",
|
"input_ids",
|
||||||
@@ -125,20 +131,19 @@ class TestMarian_EN_DE_More(MarianIntegrationTest):
|
|||||||
|
|
||||||
def test_tokenizer_equivalence(self):
|
def test_tokenizer_equivalence(self):
|
||||||
batch = self.tokenizer.prepare_translation_batch(["I am a small frog"]).to(torch_device)
|
batch = self.tokenizer.prepare_translation_batch(["I am a small frog"]).to(torch_device)
|
||||||
input_ids = batch["input_ids"][0]
|
|
||||||
expected = [38, 121, 14, 697, 38848, 0]
|
expected = [38, 121, 14, 697, 38848, 0]
|
||||||
self.assertListEqual(expected, input_ids.tolist())
|
self.assertListEqual(expected, batch.input_ids[0].tolist())
|
||||||
|
|
||||||
def test_unk_support(self):
|
def test_unk_support(self):
|
||||||
t = self.tokenizer
|
t = self.tokenizer
|
||||||
ids = t.prepare_translation_batch(["||"]).to(torch_device)["input_ids"][0].tolist()
|
ids = t.prepare_translation_batch(["||"]).to(torch_device).input_ids[0].tolist()
|
||||||
expected = [t.unk_token_id, t.unk_token_id, t.eos_token_id]
|
expected = [t.unk_token_id, t.unk_token_id, t.eos_token_id]
|
||||||
self.assertEqual(expected, ids)
|
self.assertEqual(expected, ids)
|
||||||
|
|
||||||
def test_pad_not_split(self):
|
def test_pad_not_split(self):
|
||||||
input_ids_w_pad = self.tokenizer.prepare_translation_batch(["I am a small frog <pad>"])["input_ids"][0]
|
input_ids_w_pad = self.tokenizer.prepare_translation_batch(["I am a small frog <pad>"]).input_ids[0].tolist()
|
||||||
expected_w_pad = [38, 121, 14, 697, 38848, self.tokenizer.pad_token_id, 0] # pad
|
expected_w_pad = [38, 121, 14, 697, 38848, self.tokenizer.pad_token_id, 0] # pad
|
||||||
self.assertListEqual(expected_w_pad, input_ids_w_pad.tolist())
|
self.assertListEqual(expected_w_pad, input_ids_w_pad)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_batch_generation_en_de(self):
|
def test_batch_generation_en_de(self):
|
||||||
@@ -187,9 +192,8 @@ class TestMarian_RU_FR(MarianIntegrationTest):
|
|||||||
src = "ru"
|
src = "ru"
|
||||||
tgt = "fr"
|
tgt = "fr"
|
||||||
src_text = ["Он показал мне рукопись своей новой пьесы."]
|
src_text = ["Он показал мне рукопись своей новой пьесы."]
|
||||||
expected_text = ["Il me montre un manuscrit de sa nouvelle pièce."]
|
expected_text = ["Il m'a montré le manuscrit de sa nouvelle pièce."]
|
||||||
|
|
||||||
@slow
|
|
||||||
def test_batch_generation_ru_fr(self):
|
def test_batch_generation_ru_fr(self):
|
||||||
self._assert_generated_batch_equal_expected()
|
self._assert_generated_batch_equal_expected()
|
||||||
|
|
||||||
@@ -197,36 +201,59 @@ class TestMarian_RU_FR(MarianIntegrationTest):
|
|||||||
class TestMarian_MT_EN(MarianIntegrationTest):
|
class TestMarian_MT_EN(MarianIntegrationTest):
|
||||||
src = "mt"
|
src = "mt"
|
||||||
tgt = "en"
|
tgt = "en"
|
||||||
src_text = ["Il - Babiloniżi b'mod żbaljat ikkonkludew li l - Alla l - veru kien dgħajjef."]
|
src_text = ["Billi messu b'mod ġentili, Ġesù fejjaq raġel li kien milqut bil - marda kerha tal - ġdiem."]
|
||||||
expected_text = ["The Babylonians wrongly concluded that the true God was weak."]
|
expected_text = ["Touching gently, Jesus healed a man who was affected by the sad disease of leprosy."]
|
||||||
|
|
||||||
@unittest.skip("") # Known Issue: This model generates a string of .... at the end of the translation.
|
|
||||||
def test_batch_generation_mt_en(self):
|
def test_batch_generation_mt_en(self):
|
||||||
self._assert_generated_batch_equal_expected()
|
self._assert_generated_batch_equal_expected()
|
||||||
|
|
||||||
|
|
||||||
class TestMarian_DE_Multi(MarianIntegrationTest):
|
class TestMarian_en_ROMANCE(MarianIntegrationTest):
|
||||||
src = "de"
|
"""Multilingual on target side."""
|
||||||
tgt = "ch_group"
|
|
||||||
src_text = ["Er aber sprach: Das ist die Gottlosigkeit."]
|
src = "en"
|
||||||
|
tgt = "ROMANCE"
|
||||||
|
src_text = [
|
||||||
|
">>fr<< Don't spend so much time watching TV.",
|
||||||
|
">>pt<< Your message has been sent.",
|
||||||
|
">>es<< He's two years older than me.",
|
||||||
|
]
|
||||||
|
expected_text = [
|
||||||
|
"Ne passez pas autant de temps à regarder la télé.",
|
||||||
|
"A sua mensagem foi enviada.",
|
||||||
|
"Es dos años más viejo que yo.",
|
||||||
|
]
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_translation_de_multi_does_not_error(self):
|
def test_batch_generation_en_ROMANCE_multi(self):
|
||||||
self.translate_src_text()
|
|
||||||
|
|
||||||
@unittest.skip("") # "Language codes are not yet supported."
|
|
||||||
def test_batch_generation_de_multi_tgt(self):
|
|
||||||
self._assert_generated_batch_equal_expected()
|
self._assert_generated_batch_equal_expected()
|
||||||
|
|
||||||
@unittest.skip("") # "Language codes are not yet supported."
|
def test_tokenizer_handles_empty(self):
|
||||||
def test_lang_code(self):
|
normalized = self.tokenizer.normalize("")
|
||||||
t = "Er aber sprach"
|
self.assertIsInstance(normalized, str)
|
||||||
zh_code = self.code
|
with self.assertRaises(ValueError):
|
||||||
tok_fn = self.tokenizer.prepare_translation_batch
|
self.tokenizer.prepare_translation_batch([""])
|
||||||
pass_code = tok_fn(src_texts=[t], tgt_lang_code=zh_code)["input_ids"][0]
|
|
||||||
preprocess_with_code = tok_fn(src_texts=[zh_code + " " + t])["input_ids"][0]
|
|
||||||
self.assertListEqual(pass_code.tolist(), preprocess_with_code.tolist())
|
@require_torch
|
||||||
for code in self.tokenizer.supported_language_codes:
|
class TestConversionUtils(unittest.TestCase):
|
||||||
self.assertIn(code, self.tokenizer.encoder)
|
def test_renaming_multilingual(self):
|
||||||
pass_only_code = tok_fn(src_texts=[""], tgt_lang_code=zh_code)["input_ids"][0].tolist()
|
old_names = [
|
||||||
self.assertListEqual(pass_only_code, [self.tokenizer.encoder[zh_code], self.tokenizer.eos_token_id])
|
"opus-mt-cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh-fi",
|
||||||
|
"opus-mt-cmn+cn-fi", # no group
|
||||||
|
"opus-mt-en-de", # standard name
|
||||||
|
"opus-mt-en-de", # standard name
|
||||||
|
]
|
||||||
|
expected = ["opus-mt-ZH-fi", "opus-mt-cmn_cn-fi", "opus-mt-en-de", "opus-mt-en-de"]
|
||||||
|
self.assertListEqual(expected, [convert_opus_name_to_hf_name(x) for x in old_names])
|
||||||
|
|
||||||
|
def test_undoing_renaming(self):
|
||||||
|
hf_names = ["opus-mt-ZH-fi", "opus-mt-cmn_cn-fi", "opus-mt-en-de", "opus-mt-en-de"]
|
||||||
|
converted_opus_names = [convert_hf_name_to_opus_name(x) for x in hf_names]
|
||||||
|
expected_opus_names = [
|
||||||
|
"cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh-fi",
|
||||||
|
"cmn+cn-fi",
|
||||||
|
"en-de", # standard name
|
||||||
|
"en-de",
|
||||||
|
]
|
||||||
|
self.assertListEqual(expected_opus_names, converted_opus_names)
|
||||||
|
|||||||
Reference in New Issue
Block a user