Add mbart-large-cc25, support translation finetuning (#5129)

improve unittests for finetuning, especially w.r.t testing frozen parameters
fix freeze_embeds for T5
add streamlit setup.cfg
This commit is contained in:
Sam Shleifer
2020-07-07 13:23:01 -04:00
committed by GitHub
parent 141492448b
commit 353b8f1e7a
14 changed files with 521 additions and 204 deletions

View File

@@ -14,6 +14,8 @@ from torch import nn
from torch.utils.data import Dataset, Sampler
from tqdm import tqdm
from transformers import BartTokenizer
def encode_file(
tokenizer,
@@ -25,6 +27,7 @@ def encode_file(
prefix="",
tok_name="",
):
extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {}
cache_path = Path(f"{data_path}_{tok_name}{max_length}.pt")
if not overwrite_cache and cache_path.exists():
try:
@@ -46,8 +49,8 @@ def encode_file(
max_length=max_length,
padding="max_length" if pad_to_max_length else None,
truncation=True,
add_prefix_space=True,
return_tensors=return_tensors,
**extra_kw,
)
assert tokenized.input_ids.shape[1] == max_length
examples.append(tokenized)
@@ -87,9 +90,14 @@ class SummarizationDataset(Dataset):
n_obs=None,
overwrite_cache=False,
prefix="",
src_lang=None,
tgt_lang=None,
):
super().__init__()
# FIXME: the rstrip logic strips all the chars, it seems.
tok_name = tokenizer.__class__.__name__.lower().rstrip("tokenizer")
if hasattr(tokenizer, "set_lang") and src_lang is not None:
tokenizer.set_lang(src_lang) # HACK: only applies to mbart
self.source = encode_file(
tokenizer,
os.path.join(data_dir, type_path + ".source"),
@@ -100,7 +108,8 @@ class SummarizationDataset(Dataset):
)
tgt_path = os.path.join(data_dir, type_path + ".target")
if hasattr(tokenizer, "set_lang"):
tokenizer.set_lang("ro_RO") # HACK: only applies to mbart
assert tgt_lang is not None, "--tgt_lang must be passed to build a translation"
tokenizer.set_lang(tgt_lang) # HACK: only applies to mbart
self.target = encode_file(
tokenizer, tgt_path, max_target_length, overwrite_cache=overwrite_cache, tok_name=tok_name
)
@@ -224,8 +233,8 @@ def get_git_info():
ROUGE_KEYS = ["rouge1", "rouge2", "rougeL"]
def calculate_rouge(output_lns: List[str], reference_lns: List[str]) -> Dict:
scorer = rouge_scorer.RougeScorer(ROUGE_KEYS, use_stemmer=True)
def calculate_rouge(output_lns: List[str], reference_lns: List[str], use_stemmer=True) -> Dict:
scorer = rouge_scorer.RougeScorer(ROUGE_KEYS, use_stemmer=use_stemmer)
aggregator = scoring.BootstrapAggregator()
for reference_ln, output_ln in zip(reference_lns, output_lns):