[Marian Fixes] prevent predicting pad_token_id before softmax, support language codes, name multilingual models (#4290)
This commit is contained in:
@@ -95,6 +95,97 @@ def find_model_file(dest_dir): # this one better
|
||||
return model_file
|
||||
|
||||
|
||||
# Group Names Logic: change long opus model names to something shorter, like opus-mt-en-ROMANCE
|
||||
ROM_GROUP = "fr+fr_BE+fr_CA+fr_FR+wa+frp+oc+ca+rm+lld+fur+lij+lmo+es+es_AR+es_CL+es_CO+es_CR+es_DO+es_EC+es_ES+es_GT+es_HN+es_MX+es_NI+es_PA+es_PE+es_PR+es_SV+es_UY+es_VE+pt+pt_br+pt_BR+pt_PT+gl+lad+an+mwl+it+it_IT+co+nap+scn+vec+sc+ro+la"
|
||||
GROUPS = [
|
||||
("cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh", "ZH"),
|
||||
(ROM_GROUP, "ROMANCE"),
|
||||
("de+nl+fy+af+da+fo+is+no+nb+nn+sv", "NORTH_EU"),
|
||||
("da+fo+is+no+nb+nn+sv", "SCANDINAVIA"),
|
||||
("se+sma+smj+smn+sms", "SAMI"),
|
||||
("nb_NO+nb+nn_NO+nn+nog+no_nb+no", "NORWAY"),
|
||||
("ga+cy+br+gd+kw+gv", "CELTIC"), # https://en.wikipedia.org/wiki/Insular_Celtic_languages
|
||||
]
|
||||
GROUP_TO_OPUS_NAME = {
|
||||
"opus-mt-ZH-de": "cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh-de",
|
||||
"opus-mt-ZH-fi": "cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh-fi",
|
||||
"opus-mt-ZH-sv": "cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh-sv",
|
||||
"opus-mt-SCANDINAVIA-SCANDINAVIA": "da+fo+is+no+nb+nn+sv-da+fo+is+no+nb+nn+sv",
|
||||
"opus-mt-NORTH_EU-NORTH_EU": "de+nl+fy+af+da+fo+is+no+nb+nn+sv-de+nl+fy+af+da+fo+is+no+nb+nn+sv",
|
||||
"opus-mt-de-ZH": "de-cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh",
|
||||
"opus-mt-en_el_es_fi-en_el_es_fi": "en+el+es+fi-en+el+es+fi",
|
||||
"opus-mt-en-ROMANCE": "en-fr+fr_BE+fr_CA+fr_FR+wa+frp+oc+ca+rm+lld+fur+lij+lmo+es+es_AR+es_CL+es_CO+es_CR+es_DO"
|
||||
"+es_EC+es_ES+es_GT+es_HN+es_MX+es_NI+es_PA+es_PE+es_PR+es_SV+es_UY+es_VE+pt+pt_br+pt_BR"
|
||||
"+pt_PT+gl+lad+an+mwl+it+it_IT+co+nap+scn+vec+sc+ro+la",
|
||||
"opus-mt-en-CELTIC": "en-ga+cy+br+gd+kw+gv",
|
||||
"opus-mt-es-NORWAY": "es-nb_NO+nb+nn_NO+nn+nog+no_nb+no",
|
||||
"opus-mt-fi_nb_no_nn_ru_sv_en-SAMI": "fi+nb+no+nn+ru+sv+en-se+sma+smj+smn+sms",
|
||||
"opus-mt-fi-ZH": "fi-cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh",
|
||||
"opus-mt-fi-NORWAY": "fi-nb_NO+nb+nn_NO+nn+nog+no_nb+no",
|
||||
"opus-mt-ROMANCE-en": "fr+fr_BE+fr_CA+fr_FR+wa+frp+oc+ca+rm+lld+fur+lij+lmo+es+es_AR+es_CL+es_CO+es_CR+es_DO"
|
||||
"+es_EC+es_ES+es_GT+es_HN+es_MX+es_NI+es_PA+es_PE+es_PR+es_SV+es_UY+es_VE+pt+pt_br+pt_BR"
|
||||
"+pt_PT+gl+lad+an+mwl+it+it_IT+co+nap+scn+vec+sc+ro+la-en",
|
||||
"opus-mt-CELTIC-en": "ga+cy+br+gd+kw+gv-en",
|
||||
"opus-mt-sv-ZH": "sv-cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh",
|
||||
"opus-mt-sv-NORWAY": "sv-nb_NO+nb+nn_NO+nn+nog+no_nb+no",
|
||||
}
|
||||
OPUS_GITHUB_URL = "https://github.com/Helsinki-NLP/OPUS-MT-train/blob/master/models/"
|
||||
ORG_NAME = "Helsinki-NLP/"
|
||||
|
||||
|
||||
def convert_opus_name_to_hf_name(x):
|
||||
for substr, grp_name in GROUPS:
|
||||
x = x.replace(substr, grp_name)
|
||||
return x.replace("+", "_")
|
||||
|
||||
|
||||
def convert_hf_name_to_opus_name(hf_model_name):
|
||||
"""Relies on the assumption that there are no language codes like pt_br in models that are not in GROUP_TO_OPUS_NAME."""
|
||||
hf_model_name = remove_prefix(hf_model_name, ORG_NAME)
|
||||
if hf_model_name in GROUP_TO_OPUS_NAME:
|
||||
opus_w_prefix = GROUP_TO_OPUS_NAME[hf_model_name]
|
||||
else:
|
||||
opus_w_prefix = hf_model_name.replace("_", "+")
|
||||
return remove_prefix(opus_w_prefix, "opus-mt-")
|
||||
|
||||
|
||||
def write_model_card(
|
||||
hf_model_name: str,
|
||||
repo_path="OPUS-MT-train/models/",
|
||||
dry_run=False,
|
||||
model_card_dir=Path("marian_converted/model_cards/Helsinki-NLP/"),
|
||||
) -> str:
|
||||
"""Copy the most recent model's readme section from opus, and add metadata.
|
||||
upload command: s3cmd sync --recursive model_card_dir s3://models.huggingface.co/bert/Helsinki-NLP/
|
||||
"""
|
||||
hf_model_name = remove_prefix(hf_model_name, ORG_NAME)
|
||||
opus_name: str = convert_hf_name_to_opus_name(hf_model_name)
|
||||
opus_src, opus_tgt = [x.split("+") for x in opus_name.split("-")]
|
||||
readme_url = OPUS_GITHUB_URL + f"{opus_name}/README.md"
|
||||
s, t = ",".join(opus_src), ",".join(opus_tgt)
|
||||
extra_markdown = f"### {hf_model_name}\n\n* source languages: {s}\n* target languages: {t}\n* OPUS readme: [{opus_name}]({readme_url})\n"
|
||||
# combine with opus markdown
|
||||
opus_readme_path = Path(f"{repo_path}{opus_name}/README.md")
|
||||
assert opus_readme_path.exists(), opus_readme_path
|
||||
content = opus_readme_path.open().read()
|
||||
content = content.split("\n# ")[-1] # Get the lowest level 1 header in the README -- the most recent model.
|
||||
content = "*".join(content.split("*")[1:])
|
||||
content = extra_markdown + "\n* " + content.replace("download", "download original weights")
|
||||
if dry_run:
|
||||
return content
|
||||
# Save string to model_cards/hf_model_name/readme.md
|
||||
model_card_dir.mkdir(exist_ok=True)
|
||||
sub_dir = model_card_dir / hf_model_name
|
||||
sub_dir.mkdir(exist_ok=True)
|
||||
dest = sub_dir / "README.md"
|
||||
dest.open("w").write(content)
|
||||
return content
|
||||
|
||||
|
||||
def get_clean_model_id_mapping(multiling_model_ids):
|
||||
return {x: convert_opus_name_to_hf_name(x) for x in multiling_model_ids}
|
||||
|
||||
|
||||
def make_registry(repo_path="Opus-MT-train/models"):
|
||||
if not (Path(repo_path) / "fr-en" / "README.md").exists():
|
||||
raise ValueError(
|
||||
@@ -109,10 +200,7 @@ def make_registry(repo_path="Opus-MT-train/models"):
|
||||
else:
|
||||
lns = list(open(p / "README.md").readlines())
|
||||
results[p.name] = _parse_readme(lns)
|
||||
return [(k, v["pre-processing"], v["download"]) for k, v in results.items()]
|
||||
|
||||
|
||||
CH_GROUP = "cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh"
|
||||
return [(k, v["pre-processing"], v["download"], v["download"][:-4] + ".test.txt") for k, v in results.items()]
|
||||
|
||||
|
||||
def convert_all_sentencepiece_models(model_list=None, repo_path=None):
|
||||
@@ -122,12 +210,12 @@ def convert_all_sentencepiece_models(model_list=None, repo_path=None):
|
||||
dest_dir.mkdir(exist_ok=True)
|
||||
if model_list is None:
|
||||
model_list: list = make_registry(repo_path=repo_path)
|
||||
for k, prepro, download in tqdm(model_list):
|
||||
for k, prepro, download, test_set_url in tqdm(model_list):
|
||||
if "SentencePiece" not in prepro: # dont convert BPE models.
|
||||
continue
|
||||
if not os.path.exists(save_dir / k / "pytorch_model.bin"):
|
||||
download_and_unzip(download, save_dir / k)
|
||||
pair_name = k.replace(CH_GROUP, "ch_group")
|
||||
pair_name = convert_opus_name_to_hf_name(k)
|
||||
convert(save_dir / k, dest_dir / f"opus-mt-{pair_name}")
|
||||
|
||||
|
||||
@@ -135,12 +223,10 @@ def lmap(f, x) -> List:
|
||||
return list(map(f, x))
|
||||
|
||||
|
||||
def fetch_test_set(readmes_raw, pair):
|
||||
def fetch_test_set(test_set_url):
|
||||
import wget
|
||||
|
||||
download_url = readmes_raw[pair]["download"]
|
||||
test_set_url = download_url[:-4] + ".test.txt"
|
||||
fname = wget.download(test_set_url, f"opus_test_{pair}.txt")
|
||||
fname = wget.download(test_set_url, f"opus_test.txt")
|
||||
lns = Path(fname).open().readlines()
|
||||
src = lmap(str.strip, lns[::4])
|
||||
gold = lmap(str.strip, lns[1::4])
|
||||
|
||||
@@ -980,12 +980,12 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||
}
|
||||
|
||||
def prepare_scores_for_generation(self, scores, cur_len, max_length):
|
||||
def prepare_logits_for_generation(self, logits, cur_len, max_length):
|
||||
if cur_len == 1:
|
||||
self._force_token_ids_generation(scores, self.config.bos_token_id)
|
||||
self._force_token_ids_generation(logits, self.config.bos_token_id)
|
||||
if cur_len == max_length - 1 and self.config.eos_token_id is not None:
|
||||
self._force_token_ids_generation(scores, self.config.eos_token_id)
|
||||
return scores
|
||||
self._force_token_ids_generation(logits, self.config.eos_token_id)
|
||||
return logits
|
||||
|
||||
def _force_token_ids_generation(self, scores, token_ids) -> None:
|
||||
"""force one of token_ids to be generated by setting prob of all other tokens to 0"""
|
||||
|
||||
@@ -31,7 +31,7 @@ class MarianMTModel(BartForConditionalGeneration):
|
||||
src = 'fr' # source language
|
||||
trg = 'en' # target language
|
||||
sample_text = "où est l'arrêt de bus ?"
|
||||
mname = f'Helsinki-NLP/opus-mt-{src}-{trg}' # `Model List`__
|
||||
mname = f'Helsinki-NLP/opus-mt-{src}-{trg}'
|
||||
|
||||
model = MarianMTModel.from_pretrained(mname)
|
||||
tok = MarianTokenizer.from_pretrained(mname)
|
||||
@@ -43,7 +43,8 @@ class MarianMTModel(BartForConditionalGeneration):
|
||||
|
||||
pretrained_model_archive_map = {} # see https://huggingface.co/models?search=Helsinki-NLP
|
||||
|
||||
def prepare_scores_for_generation(self, scores, cur_len, max_length):
|
||||
def prepare_logits_for_generation(self, logits, cur_len, max_length):
|
||||
logits[:, self.config.pad_token_id] = float("-inf")
|
||||
if cur_len == max_length - 1 and self.config.eos_token_id is not None:
|
||||
self._force_token_ids_generation(scores, self.config.eos_token_id)
|
||||
return scores
|
||||
self._force_token_ids_generation(logits, self.config.eos_token_id)
|
||||
return logits
|
||||
|
||||
@@ -744,8 +744,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
def prepare_inputs_for_generation(self, input_ids, **kwargs):
|
||||
return {"input_ids": input_ids}
|
||||
|
||||
def prepare_scores_for_generation(self, scores, **kwargs):
|
||||
return scores
|
||||
def prepare_logits_for_generation(self, logits, **kwargs):
|
||||
return logits
|
||||
|
||||
def _use_cache(self, outputs, use_cache):
|
||||
"""During generation, decide whether to pass the `past` variable to the next forward pass."""
|
||||
@@ -857,7 +857,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
|
||||
Defaults to `None`.
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
|
||||
decoder_start_token_id=None: (`optional`) int
|
||||
If an encoder-decoder model starts decoding with a different token than BOS.
|
||||
@@ -1342,10 +1342,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
if temperature != 1.0:
|
||||
next_token_logits = next_token_logits / temperature
|
||||
|
||||
scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
|
||||
if self.config.is_encoder_decoder and do_sample is False:
|
||||
# TODO (PVP) still a bit hacky here - there might be a better solutino
|
||||
scores = self.prepare_scores_for_generation(scores, cur_len=cur_len, max_length=max_length)
|
||||
# TODO (PVP) still a bit hacky here - there might be a better solution
|
||||
next_token_logits = self.prepare_logits_for_generation(
|
||||
next_token_logits, cur_len=cur_len, max_length=max_length
|
||||
)
|
||||
|
||||
scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
|
||||
|
||||
# set eos token prob to zero if min_length is not reached
|
||||
if eos_token_id is not None and cur_len < min_length:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import json
|
||||
import re
|
||||
import warnings
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
@@ -14,7 +15,7 @@ vocab_files_names = {
|
||||
"vocab": "vocab.json",
|
||||
"tokenizer_config_file": "tokenizer_config.json",
|
||||
}
|
||||
MODEL_NAMES = ("opus-mt-en-de",)
|
||||
MODEL_NAMES = ("opus-mt-en-de",) # TODO(SS): the only required constant is vocab_files_names
|
||||
PRETRAINED_VOCAB_FILES_MAP = {
|
||||
k: {m: f"{S3_BUCKET_PREFIX}/Helsinki-NLP/{m}/{fname}" for m in MODEL_NAMES}
|
||||
for k, fname in vocab_files_names.items()
|
||||
@@ -41,6 +42,7 @@ class MarianTokenizer(PreTrainedTokenizer):
|
||||
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
||||
max_model_input_sizes = {m: 512 for m in MODEL_NAMES}
|
||||
model_input_names = ["attention_mask"] # actually attention_mask, decoder_attention_mask
|
||||
language_code_re = re.compile(">>.+<<") # type: re.Pattern
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -72,8 +74,6 @@ class MarianTokenizer(PreTrainedTokenizer):
|
||||
self.target_lang = target_lang
|
||||
|
||||
# load SentencePiece model for pre-processing
|
||||
self.paths = {}
|
||||
|
||||
self.spm_source = sentencepiece.SentencePieceProcessor()
|
||||
self.spm_source.Load(source_spm)
|
||||
|
||||
@@ -82,9 +82,7 @@ class MarianTokenizer(PreTrainedTokenizer):
|
||||
|
||||
# Multilingual target side: default to using first supported language code.
|
||||
self.supported_language_codes: list = [k for k in self.encoder if k.startswith(">>") and k.endswith("<<")]
|
||||
self.tgt_lang_id = None # will not be used unless it is set through prepare_translation_batch
|
||||
|
||||
# Note(SS): sentence_splitter would require lots of book-keeping.
|
||||
try:
|
||||
from mosestokenizer import MosesPunctuationNormalizer
|
||||
|
||||
@@ -93,11 +91,23 @@ class MarianTokenizer(PreTrainedTokenizer):
|
||||
warnings.warn("Recommended: pip install mosestokenizer")
|
||||
self.punc_normalizer = lambda x: x
|
||||
|
||||
def normalize(self, x: str) -> str:
|
||||
"""Cover moses empty string edge case. They return empty list for '' input!"""
|
||||
return self.punc_normalizer(x) if x else ""
|
||||
|
||||
def _convert_token_to_id(self, token):
|
||||
return self.encoder.get(token, self.encoder[self.unk_token])
|
||||
|
||||
def remove_language_code(self, text: str):
|
||||
"""Remove language codes like <<fr>> before sentencepiece"""
|
||||
match = self.language_code_re.match(text)
|
||||
code: list = [match.group(0)] if match else []
|
||||
return code, self.language_code_re.sub("", text)
|
||||
|
||||
def _tokenize(self, text: str) -> List[str]:
|
||||
return self.current_spm.EncodeAsPieces(text)
|
||||
code, text = self.remove_language_code(text)
|
||||
pieces = self.current_spm.EncodeAsPieces(text)
|
||||
return code + pieces
|
||||
|
||||
def _convert_id_to_token(self, index: int) -> str:
|
||||
"""Converts an index (integer) in a token (str) using the encoder."""
|
||||
@@ -125,7 +135,7 @@ class MarianTokenizer(PreTrainedTokenizer):
|
||||
pad_to_max_length: bool = True,
|
||||
return_tensors: str = "pt",
|
||||
) -> BatchEncoding:
|
||||
"""
|
||||
"""Prepare model inputs for translation. For best performance, translate one sentence at a time.
|
||||
Arguments:
|
||||
src_texts: list of src language texts
|
||||
tgt_texts: list of tgt language texts
|
||||
@@ -138,7 +148,10 @@ class MarianTokenizer(PreTrainedTokenizer):
|
||||
all shaped bs, seq_len. (BatchEncoding is a dict of string -> tensor or lists).
|
||||
If no tgt_text is specified, the only keys will be input_ids and attention_mask.
|
||||
"""
|
||||
if "" in src_texts:
|
||||
raise ValueError(f"found empty string in src_texts: {src_texts}")
|
||||
self.current_spm = self.spm_source
|
||||
src_texts = [self.normalize(t) for t in src_texts] # this does not appear to do much
|
||||
model_inputs: BatchEncoding = self.batch_encode_plus(
|
||||
src_texts,
|
||||
add_special_tokens=True,
|
||||
|
||||
Reference in New Issue
Block a user