Add mbart-large-cc25, support translation finetuning (#5129)
improve unittests for finetuning, especially w.r.t testing frozen parameters fix freeze_embeds for T5 add streamlit setup.cfg
This commit is contained in:
@@ -14,11 +14,12 @@ import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from lightning_base import BaseTransformer, add_generic_args, generic_train
|
||||
from transformers import get_linear_schedule_with_warmup
|
||||
from transformers import MBartTokenizer, get_linear_schedule_with_warmup
|
||||
|
||||
|
||||
try:
|
||||
from .utils import (
|
||||
assert_all_frozen,
|
||||
use_task_specific_params,
|
||||
SummarizationDataset,
|
||||
lmap,
|
||||
@@ -47,6 +48,7 @@ except ImportError:
|
||||
get_git_info,
|
||||
ROUGE_KEYS,
|
||||
calculate_bleu_score,
|
||||
assert_all_frozen,
|
||||
)
|
||||
from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback
|
||||
|
||||
@@ -92,9 +94,12 @@ class SummarizationModule(BaseTransformer):
|
||||
if self.hparams.freeze_embeds:
|
||||
self.freeze_embeds()
|
||||
if self.hparams.freeze_encoder:
|
||||
freeze_params(self.model.model.encoder) # TODO: this will break for t5
|
||||
freeze_params(self.model.get_encoder())
|
||||
assert_all_frozen(self.model.get_encoder())
|
||||
|
||||
self.hparams.git_sha = get_git_info()["repo_sha"]
|
||||
self.num_workers = hparams.num_workers
|
||||
self.decoder_start_token_id = None
|
||||
|
||||
def freeze_embeds(self):
|
||||
"""Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
|
||||
@@ -160,7 +165,12 @@ class SummarizationModule(BaseTransformer):
|
||||
pad_token_id = self.tokenizer.pad_token_id
|
||||
source_ids, source_mask, y = SummarizationDataset.trim_seq2seq_batch(batch, pad_token_id)
|
||||
t0 = time.time()
|
||||
generated_ids = self.model.generate(input_ids=source_ids, attention_mask=source_mask, use_cache=True,)
|
||||
generated_ids = self.model.generate(
|
||||
input_ids=source_ids,
|
||||
attention_mask=source_mask,
|
||||
use_cache=True,
|
||||
decoder_start_token_id=self.decoder_start_token_id,
|
||||
)
|
||||
gen_time = (time.time() - t0) / source_ids.shape[0]
|
||||
preds = self.ids_to_clean_text(generated_ids)
|
||||
target = self.ids_to_clean_text(y)
|
||||
@@ -276,6 +286,9 @@ class SummarizationModule(BaseTransformer):
|
||||
parser.add_argument(
|
||||
"--task", type=str, default="summarization", required=False, help="# examples. -1 means use all."
|
||||
)
|
||||
parser.add_argument("--src_lang", type=str, default="", required=False)
|
||||
parser.add_argument("--tgt_lang", type=str, default="", required=False)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@@ -285,6 +298,13 @@ class TranslationModule(SummarizationModule):
|
||||
metric_names = ["bleu"]
|
||||
val_metric = "bleu"
|
||||
|
||||
def __init__(self, hparams, **kwargs):
|
||||
super().__init__(hparams, **kwargs)
|
||||
self.dataset_kwargs["src_lang"] = hparams.src_lang
|
||||
self.dataset_kwargs["tgt_lang"] = hparams.tgt_lang
|
||||
if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
|
||||
self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
|
||||
|
||||
def calc_generative_metrics(self, preds, target) -> dict:
|
||||
return calculate_bleu_score(preds, target)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user