Add mbart-large-cc25, support translation finetuning (#5129)
improve unittests for finetuning, especially w.r.t testing frozen parameters fix freeze_embeds for T5 add streamlit setup.cfg
This commit is contained in:
@@ -14,6 +14,8 @@ from torch import nn
|
||||
from torch.utils.data import Dataset, Sampler
|
||||
from tqdm import tqdm
|
||||
|
||||
from transformers import BartTokenizer
|
||||
|
||||
|
||||
def encode_file(
|
||||
tokenizer,
|
||||
@@ -25,6 +27,7 @@ def encode_file(
|
||||
prefix="",
|
||||
tok_name="",
|
||||
):
|
||||
extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {}
|
||||
cache_path = Path(f"{data_path}_{tok_name}{max_length}.pt")
|
||||
if not overwrite_cache and cache_path.exists():
|
||||
try:
|
||||
@@ -46,8 +49,8 @@ def encode_file(
|
||||
max_length=max_length,
|
||||
padding="max_length" if pad_to_max_length else None,
|
||||
truncation=True,
|
||||
add_prefix_space=True,
|
||||
return_tensors=return_tensors,
|
||||
**extra_kw,
|
||||
)
|
||||
assert tokenized.input_ids.shape[1] == max_length
|
||||
examples.append(tokenized)
|
||||
@@ -87,9 +90,14 @@ class SummarizationDataset(Dataset):
|
||||
n_obs=None,
|
||||
overwrite_cache=False,
|
||||
prefix="",
|
||||
src_lang=None,
|
||||
tgt_lang=None,
|
||||
):
|
||||
super().__init__()
|
||||
# FIXME: the rstrip logic strips all the chars, it seems.
|
||||
tok_name = tokenizer.__class__.__name__.lower().rstrip("tokenizer")
|
||||
if hasattr(tokenizer, "set_lang") and src_lang is not None:
|
||||
tokenizer.set_lang(src_lang) # HACK: only applies to mbart
|
||||
self.source = encode_file(
|
||||
tokenizer,
|
||||
os.path.join(data_dir, type_path + ".source"),
|
||||
@@ -100,7 +108,8 @@ class SummarizationDataset(Dataset):
|
||||
)
|
||||
tgt_path = os.path.join(data_dir, type_path + ".target")
|
||||
if hasattr(tokenizer, "set_lang"):
|
||||
tokenizer.set_lang("ro_RO") # HACK: only applies to mbart
|
||||
assert tgt_lang is not None, "--tgt_lang must be passed to build a translation"
|
||||
tokenizer.set_lang(tgt_lang) # HACK: only applies to mbart
|
||||
self.target = encode_file(
|
||||
tokenizer, tgt_path, max_target_length, overwrite_cache=overwrite_cache, tok_name=tok_name
|
||||
)
|
||||
@@ -224,8 +233,8 @@ def get_git_info():
|
||||
ROUGE_KEYS = ["rouge1", "rouge2", "rougeL"]
|
||||
|
||||
|
||||
def calculate_rouge(output_lns: List[str], reference_lns: List[str]) -> Dict:
|
||||
scorer = rouge_scorer.RougeScorer(ROUGE_KEYS, use_stemmer=True)
|
||||
def calculate_rouge(output_lns: List[str], reference_lns: List[str], use_stemmer=True) -> Dict:
|
||||
scorer = rouge_scorer.RougeScorer(ROUGE_KEYS, use_stemmer=use_stemmer)
|
||||
aggregator = scoring.BootstrapAggregator()
|
||||
|
||||
for reference_ln, output_ln in zip(reference_lns, output_lns):
|
||||
|
||||
Reference in New Issue
Block a user