MBART: support summarization tasks where max_src_len > max_tgt_len (#6003)

* MBART: support summarization tasks

* fix test

* Style

* add tokenizer test
This commit is contained in:
Sam Shleifer
2020-07-28 08:18:11 -04:00
committed by GitHub
parent 842eb45606
commit 3c7fbf35a6
7 changed files with 38 additions and 15 deletions

View File

@@ -157,7 +157,8 @@ class MBartDataset(Seq2SeqDataset):
super().__init__(*args, **kwargs)
if self.max_source_length != self.max_target_length:
warnings.warn(
f"Mbart will ignore max_target_length = {self.max_target_length} and use {self.max_source_length} for both sides."
f"Mbart is using sequence lengths {self.max_source_length}, {self.max_target_length}. "
f"Imbalanced sequence lengths may be undesired for translation tasks"
)
def __getitem__(self, index) -> Dict[str, str]:
@@ -178,6 +179,7 @@ class MBartDataset(Seq2SeqDataset):
tgt_texts=[x["tgt_texts"] for x in batch],
tgt_lang=self.tgt_lang,
max_length=self.max_source_length,
max_target_length=self.max_target_length,
)
return batch_encoding.data