From 9a687ebb7750edee86047e0f243c332effdf86b6 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Wed, 13 May 2020 17:29:41 -0400 Subject: [PATCH] [Marian Fixes] prevent predicting pad_token_id before softmax, support language codes, name multilingual models (#4290) --- docs/source/model_doc/marian.rst | 84 ++++++++++++-- src/transformers/convert_marian_to_pytorch.py | 106 +++++++++++++++-- src/transformers/modeling_bart.py | 8 +- src/transformers/modeling_marian.py | 9 +- src/transformers/modeling_utils.py | 15 ++- src/transformers/tokenization_marian.py | 27 +++-- tests/test_modeling_marian.py | 107 +++++++++++------- 7 files changed, 274 insertions(+), 82 deletions(-) diff --git a/docs/source/model_doc/marian.rst b/docs/source/model_doc/marian.rst index 6874368011..ef72e93e13 100644 --- a/docs/source/model_doc/marian.rst +++ b/docs/source/model_doc/marian.rst @@ -1,28 +1,90 @@ -MarianMTModel +MarianMT ---------------------------------------------------- **DISCLAIMER:** If you see something strange, file a `Github Issue `__ and assign -@sshleifer -These models are for machine translation. The list of supported language pairs can be found `here `__. - -Opus Project -~~~~~~~~~~~~ -The 1,000+ models were originally trained by `Jörg Tiedemann `__ using the `Marian `_ C++ library, which supports fast training and translation. -All models are transformer encoder-decoders with 6 layers in each component. Each model's performance is documented in a model card. +@sshleifer. Translations should be similar, but not identical to, output in the test set linked to in each model card. Implementation Notes ~~~~~~~~~~~~~~~~~~~~ - each model is about 298 MB on disk, there are 1,000+ models. -- Models are named with the following patter 'Helsinki-NLP/opus-mt-{src_langs}-{targ_langs}'. If there are multiple source or target languages they are joined by a '+' symbol. +- The list of supported language pairs can be found `here `__. +- The 1,000+ models were originally trained by `Jörg Tiedemann `__ using the `Marian `_ C++ library, which supports fast training and translation. +- All models are transformer encoder-decoders with 6 layers in each component. Each model's performance is documented in a model card. - the 80 opus models that require BPE preprocessing are not supported. -- There is an outstanding issue w.r.t multilingual models and language codes. -- The modeling code is the same as ``BartModel`` with a few minor modifications: +- The modeling code is the same as ``BartForConditionalGeneration`` with a few minor modifications: - static (sinusoid) positional embeddings (``MarianConfig.static_position_embeddings=True``) - a new final_logits_bias (``MarianConfig.add_bias_logits=True``) - no layernorm_embedding (``MarianConfig.normalize_embedding=False``) - the model starts generating with pad_token_id (which has 0 token_embedding) as the prefix. (Bart uses ) - Code to bulk convert models can be found in ``convert_marian_to_pytorch.py`` +Naming +~~~~~~ +- All model names use the following format: ``Helsinki-NLP/opus-mt-{src}-{tgt}`` +- The language codes used to name models are inconsistent. Two digit codes can usually be found `here `_, three digit codes require googling "language code {code}". +- Codes formatted like ``es_AR`` are usually ``code_{region}``. That one is spanish documents from Argentina. + + +Multilingual Models +~~~~~~~~~~~~~~~~~~~~ + +All model names use the following format: ``Helsinki-NLP/opus-mt-{src}-{tgt}``: + - if ``src`` is in all caps, the model supports multiple input languages, you can figure out which ones by looking at the model card, or the Group Members `mapping `_ . + - if ``tgt`` is in all caps, the model can output multiple languages, and you should specify a language code by prepending the desired output language to the src_text + - You can see a tokenizer's supported language codes in ``tokenizer.supported_language_codes`` + +Example of translating english to many romance languages, using language codes: + +.. code-block:: python + + from transformers import MarianMTModel, MarianTokenizer + src_text = [ + '>>fr<< this is a sentence in english that we want to translate to french', + '>>pt<< This should go to portuguese', + '>>es<< And this to Spanish' + ] + + model_name = 'Helsinki-NLP/opus-mt-en-ROMANCE' + tokenizer = MarianTokenizer.from_pretrained(model_name) + print(tokenizer.supported_language_codes) + model = MarianMTModel.from_pretrained(model_name) + translated = model.generate(**tokenizer.prepare_translation_batch(src_text)) + tgt_text = [tokenizer.decode(t, skip_special_tokens=True) for t in translated] + # ["c'est une phrase en anglais que nous voulons traduire en français", + # 'Isto deve ir para o português.', + # 'Y esto al español'] + +Sometimes, models were trained on collections of languages that do not resolve to a group. In this case, _ is used as a separator for src or tgt, as in ``'Helsinki-NLP/opus-mt-en_el_es_fi-en_el_es_fi'``. These still require language codes. +There are many supported regional language codes, like ``>>es_ES<<`` (Spain) and ``>>es_AR<<`` (Argentina), that do not seem to change translations. I have not found these to provide different results than just using ``>>es<<``. + +For Example: + - ``Helsinki-NLP/opus-mt-NORTH_EU-NORTH_EU``: translates from all NORTH_EU languages (see `mapping `_) to all NORTH_EU languages. Use a special language code like ``>>de<<`` to specify output language. + - ``Helsinki-NLP/opus-mt-ROMANCE-en``: translates from many romance languages to english, no codes needed since there is only 1 tgt language. + + + +.. code-block:: python + + GROUP_MEMBERS = { + 'ZH': ['cmn', 'cn', 'yue', 'ze_zh', 'zh_cn', 'zh_CN', 'zh_HK', 'zh_tw', 'zh_TW', 'zh_yue', 'zhs', 'zht', 'zh'], + 'ROMANCE': ['fr', 'fr_BE', 'fr_CA', 'fr_FR', 'wa', 'frp', 'oc', 'ca', 'rm', 'lld', 'fur', 'lij', 'lmo', 'es', 'es_AR', 'es_CL', 'es_CO', 'es_CR', 'es_DO', 'es_EC', 'es_ES', 'es_GT', 'es_HN', 'es_MX', 'es_NI', 'es_PA', 'es_PE', 'es_PR', 'es_SV', 'es_UY', 'es_VE', 'pt', 'pt_br', 'pt_BR', 'pt_PT', 'gl', 'lad', 'an', 'mwl', 'it', 'it_IT', 'co', 'nap', 'scn', 'vec', 'sc', 'ro', 'la'], + 'NORTH_EU': ['de', 'nl', 'fy', 'af', 'da', 'fo', 'is', 'no', 'nb', 'nn', 'sv'], + 'SCANDINAVIA': ['da', 'fo', 'is', 'no', 'nb', 'nn', 'sv'], + 'SAMI': ['se', 'sma', 'smj', 'smn', 'sms'], + 'NORWAY': ['nb_NO', 'nb', 'nn_NO', 'nn', 'nog', 'no_nb', 'no'], + 'CELTIC': ['ga', 'cy', 'br', 'gd', 'kw', 'gv'] + } + +Code to see available pretrained models: + +.. code-block:: python + + from transformers.hf_api import HfApi + model_list = HfApi().model_list() + org = "Helsinki-NLP" + model_ids = [x.modelId for x in model_list if x.modelId.startswith(org)] + suffix = [x.split('/')[1] for x in model_ids] + multi_models = [f'{org}/{s}' for s in suffix if s != s.lower()] MarianMTModel ~~~~~~~~~~~~~ diff --git a/src/transformers/convert_marian_to_pytorch.py b/src/transformers/convert_marian_to_pytorch.py index 281bd4aff8..c140fafca4 100644 --- a/src/transformers/convert_marian_to_pytorch.py +++ b/src/transformers/convert_marian_to_pytorch.py @@ -95,6 +95,97 @@ def find_model_file(dest_dir): # this one better return model_file +# Group Names Logic: change long opus model names to something shorter, like opus-mt-en-ROMANCE +ROM_GROUP = "fr+fr_BE+fr_CA+fr_FR+wa+frp+oc+ca+rm+lld+fur+lij+lmo+es+es_AR+es_CL+es_CO+es_CR+es_DO+es_EC+es_ES+es_GT+es_HN+es_MX+es_NI+es_PA+es_PE+es_PR+es_SV+es_UY+es_VE+pt+pt_br+pt_BR+pt_PT+gl+lad+an+mwl+it+it_IT+co+nap+scn+vec+sc+ro+la" +GROUPS = [ + ("cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh", "ZH"), + (ROM_GROUP, "ROMANCE"), + ("de+nl+fy+af+da+fo+is+no+nb+nn+sv", "NORTH_EU"), + ("da+fo+is+no+nb+nn+sv", "SCANDINAVIA"), + ("se+sma+smj+smn+sms", "SAMI"), + ("nb_NO+nb+nn_NO+nn+nog+no_nb+no", "NORWAY"), + ("ga+cy+br+gd+kw+gv", "CELTIC"), # https://en.wikipedia.org/wiki/Insular_Celtic_languages +] +GROUP_TO_OPUS_NAME = { + "opus-mt-ZH-de": "cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh-de", + "opus-mt-ZH-fi": "cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh-fi", + "opus-mt-ZH-sv": "cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh-sv", + "opus-mt-SCANDINAVIA-SCANDINAVIA": "da+fo+is+no+nb+nn+sv-da+fo+is+no+nb+nn+sv", + "opus-mt-NORTH_EU-NORTH_EU": "de+nl+fy+af+da+fo+is+no+nb+nn+sv-de+nl+fy+af+da+fo+is+no+nb+nn+sv", + "opus-mt-de-ZH": "de-cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh", + "opus-mt-en_el_es_fi-en_el_es_fi": "en+el+es+fi-en+el+es+fi", + "opus-mt-en-ROMANCE": "en-fr+fr_BE+fr_CA+fr_FR+wa+frp+oc+ca+rm+lld+fur+lij+lmo+es+es_AR+es_CL+es_CO+es_CR+es_DO" + "+es_EC+es_ES+es_GT+es_HN+es_MX+es_NI+es_PA+es_PE+es_PR+es_SV+es_UY+es_VE+pt+pt_br+pt_BR" + "+pt_PT+gl+lad+an+mwl+it+it_IT+co+nap+scn+vec+sc+ro+la", + "opus-mt-en-CELTIC": "en-ga+cy+br+gd+kw+gv", + "opus-mt-es-NORWAY": "es-nb_NO+nb+nn_NO+nn+nog+no_nb+no", + "opus-mt-fi_nb_no_nn_ru_sv_en-SAMI": "fi+nb+no+nn+ru+sv+en-se+sma+smj+smn+sms", + "opus-mt-fi-ZH": "fi-cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh", + "opus-mt-fi-NORWAY": "fi-nb_NO+nb+nn_NO+nn+nog+no_nb+no", + "opus-mt-ROMANCE-en": "fr+fr_BE+fr_CA+fr_FR+wa+frp+oc+ca+rm+lld+fur+lij+lmo+es+es_AR+es_CL+es_CO+es_CR+es_DO" + "+es_EC+es_ES+es_GT+es_HN+es_MX+es_NI+es_PA+es_PE+es_PR+es_SV+es_UY+es_VE+pt+pt_br+pt_BR" + "+pt_PT+gl+lad+an+mwl+it+it_IT+co+nap+scn+vec+sc+ro+la-en", + "opus-mt-CELTIC-en": "ga+cy+br+gd+kw+gv-en", + "opus-mt-sv-ZH": "sv-cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh", + "opus-mt-sv-NORWAY": "sv-nb_NO+nb+nn_NO+nn+nog+no_nb+no", +} +OPUS_GITHUB_URL = "https://github.com/Helsinki-NLP/OPUS-MT-train/blob/master/models/" +ORG_NAME = "Helsinki-NLP/" + + +def convert_opus_name_to_hf_name(x): + for substr, grp_name in GROUPS: + x = x.replace(substr, grp_name) + return x.replace("+", "_") + + +def convert_hf_name_to_opus_name(hf_model_name): + """Relies on the assumption that there are no language codes like pt_br in models that are not in GROUP_TO_OPUS_NAME.""" + hf_model_name = remove_prefix(hf_model_name, ORG_NAME) + if hf_model_name in GROUP_TO_OPUS_NAME: + opus_w_prefix = GROUP_TO_OPUS_NAME[hf_model_name] + else: + opus_w_prefix = hf_model_name.replace("_", "+") + return remove_prefix(opus_w_prefix, "opus-mt-") + + +def write_model_card( + hf_model_name: str, + repo_path="OPUS-MT-train/models/", + dry_run=False, + model_card_dir=Path("marian_converted/model_cards/Helsinki-NLP/"), +) -> str: + """Copy the most recent model's readme section from opus, and add metadata. + upload command: s3cmd sync --recursive model_card_dir s3://models.huggingface.co/bert/Helsinki-NLP/ + """ + hf_model_name = remove_prefix(hf_model_name, ORG_NAME) + opus_name: str = convert_hf_name_to_opus_name(hf_model_name) + opus_src, opus_tgt = [x.split("+") for x in opus_name.split("-")] + readme_url = OPUS_GITHUB_URL + f"{opus_name}/README.md" + s, t = ",".join(opus_src), ",".join(opus_tgt) + extra_markdown = f"### {hf_model_name}\n\n* source languages: {s}\n* target languages: {t}\n* OPUS readme: [{opus_name}]({readme_url})\n" + # combine with opus markdown + opus_readme_path = Path(f"{repo_path}{opus_name}/README.md") + assert opus_readme_path.exists(), opus_readme_path + content = opus_readme_path.open().read() + content = content.split("\n# ")[-1] # Get the lowest level 1 header in the README -- the most recent model. + content = "*".join(content.split("*")[1:]) + content = extra_markdown + "\n* " + content.replace("download", "download original weights") + if dry_run: + return content + # Save string to model_cards/hf_model_name/readme.md + model_card_dir.mkdir(exist_ok=True) + sub_dir = model_card_dir / hf_model_name + sub_dir.mkdir(exist_ok=True) + dest = sub_dir / "README.md" + dest.open("w").write(content) + return content + + +def get_clean_model_id_mapping(multiling_model_ids): + return {x: convert_opus_name_to_hf_name(x) for x in multiling_model_ids} + + def make_registry(repo_path="Opus-MT-train/models"): if not (Path(repo_path) / "fr-en" / "README.md").exists(): raise ValueError( @@ -109,10 +200,7 @@ def make_registry(repo_path="Opus-MT-train/models"): else: lns = list(open(p / "README.md").readlines()) results[p.name] = _parse_readme(lns) - return [(k, v["pre-processing"], v["download"]) for k, v in results.items()] - - -CH_GROUP = "cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh" + return [(k, v["pre-processing"], v["download"], v["download"][:-4] + ".test.txt") for k, v in results.items()] def convert_all_sentencepiece_models(model_list=None, repo_path=None): @@ -122,12 +210,12 @@ def convert_all_sentencepiece_models(model_list=None, repo_path=None): dest_dir.mkdir(exist_ok=True) if model_list is None: model_list: list = make_registry(repo_path=repo_path) - for k, prepro, download in tqdm(model_list): + for k, prepro, download, test_set_url in tqdm(model_list): if "SentencePiece" not in prepro: # dont convert BPE models. continue if not os.path.exists(save_dir / k / "pytorch_model.bin"): download_and_unzip(download, save_dir / k) - pair_name = k.replace(CH_GROUP, "ch_group") + pair_name = convert_opus_name_to_hf_name(k) convert(save_dir / k, dest_dir / f"opus-mt-{pair_name}") @@ -135,12 +223,10 @@ def lmap(f, x) -> List: return list(map(f, x)) -def fetch_test_set(readmes_raw, pair): +def fetch_test_set(test_set_url): import wget - download_url = readmes_raw[pair]["download"] - test_set_url = download_url[:-4] + ".test.txt" - fname = wget.download(test_set_url, f"opus_test_{pair}.txt") + fname = wget.download(test_set_url, f"opus_test.txt") lns = Path(fname).open().readlines() src = lmap(str.strip, lns[::4]) gold = lmap(str.strip, lns[1::4]) diff --git a/src/transformers/modeling_bart.py b/src/transformers/modeling_bart.py index a461a6b478..227a440c9d 100644 --- a/src/transformers/modeling_bart.py +++ b/src/transformers/modeling_bart.py @@ -980,12 +980,12 @@ class BartForConditionalGeneration(PretrainedBartModel): "use_cache": use_cache, # change this to avoid caching (presumably for debugging) } - def prepare_scores_for_generation(self, scores, cur_len, max_length): + def prepare_logits_for_generation(self, logits, cur_len, max_length): if cur_len == 1: - self._force_token_ids_generation(scores, self.config.bos_token_id) + self._force_token_ids_generation(logits, self.config.bos_token_id) if cur_len == max_length - 1 and self.config.eos_token_id is not None: - self._force_token_ids_generation(scores, self.config.eos_token_id) - return scores + self._force_token_ids_generation(logits, self.config.eos_token_id) + return logits def _force_token_ids_generation(self, scores, token_ids) -> None: """force one of token_ids to be generated by setting prob of all other tokens to 0""" diff --git a/src/transformers/modeling_marian.py b/src/transformers/modeling_marian.py index 1953b479fc..701eedda43 100644 --- a/src/transformers/modeling_marian.py +++ b/src/transformers/modeling_marian.py @@ -31,7 +31,7 @@ class MarianMTModel(BartForConditionalGeneration): src = 'fr' # source language trg = 'en' # target language sample_text = "où est l'arrêt de bus ?" - mname = f'Helsinki-NLP/opus-mt-{src}-{trg}' # `Model List`__ + mname = f'Helsinki-NLP/opus-mt-{src}-{trg}' model = MarianMTModel.from_pretrained(mname) tok = MarianTokenizer.from_pretrained(mname) @@ -43,7 +43,8 @@ class MarianMTModel(BartForConditionalGeneration): pretrained_model_archive_map = {} # see https://huggingface.co/models?search=Helsinki-NLP - def prepare_scores_for_generation(self, scores, cur_len, max_length): + def prepare_logits_for_generation(self, logits, cur_len, max_length): + logits[:, self.config.pad_token_id] = float("-inf") if cur_len == max_length - 1 and self.config.eos_token_id is not None: - self._force_token_ids_generation(scores, self.config.eos_token_id) - return scores + self._force_token_ids_generation(logits, self.config.eos_token_id) + return logits diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 02ae2240bc..d5d06134bb 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -744,8 +744,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): def prepare_inputs_for_generation(self, input_ids, **kwargs): return {"input_ids": input_ids} - def prepare_scores_for_generation(self, scores, **kwargs): - return scores + def prepare_logits_for_generation(self, logits, **kwargs): + return logits def _use_cache(self, outputs, use_cache): """During generation, decide whether to pass the `past` variable to the next forward pass.""" @@ -857,7 +857,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. Defaults to `None`. - `What are attention masks? <../glossary.html#attention-mask>`__ + `What are attention masks? <../glossary.html#attention-mask>`__ decoder_start_token_id=None: (`optional`) int If an encoder-decoder model starts decoding with a different token than BOS. @@ -1342,10 +1342,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): if temperature != 1.0: next_token_logits = next_token_logits / temperature - scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size) if self.config.is_encoder_decoder and do_sample is False: - # TODO (PVP) still a bit hacky here - there might be a better solutino - scores = self.prepare_scores_for_generation(scores, cur_len=cur_len, max_length=max_length) + # TODO (PVP) still a bit hacky here - there might be a better solution + next_token_logits = self.prepare_logits_for_generation( + next_token_logits, cur_len=cur_len, max_length=max_length + ) + + scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size) # set eos token prob to zero if min_length is not reached if eos_token_id is not None and cur_len < min_length: diff --git a/src/transformers/tokenization_marian.py b/src/transformers/tokenization_marian.py index 7524a41d41..bd56382cb7 100644 --- a/src/transformers/tokenization_marian.py +++ b/src/transformers/tokenization_marian.py @@ -1,4 +1,5 @@ import json +import re import warnings from typing import Dict, List, Optional, Union @@ -14,7 +15,7 @@ vocab_files_names = { "vocab": "vocab.json", "tokenizer_config_file": "tokenizer_config.json", } -MODEL_NAMES = ("opus-mt-en-de",) +MODEL_NAMES = ("opus-mt-en-de",) # TODO(SS): the only required constant is vocab_files_names PRETRAINED_VOCAB_FILES_MAP = { k: {m: f"{S3_BUCKET_PREFIX}/Helsinki-NLP/{m}/{fname}" for m in MODEL_NAMES} for k, fname in vocab_files_names.items() @@ -41,6 +42,7 @@ class MarianTokenizer(PreTrainedTokenizer): pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP max_model_input_sizes = {m: 512 for m in MODEL_NAMES} model_input_names = ["attention_mask"] # actually attention_mask, decoder_attention_mask + language_code_re = re.compile(">>.+<<") # type: re.Pattern def __init__( self, @@ -72,8 +74,6 @@ class MarianTokenizer(PreTrainedTokenizer): self.target_lang = target_lang # load SentencePiece model for pre-processing - self.paths = {} - self.spm_source = sentencepiece.SentencePieceProcessor() self.spm_source.Load(source_spm) @@ -82,9 +82,7 @@ class MarianTokenizer(PreTrainedTokenizer): # Multilingual target side: default to using first supported language code. self.supported_language_codes: list = [k for k in self.encoder if k.startswith(">>") and k.endswith("<<")] - self.tgt_lang_id = None # will not be used unless it is set through prepare_translation_batch - # Note(SS): sentence_splitter would require lots of book-keeping. try: from mosestokenizer import MosesPunctuationNormalizer @@ -93,11 +91,23 @@ class MarianTokenizer(PreTrainedTokenizer): warnings.warn("Recommended: pip install mosestokenizer") self.punc_normalizer = lambda x: x + def normalize(self, x: str) -> str: + """Cover moses empty string edge case. They return empty list for '' input!""" + return self.punc_normalizer(x) if x else "" + def _convert_token_to_id(self, token): return self.encoder.get(token, self.encoder[self.unk_token]) + def remove_language_code(self, text: str): + """Remove language codes like <> before sentencepiece""" + match = self.language_code_re.match(text) + code: list = [match.group(0)] if match else [] + return code, self.language_code_re.sub("", text) + def _tokenize(self, text: str) -> List[str]: - return self.current_spm.EncodeAsPieces(text) + code, text = self.remove_language_code(text) + pieces = self.current_spm.EncodeAsPieces(text) + return code + pieces def _convert_id_to_token(self, index: int) -> str: """Converts an index (integer) in a token (str) using the encoder.""" @@ -125,7 +135,7 @@ class MarianTokenizer(PreTrainedTokenizer): pad_to_max_length: bool = True, return_tensors: str = "pt", ) -> BatchEncoding: - """ + """Prepare model inputs for translation. For best performance, translate one sentence at a time. Arguments: src_texts: list of src language texts tgt_texts: list of tgt language texts @@ -138,7 +148,10 @@ class MarianTokenizer(PreTrainedTokenizer): all shaped bs, seq_len. (BatchEncoding is a dict of string -> tensor or lists). If no tgt_text is specified, the only keys will be input_ids and attention_mask. """ + if "" in src_texts: + raise ValueError(f"found empty string in src_texts: {src_texts}") self.current_spm = self.spm_source + src_texts = [self.normalize(t) for t in src_texts] # this does not appear to do much model_inputs: BatchEncoding = self.batch_encode_plus( src_texts, add_special_tokens=True, diff --git a/tests/test_modeling_marian.py b/tests/test_modeling_marian.py index 059b4cef5a..3858d273ab 100644 --- a/tests/test_modeling_marian.py +++ b/tests/test_modeling_marian.py @@ -33,15 +33,21 @@ if is_torch_available(): MarianTokenizer, MarianMTModel, ) + from transformers.convert_marian_to_pytorch import ( + convert_hf_name_to_opus_name, + convert_opus_name_to_hf_name, + ORG_NAME, + ) class ModelManagementTests(unittest.TestCase): @slow - def test_model_count(self): + def test_model_names(self): model_list = HfApi().model_list() - expected_num_models = 1011 - actual_num_models = len([x for x in model_list if x.modelId.startswith("Helsinki-NLP")]) - self.assertEqual(expected_num_models, actual_num_models) + model_ids = [x.modelId for x in model_list if x.modelId.startswith(ORG_NAME)] + bad_model_ids = [mid for mid in model_ids if "+" in model_ids] + self.assertListEqual([], bad_model_ids) + self.assertGreater(len(model_ids), 500) @require_torch @@ -91,12 +97,12 @@ class MarianIntegrationTest(unittest.TestCase): self.assertListEqual(self.expected_text, generated_words) def translate_src_text(self, **tokenizer_kwargs): - model_inputs: dict = self.tokenizer.prepare_translation_batch(src_texts=self.src_text, **tokenizer_kwargs).to( + model_inputs = self.tokenizer.prepare_translation_batch(src_texts=self.src_text, **tokenizer_kwargs).to( torch_device ) - self.assertEqual(self.model.device, model_inputs["input_ids"].device) + self.assertEqual(self.model.device, model_inputs.input_ids.device) generated_ids = self.model.generate( - model_inputs["input_ids"], attention_mask=model_inputs["attention_mask"], num_beams=2 + model_inputs.input_ids, attention_mask=model_inputs.attention_mask, num_beams=2 ) generated_words = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True) return generated_words @@ -106,10 +112,10 @@ class TestMarian_EN_DE_More(MarianIntegrationTest): @slow def test_forward(self): src, tgt = ["I am a small frog"], ["Ich bin ein kleiner Frosch."] - expected = [38, 121, 14, 697, 38848, 0] + expected_ids = [38, 121, 14, 697, 38848, 0] model_inputs: dict = self.tokenizer.prepare_translation_batch(src, tgt_texts=tgt).to(torch_device) - self.assertListEqual(expected, model_inputs["input_ids"][0].tolist()) + self.assertListEqual(expected_ids, model_inputs.input_ids[0].tolist()) desired_keys = { "input_ids", @@ -125,20 +131,19 @@ class TestMarian_EN_DE_More(MarianIntegrationTest): def test_tokenizer_equivalence(self): batch = self.tokenizer.prepare_translation_batch(["I am a small frog"]).to(torch_device) - input_ids = batch["input_ids"][0] expected = [38, 121, 14, 697, 38848, 0] - self.assertListEqual(expected, input_ids.tolist()) + self.assertListEqual(expected, batch.input_ids[0].tolist()) def test_unk_support(self): t = self.tokenizer - ids = t.prepare_translation_batch(["||"]).to(torch_device)["input_ids"][0].tolist() + ids = t.prepare_translation_batch(["||"]).to(torch_device).input_ids[0].tolist() expected = [t.unk_token_id, t.unk_token_id, t.eos_token_id] self.assertEqual(expected, ids) def test_pad_not_split(self): - input_ids_w_pad = self.tokenizer.prepare_translation_batch(["I am a small frog "])["input_ids"][0] + input_ids_w_pad = self.tokenizer.prepare_translation_batch(["I am a small frog "]).input_ids[0].tolist() expected_w_pad = [38, 121, 14, 697, 38848, self.tokenizer.pad_token_id, 0] # pad - self.assertListEqual(expected_w_pad, input_ids_w_pad.tolist()) + self.assertListEqual(expected_w_pad, input_ids_w_pad) @slow def test_batch_generation_en_de(self): @@ -187,9 +192,8 @@ class TestMarian_RU_FR(MarianIntegrationTest): src = "ru" tgt = "fr" src_text = ["Он показал мне рукопись своей новой пьесы."] - expected_text = ["Il me montre un manuscrit de sa nouvelle pièce."] + expected_text = ["Il m'a montré le manuscrit de sa nouvelle pièce."] - @slow def test_batch_generation_ru_fr(self): self._assert_generated_batch_equal_expected() @@ -197,36 +201,59 @@ class TestMarian_RU_FR(MarianIntegrationTest): class TestMarian_MT_EN(MarianIntegrationTest): src = "mt" tgt = "en" - src_text = ["Il - Babiloniżi b'mod żbaljat ikkonkludew li l - Alla l - veru kien dgħajjef."] - expected_text = ["The Babylonians wrongly concluded that the true God was weak."] + src_text = ["Billi messu b'mod ġentili, Ġesù fejjaq raġel li kien milqut bil - marda kerha tal - ġdiem."] + expected_text = ["Touching gently, Jesus healed a man who was affected by the sad disease of leprosy."] - @unittest.skip("") # Known Issue: This model generates a string of .... at the end of the translation. def test_batch_generation_mt_en(self): self._assert_generated_batch_equal_expected() -class TestMarian_DE_Multi(MarianIntegrationTest): - src = "de" - tgt = "ch_group" - src_text = ["Er aber sprach: Das ist die Gottlosigkeit."] +class TestMarian_en_ROMANCE(MarianIntegrationTest): + """Multilingual on target side.""" + + src = "en" + tgt = "ROMANCE" + src_text = [ + ">>fr<< Don't spend so much time watching TV.", + ">>pt<< Your message has been sent.", + ">>es<< He's two years older than me.", + ] + expected_text = [ + "Ne passez pas autant de temps à regarder la télé.", + "A sua mensagem foi enviada.", + "Es dos años más viejo que yo.", + ] @slow - def test_translation_de_multi_does_not_error(self): - self.translate_src_text() - - @unittest.skip("") # "Language codes are not yet supported." - def test_batch_generation_de_multi_tgt(self): + def test_batch_generation_en_ROMANCE_multi(self): self._assert_generated_batch_equal_expected() - @unittest.skip("") # "Language codes are not yet supported." - def test_lang_code(self): - t = "Er aber sprach" - zh_code = self.code - tok_fn = self.tokenizer.prepare_translation_batch - pass_code = tok_fn(src_texts=[t], tgt_lang_code=zh_code)["input_ids"][0] - preprocess_with_code = tok_fn(src_texts=[zh_code + " " + t])["input_ids"][0] - self.assertListEqual(pass_code.tolist(), preprocess_with_code.tolist()) - for code in self.tokenizer.supported_language_codes: - self.assertIn(code, self.tokenizer.encoder) - pass_only_code = tok_fn(src_texts=[""], tgt_lang_code=zh_code)["input_ids"][0].tolist() - self.assertListEqual(pass_only_code, [self.tokenizer.encoder[zh_code], self.tokenizer.eos_token_id]) + def test_tokenizer_handles_empty(self): + normalized = self.tokenizer.normalize("") + self.assertIsInstance(normalized, str) + with self.assertRaises(ValueError): + self.tokenizer.prepare_translation_batch([""]) + + +@require_torch +class TestConversionUtils(unittest.TestCase): + def test_renaming_multilingual(self): + old_names = [ + "opus-mt-cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh-fi", + "opus-mt-cmn+cn-fi", # no group + "opus-mt-en-de", # standard name + "opus-mt-en-de", # standard name + ] + expected = ["opus-mt-ZH-fi", "opus-mt-cmn_cn-fi", "opus-mt-en-de", "opus-mt-en-de"] + self.assertListEqual(expected, [convert_opus_name_to_hf_name(x) for x in old_names]) + + def test_undoing_renaming(self): + hf_names = ["opus-mt-ZH-fi", "opus-mt-cmn_cn-fi", "opus-mt-en-de", "opus-mt-en-de"] + converted_opus_names = [convert_hf_name_to_opus_name(x) for x in hf_names] + expected_opus_names = [ + "cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh-fi", + "cmn+cn-fi", + "en-de", # standard name + "en-de", + ] + self.assertListEqual(expected_opus_names, converted_opus_names)