MBART: support summarization tasks where max_src_len > max_tgt_len (#6003)

* MBART: support summarization tasks

* fix test

* Style

* add tokenizer test
This commit is contained in:
Sam Shleifer
2020-07-28 08:18:11 -04:00
committed by GitHub
parent 842eb45606
commit 3c7fbf35a6
7 changed files with 38 additions and 15 deletions

View File

@@ -105,7 +105,13 @@ class SummarizationModule(BaseTransformer):
self.hparams.git_sha = get_git_info()["repo_sha"]
self.num_workers = hparams.num_workers
self.decoder_start_token_id = None
self.dataset_class = Seq2SeqDataset
if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
self.model.config.decoder_start_token_id = self.decoder_start_token_id
if isinstance(self.tokenizer, MBartTokenizer):
self.dataset_class = MBartDataset
else:
self.dataset_class = Seq2SeqDataset
def freeze_embeds(self):
"""Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
@@ -331,11 +337,6 @@ class TranslationModule(SummarizationModule):
super().__init__(hparams, **kwargs)
self.dataset_kwargs["src_lang"] = hparams.src_lang
self.dataset_kwargs["tgt_lang"] = hparams.tgt_lang
if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
self.model.config.decoder_start_token_id = self.decoder_start_token_id
if isinstance(self.tokenizer, MBartTokenizer):
self.dataset_class = MBartDataset
def calc_generative_metrics(self, preds, target) -> dict:
return calculate_bleu_score(preds, target)