From 6fc940ed09419cf83288f9ac32e09f10046ec508 Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Mon, 15 Feb 2021 20:58:54 +0530 Subject: [PATCH] Add mBART-50 (#10154) * add tokenizer for mBART-50 * update tokenizers * make src_lang and tgt_lang optional * update tokenizer test * add setter * update docs * update conversion script * update docs * update conversion script * update tokenizer * update test * update docs * doc * address Sylvain's suggestions * fix test * fix formatting * nits --- README.md | 1 + docs/source/index.rst | 33 +- docs/source/model_doc/mbart.rst | 109 ++++++- docs/source/pretrained_models.rst | 9 + src/transformers/__init__.py | 6 +- src/transformers/convert_slow_tokenizer.py | 30 ++ src/transformers/models/mbart/__init__.py | 4 + ...rt_mbart_original_checkpoint_to_pytorch.py | 46 ++- .../models/mbart/tokenization_mbart50.py | 308 ++++++++++++++++++ .../models/mbart/tokenization_mbart50_fast.py | 278 ++++++++++++++++ .../utils/dummy_sentencepiece_objects.py | 9 + .../utils/dummy_tokenizers_objects.py | 9 + tests/test_tokenization_mbart50.py | 200 ++++++++++++ 13 files changed, 1008 insertions(+), 34 deletions(-) create mode 100644 src/transformers/models/mbart/tokenization_mbart50.py create mode 100644 src/transformers/models/mbart/tokenization_mbart50_fast.py create mode 100644 tests/test_tokenization_mbart50.py diff --git a/README.md b/README.md index 7e98e88d14..cae90de239 100644 --- a/README.md +++ b/README.md @@ -217,6 +217,7 @@ Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih. 1. **[LXMERT](https://huggingface.co/transformers/model_doc/lxmert.html)** (from UNC Chapel Hill) released with the paper [LXMERT: Learning Cross-Modality Encoder Representations from Transformers for Open-Domain Question Answering](https://arxiv.org/abs/1908.07490) by Hao Tan and Mohit Bansal. 1. **[MarianMT](https://huggingface.co/transformers/model_doc/marian.html)** Machine translation models trained using [OPUS](http://opus.nlpl.eu/) data by Jörg Tiedemann. The [Marian Framework](https://marian-nmt.github.io/) is being developed by the Microsoft Translator Team. 1. **[MBart](https://huggingface.co/transformers/model_doc/mbart.html)** (from Facebook) released with the paper [Multilingual Denoising Pre-training for Neural Machine Translation](https://arxiv.org/abs/2001.08210) by Yinhan Liu, Jiatao Gu, Naman Goyal, Xian Li, Sergey Edunov, Marjan Ghazvininejad, Mike Lewis, Luke Zettlemoyer. +1. **[MBart-50](https://huggingface.co/transformers/model_doc/mbart.html)** (from Facebook) released with the paper [Multilingual Translation with Extensible Multilingual Pretraining and Finetuning](https://arxiv.org/abs/2008.00401) by Yuqing Tang, Chau Tran, Xian Li, Peng-Jen Chen, Naman Goyal, Vishrav Chaudhary, Jiatao Gu, Angela Fan. 1. **[MPNet](https://huggingface.co/transformers/model_doc/mpnet.html)** (from Microsoft Research) released with the paper [MPNet: Masked and Permuted Pre-training for Language Understanding](https://arxiv.org/abs/2004.09297) by Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu. 1. **[MT5](https://huggingface.co/transformers/model_doc/mt5.html)** (from Google AI) released with the paper [mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/abs/2010.11934) by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel. 1. **[Pegasus](https://huggingface.co/transformers/model_doc/pegasus.html)** (from Google) released with the paper [PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization](https://arxiv.org/abs/1912.08777)> by Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu. diff --git a/docs/source/index.rst b/docs/source/index.rst index b28ec05b6b..63ddcbba5c 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -161,48 +161,51 @@ and conversion utilities for the following models: 26. :doc:`MBart ` (from Facebook) released with the paper `Multilingual Denoising Pre-training for Neural Machine Translation `__ by Yinhan Liu, Jiatao Gu, Naman Goyal, Xian Li, Sergey Edunov, Marjan Ghazvininejad, Mike Lewis, Luke Zettlemoyer. -27. :doc:`MPNet ` (from Microsoft Research) released with the paper `MPNet: Masked and Permuted +27. :doc:`MBart-50 ` (from Facebook) released with the paper `Multilingual Translation with Extensible + Multilingual Pretraining and Finetuning `__ by Yuqing Tang, Chau Tran, Xian Li, + Peng-Jen Chen, Naman Goyal, Vishrav Chaudhary, Jiatao Gu, Angela Fan. +28. :doc:`MPNet ` (from Microsoft Research) released with the paper `MPNet: Masked and Permuted Pre-training for Language Understanding `__ by Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu. -28. :doc:`MT5 ` (from Google AI) released with the paper `mT5: A massively multilingual pre-trained +29. :doc:`MT5 ` (from Google AI) released with the paper `mT5: A massively multilingual pre-trained text-to-text transformer `__ by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel. -29. :doc:`Pegasus ` (from Google) released with the paper `PEGASUS: Pre-training with Extracted +30. :doc:`Pegasus ` (from Google) released with the paper `PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization `__> by Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu. -30. :doc:`ProphetNet ` (from Microsoft Research) released with the paper `ProphetNet: Predicting +31. :doc:`ProphetNet ` (from Microsoft Research) released with the paper `ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training `__ by Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou. -31. :doc:`Reformer ` (from Google Research) released with the paper `Reformer: The Efficient +32. :doc:`Reformer ` (from Google Research) released with the paper `Reformer: The Efficient Transformer `__ by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya. -32. :doc:`RoBERTa ` (from Facebook), released together with the paper a `Robustly Optimized BERT +33. :doc:`RoBERTa ` (from Facebook), released together with the paper a `Robustly Optimized BERT Pretraining Approach `__ by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov. -33. :doc:`SqueezeBert ` released with the paper `SqueezeBERT: What can computer vision teach NLP +34. :doc:`SqueezeBert ` released with the paper `SqueezeBERT: What can computer vision teach NLP about efficient neural networks? `__ by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer. -34. :doc:`T5 ` (from Google AI) released with the paper `Exploring the Limits of Transfer Learning with a +35. :doc:`T5 ` (from Google AI) released with the paper `Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer `__ by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu. -35. :doc:`TAPAS ` (from Google AI) released with the paper `TAPAS: Weakly Supervised Table Parsing via +36. :doc:`TAPAS ` (from Google AI) released with the paper `TAPAS: Weakly Supervised Table Parsing via Pre-training `__ by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos. -36. :doc:`Transformer-XL ` (from Google/CMU) released with the paper `Transformer-XL: +37. :doc:`Transformer-XL ` (from Google/CMU) released with the paper `Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context `__ by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov. -37. :doc:`Wav2Vec2 ` (from Facebook AI) released with the paper `wav2vec 2.0: A Framework for +38. :doc:`Wav2Vec2 ` (from Facebook AI) released with the paper `wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations `__ by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli. -38. :doc:`XLM ` (from Facebook) released together with the paper `Cross-lingual Language Model +39. :doc:`XLM ` (from Facebook) released together with the paper `Cross-lingual Language Model Pretraining `__ by Guillaume Lample and Alexis Conneau. -39. :doc:`XLM-ProphetNet ` (from Microsoft Research) released with the paper `ProphetNet: +40. :doc:`XLM-ProphetNet ` (from Microsoft Research) released with the paper `ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training `__ by Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou. -40. :doc:`XLM-RoBERTa ` (from Facebook AI), released together with the paper `Unsupervised +41. :doc:`XLM-RoBERTa ` (from Facebook AI), released together with the paper `Unsupervised Cross-lingual Representation Learning at Scale `__ by Alexis Conneau*, Kartikay Khandelwal*, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke Zettlemoyer and Veselin Stoyanov. -41. :doc:`XLNet ` (from Google/CMU) released with the paper `​XLNet: Generalized Autoregressive +42. :doc:`XLNet ` (from Google/CMU) released with the paper `​XLNet: Generalized Autoregressive Pretraining for Language Understanding `__ by Zhilin Yang*, Zihang Dai*, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le. diff --git a/docs/source/model_doc/mbart.rst b/docs/source/model_doc/mbart.rst index 0a9c1cff90..05bbec1cbe 100644 --- a/docs/source/model_doc/mbart.rst +++ b/docs/source/model_doc/mbart.rst @@ -10,14 +10,14 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. -MBart +MBart and MBart-50 ----------------------------------------------------------------------------------------------------------------------- **DISCLAIMER:** If you see something strange, file a `Github Issue `__ and assign @patrickvonplaten -Overview +Overview of MBart ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ The MBart model was presented in `Multilingual Denoising Pre-training for Neural Machine Translation @@ -31,17 +31,9 @@ on the encoder, decoder, or reconstructing parts of the text. The Authors' code can be found `here `__ -Examples +Training of MBart _______________________________________________________________________________________________________________________ -- Examples and scripts for fine-tuning mBART and other models for sequence to sequence tasks can be found in - :prefix_link:`examples/seq2seq/ `. -- Given the large embeddings table, mBART consumes a large amount of GPU RAM, especially for fine-tuning. - :class:`MarianMTModel` is usually a better choice for bilingual machine translation. - -Training -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - MBart is a multilingual encoder-decoder (seq-to-seq) model primarily intended for translation task. As the model is multilingual it expects the sequences in a different format. A special language id token is added in both the source and target text. The source text format is :obj:`X [eos, src_lang_code]` where :obj:`X` is the source text. The target @@ -76,6 +68,87 @@ the sequences for sequence-to-sequence fine-tuning. assert translation == "Şeful ONU declară că nu există o soluţie militară în Siria" +Overview of MBart-50 +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +MBart-50 was introduced in the `Multilingual Translation with Extensible Multilingual Pretraining and Finetuning +` paper by Yuqing Tang, Chau Tran, Xian Li, Peng-Jen Chen, Naman Goyal, Vishrav +Chaudhary, Jiatao Gu, Angela Fan. MBart-50 is created using the original `mbart-large-cc25` checkpoint by extendeding +its embedding layers with randomly initialized vectors for an extra set of 25 language tokens and then pretrained on 50 +languages. + +According to the abstract + +*Multilingual translation models can be created through multilingual finetuning. Instead of finetuning on one +direction, a pretrained model is finetuned on many directions at the same time. It demonstrates that pretrained models +can be extended to incorporate additional languages without loss of performance. Multilingual finetuning improves on +average 1 BLEU over the strongest baselines (being either multilingual from scratch or bilingual finetuning) while +improving 9.3 BLEU on average over bilingual baselines from scratch.* + + +Training of MBart-50 +_______________________________________________________________________________________________________________________ + +The text format for MBart-50 is slightly different from mBART. For MBart-50 the language id token is used as a prefix +for both source and target text i.e the text format is :obj:`[lang_code] X [eos]`, where :obj:`lang_code` is source +language id for source text and target language id for target text, with :obj:`X` being the source or target text +respectively. + + +MBart-50 has its own tokenizer :class:`~transformers.MBart50Tokenizer`. + +- Supervised training + +.. code-block:: + + from transformers import MBartForConditionalGeneration, MBart50TokenizerFast + + model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50") + tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50", src_lang="en_XX", tgt_lang="ro_RO") + + src_text = " UN Chief Says There Is No Military Solution in Syria" + tgt_text = "Şeful ONU declară că nu există o soluţie militară în Siria" + + model_inputs = tokenizer(src_text, return_tensors="pt") + with tokenizer.as_target_tokenizer(): + labels = tokenizer(tgt_text, return_tensors="pt").input_ids + + model(**model_inputs, labels=labels) # forward pass + + +- Generation + + To generate using the mBART-50 multilingual translation models, :obj:`eos_token_id` is used as the + :obj:`decoder_start_token_id` and the target language id is forced as the first generated token. To force the + target language id as the first generated token, pass the `forced_bos_token_id` parameter to the `generate` method. + The following example shows how to translate between Hindi to French and Arabic to English using the + `facebook/mbart-50-large-many-to-many` checkpoint. + +.. code-block:: + + from transformers import MBartForConditionalGeneration, MBart50TokenizerFast + + article_hi = "संयुक्त राष्ट्र के प्रमुख का कहना है कि सीरिया में कोई सैन्य समाधान नहीं है" + article_ar = "الأمين العام للأمم المتحدة يقول إنه لا يوجد حل عسكري في سوريا." + + model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt") + tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt") + + # translate Hindi to French + tokenizer.src_lang = "hi_IN" + encoded_hi = tokenizer(article_hi, return_tensors="pt") + generated_tokens = model.generate(**encoded_hi, forced_bos_token_id=tokenizer.lang_code_to_id["fr_XX"]) + tokenizer.batch_decode(generated_tokens, skip_special_tokens=True) + # => "Le chef de l 'ONU affirme qu 'il n 'y a pas de solution militaire en Syria." + + # translate Arabic to English + tokenizer.src_lang = "ar_AR" + encoded_ar = tokenizer(article_ar, return_tensors="pt") + generated_tokens = model.generate(**encoded_ar, forced_bos_token_id=tokenizer.lang_code_to_id["en_XX"]) + tokenizer.batch_decode(generated_tokens, skip_special_tokens=True) + # => "The Secretary-General of the United Nations says there is no military solution in Syria." + + MBartConfig ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -97,6 +170,20 @@ MBartTokenizerFast :members: +MBart50Tokenizer +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.MBart50Tokenizer + :members: + + +MBart50TokenizerFast +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.MBart50TokenizerFast + :members: + + MBartModel ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/pretrained_models.rst b/docs/source/pretrained_models.rst index c25c58031f..d213267197 100644 --- a/docs/source/pretrained_models.rst +++ b/docs/source/pretrained_models.rst @@ -381,6 +381,15 @@ For the full list, refer to `https://huggingface.co/models ", 0.0), + ("", 0.0), + ("", 0.0), + ("", 0.0), + ] + vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]] + # fmt: off + vocab += [("ar_AR", 0.0), ("cs_CZ", 0.0), ("de_DE", 0.0), ("en_XX", 0.0), ("es_XX", 0.0), ("et_EE", 0.0), ("fi_FI", 0.0), ("fr_XX", 0.0), ("gu_IN", 0.0), ("hi_IN", 0.0), ("it_IT", 0.0), ("ja_XX", 0.0), ("kk_KZ", 0.0), ("ko_KR", 0.0), ("lt_LT", 0.0), ("lv_LV", 0.0), ("my_MM", 0.0), ("ne_NP", 0.0), ("nl_XX", 0.0), ("ro_RO", 0.0), ("ru_RU", 0.0), ("si_LK", 0.0), ("tr_TR", 0.0), ("vi_VN", 0.0), ("zh_CN", 0.0), ("af_ZA", 0.0), ("az_AZ", 0.0), ("bn_IN", 0.0), ("fa_IR", 0.0), ("he_IL", 0.0), ("hr_HR", 0.0), ("id_ID", 0.0), ("ka_GE", 0.0), ("km_KH", 0.0), ("mk_MK", 0.0), ("ml_IN", 0.0), ("mn_MN", 0.0), ("mr_IN", 0.0), ("pl_PL", 0.0), ("ps_AF", 0.0), ("pt_XX", 0.0), ("sv_SE", 0.0), ("sw_KE", 0.0), ("ta_IN", 0.0), ("te_IN", 0.0), ("th_TH", 0.0), ("tl_XX", 0.0), ("uk_UA", 0.0), ("ur_PK", 0.0), ("xh_ZA", 0.0), ("gl_ES", 0.0), ("sl_SI", 0.0)] + # fmt: on + vocab += [("", 0.0)] + return vocab + + def unk_id(self, proto): + return 3 + + def post_processor(self): + return processors.TemplateProcessing( + single="en_XX $A ", + pair="en_XX $A $B ", + special_tokens=[ + ("en_XX", self.original_tokenizer.convert_tokens_to_ids("en_XX")), + ("", self.original_tokenizer.convert_tokens_to_ids("")), + ], + ) + + class XLMRobertaConverter(SpmConverter): def vocab(self, proto): vocab = [ @@ -637,6 +666,7 @@ SLOW_TO_FAST_CONVERTERS = { "LEDTokenizer": RobertaConverter, "LxmertTokenizer": BertConverter, "MBartTokenizer": MBartConverter, + "MBart50Tokenizer": MBart50Converter, "MPNetTokenizer": MPNetConverter, "MobileBertTokenizer": BertConverter, "OpenAIGPTTokenizer": OpenAIGPTConverter, diff --git a/src/transformers/models/mbart/__init__.py b/src/transformers/models/mbart/__init__.py index e6f2d53b80..ed4856c451 100644 --- a/src/transformers/models/mbart/__init__.py +++ b/src/transformers/models/mbart/__init__.py @@ -32,9 +32,11 @@ _import_structure = { if is_sentencepiece_available(): _import_structure["tokenization_mbart"] = ["MBartTokenizer"] + _import_structure["tokenization_mbart50"] = ["MBart50Tokenizer"] if is_tokenizers_available(): _import_structure["tokenization_mbart_fast"] = ["MBartTokenizerFast"] + _import_structure["tokenization_mbart50_fast"] = ["MBart50TokenizerFast"] if is_torch_available(): _import_structure["modeling_mbart"] = [ @@ -56,8 +58,10 @@ if TYPE_CHECKING: if is_sentencepiece_available(): from .tokenization_mbart import MBartTokenizer + from .tokenization_mbart50 import MBart50Tokenizer if is_tokenizers_available(): + from .tokenization_mbart50_fast import MBart50TokenizerFast from .tokenization_mbart_fast import MBartTokenizerFast if is_torch_available(): diff --git a/src/transformers/models/mbart/convert_mbart_original_checkpoint_to_pytorch.py b/src/transformers/models/mbart/convert_mbart_original_checkpoint_to_pytorch.py index cb4c7c9d72..eb7f00bf77 100644 --- a/src/transformers/models/mbart/convert_mbart_original_checkpoint_to_pytorch.py +++ b/src/transformers/models/mbart/convert_mbart_original_checkpoint_to_pytorch.py @@ -15,19 +15,49 @@ import argparse import torch +from torch import nn -from transformers import BartForConditionalGeneration, MBartConfig -from transformers.models.bart.convert_bart_original_pytorch_checkpoint_to_pytorch import remove_ignore_keys_ +from transformers import MBartConfig, MBartForConditionalGeneration -def convert_fairseq_mbart_checkpoint_from_disk(checkpoint_path, hf_config_path="facebook/mbart-large-en-ro"): +def remove_ignore_keys_(state_dict): + ignore_keys = [ + "encoder.version", + "decoder.version", + "model.encoder.version", + "model.decoder.version", + "_float_tensor", + "decoder.output_projection.weight", + ] + for k in ignore_keys: + state_dict.pop(k, None) + + +def make_linear_from_emb(emb): + vocab_size, emb_size = emb.weight.shape + lin_layer = nn.Linear(vocab_size, emb_size, bias=False) + lin_layer.weight.data = emb.weight.data + return lin_layer + + +def convert_fairseq_mbart_checkpoint_from_disk( + checkpoint_path, hf_config_path="facebook/mbart-large-en-ro", finetuned=False, mbart_50=False +): state_dict = torch.load(checkpoint_path, map_location="cpu")["model"] remove_ignore_keys_(state_dict) vocab_size = state_dict["encoder.embed_tokens.weight"].shape[0] + mbart_config = MBartConfig.from_pretrained(hf_config_path, vocab_size=vocab_size) + if mbart_50 and finetuned: + mbart_config.activation_function = "relu" + state_dict["shared.weight"] = state_dict["decoder.embed_tokens.weight"] - model = BartForConditionalGeneration(mbart_config) + model = MBartForConditionalGeneration(mbart_config) model.model.load_state_dict(state_dict) + + if finetuned: + model.lm_head = make_linear_from_emb(model.model.shared) + return model @@ -42,8 +72,12 @@ if __name__ == "__main__": "--hf_config", default="facebook/mbart-large-cc25", type=str, - help="Which huggingface architecture to use: bart-large-xsum", + help="Which huggingface architecture to use: mbart-large", ) + parser.add_argument("--mbart_50", action="store_true", help="whether the model is mMART-50 checkpoint") + parser.add_argument("--finetuned", action="store_true", help="whether the model is a fine-tuned checkpoint") args = parser.parse_args() - model = convert_fairseq_mbart_checkpoint_from_disk(args.fairseq_path, hf_config_path=args.hf_config) + model = convert_fairseq_mbart_checkpoint_from_disk( + args.fairseq_path, hf_config_path=args.hf_config, finetuned=args.finetuned, mbart_50=args.mbart_50 + ) model.save_pretrained(args.pytorch_dump_folder_path) diff --git a/src/transformers/models/mbart/tokenization_mbart50.py b/src/transformers/models/mbart/tokenization_mbart50.py new file mode 100644 index 0000000000..e6d38a3821 --- /dev/null +++ b/src/transformers/models/mbart/tokenization_mbart50.py @@ -0,0 +1,308 @@ +# coding=utf-8 +# Copyright 2021 The Facebook AI Research Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from contextlib import contextmanager +from shutil import copyfile +from typing import Dict, List, Optional, Tuple + +import sentencepiece as spm + +from ...tokenization_utils import AddedToken, BatchEncoding, PreTrainedTokenizer +from ...utils import logging + + +logger = logging.get_logger(__name__) + +SPIECE_UNDERLINE = "▁" + +VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model"} + +_all_mbart50_models = ["facebook/mbart-large-50-one-to-many-mmt"] +SPM_URL = "https://huggingface.co/facebook/mbart-large-50-one-to-many-mmt/resolve/main/sentencepiece.bpe.model" + +# fmt: off +FAIRSEQ_LANGUAGE_CODES = ["ar_AR", "cs_CZ", "de_DE", "en_XX", "es_XX", "et_EE", "fi_FI", "fr_XX", "gu_IN", "hi_IN", "it_IT", "ja_XX", "kk_KZ", "ko_KR", "lt_LT", "lv_LV", "my_MM", "ne_NP", "nl_XX", "ro_RO", "ru_RU", "si_LK", "tr_TR", "vi_VN", "zh_CN", "af_ZA", "az_AZ", "bn_IN", "fa_IR", "he_IL", "hr_HR", "id_ID", "ka_GE", "km_KH", "mk_MK", "ml_IN", "mn_MN", "mr_IN", "pl_PL", "ps_AF", "pt_XX", "sv_SE", "sw_KE", "ta_IN", "te_IN", "th_TH", "tl_XX", "uk_UA", "ur_PK", "xh_ZA", "gl_ES", "sl_SI"] +# fmt: on + + +class MBart50Tokenizer(PreTrainedTokenizer): + """ + Construct a MBart50 tokenizer. Based on `SentencePiece `__. + + This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the main methods. + Users should refer to this superclass for more information regarding those methods. + + Args: + vocab_file (:obj:`str`): + Path to the vocabulary file. + src_lang (:obj:`str`, `optional`): + A string representing the source language. + tgt_lang (:obj:`str`, `optional`): + A string representing the target language. + eos_token (:obj:`str`, `optional`, defaults to :obj:`""`): + The end of sequence token. + sep_token (:obj:`str`, `optional`, defaults to :obj:`""`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + cls_token (:obj:`str`, `optional`, defaults to :obj:`""`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + unk_token (:obj:`str`, `optional`, defaults to :obj:`""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + pad_token (:obj:`str`, `optional`, defaults to :obj:`""`): + The token used for padding, for example when batching sequences of different lengths. + mask_token (:obj:`str`, `optional`, defaults to :obj:`""`): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. + + Examples:: + + >>> from transformers import MBart50Tokenizer + >>> tokenizer = MBart50Tokenizer.from_pretrained("facebook/mbart-large-50", src_lang="en_XX", tgt_lang="ro_RO") + >>> src_text = " UN Chief Says There Is No Military Solution in Syria" + >>> tgt_text = "Şeful ONU declară că nu există o soluţie militară în Siria" + >>> model_inputs = tokenizer(src_text, return_tensors="pt") + >>> with tokenizer.as_target_tokenizer(): + ... labels = tokenizer(tgt_text, return_tensors="pt").input_ids + >>> # model(**model_inputs, labels=labels) should work + """ + + vocab_files_names = VOCAB_FILES_NAMES + max_model_input_sizes = {m: 1024 for m in _all_mbart50_models} + pretrained_vocab_files_map = {"vocab_file": {m: SPM_URL for m in _all_mbart50_models}} + model_input_names = ["input_ids", "attention_mask"] + + prefix_tokens: List[int] = [] + suffix_tokens: List[int] = [] + + def __init__( + self, + vocab_file, + src_lang=None, + tgt_lang=None, + eos_token="", + sep_token="", + cls_token="", + unk_token="", + pad_token="", + mask_token="", + **kwargs + ): + # Mask token behave like a normal word, i.e. include the space before it + mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token + + super().__init__( + src_lang=src_lang, + tgt_lang=tgt_lang, + eos_token=eos_token, + unk_token=unk_token, + sep_token=sep_token, + cls_token=cls_token, + pad_token=pad_token, + mask_token=mask_token, + **kwargs, + ) + + self.sp_model = spm.SentencePieceProcessor() + self.sp_model.Load(str(vocab_file)) + self.vocab_file = vocab_file + + # Original fairseq vocab and spm vocab must be "aligned": + # Vocab | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 + # -------- | ------- | ------- | ------ | ------- | --- | --- | --- | ----- | ----- | ---- + # fairseq | '' | '' | '' | '' | ',' | '.' | '▁' | 's' | '▁de' | '-' + # spm | '' | '' | '' | ',' | '.' | '▁' | 's' | '▁de' | '-' | '▁a' + + # Mimic fairseq token-to-id alignment for the first 4 token + self.fairseq_tokens_to_ids = {"": 0, "": 1, "": 2, "": 3} + + # The first "real" token "," has position 4 in the original fairseq vocab and position 3 in the spm vocab + self.fairseq_offset = 1 + + self.sp_model_size = len(self.sp_model) + self.lang_code_to_id = { + code: self.sp_model_size + i + self.fairseq_offset for i, code in enumerate(FAIRSEQ_LANGUAGE_CODES) + } + self.id_to_lang_code = {v: k for k, v in self.lang_code_to_id.items()} + self.fairseq_tokens_to_ids[""] = len(self.sp_model) + len(self.lang_code_to_id) + self.fairseq_offset + + self.fairseq_tokens_to_ids.update(self.lang_code_to_id) + self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()} + self._additional_special_tokens = list(self.lang_code_to_id.keys()) + + self._src_lang = src_lang if src_lang is not None else "en_XX" + self.cur_lang_code_id = self.lang_code_to_id[self._src_lang] + self.tgt_lang = tgt_lang + self.set_src_lang_special_tokens(self._src_lang) + + @property + def vocab_size(self) -> int: + return len(self.sp_model) + len(self.lang_code_to_id) + self.fairseq_offset + 1 # Plus 1 for the mask token + + @property + def src_lang(self) -> str: + return self._src_lang + + @src_lang.setter + def src_lang(self, new_src_lang: str) -> None: + self._src_lang = new_src_lang + self.set_src_lang_special_tokens(self._src_lang) + + def __getstate__(self) -> Dict: + state = self.__dict__.copy() + state["sp_model"] = None + return state + + def __setstate__(self, d: Dict) -> None: + self.__dict__ = d + self.sp_model = spm.SentencePieceProcessor() + self.sp_model.Load(self.vocab_file) + + def get_vocab(self) -> Dict: + vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} + vocab.update(self.added_tokens_encoder) + return vocab + + def _tokenize(self, text: str) -> List[str]: + return self.sp_model.EncodeAsPieces(text) + + def _convert_token_to_id(self, token: str) -> int: + """ Converts a token (str) in an id using the vocab. """ + if token in self.fairseq_tokens_to_ids: + return self.fairseq_tokens_to_ids[token] + spm_id = self.sp_model.PieceToId(token) + + # Need to return unknown token if the SP model returned 0 + return spm_id + self.fairseq_offset if spm_id else self.unk_token_id + + def _convert_id_to_token(self, index: int) -> str: + """Converts an index (integer) in a token (str) using the vocab.""" + if index in self.fairseq_ids_to_tokens: + return self.fairseq_ids_to_tokens[index] + return self.sp_model.IdToPiece(index - self.fairseq_offset) + + def convert_tokens_to_string(self, tokens: List[str]) -> str: + """Converts a sequence of tokens (strings for sub-words) in a single string.""" + out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip() + return out_string + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + if not os.path.isdir(save_directory): + logger.error("Vocabulary path ({}) should be a directory".format(save_directory)) + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): + copyfile(self.vocab_file, out_vocab_file) + + return (out_vocab_file,) + + def get_special_tokens_mask( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + ) -> List[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer ``prepare_for_model`` method. + + Args: + token_ids_0 (:obj:`List[int]`): + List of IDs. + token_ids_1 (:obj:`List[int]`, `optional`): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + + if already_has_special_tokens: + if token_ids_1 is not None: + raise ValueError( + "You should not supply a second sequence if the provided sequence of " + "ids is already formatted with special tokens for the model." + ) + return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0)) + prefix_ones = [1] * len(self.prefix_tokens) + suffix_ones = [1] * len(self.suffix_tokens) + if token_ids_1 is None: + return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones + return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones + + def build_inputs_with_special_tokens( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. An MBART-50 sequence has the following format, where ``X`` represents the sequence: + + - ``input_ids`` (for encoder) ``[src_lang_code] X [eos]`` + - ``labels``: (for decoder) ``[tgt_lang_code] X [eos]`` + + BOS is never used. Pairs of sequences are not the expected use case, but they will be handled without a + separator. + + Args: + token_ids_0 (:obj:`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (:obj:`List[int]`, `optional`): + Optional second list of IDs for sequence pairs. + + Returns: + :obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens. + """ + if token_ids_1 is None: + return self.prefix_tokens + token_ids_0 + self.suffix_tokens + # We don't expect to process pairs, but leave the pair logic for API consistency + return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens + + def prepare_seq2seq_batch( + self, + src_texts: List[str], + src_lang: str = "en_XX", + tgt_texts: Optional[List[str]] = None, + tgt_lang: str = "ro_RO", + **kwargs, + ) -> BatchEncoding: + self.src_lang = src_lang + self.tgt_lang = tgt_lang + return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs) + + @contextmanager + def as_target_tokenizer(self): + """ + Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to + sequence-to-sequence models that need a slightly different processing for the labels. + """ + self.set_tgt_lang_special_tokens(self.tgt_lang) + yield + self.set_src_lang_special_tokens(self.src_lang) + + def set_src_lang_special_tokens(self, src_lang: str) -> None: + """Reset the special tokens to the source lang setting. prefix=[src_lang_code] and suffix=[eos].""" + self.cur_lang_code_id = self.lang_code_to_id[src_lang] + self.prefix_tokens = [self.cur_lang_code_id] + self.suffix_tokens = [self.eos_token_id] + + def set_tgt_lang_special_tokens(self, tgt_lang: str) -> None: + """Reset the special tokens to the target language setting. prefix=[tgt_lang_code] and suffix=[eos].""" + self.cur_lang_code_id = self.lang_code_to_id[tgt_lang] + self.prefix_tokens = [self.cur_lang_code_id] + self.suffix_tokens = [self.eos_token_id] diff --git a/src/transformers/models/mbart/tokenization_mbart50_fast.py b/src/transformers/models/mbart/tokenization_mbart50_fast.py new file mode 100644 index 0000000000..11b21f139e --- /dev/null +++ b/src/transformers/models/mbart/tokenization_mbart50_fast.py @@ -0,0 +1,278 @@ +# coding=utf-8 +# Copyright 2021 The Facebook AI Research Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from contextlib import contextmanager +from shutil import copyfile +from typing import List, Optional, Tuple + +from tokenizers import processors + +from ...file_utils import is_sentencepiece_available +from ...tokenization_utils import AddedToken, BatchEncoding +from ...tokenization_utils_fast import PreTrainedTokenizerFast +from ...utils import logging + + +if is_sentencepiece_available(): + from .tokenization_mbart50 import MBart50Tokenizer +else: + MBart50Tokenizer = None + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model", "tokenizer_file": "tokenizer.json"} + +_all_mbart50_models = ["facebook/mbart-large-50-one-to-many-mmt"] +SPM_URL = "https://huggingface.co/facebook/mbart-large-50-one-to-many-mmt/resolve/main/sentencepiece.bpe.model" +tokenizer_URL = "https://huggingface.co/facebook/mbart-large-50-one-to-many-mmt/resolve/main/tokenizer.json" + +# fmt: off +FAIRSEQ_LANGUAGE_CODES = ["ar_AR", "cs_CZ", "de_DE", "en_XX", "es_XX", "et_EE", "fi_FI", "fr_XX", "gu_IN", "hi_IN", "it_IT", "ja_XX", "kk_KZ", "ko_KR", "lt_LT", "lv_LV", "my_MM", "ne_NP", "nl_XX", "ro_RO", "ru_RU", "si_LK", "tr_TR", "vi_VN", "zh_CN", "af_ZA", "az_AZ", "bn_IN", "fa_IR", "he_IL", "hr_HR", "id_ID", "ka_GE", "km_KH", "mk_MK", "ml_IN", "mn_MN", "mr_IN", "pl_PL", "ps_AF", "pt_XX", "sv_SE", "sw_KE", "ta_IN", "te_IN", "th_TH", "tl_XX", "uk_UA", "ur_PK", "xh_ZA", "gl_ES", "sl_SI"] +# fmt: on + + +class MBart50TokenizerFast(PreTrainedTokenizerFast): + """ + Construct a "fast" MBART tokenizer for mBART-50 (backed by HuggingFace's `tokenizers` library). Based on `BPE + `__. + + This tokenizer inherits from :class:`~transformers.PreTrainedTokenizerFast` which contains most of the main + methods. Users should refer to this superclass for more information regarding those methods. + + Args: + vocab_file (:obj:`str`): + Path to the vocabulary file. + src_lang (:obj:`str`, `optional`): + A string representing the source language. + tgt_lang (:obj:`str`, `optional`): + A string representing the target language. + eos_token (:obj:`str`, `optional`, defaults to :obj:`""`): + The end of sequence token. + sep_token (:obj:`str`, `optional`, defaults to :obj:`""`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + cls_token (:obj:`str`, `optional`, defaults to :obj:`""`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + unk_token (:obj:`str`, `optional`, defaults to :obj:`""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + pad_token (:obj:`str`, `optional`, defaults to :obj:`""`): + The token used for padding, for example when batching sequences of different lengths. + mask_token (:obj:`str`, `optional`, defaults to :obj:`""`): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. + + Examples:: + + >>> from transformers import MBart50TokenizerFast + >>> tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50", src_lang="en_XX", tgt_lang="ro_RO") + >>> src_text = " UN Chief Says There Is No Military Solution in Syria" + >>> tgt_text = "Şeful ONU declară că nu există o soluţie militară în Siria" + >>> model_inputs = tokenizer(src_text, return_tensors="pt") + >>> with tokenizer.as_target_tokenizer(): + ... labels = tokenizer(tgt_text, return_tensors="pt").input_ids + >>> # model(**model_inputs, labels=labels) should work + """ + + vocab_files_names = VOCAB_FILES_NAMES + max_model_input_sizes = {m: 1024 for m in _all_mbart50_models} + pretrained_vocab_files_map = {"vocab_file": {m: SPM_URL for m in _all_mbart50_models}} + model_input_names = ["input_ids", "attention_mask"] + slow_tokenizer_class = MBart50Tokenizer + + prefix_tokens: List[int] = [] + suffix_tokens: List[int] = [] + + def __init__( + self, + vocab_file, + src_lang=None, + tgt_lang=None, + tokenizer_file=None, + eos_token="", + sep_token="", + cls_token="", + unk_token="", + pad_token="", + mask_token="", + **kwargs + ): + # Mask token behave like a normal word, i.e. include the space before it + mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token + + super().__init__( + vocab_file, + src_lang=src_lang, + tgt_lang=tgt_lang, + tokenizer_file=tokenizer_file, + eos_token=eos_token, + sep_token=sep_token, + cls_token=cls_token, + unk_token=unk_token, + pad_token=pad_token, + mask_token=mask_token, + **kwargs, + ) + + self.vocab_file = vocab_file + + self.add_special_tokens({"additional_special_tokens": FAIRSEQ_LANGUAGE_CODES}) + self.lang_code_to_id = { + lang_code: self.convert_tokens_to_ids(lang_code) for lang_code in FAIRSEQ_LANGUAGE_CODES + } + + self._src_lang = src_lang if src_lang is not None else "en_XX" + self.tgt_lang = tgt_lang + self.cur_lang_code_id = self.lang_code_to_id[self._src_lang] + self.set_src_lang_special_tokens(self._src_lang) + + @property + def src_lang(self) -> str: + return self._src_lang + + @src_lang.setter + def src_lang(self, new_src_lang: str) -> None: + self._src_lang = new_src_lang + self.set_src_lang_special_tokens(self._src_lang) + + def get_special_tokens_mask( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + ) -> List[int]: + """ + Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer ``prepare_for_model`` method. + + Args: + token_ids_0 (:obj:`List[int]`): + List of ids. + token_ids_1 (:obj:`List[int]`, `optional`): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + + if already_has_special_tokens: + if token_ids_1 is not None: + raise ValueError( + "You should not supply a second sequence if the provided sequence of " + "ids is already formatted with special tokens for the model." + ) + return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0)) + prefix_ones = [1] * len(self.prefix_tokens) + suffix_ones = [1] * len(self.suffix_tokens) + if token_ids_1 is None: + return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones + return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones + + def build_inputs_with_special_tokens( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. The special tokens depend on calling set_lang. + + An MBART-50 sequence has the following format, where ``X`` represents the sequence: + + - ``input_ids`` (for encoder) ``[src_lang_code] X [eos]`` + - ``labels``: (for decoder) ``[tgt_lang_code] X [eos]`` + + BOS is never used. Pairs of sequences are not the expected use case, but they will be handled without a + separator. + + Args: + token_ids_0 (:obj:`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (:obj:`List[int]`, `optional`): + Optional second list of IDs for sequence pairs. + + Returns: + :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens. + """ + if token_ids_1 is None: + return self.prefix_tokens + token_ids_0 + self.suffix_tokens + # We don't expect to process pairs, but leave the pair logic for API consistency + return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens + + def prepare_seq2seq_batch( + self, + src_texts: List[str], + src_lang: str = "en_XX", + tgt_texts: Optional[List[str]] = None, + tgt_lang: str = "ro_RO", + **kwargs, + ) -> BatchEncoding: + self.src_lang = src_lang + self.tgt_lang = tgt_lang + return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs) + + @contextmanager + def as_target_tokenizer(self): + """ + Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to + sequence-to-sequence models that need a slightly different processing for the labels. + """ + self.set_tgt_lang_special_tokens(self.tgt_lang) + yield + self.set_src_lang_special_tokens(self.src_lang) + + def set_src_lang_special_tokens(self, src_lang: str) -> None: + """Reset the special tokens to the source lang setting. prefix=[src_lang_code] and suffix=[eos].""" + self.cur_lang_code_id = self.convert_tokens_to_ids(src_lang) + self.prefix_tokens = [self.cur_lang_code_id] + self.suffix_tokens = [self.eos_token_id] + + prefix_tokens_str = self.convert_ids_to_tokens(self.prefix_tokens) + suffix_tokens_str = self.convert_ids_to_tokens(self.suffix_tokens) + + self._tokenizer.post_processor = processors.TemplateProcessing( + single=prefix_tokens_str + ["$A"] + suffix_tokens_str, + pair=prefix_tokens_str + ["$A", "$B"] + suffix_tokens_str, + special_tokens=list(zip(prefix_tokens_str + suffix_tokens_str, self.prefix_tokens + self.suffix_tokens)), + ) + + def set_tgt_lang_special_tokens(self, tgt_lang: str) -> None: + """Reset the special tokens to the target language setting. prefix=[src_lang_code] and suffix=[eos].""" + self.cur_lang_code_id = self.convert_tokens_to_ids(tgt_lang) + self.prefix_tokens = [self.cur_lang_code_id] + self.suffix_tokens = [self.eos_token_id] + + prefix_tokens_str = self.convert_ids_to_tokens(self.prefix_tokens) + suffix_tokens_str = self.convert_ids_to_tokens(self.suffix_tokens) + + self._tokenizer.post_processor = processors.TemplateProcessing( + single=prefix_tokens_str + ["$A"] + suffix_tokens_str, + pair=prefix_tokens_str + ["$A", "$B"] + suffix_tokens_str, + special_tokens=list(zip(prefix_tokens_str + suffix_tokens_str, self.prefix_tokens + self.suffix_tokens)), + ) + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + if not os.path.isdir(save_directory): + logger.error("Vocabulary path ({}) should be a directory".format(save_directory)) + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): + copyfile(self.vocab_file, out_vocab_file) + + return (out_vocab_file,) diff --git a/src/transformers/utils/dummy_sentencepiece_objects.py b/src/transformers/utils/dummy_sentencepiece_objects.py index bbaddd2cc9..4c3c3c2abd 100644 --- a/src/transformers/utils/dummy_sentencepiece_objects.py +++ b/src/transformers/utils/dummy_sentencepiece_objects.py @@ -47,6 +47,15 @@ class MarianTokenizer: requires_sentencepiece(self) +class MBart50Tokenizer: + def __init__(self, *args, **kwargs): + requires_sentencepiece(self) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_sentencepiece(self) + + class MBartTokenizer: def __init__(self, *args, **kwargs): requires_sentencepiece(self) diff --git a/src/transformers/utils/dummy_tokenizers_objects.py b/src/transformers/utils/dummy_tokenizers_objects.py index 543cabc780..d9a1b8c055 100644 --- a/src/transformers/utils/dummy_tokenizers_objects.py +++ b/src/transformers/utils/dummy_tokenizers_objects.py @@ -164,6 +164,15 @@ class LxmertTokenizerFast: requires_tokenizers(self) +class MBart50TokenizerFast: + def __init__(self, *args, **kwargs): + requires_tokenizers(self) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_tokenizers(self) + + class MBartTokenizerFast: def __init__(self, *args, **kwargs): requires_tokenizers(self) diff --git a/tests/test_tokenization_mbart50.py b/tests/test_tokenization_mbart50.py new file mode 100644 index 0000000000..ddb95d6ec2 --- /dev/null +++ b/tests/test_tokenization_mbart50.py @@ -0,0 +1,200 @@ +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +import unittest + +from transformers import SPIECE_UNDERLINE, BatchEncoding, MBart50Tokenizer, MBart50TokenizerFast, is_torch_available +from transformers.file_utils import is_sentencepiece_available +from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch + +from .test_tokenization_common import TokenizerTesterMixin + + +if is_sentencepiece_available(): + SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model") + + +if is_torch_available(): + from transformers.models.mbart.modeling_mbart import shift_tokens_right + +EN_CODE = 250004 +RO_CODE = 250020 + + +@require_sentencepiece +@require_tokenizers +class MBartTokenizationTest(TokenizerTesterMixin, unittest.TestCase): + tokenizer_class = MBart50Tokenizer + rust_tokenizer_class = MBart50TokenizerFast + test_rust_tokenizer = True + + def setUp(self): + super().setUp() + + # We have a SentencePiece fixture for testing + tokenizer = MBart50Tokenizer(SAMPLE_VOCAB, src_lang="en_XX", tgt_lang="ro_RO", keep_accents=True) + tokenizer.save_pretrained(self.tmpdirname) + + def test_full_tokenizer(self): + tokenizer = MBart50Tokenizer(SAMPLE_VOCAB, src_lang="en_XX", tgt_lang="ro_RO", keep_accents=True) + + tokens = tokenizer.tokenize("This is a test") + self.assertListEqual(tokens, ["▁This", "▁is", "▁a", "▁t", "est"]) + + self.assertListEqual( + tokenizer.convert_tokens_to_ids(tokens), + [value + tokenizer.fairseq_offset for value in [285, 46, 10, 170, 382]], + ) + + tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.") + self.assertListEqual( + tokens, + # fmt: off + [SPIECE_UNDERLINE + "I", SPIECE_UNDERLINE + "was", SPIECE_UNDERLINE + "b", "or", "n", SPIECE_UNDERLINE + "in", SPIECE_UNDERLINE + "", "9", "2", "0", "0", "0", ",", SPIECE_UNDERLINE + "and", SPIECE_UNDERLINE + "this", SPIECE_UNDERLINE + "is", SPIECE_UNDERLINE + "f", "al", "s", "é", "."], + # fmt: on + ) + ids = tokenizer.convert_tokens_to_ids(tokens) + self.assertListEqual( + ids, + [ + value + tokenizer.fairseq_offset + for value in [8, 21, 84, 55, 24, 19, 7, 2, 602, 347, 347, 347, 3, 12, 66, 46, 72, 80, 6, 2, 4] + ], + ) + + back_tokens = tokenizer.convert_ids_to_tokens(ids) + self.assertListEqual( + back_tokens, + # fmt: off + [SPIECE_UNDERLINE + "I", SPIECE_UNDERLINE + "was", SPIECE_UNDERLINE + "b", "or", "n", SPIECE_UNDERLINE + "in", SPIECE_UNDERLINE + "", "", "2", "0", "0", "0", ",", SPIECE_UNDERLINE + "and", SPIECE_UNDERLINE + "this", SPIECE_UNDERLINE + "is", SPIECE_UNDERLINE + "f", "al", "s", "", "."], + # fmt: on + ) + + +@require_torch +@require_sentencepiece +@require_tokenizers +class MBartOneToManyIntegrationTest(unittest.TestCase): + checkpoint_name = "facebook/mbart-large-50-one-to-many-mmt" + src_text = [ + " UN Chief Says There Is No Military Solution in Syria", + """ Secretary-General Ban Ki-moon says his response to Russia's stepped up military support for Syria is that "there is no military solution" to the nearly five-year conflict and more weapons will only worsen the violence and misery for millions of people.""", + ] + tgt_text = [ + "Şeful ONU declară că nu există o soluţie militară în Siria", + 'Secretarul General Ban Ki-moon declară că răspunsul său la intensificarea sprijinului militar al Rusiei pentru Siria este că "nu există o soluţie militară" la conflictul de aproape cinci ani şi că noi arme nu vor face decât să înrăutăţească violenţele şi mizeria pentru milioane de oameni.', + ] + expected_src_tokens = [EN_CODE, 8274, 127873, 25916, 7, 8622, 2071, 438, 67485, 53, 187895, 23, 51712, 2] + + @classmethod + def setUpClass(cls): + cls.tokenizer: MBart50Tokenizer = MBart50Tokenizer.from_pretrained( + cls.checkpoint_name, src_lang="en_XX", tgt_lang="ro_RO" + ) + cls.pad_token_id = 1 + return cls + + def check_language_codes(self): + self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["ar_AR"], 250001) + self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["en_EN"], 250004) + self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["ro_RO"], 250020) + self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["mr_IN"], 250038) + + def test_tokenizer_batch_encode_plus(self): + ids = self.tokenizer.batch_encode_plus(self.src_text).input_ids[0] + self.assertListEqual(self.expected_src_tokens, ids) + + def test_tokenizer_decode_ignores_language_codes(self): + self.assertIn(RO_CODE, self.tokenizer.all_special_ids) + generated_ids = [RO_CODE, 884, 9019, 96, 9, 916, 86792, 36, 18743, 15596, 5, 2] + result = self.tokenizer.decode(generated_ids, skip_special_tokens=True) + expected_romanian = self.tokenizer.decode(generated_ids[1:], skip_special_tokens=True) + self.assertEqual(result, expected_romanian) + self.assertNotIn(self.tokenizer.eos_token, result) + + def test_tokenizer_truncation(self): + src_text = ["this is gunna be a long sentence " * 20] + assert isinstance(src_text[0], str) + desired_max_length = 10 + ids = self.tokenizer.prepare_seq2seq_batch( + src_text, + max_length=desired_max_length, + ).input_ids[0] + self.assertEqual(ids[0], EN_CODE) + self.assertEqual(ids[-1], 2) + self.assertEqual(len(ids), desired_max_length) + + def test_mask_token(self): + self.assertListEqual(self.tokenizer.convert_tokens_to_ids(["", "ar_AR"]), [250053, 250001]) + + def test_special_tokens_unaffacted_by_save_load(self): + tmpdirname = tempfile.mkdtemp() + original_special_tokens = self.tokenizer.fairseq_tokens_to_ids + self.tokenizer.save_pretrained(tmpdirname) + new_tok = MBart50Tokenizer.from_pretrained(tmpdirname) + self.assertDictEqual(new_tok.fairseq_tokens_to_ids, original_special_tokens) + + # prepare_seq2seq_batch tests below + + @require_torch + def test_batch_fairseq_parity(self): + batch: BatchEncoding = self.tokenizer.prepare_seq2seq_batch( + self.src_text, tgt_texts=self.tgt_text, return_tensors="pt" + ) + batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id) + + for k in batch: + batch[k] = batch[k].tolist() + # batch = {k: v.tolist() for k,v in batch.items()} + # fairseq batch: https://gist.github.com/sshleifer/cba08bc2109361a74ac3760a7e30e4f4 + # batch.decoder_inputs_ids[0][0] == + assert batch.input_ids[1][0] == EN_CODE + assert batch.input_ids[1][-1] == 2 + assert batch.labels[1][0] == RO_CODE + assert batch.labels[1][-1] == 2 + assert batch.decoder_input_ids[1][:2] == [2, RO_CODE] + + @require_torch + def test_tokenizer_prepare_seq2seq_batch(self): + batch = self.tokenizer.prepare_seq2seq_batch( + self.src_text, tgt_texts=self.tgt_text, max_length=len(self.expected_src_tokens), return_tensors="pt" + ) + batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id) + self.assertIsInstance(batch, BatchEncoding) + + self.assertEqual((2, 14), batch.input_ids.shape) + self.assertEqual((2, 14), batch.attention_mask.shape) + result = batch.input_ids.tolist()[0] + self.assertListEqual(self.expected_src_tokens, result) + self.assertEqual(2, batch.decoder_input_ids[0, 0]) # decoder_start_token_id + # Test that special tokens are reset + self.assertEqual(self.tokenizer.prefix_tokens, [EN_CODE]) + self.assertEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id]) + + def test_seq2seq_max_target_length(self): + batch = self.tokenizer.prepare_seq2seq_batch( + self.src_text, tgt_texts=self.tgt_text, max_length=3, max_target_length=10, return_tensors="pt" + ) + batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id) + self.assertEqual(batch.input_ids.shape[1], 3) + self.assertEqual(batch.decoder_input_ids.shape[1], 10) + # max_target_length will default to max_length if not specified + batch = self.tokenizer.prepare_seq2seq_batch( + self.src_text, tgt_texts=self.tgt_text, max_length=3, return_tensors="pt" + ) + batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id) + self.assertEqual(batch.input_ids.shape[1], 3) + self.assertEqual(batch.decoder_input_ids.shape[1], 3)