[marian] converter supports models from new Tatoeba project (#6342)
This commit is contained in:
@@ -1,14 +1,14 @@
|
|||||||
MarianMT
|
MarianMT
|
||||||
----------------------------------------------------
|
----------------------------------------------------
|
||||||
**DISCLAIMER:** If you see something strange,
|
**Bugs:** If you see something strange,
|
||||||
file a `Github Issue <https://github.com/huggingface/transformers/issues/new?assignees=&labels=&template=bug-report.md&title>`__ and assign
|
file a `Github Issue <https://github.com/huggingface/transformers/issues/new?assignees=sshleifer&labels=&template=bug-report.md&title>`__ and assign
|
||||||
@sshleifer. Translations should be similar, but not identical to, output in the test set linked to in each model card.
|
@sshleifer. Translations should be similar, but not identical to, output in the test set linked to in each model card.
|
||||||
|
|
||||||
Implementation Notes
|
Implementation Notes
|
||||||
~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~
|
||||||
- Each model is about 298 MB on disk, there are 1,000+ models.
|
- Each model is about 298 MB on disk, there are 1,000+ models.
|
||||||
- The list of supported language pairs can be found `here <https://huggingface.co/Helsinki-NLP>`__.
|
- The list of supported language pairs can be found `here <https://huggingface.co/Helsinki-NLP>`__.
|
||||||
- The 1,000+ models were originally trained by `Jörg Tiedemann <https://researchportal.helsinki.fi/en/persons/j%C3%B6rg-tiedemann>`__ using the `Marian <https://marian-nmt.github.io/>`_ C++ library, which supports fast training and translation.
|
- models were originally trained by `Jörg Tiedemann <https://researchportal.helsinki.fi/en/persons/j%C3%B6rg-tiedemann>`__ using the `Marian <https://marian-nmt.github.io/>`_ C++ library, which supports fast training and translation.
|
||||||
- All models are transformer encoder-decoders with 6 layers in each component. Each model's performance is documented in a model card.
|
- All models are transformer encoder-decoders with 6 layers in each component. Each model's performance is documented in a model card.
|
||||||
- The 80 opus models that require BPE preprocessing are not supported.
|
- The 80 opus models that require BPE preprocessing are not supported.
|
||||||
- The modeling code is the same as ``BartForConditionalGeneration`` with a few minor modifications:
|
- The modeling code is the same as ``BartForConditionalGeneration`` with a few minor modifications:
|
||||||
|
|||||||
@@ -2,9 +2,11 @@ import argparse
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
|
import socket
|
||||||
|
import time
|
||||||
import warnings
|
import warnings
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Union
|
from typing import Dict, List, Tuple, Union
|
||||||
from zipfile import ZipFile
|
from zipfile import ZipFile
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -15,6 +17,87 @@ from transformers import MarianConfig, MarianMTModel, MarianTokenizer
|
|||||||
from transformers.hf_api import HfApi
|
from transformers.hf_api import HfApi
|
||||||
|
|
||||||
|
|
||||||
|
def remove_suffix(text: str, suffix: str):
|
||||||
|
if text.endswith(suffix):
|
||||||
|
return text[: -len(suffix)]
|
||||||
|
return text # or whatever
|
||||||
|
|
||||||
|
|
||||||
|
def _process_benchmark_table_row(x):
|
||||||
|
fields = lmap(str.strip, x.replace("\t", "").split("|")[1:-1])
|
||||||
|
assert len(fields) == 3
|
||||||
|
return (fields[0], float(fields[1]), float(fields[2]))
|
||||||
|
|
||||||
|
|
||||||
|
def process_last_benchmark_table(readme_path) -> List[Tuple[str, float, float]]:
|
||||||
|
md_content = Path(readme_path).open().read()
|
||||||
|
entries = md_content.split("## Benchmarks")[-1].strip().split("\n")[2:]
|
||||||
|
data = lmap(_process_benchmark_table_row, entries)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def check_if_models_are_dominated(old_repo_path="OPUS-MT-train/models", new_repo_path="Tatoeba-Challenge/models/"):
|
||||||
|
"""Make a blacklist for models where we have already ported the same language pair, and the ported model has higher BLEU score."""
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
released_cols = [
|
||||||
|
"url_base",
|
||||||
|
"pair", # (ISO639-3/ISO639-5 codes),
|
||||||
|
"short_pair", # (reduced codes),
|
||||||
|
"chrF2_score",
|
||||||
|
"bleu",
|
||||||
|
"brevity_penalty",
|
||||||
|
"ref_len",
|
||||||
|
"src_name",
|
||||||
|
"tgt_name",
|
||||||
|
]
|
||||||
|
|
||||||
|
released = pd.read_csv(f"{new_repo_path}/released-models.txt", sep="\t", header=None).iloc[:-1]
|
||||||
|
released.columns = released_cols
|
||||||
|
old_reg = make_registry(repo_path=old_repo_path)
|
||||||
|
old_reg = pd.DataFrame(old_reg, columns=["id", "prepro", "url_model", "url_test_set"])
|
||||||
|
assert old_reg.id.value_counts().max() == 1
|
||||||
|
old_reg = old_reg.set_index("id")
|
||||||
|
|
||||||
|
released["fname"] = released["url_base"].apply(
|
||||||
|
lambda x: remove_suffix(remove_prefix(x, "https://object.pouta.csc.fi/Tatoeba-Challenge/opus"), ".zip")
|
||||||
|
)
|
||||||
|
|
||||||
|
released["2m"] = released.fname.str.startswith("2m")
|
||||||
|
released["date"] = pd.to_datetime(released["fname"].apply(lambda x: remove_prefix(remove_prefix(x, "2m-"), "-")))
|
||||||
|
|
||||||
|
newest_released = released.dsort("date").drop_duplicates(["short_pair"], keep="first")
|
||||||
|
|
||||||
|
short_to_new_bleu = newest_released.set_index("short_pair").bleu
|
||||||
|
|
||||||
|
assert released.groupby("short_pair").pair.nunique().max() == 1
|
||||||
|
|
||||||
|
short_to_long = released.groupby("short_pair").pair.first().to_dict()
|
||||||
|
|
||||||
|
overlap_short = old_reg.index.intersection(released.short_pair.unique())
|
||||||
|
overlap_long = [short_to_long[o] for o in overlap_short]
|
||||||
|
new_reported_bleu = [short_to_new_bleu[o] for o in overlap_short]
|
||||||
|
|
||||||
|
def get_old_bleu(o) -> float:
|
||||||
|
pat = old_repo_path + "/{}/README.md"
|
||||||
|
bm_data = process_last_benchmark_table(pat.format(o))
|
||||||
|
tab = pd.DataFrame(bm_data, columns=["testset", "bleu", "chr-f"])
|
||||||
|
tato_bleu = tab.loc[lambda x: x.testset.str.startswith("Tato")].bleu
|
||||||
|
if tato_bleu.shape[0] > 0:
|
||||||
|
return tato_bleu.iloc[0]
|
||||||
|
else:
|
||||||
|
return np.nan
|
||||||
|
|
||||||
|
old_bleu = [get_old_bleu(o) for o in overlap_short]
|
||||||
|
cmp_df = pd.DataFrame(
|
||||||
|
dict(short=overlap_short, long=overlap_long, old_bleu=old_bleu, new_bleu=new_reported_bleu)
|
||||||
|
).fillna(-1)
|
||||||
|
|
||||||
|
dominated = cmp_df[cmp_df.old_bleu > cmp_df.new_bleu]
|
||||||
|
blacklist = dominated.long.unique().tolist() # 3 letter codes
|
||||||
|
return dominated, blacklist
|
||||||
|
|
||||||
|
|
||||||
def remove_prefix(text: str, prefix: str):
|
def remove_prefix(text: str, prefix: str):
|
||||||
if text.startswith(prefix):
|
if text.startswith(prefix):
|
||||||
return text[len(prefix) :]
|
return text[len(prefix) :]
|
||||||
@@ -149,37 +232,87 @@ def convert_hf_name_to_opus_name(hf_model_name):
|
|||||||
return remove_prefix(opus_w_prefix, "opus-mt-")
|
return remove_prefix(opus_w_prefix, "opus-mt-")
|
||||||
|
|
||||||
|
|
||||||
|
def get_system_metadata(repo_root):
|
||||||
|
import git
|
||||||
|
|
||||||
|
return dict(
|
||||||
|
helsinki_git_sha=git.Repo(path=repo_root, search_parent_directories=True).head.object.hexsha,
|
||||||
|
transformers_git_sha=git.Repo(path=".", search_parent_directories=True).head.object.hexsha,
|
||||||
|
port_machine=socket.gethostname(),
|
||||||
|
port_time=time.strftime("%Y-%m-%d-%H:%M"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
front_matter = """---
|
||||||
|
language: {}
|
||||||
|
tags:
|
||||||
|
- translation
|
||||||
|
|
||||||
|
license: apache-2.0
|
||||||
|
---
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
def write_model_card(
|
def write_model_card(
|
||||||
hf_model_name: str,
|
hf_model_name: str, repo_root="OPUS-MT-train", save_dir=Path("marian_converted"), dry_run=False, extra_metadata={},
|
||||||
repo_path="OPUS-MT-train/models/",
|
|
||||||
dry_run=False,
|
|
||||||
model_card_dir=Path("marian_converted/model_cards/Helsinki-NLP/"),
|
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Copy the most recent model's readme section from opus, and add metadata.
|
"""Copy the most recent model's readme section from opus, and add metadata.
|
||||||
upload command: s3cmd sync --recursive model_card_dir s3://models.huggingface.co/bert/Helsinki-NLP/
|
upload command: aws s3 sync model_card_dir s3://models.huggingface.co/bert/Helsinki-NLP/ --dryrun
|
||||||
"""
|
"""
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
hf_model_name = remove_prefix(hf_model_name, ORG_NAME)
|
hf_model_name = remove_prefix(hf_model_name, ORG_NAME)
|
||||||
opus_name: str = convert_hf_name_to_opus_name(hf_model_name)
|
opus_name: str = convert_hf_name_to_opus_name(hf_model_name)
|
||||||
opus_src, opus_tgt = [x.split("+") for x in opus_name.split("-")]
|
assert repo_root in ("OPUS-MT-train", "Tatoeba-Challenge")
|
||||||
readme_url = OPUS_GITHUB_URL + f"{opus_name}/README.md"
|
opus_readme_path = Path(repo_root).joinpath("models", opus_name, "README.md")
|
||||||
s, t = ",".join(opus_src), ",".join(opus_tgt)
|
|
||||||
extra_markdown = f"### {hf_model_name}\n\n* source languages: {s}\n* target languages: {t}\n* OPUS readme: [{opus_name}]({readme_url})\n"
|
|
||||||
# combine with opus markdown
|
|
||||||
opus_readme_path = Path(f"{repo_path}{opus_name}/README.md")
|
|
||||||
assert opus_readme_path.exists(), f"Readme file {opus_readme_path} not found"
|
assert opus_readme_path.exists(), f"Readme file {opus_readme_path} not found"
|
||||||
|
|
||||||
|
opus_src, opus_tgt = [x.split("+") for x in opus_name.split("-")]
|
||||||
|
|
||||||
|
readme_url = f"https://github.com/Helsinki-NLP/{repo_root}/tree/master/models/{opus_name}/README.md"
|
||||||
|
|
||||||
|
s, t = ",".join(opus_src), ",".join(opus_tgt)
|
||||||
|
metadata = {
|
||||||
|
"hf_name": hf_model_name,
|
||||||
|
"source_languages": s,
|
||||||
|
"target_languages": t,
|
||||||
|
"opus_readme_url": readme_url,
|
||||||
|
"original_repo": repo_root,
|
||||||
|
"tags": ["translation"],
|
||||||
|
}
|
||||||
|
metadata.update(extra_metadata)
|
||||||
|
metadata.update(get_system_metadata(repo_root))
|
||||||
|
|
||||||
|
# combine with opus markdown
|
||||||
|
|
||||||
|
extra_markdown = f"### {hf_model_name}\n\n* source group: {metadata['src_name']} \n* target group: {metadata['tgt_name']} \n* OPUS readme: [{opus_name}]({readme_url})\n"
|
||||||
|
|
||||||
content = opus_readme_path.open().read()
|
content = opus_readme_path.open().read()
|
||||||
content = content.split("\n# ")[-1] # Get the lowest level 1 header in the README -- the most recent model.
|
content = content.split("\n# ")[-1] # Get the lowest level 1 header in the README -- the most recent model.
|
||||||
content = "*".join(content.split("*")[1:])
|
splat = content.split("*")[2:]
|
||||||
content = extra_markdown + "\n* " + content.replace("download", "download original weights")
|
print(splat[3])
|
||||||
|
content = "*".join(splat)
|
||||||
|
content = (
|
||||||
|
front_matter.format(metadata["src_alpha2"])
|
||||||
|
+ extra_markdown
|
||||||
|
+ "\n* "
|
||||||
|
+ content.replace("download", "download original weights")
|
||||||
|
)
|
||||||
|
|
||||||
|
items = "\n\n".join([f"- {k}: {v}" for k, v in metadata.items()])
|
||||||
|
sec3 = "\n### System Info: \n" + items
|
||||||
|
content += sec3
|
||||||
if dry_run:
|
if dry_run:
|
||||||
return content
|
return content, metadata
|
||||||
# Save string to model_cards/hf_model_name/readme.md
|
sub_dir = save_dir / f"opus-mt-{hf_model_name}"
|
||||||
model_card_dir.mkdir(exist_ok=True)
|
|
||||||
sub_dir = model_card_dir / hf_model_name
|
|
||||||
sub_dir.mkdir(exist_ok=True)
|
sub_dir.mkdir(exist_ok=True)
|
||||||
dest = sub_dir / "README.md"
|
dest = sub_dir / "README.md"
|
||||||
dest.open("w").write(content)
|
dest.open("w").write(content)
|
||||||
return content
|
pd.Series(metadata).to_json(sub_dir / "metadata.json")
|
||||||
|
|
||||||
|
# if dry_run:
|
||||||
|
return content, metadata
|
||||||
|
|
||||||
|
|
||||||
def get_clean_model_id_mapping(multiling_model_ids):
|
def get_clean_model_id_mapping(multiling_model_ids):
|
||||||
@@ -193,7 +326,7 @@ def make_registry(repo_path="Opus-MT-train/models"):
|
|||||||
"You must run: git clone git@github.com:Helsinki-NLP/Opus-MT-train.git before calling."
|
"You must run: git clone git@github.com:Helsinki-NLP/Opus-MT-train.git before calling."
|
||||||
)
|
)
|
||||||
results = {}
|
results = {}
|
||||||
for p in Path(repo_path).ls():
|
for p in Path(repo_path).iterdir():
|
||||||
n_dash = p.name.count("-")
|
n_dash = p.name.count("-")
|
||||||
if n_dash == 0:
|
if n_dash == 0:
|
||||||
continue
|
continue
|
||||||
@@ -203,6 +336,21 @@ def make_registry(repo_path="Opus-MT-train/models"):
|
|||||||
return [(k, v["pre-processing"], v["download"], v["download"][:-4] + ".test.txt") for k, v in results.items()]
|
return [(k, v["pre-processing"], v["download"], v["download"][:-4] + ".test.txt") for k, v in results.items()]
|
||||||
|
|
||||||
|
|
||||||
|
def make_tatoeba_registry(repo_path="Tatoeba-Challenge/models"):
|
||||||
|
if not (Path(repo_path) / "zho-eng" / "README.md").exists():
|
||||||
|
raise ValueError(
|
||||||
|
f"repo_path:{repo_path} does not exist: "
|
||||||
|
"You must run: git clone git@github.com:Helsinki-NLP/Tatoeba-Challenge.git before calling."
|
||||||
|
)
|
||||||
|
results = {}
|
||||||
|
for p in Path(repo_path).iterdir():
|
||||||
|
if len(p.name) != 7:
|
||||||
|
continue
|
||||||
|
lns = list(open(p / "README.md").readlines())
|
||||||
|
results[p.name] = _parse_readme(lns)
|
||||||
|
return [(k, v["pre-processing"], v["download"], v["download"][:-4] + ".test.txt") for k, v in results.items()]
|
||||||
|
|
||||||
|
|
||||||
def convert_all_sentencepiece_models(model_list=None, repo_path=None):
|
def convert_all_sentencepiece_models(model_list=None, repo_path=None):
|
||||||
"""Requires 300GB"""
|
"""Requires 300GB"""
|
||||||
save_dir = Path("marian_ckpt")
|
save_dir = Path("marian_ckpt")
|
||||||
@@ -516,19 +664,6 @@ def convert(source_dir: Path, dest_dir):
|
|||||||
model.from_pretrained(dest_dir) # sanity check
|
model.from_pretrained(dest_dir) # sanity check
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
# Required parameters
|
|
||||||
parser.add_argument("--src", type=str, help="path to marian model dir", default="en-de")
|
|
||||||
parser.add_argument("--dest", type=str, default=None, help="Path to the output PyTorch model.")
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
source_dir = Path(args.src)
|
|
||||||
assert source_dir.exists(), f"Source directory {source_dir} not found"
|
|
||||||
dest_dir = f"converted-{source_dir.name}" if args.dest is None else args.dest
|
|
||||||
convert(source_dir, dest_dir)
|
|
||||||
|
|
||||||
|
|
||||||
def load_yaml(path):
|
def load_yaml(path):
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
@@ -544,3 +679,23 @@ def save_json(content: Union[Dict, List], path: str) -> None:
|
|||||||
def unzip(zip_path: str, dest_dir: str) -> None:
|
def unzip(zip_path: str, dest_dir: str) -> None:
|
||||||
with ZipFile(zip_path, "r") as zipObj:
|
with ZipFile(zip_path, "r") as zipObj:
|
||||||
zipObj.extractall(dest_dir)
|
zipObj.extractall(dest_dir)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
"""
|
||||||
|
To bulk convert, run
|
||||||
|
>>> from transformers.convert_marian_to_pytorch import make_tatoeba_registry, convert_all_sentencepiece_models
|
||||||
|
>>> reg = make_tatoeba_registry()
|
||||||
|
>>> convert_all_sentencepiece_models(model_list=reg) # saves to marian_converted
|
||||||
|
(bash) aws s3 sync marian_converted s3://models.huggingface.co/bert/Helsinki-NLP/ --dryrun
|
||||||
|
"""
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
# Required parameters
|
||||||
|
parser.add_argument("--src", type=str, help="path to marian model dir", default="en-de")
|
||||||
|
parser.add_argument("--dest", type=str, default=None, help="Path to the output PyTorch model.")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
source_dir = Path(args.src)
|
||||||
|
assert source_dir.exists(), f"Source directory {source_dir} not found"
|
||||||
|
dest_dir = f"converted-{source_dir.name}" if args.dest is None else args.dest
|
||||||
|
convert(source_dir, dest_dir)
|
||||||
|
|||||||
@@ -205,6 +205,17 @@ class TestMarian_MT_EN(MarianIntegrationTest):
|
|||||||
self._assert_generated_batch_equal_expected()
|
self._assert_generated_batch_equal_expected()
|
||||||
|
|
||||||
|
|
||||||
|
class TestMarian_eng_zho(MarianIntegrationTest):
|
||||||
|
src = "eng"
|
||||||
|
tgt = "zho"
|
||||||
|
src_text = ["My name is Wolfgang and I live in Berlin"]
|
||||||
|
expected_text = ["我叫沃尔夫冈 我住在柏林"]
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_batch_generation_eng_zho(self):
|
||||||
|
self._assert_generated_batch_equal_expected()
|
||||||
|
|
||||||
|
|
||||||
class TestMarian_en_ROMANCE(MarianIntegrationTest):
|
class TestMarian_en_ROMANCE(MarianIntegrationTest):
|
||||||
"""Multilingual on target side."""
|
"""Multilingual on target side."""
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user