MBART: support summarization tasks where max_src_len > max_tgt_len (#6003)
* MBART: support summarization tasks * fix test * Style * add tokenizer test
This commit is contained in:
@@ -180,6 +180,8 @@ python run_eval.py sshleifer/distilbart-cnn-12-6 $DATA_DIR/val.source dbart_val_
|
|||||||
--task summarization \
|
--task summarization \
|
||||||
--n_obs 100 \
|
--n_obs 100 \
|
||||||
--device cuda \
|
--device cuda \
|
||||||
|
--max_source_length 1024 \
|
||||||
|
--max_target_length 56 \
|
||||||
--fp16 \
|
--fp16 \
|
||||||
--bs 32
|
--bs 32
|
||||||
```
|
```
|
||||||
|
|||||||
@@ -105,6 +105,12 @@ class SummarizationModule(BaseTransformer):
|
|||||||
self.hparams.git_sha = get_git_info()["repo_sha"]
|
self.hparams.git_sha = get_git_info()["repo_sha"]
|
||||||
self.num_workers = hparams.num_workers
|
self.num_workers = hparams.num_workers
|
||||||
self.decoder_start_token_id = None
|
self.decoder_start_token_id = None
|
||||||
|
if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
|
||||||
|
self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
|
||||||
|
self.model.config.decoder_start_token_id = self.decoder_start_token_id
|
||||||
|
if isinstance(self.tokenizer, MBartTokenizer):
|
||||||
|
self.dataset_class = MBartDataset
|
||||||
|
else:
|
||||||
self.dataset_class = Seq2SeqDataset
|
self.dataset_class = Seq2SeqDataset
|
||||||
|
|
||||||
def freeze_embeds(self):
|
def freeze_embeds(self):
|
||||||
@@ -331,11 +337,6 @@ class TranslationModule(SummarizationModule):
|
|||||||
super().__init__(hparams, **kwargs)
|
super().__init__(hparams, **kwargs)
|
||||||
self.dataset_kwargs["src_lang"] = hparams.src_lang
|
self.dataset_kwargs["src_lang"] = hparams.src_lang
|
||||||
self.dataset_kwargs["tgt_lang"] = hparams.tgt_lang
|
self.dataset_kwargs["tgt_lang"] = hparams.tgt_lang
|
||||||
if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
|
|
||||||
self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
|
|
||||||
self.model.config.decoder_start_token_id = self.decoder_start_token_id
|
|
||||||
if isinstance(self.tokenizer, MBartTokenizer):
|
|
||||||
self.dataset_class = MBartDataset
|
|
||||||
|
|
||||||
def calc_generative_metrics(self, preds, target) -> dict:
|
def calc_generative_metrics(self, preds, target) -> dict:
|
||||||
return calculate_bleu_score(preds, target)
|
return calculate_bleu_score(preds, target)
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ python finetune.py \
|
|||||||
--eval_batch_size=$BS \
|
--eval_batch_size=$BS \
|
||||||
--output_dir=$OUTPUT_DIR \
|
--output_dir=$OUTPUT_DIR \
|
||||||
--max_source_length=512 \
|
--max_source_length=512 \
|
||||||
|
--max_target_length=56 \
|
||||||
--val_check_interval=0.1 --n_val=200 \
|
--val_check_interval=0.1 --n_val=200 \
|
||||||
--do_train --do_predict \
|
--do_train --do_predict \
|
||||||
$@
|
$@
|
||||||
|
|||||||
@@ -300,14 +300,17 @@ def test_mbart_dataset_truncation():
|
|||||||
tmp_dir = make_test_data_dir()
|
tmp_dir = make_test_data_dir()
|
||||||
max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
|
max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
|
||||||
max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
|
max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
|
||||||
trunc = 4
|
max_src_len = 4
|
||||||
|
max_tgt_len = 8
|
||||||
|
assert max_len_target > max_src_len # Truncated
|
||||||
|
assert max_len_source > max_src_len
|
||||||
src_lang, tgt_lang = "ro_RO", "de_DE" # NOT WHAT IT WAS TRAINED ON
|
src_lang, tgt_lang = "ro_RO", "de_DE" # NOT WHAT IT WAS TRAINED ON
|
||||||
train_dataset = MBartDataset(
|
train_dataset = MBartDataset(
|
||||||
tokenizer,
|
tokenizer,
|
||||||
data_dir=tmp_dir,
|
data_dir=tmp_dir,
|
||||||
type_path="train",
|
type_path="train",
|
||||||
max_source_length=trunc,
|
max_source_length=max_src_len,
|
||||||
max_target_length=1000, # ignored
|
max_target_length=max_tgt_len, # ignored
|
||||||
src_lang=src_lang,
|
src_lang=src_lang,
|
||||||
tgt_lang=tgt_lang,
|
tgt_lang=tgt_lang,
|
||||||
)
|
)
|
||||||
@@ -316,17 +319,15 @@ def test_mbart_dataset_truncation():
|
|||||||
assert isinstance(batch, dict)
|
assert isinstance(batch, dict)
|
||||||
assert batch["attention_mask"].shape == batch["input_ids"].shape
|
assert batch["attention_mask"].shape == batch["input_ids"].shape
|
||||||
# show that articles were trimmed.
|
# show that articles were trimmed.
|
||||||
assert batch["input_ids"].shape[1] == trunc
|
assert batch["input_ids"].shape[1] == max_src_len
|
||||||
# show that targets are the same len
|
# show that targets are the same len
|
||||||
assert batch["decoder_input_ids"].shape[1] == trunc
|
assert batch["decoder_input_ids"].shape[1] == max_tgt_len
|
||||||
# check language codes in correct place
|
# check language codes in correct place
|
||||||
assert batch["decoder_input_ids"][0, 0].item() == tokenizer.lang_code_to_id[tgt_lang]
|
assert batch["decoder_input_ids"][0, 0].item() == tokenizer.lang_code_to_id[tgt_lang]
|
||||||
assert batch["decoder_input_ids"][0, -1].item() == tokenizer.eos_token_id
|
assert batch["decoder_input_ids"][0, -1].item() == tokenizer.eos_token_id
|
||||||
assert batch["input_ids"][0, -2].item() == tokenizer.eos_token_id
|
assert batch["input_ids"][0, -2].item() == tokenizer.eos_token_id
|
||||||
assert batch["input_ids"][0, -1].item() == tokenizer.lang_code_to_id[src_lang]
|
assert batch["input_ids"][0, -1].item() == tokenizer.lang_code_to_id[src_lang]
|
||||||
|
|
||||||
assert max_len_target > trunc # Truncated
|
|
||||||
assert max_len_source > trunc
|
|
||||||
break # No need to test every batch
|
break # No need to test every batch
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -157,7 +157,8 @@ class MBartDataset(Seq2SeqDataset):
|
|||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
if self.max_source_length != self.max_target_length:
|
if self.max_source_length != self.max_target_length:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
f"Mbart will ignore max_target_length = {self.max_target_length} and use {self.max_source_length} for both sides."
|
f"Mbart is using sequence lengths {self.max_source_length}, {self.max_target_length}. "
|
||||||
|
f"Imbalanced sequence lengths may be undesired for translation tasks"
|
||||||
)
|
)
|
||||||
|
|
||||||
def __getitem__(self, index) -> Dict[str, str]:
|
def __getitem__(self, index) -> Dict[str, str]:
|
||||||
@@ -178,6 +179,7 @@ class MBartDataset(Seq2SeqDataset):
|
|||||||
tgt_texts=[x["tgt_texts"] for x in batch],
|
tgt_texts=[x["tgt_texts"] for x in batch],
|
||||||
tgt_lang=self.tgt_lang,
|
tgt_lang=self.tgt_lang,
|
||||||
max_length=self.max_source_length,
|
max_length=self.max_source_length,
|
||||||
|
max_target_length=self.max_target_length,
|
||||||
)
|
)
|
||||||
return batch_encoding.data
|
return batch_encoding.data
|
||||||
|
|
||||||
|
|||||||
@@ -193,6 +193,7 @@ class MBartTokenizer(XLMRobertaTokenizer):
|
|||||||
tgt_texts: Optional[List[str]] = None,
|
tgt_texts: Optional[List[str]] = None,
|
||||||
tgt_lang: str = "ro_RO",
|
tgt_lang: str = "ro_RO",
|
||||||
max_length: Optional[int] = None,
|
max_length: Optional[int] = None,
|
||||||
|
max_target_length: Optional[int] = None,
|
||||||
padding: str = "longest",
|
padding: str = "longest",
|
||||||
return_tensors: str = "pt",
|
return_tensors: str = "pt",
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@@ -224,13 +225,16 @@ class MBartTokenizer(XLMRobertaTokenizer):
|
|||||||
)
|
)
|
||||||
if tgt_texts is None:
|
if tgt_texts is None:
|
||||||
return model_inputs
|
return model_inputs
|
||||||
|
# Process tgt_texts
|
||||||
|
if max_target_length is None:
|
||||||
|
max_target_length = max_length
|
||||||
self.set_tgt_lang_special_tokens(tgt_lang)
|
self.set_tgt_lang_special_tokens(tgt_lang)
|
||||||
decoder_inputs: BatchEncoding = self(
|
decoder_inputs: BatchEncoding = self(
|
||||||
tgt_texts,
|
tgt_texts,
|
||||||
add_special_tokens=True,
|
add_special_tokens=True,
|
||||||
return_tensors=return_tensors,
|
return_tensors=return_tensors,
|
||||||
padding=padding,
|
padding=padding,
|
||||||
max_length=max_length,
|
max_length=max_target_length,
|
||||||
truncation=True,
|
truncation=True,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -137,6 +137,18 @@ class MBartEnroIntegrationTest(unittest.TestCase):
|
|||||||
self.assertEqual(self.tokenizer.prefix_tokens, [])
|
self.assertEqual(self.tokenizer.prefix_tokens, [])
|
||||||
self.assertEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id, EN_CODE])
|
self.assertEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id, EN_CODE])
|
||||||
|
|
||||||
|
def test_max_target_length(self):
|
||||||
|
|
||||||
|
batch = self.tokenizer.prepare_translation_batch(
|
||||||
|
self.src_text, tgt_texts=self.tgt_text, max_length=3, max_target_length=10
|
||||||
|
)
|
||||||
|
self.assertEqual(batch.input_ids.shape[1], 3)
|
||||||
|
self.assertEqual(batch.decoder_input_ids.shape[1], 10)
|
||||||
|
# max_target_length will default to max_length if not specified
|
||||||
|
batch = self.tokenizer.prepare_translation_batch(self.src_text, tgt_texts=self.tgt_text, max_length=3)
|
||||||
|
self.assertEqual(batch.input_ids.shape[1], 3)
|
||||||
|
self.assertEqual(batch.decoder_input_ids.shape[1], 3)
|
||||||
|
|
||||||
def test_enro_tokenizer_batch_encode_plus(self):
|
def test_enro_tokenizer_batch_encode_plus(self):
|
||||||
ids = self.tokenizer.batch_encode_plus(self.src_text).input_ids[0]
|
ids = self.tokenizer.batch_encode_plus(self.src_text).input_ids[0]
|
||||||
self.assertListEqual(self.expected_src_tokens, ids)
|
self.assertListEqual(self.expected_src_tokens, ids)
|
||||||
|
|||||||
Reference in New Issue
Block a user