From 5abe50381afc3a02cb5776f990bf443f83430ef4 Mon Sep 17 00:00:00 2001
From: Sam Shleifer
Date: Tue, 28 Jul 2020 18:27:58 -0400
Subject: [PATCH] Fix #6096: MBartTokenizer's mask token (#6098)

---
 src/transformers/tokenization_bart.py |  1 +
 tests/test_modeling_mbart.py          | 12 ++++++++++++
 tests/test_tokenization_mbart.py      | 11 +++++++++++
 3 files changed, 24 insertions(+)

diff --git a/src/transformers/tokenization_bart.py b/src/transformers/tokenization_bart.py
index c83ad0d333..499895e0bd 100644
--- a/src/transformers/tokenization_bart.py
+++ b/src/transformers/tokenization_bart.py
@@ -122,6 +122,7 @@ class MBartTokenizer(XLMRobertaTokenizer):
         }
         self.id_to_lang_code = {v: k for k, v in self.lang_code_to_id.items()}
         self.cur_lang_code = self.lang_code_to_id["en_XX"]
+        self.fairseq_tokens_to_ids["<mask>"] = len(self.sp_model) + len(self.lang_code_to_id) + self.fairseq_offset
 
         self.fairseq_tokens_to_ids.update(self.lang_code_to_id)
         self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}
diff --git a/tests/test_modeling_mbart.py b/tests/test_modeling_mbart.py
index 0d10a8e406..159fc42976 100644
--- a/tests/test_modeling_mbart.py
+++ b/tests/test_modeling_mbart.py
@@ -123,6 +123,7 @@ class MBartEnroIntegrationTest(AbstractMBartIntegrationTest):
         self.assertEqual(logits.shape, expected_shape)
 
 
+@require_torch
 class MBartCC25IntegrationTest(AbstractMBartIntegrationTest):
     checkpoint_name = "facebook/mbart-large-cc25"
     src_text = [
@@ -140,3 +141,14 @@ class MBartCC25IntegrationTest(AbstractMBartIntegrationTest):
         )
         decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
         self.assertEqual(self.tgt_text[0], decoded[0])
+
+    @slow
+    def test_fill_mask(self):
+        inputs = self.tokenizer.prepare_translation_batch(["One of the best <mask> I ever read!"]).to(torch_device)
+        outputs = self.model.generate(
+            inputs["input_ids"], decoder_start_token_id=self.tokenizer.lang_code_to_id["en_XX"], num_beams=1
+        )
+        prediction: str = self.tokenizer.batch_decode(
+            outputs, clean_up_tokenization_spaces=True, skip_special_tokens=True
+        )[0]
+        self.assertEqual(prediction, "of the best books I ever read!")
diff --git a/tests/test_tokenization_mbart.py b/tests/test_tokenization_mbart.py
index 14566ac975..74bfd5b5bf 100644
--- a/tests/test_tokenization_mbart.py
+++ b/tests/test_tokenization_mbart.py
@@ -1,3 +1,4 @@
+import tempfile
 import unittest
 
 from transformers import AutoTokenizer, BatchEncoding, MBartTokenizer
@@ -171,3 +172,13 @@ class MBartEnroIntegrationTest(unittest.TestCase):
         self.assertEqual(ids[-2], 2)
         self.assertEqual(ids[-1], EN_CODE)
         self.assertEqual(len(ids), desired_max_length)
+
+    def test_mask_token(self):
+        self.assertListEqual(self.tokenizer.convert_tokens_to_ids(["<mask>", "ar_AR"]), [250026, 250001])
+
+    def test_special_tokens_unaffacted_by_save_load(self):
+        tmpdirname = tempfile.mkdtemp()
+        original_special_tokens = self.tokenizer.fairseq_tokens_to_ids
+        self.tokenizer.save_pretrained(tmpdirname)
+        new_tok = MBartTokenizer.from_pretrained(tmpdirname)
+        self.assertDictEqual(new_tok.fairseq_tokens_to_ids, original_special_tokens)