MBART: support summarization tasks where max_src_len > max_tgt_len (#6003)
* MBART: support summarization tasks
* fix test
* Style
* add tokenizer test
This commit is contained in:
@@ -157,7 +157,8 @@ class MBartDataset(Seq2SeqDataset):
|
||||
super().__init__(*args, **kwargs)
|
||||
if self.max_source_length != self.max_target_length:
|
||||
warnings.warn(
|
||||
- f"Mbart will ignore max_target_length = {self.max_target_length} and use {self.max_source_length} for both sides."
|
||||
+ f"Mbart is using sequence lengths {self.max_source_length}, {self.max_target_length}. "
|
||||
+ f"Imbalanced sequence lengths may be undesired for translation tasks"
|
||||
)
|
||||
|
||||
def __getitem__(self, index) -> Dict[str, str]:
|
||||
@@ -178,6 +179,7 @@ class MBartDataset(Seq2SeqDataset):
|
||||
tgt_texts=[x["tgt_texts"] for x in batch],
|
||||
tgt_lang=self.tgt_lang,
|
||||
max_length=self.max_source_length,
|
||||
+ max_target_length=self.max_target_length,
|
||||
)
|
||||
return batch_encoding.data
|
||||
|
||||
|
||||
Reference in New Issue
Block a user