Add mBART-50 (#10154)
* add tokenizer for mBART-50 * update tokenizers * make src_lang and tgt_lang optional * update tokenizer test * add setter * update docs * update conversion script * update docs * update conversion script * update tokenizer * update test * update docs * doc * address Sylvain's suggestions * fix test * fix formatting * nits
This commit is contained in:
@@ -217,6 +217,7 @@ Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih.
|
|||||||
1. **[LXMERT](https://huggingface.co/transformers/model_doc/lxmert.html)** (from UNC Chapel Hill) released with the paper [LXMERT: Learning Cross-Modality Encoder Representations from Transformers for Open-Domain Question Answering](https://arxiv.org/abs/1908.07490) by Hao Tan and Mohit Bansal.
|
1. **[LXMERT](https://huggingface.co/transformers/model_doc/lxmert.html)** (from UNC Chapel Hill) released with the paper [LXMERT: Learning Cross-Modality Encoder Representations from Transformers for Open-Domain Question Answering](https://arxiv.org/abs/1908.07490) by Hao Tan and Mohit Bansal.
|
||||||
1. **[MarianMT](https://huggingface.co/transformers/model_doc/marian.html)** Machine translation models trained using [OPUS](http://opus.nlpl.eu/) data by Jörg Tiedemann. The [Marian Framework](https://marian-nmt.github.io/) is being developed by the Microsoft Translator Team.
|
1. **[MarianMT](https://huggingface.co/transformers/model_doc/marian.html)** Machine translation models trained using [OPUS](http://opus.nlpl.eu/) data by Jörg Tiedemann. The [Marian Framework](https://marian-nmt.github.io/) is being developed by the Microsoft Translator Team.
|
||||||
1. **[MBart](https://huggingface.co/transformers/model_doc/mbart.html)** (from Facebook) released with the paper [Multilingual Denoising Pre-training for Neural Machine Translation](https://arxiv.org/abs/2001.08210) by Yinhan Liu, Jiatao Gu, Naman Goyal, Xian Li, Sergey Edunov, Marjan Ghazvininejad, Mike Lewis, Luke Zettlemoyer.
|
1. **[MBart](https://huggingface.co/transformers/model_doc/mbart.html)** (from Facebook) released with the paper [Multilingual Denoising Pre-training for Neural Machine Translation](https://arxiv.org/abs/2001.08210) by Yinhan Liu, Jiatao Gu, Naman Goyal, Xian Li, Sergey Edunov, Marjan Ghazvininejad, Mike Lewis, Luke Zettlemoyer.
|
||||||
|
1. **[MBart-50](https://huggingface.co/transformers/model_doc/mbart.html)** (from Facebook) released with the paper [Multilingual Translation with Extensible Multilingual Pretraining and Finetuning](https://arxiv.org/abs/2008.00401) by Yuqing Tang, Chau Tran, Xian Li, Peng-Jen Chen, Naman Goyal, Vishrav Chaudhary, Jiatao Gu, Angela Fan.
|
||||||
1. **[MPNet](https://huggingface.co/transformers/model_doc/mpnet.html)** (from Microsoft Research) released with the paper [MPNet: Masked and Permuted Pre-training for Language Understanding](https://arxiv.org/abs/2004.09297) by Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu.
|
1. **[MPNet](https://huggingface.co/transformers/model_doc/mpnet.html)** (from Microsoft Research) released with the paper [MPNet: Masked and Permuted Pre-training for Language Understanding](https://arxiv.org/abs/2004.09297) by Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu.
|
||||||
1. **[MT5](https://huggingface.co/transformers/model_doc/mt5.html)** (from Google AI) released with the paper [mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/abs/2010.11934) by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel.
|
1. **[MT5](https://huggingface.co/transformers/model_doc/mt5.html)** (from Google AI) released with the paper [mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/abs/2010.11934) by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel.
|
||||||
1. **[Pegasus](https://huggingface.co/transformers/model_doc/pegasus.html)** (from Google) released with the paper [PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization](https://arxiv.org/abs/1912.08777)> by Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu.
|
1. **[Pegasus](https://huggingface.co/transformers/model_doc/pegasus.html)** (from Google) released with the paper [PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization](https://arxiv.org/abs/1912.08777)> by Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu.
|
||||||
|
|||||||
@@ -161,48 +161,51 @@ and conversion utilities for the following models:
|
|||||||
26. :doc:`MBart <model_doc/mbart>` (from Facebook) released with the paper `Multilingual Denoising Pre-training for
|
26. :doc:`MBart <model_doc/mbart>` (from Facebook) released with the paper `Multilingual Denoising Pre-training for
|
||||||
Neural Machine Translation <https://arxiv.org/abs/2001.08210>`__ by Yinhan Liu, Jiatao Gu, Naman Goyal, Xian Li,
|
Neural Machine Translation <https://arxiv.org/abs/2001.08210>`__ by Yinhan Liu, Jiatao Gu, Naman Goyal, Xian Li,
|
||||||
Sergey Edunov, Marjan Ghazvininejad, Mike Lewis, Luke Zettlemoyer.
|
Sergey Edunov, Marjan Ghazvininejad, Mike Lewis, Luke Zettlemoyer.
|
||||||
27. :doc:`MPNet <model_doc/mpnet>` (from Microsoft Research) released with the paper `MPNet: Masked and Permuted
|
27. :doc:`MBart-50 <model_doc/mbart>` (from Facebook) released with the paper `Multilingual Translation with Extensible
|
||||||
|
Multilingual Pretraining and Finetuning <https://arxiv.org/abs/2008.00401>`__ by Yuqing Tang, Chau Tran, Xian Li,
|
||||||
|
Peng-Jen Chen, Naman Goyal, Vishrav Chaudhary, Jiatao Gu, Angela Fan.
|
||||||
|
28. :doc:`MPNet <model_doc/mpnet>` (from Microsoft Research) released with the paper `MPNet: Masked and Permuted
|
||||||
Pre-training for Language Understanding <https://arxiv.org/abs/2004.09297>`__ by Kaitao Song, Xu Tan, Tao Qin,
|
Pre-training for Language Understanding <https://arxiv.org/abs/2004.09297>`__ by Kaitao Song, Xu Tan, Tao Qin,
|
||||||
Jianfeng Lu, Tie-Yan Liu.
|
Jianfeng Lu, Tie-Yan Liu.
|
||||||
28. :doc:`MT5 <model_doc/mt5>` (from Google AI) released with the paper `mT5: A massively multilingual pre-trained
|
29. :doc:`MT5 <model_doc/mt5>` (from Google AI) released with the paper `mT5: A massively multilingual pre-trained
|
||||||
text-to-text transformer <https://arxiv.org/abs/2010.11934>`__ by Linting Xue, Noah Constant, Adam Roberts, Mihir
|
text-to-text transformer <https://arxiv.org/abs/2010.11934>`__ by Linting Xue, Noah Constant, Adam Roberts, Mihir
|
||||||
Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel.
|
Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel.
|
||||||
29. :doc:`Pegasus <model_doc/pegasus>` (from Google) released with the paper `PEGASUS: Pre-training with Extracted
|
30. :doc:`Pegasus <model_doc/pegasus>` (from Google) released with the paper `PEGASUS: Pre-training with Extracted
|
||||||
Gap-sentences for Abstractive Summarization <https://arxiv.org/abs/1912.08777>`__> by Jingqing Zhang, Yao Zhao,
|
Gap-sentences for Abstractive Summarization <https://arxiv.org/abs/1912.08777>`__> by Jingqing Zhang, Yao Zhao,
|
||||||
Mohammad Saleh and Peter J. Liu.
|
Mohammad Saleh and Peter J. Liu.
|
||||||
30. :doc:`ProphetNet <model_doc/prophetnet>` (from Microsoft Research) released with the paper `ProphetNet: Predicting
|
31. :doc:`ProphetNet <model_doc/prophetnet>` (from Microsoft Research) released with the paper `ProphetNet: Predicting
|
||||||
Future N-gram for Sequence-to-Sequence Pre-training <https://arxiv.org/abs/2001.04063>`__ by Yu Yan, Weizhen Qi,
|
Future N-gram for Sequence-to-Sequence Pre-training <https://arxiv.org/abs/2001.04063>`__ by Yu Yan, Weizhen Qi,
|
||||||
Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou.
|
Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou.
|
||||||
31. :doc:`Reformer <model_doc/reformer>` (from Google Research) released with the paper `Reformer: The Efficient
|
32. :doc:`Reformer <model_doc/reformer>` (from Google Research) released with the paper `Reformer: The Efficient
|
||||||
Transformer <https://arxiv.org/abs/2001.04451>`__ by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya.
|
Transformer <https://arxiv.org/abs/2001.04451>`__ by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya.
|
||||||
32. :doc:`RoBERTa <model_doc/roberta>` (from Facebook), released together with the paper a `Robustly Optimized BERT
|
33. :doc:`RoBERTa <model_doc/roberta>` (from Facebook), released together with the paper a `Robustly Optimized BERT
|
||||||
Pretraining Approach <https://arxiv.org/abs/1907.11692>`__ by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar
|
Pretraining Approach <https://arxiv.org/abs/1907.11692>`__ by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar
|
||||||
Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov.
|
Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov.
|
||||||
33. :doc:`SqueezeBert <model_doc/squeezebert>` released with the paper `SqueezeBERT: What can computer vision teach NLP
|
34. :doc:`SqueezeBert <model_doc/squeezebert>` released with the paper `SqueezeBERT: What can computer vision teach NLP
|
||||||
about efficient neural networks? <https://arxiv.org/abs/2006.11316>`__ by Forrest N. Iandola, Albert E. Shaw, Ravi
|
about efficient neural networks? <https://arxiv.org/abs/2006.11316>`__ by Forrest N. Iandola, Albert E. Shaw, Ravi
|
||||||
Krishna, and Kurt W. Keutzer.
|
Krishna, and Kurt W. Keutzer.
|
||||||
34. :doc:`T5 <model_doc/t5>` (from Google AI) released with the paper `Exploring the Limits of Transfer Learning with a
|
35. :doc:`T5 <model_doc/t5>` (from Google AI) released with the paper `Exploring the Limits of Transfer Learning with a
|
||||||
Unified Text-to-Text Transformer <https://arxiv.org/abs/1910.10683>`__ by Colin Raffel and Noam Shazeer and Adam
|
Unified Text-to-Text Transformer <https://arxiv.org/abs/1910.10683>`__ by Colin Raffel and Noam Shazeer and Adam
|
||||||
Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu.
|
Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu.
|
||||||
35. :doc:`TAPAS <model_doc/tapas>` (from Google AI) released with the paper `TAPAS: Weakly Supervised Table Parsing via
|
36. :doc:`TAPAS <model_doc/tapas>` (from Google AI) released with the paper `TAPAS: Weakly Supervised Table Parsing via
|
||||||
Pre-training <https://arxiv.org/abs/2004.02349>`__ by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller,
|
Pre-training <https://arxiv.org/abs/2004.02349>`__ by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller,
|
||||||
Francesco Piccinno and Julian Martin Eisenschlos.
|
Francesco Piccinno and Julian Martin Eisenschlos.
|
||||||
36. :doc:`Transformer-XL <model_doc/transformerxl>` (from Google/CMU) released with the paper `Transformer-XL:
|
37. :doc:`Transformer-XL <model_doc/transformerxl>` (from Google/CMU) released with the paper `Transformer-XL:
|
||||||
Attentive Language Models Beyond a Fixed-Length Context <https://arxiv.org/abs/1901.02860>`__ by Zihang Dai*,
|
Attentive Language Models Beyond a Fixed-Length Context <https://arxiv.org/abs/1901.02860>`__ by Zihang Dai*,
|
||||||
Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
|
Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
|
||||||
37. :doc:`Wav2Vec2 <model_doc/wav2vec2>` (from Facebook AI) released with the paper `wav2vec 2.0: A Framework for
|
38. :doc:`Wav2Vec2 <model_doc/wav2vec2>` (from Facebook AI) released with the paper `wav2vec 2.0: A Framework for
|
||||||
Self-Supervised Learning of Speech Representations <https://arxiv.org/abs/2006.11477>`__ by Alexei Baevski, Henry
|
Self-Supervised Learning of Speech Representations <https://arxiv.org/abs/2006.11477>`__ by Alexei Baevski, Henry
|
||||||
Zhou, Abdelrahman Mohamed, Michael Auli.
|
Zhou, Abdelrahman Mohamed, Michael Auli.
|
||||||
38. :doc:`XLM <model_doc/xlm>` (from Facebook) released together with the paper `Cross-lingual Language Model
|
39. :doc:`XLM <model_doc/xlm>` (from Facebook) released together with the paper `Cross-lingual Language Model
|
||||||
Pretraining <https://arxiv.org/abs/1901.07291>`__ by Guillaume Lample and Alexis Conneau.
|
Pretraining <https://arxiv.org/abs/1901.07291>`__ by Guillaume Lample and Alexis Conneau.
|
||||||
39. :doc:`XLM-ProphetNet <model_doc/xlmprophetnet>` (from Microsoft Research) released with the paper `ProphetNet:
|
40. :doc:`XLM-ProphetNet <model_doc/xlmprophetnet>` (from Microsoft Research) released with the paper `ProphetNet:
|
||||||
Predicting Future N-gram for Sequence-to-Sequence Pre-training <https://arxiv.org/abs/2001.04063>`__ by Yu Yan,
|
Predicting Future N-gram for Sequence-to-Sequence Pre-training <https://arxiv.org/abs/2001.04063>`__ by Yu Yan,
|
||||||
Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou.
|
Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou.
|
||||||
40. :doc:`XLM-RoBERTa <model_doc/xlmroberta>` (from Facebook AI), released together with the paper `Unsupervised
|
41. :doc:`XLM-RoBERTa <model_doc/xlmroberta>` (from Facebook AI), released together with the paper `Unsupervised
|
||||||
Cross-lingual Representation Learning at Scale <https://arxiv.org/abs/1911.02116>`__ by Alexis Conneau*, Kartikay
|
Cross-lingual Representation Learning at Scale <https://arxiv.org/abs/1911.02116>`__ by Alexis Conneau*, Kartikay
|
||||||
Khandelwal*, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke
|
Khandelwal*, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke
|
||||||
Zettlemoyer and Veselin Stoyanov.
|
Zettlemoyer and Veselin Stoyanov.
|
||||||
41. :doc:`XLNet <model_doc/xlnet>` (from Google/CMU) released with the paper `XLNet: Generalized Autoregressive
|
42. :doc:`XLNet <model_doc/xlnet>` (from Google/CMU) released with the paper `XLNet: Generalized Autoregressive
|
||||||
Pretraining for Language Understanding <https://arxiv.org/abs/1906.08237>`__ by Zhilin Yang*, Zihang Dai*, Yiming
|
Pretraining for Language Understanding <https://arxiv.org/abs/1906.08237>`__ by Zhilin Yang*, Zihang Dai*, Yiming
|
||||||
Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le.
|
Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le.
|
||||||
|
|
||||||
|
|||||||
@@ -10,14 +10,14 @@
|
|||||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||||
specific language governing permissions and limitations under the License.
|
specific language governing permissions and limitations under the License.
|
||||||
|
|
||||||
MBart
|
MBart and MBart-50
|
||||||
-----------------------------------------------------------------------------------------------------------------------
|
-----------------------------------------------------------------------------------------------------------------------
|
||||||
|
|
||||||
**DISCLAIMER:** If you see something strange, file a `Github Issue
|
**DISCLAIMER:** If you see something strange, file a `Github Issue
|
||||||
<https://github.com/huggingface/transformers/issues/new?assignees=&labels=&template=bug-report.md&title>`__ and assign
|
<https://github.com/huggingface/transformers/issues/new?assignees=&labels=&template=bug-report.md&title>`__ and assign
|
||||||
@patrickvonplaten
|
@patrickvonplaten
|
||||||
|
|
||||||
Overview
|
Overview of MBart
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
The MBart model was presented in `Multilingual Denoising Pre-training for Neural Machine Translation
|
The MBart model was presented in `Multilingual Denoising Pre-training for Neural Machine Translation
|
||||||
@@ -31,17 +31,9 @@ on the encoder, decoder, or reconstructing parts of the text.
|
|||||||
|
|
||||||
The Authors' code can be found `here <https://github.com/pytorch/fairseq/tree/master/examples/mbart>`__
|
The Authors' code can be found `here <https://github.com/pytorch/fairseq/tree/master/examples/mbart>`__
|
||||||
|
|
||||||
Examples
|
Training of MBart
|
||||||
_______________________________________________________________________________________________________________________
|
_______________________________________________________________________________________________________________________
|
||||||
|
|
||||||
- Examples and scripts for fine-tuning mBART and other models for sequence to sequence tasks can be found in
|
|
||||||
:prefix_link:`examples/seq2seq/ <examples/seq2seq/README.md>`.
|
|
||||||
- Given the large embeddings table, mBART consumes a large amount of GPU RAM, especially for fine-tuning.
|
|
||||||
:class:`MarianMTModel` is usually a better choice for bilingual machine translation.
|
|
||||||
|
|
||||||
Training
|
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
||||||
|
|
||||||
MBart is a multilingual encoder-decoder (seq-to-seq) model primarily intended for translation task. As the model is
|
MBart is a multilingual encoder-decoder (seq-to-seq) model primarily intended for translation task. As the model is
|
||||||
multilingual it expects the sequences in a different format. A special language id token is added in both the source
|
multilingual it expects the sequences in a different format. A special language id token is added in both the source
|
||||||
and target text. The source text format is :obj:`X [eos, src_lang_code]` where :obj:`X` is the source text. The target
|
and target text. The source text format is :obj:`X [eos, src_lang_code]` where :obj:`X` is the source text. The target
|
||||||
@@ -76,6 +68,87 @@ the sequences for sequence-to-sequence fine-tuning.
|
|||||||
assert translation == "Şeful ONU declară că nu există o soluţie militară în Siria"
|
assert translation == "Şeful ONU declară că nu există o soluţie militară în Siria"
|
||||||
|
|
||||||
|
|
||||||
|
Overview of MBart-50
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
MBart-50 was introduced in the `Multilingual Translation with Extensible Multilingual Pretraining and Finetuning
|
||||||
|
<https://arxiv.org/abs/2008.00401>` paper by Yuqing Tang, Chau Tran, Xian Li, Peng-Jen Chen, Naman Goyal, Vishrav
|
||||||
|
Chaudhary, Jiatao Gu, Angela Fan. MBart-50 is created using the original `mbart-large-cc25` checkpoint by extendeding
|
||||||
|
its embedding layers with randomly initialized vectors for an extra set of 25 language tokens and then pretrained on 50
|
||||||
|
languages.
|
||||||
|
|
||||||
|
According to the abstract
|
||||||
|
|
||||||
|
*Multilingual translation models can be created through multilingual finetuning. Instead of finetuning on one
|
||||||
|
direction, a pretrained model is finetuned on many directions at the same time. It demonstrates that pretrained models
|
||||||
|
can be extended to incorporate additional languages without loss of performance. Multilingual finetuning improves on
|
||||||
|
average 1 BLEU over the strongest baselines (being either multilingual from scratch or bilingual finetuning) while
|
||||||
|
improving 9.3 BLEU on average over bilingual baselines from scratch.*
|
||||||
|
|
||||||
|
|
||||||
|
Training of MBart-50
|
||||||
|
_______________________________________________________________________________________________________________________
|
||||||
|
|
||||||
|
The text format for MBart-50 is slightly different from mBART. For MBart-50 the language id token is used as a prefix
|
||||||
|
for both source and target text i.e the text format is :obj:`[lang_code] X [eos]`, where :obj:`lang_code` is source
|
||||||
|
language id for source text and target language id for target text, with :obj:`X` being the source or target text
|
||||||
|
respectively.
|
||||||
|
|
||||||
|
|
||||||
|
MBart-50 has its own tokenizer :class:`~transformers.MBart50Tokenizer`.
|
||||||
|
|
||||||
|
- Supervised training
|
||||||
|
|
||||||
|
.. code-block::
|
||||||
|
|
||||||
|
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
|
||||||
|
|
||||||
|
model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50")
|
||||||
|
tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50", src_lang="en_XX", tgt_lang="ro_RO")
|
||||||
|
|
||||||
|
src_text = " UN Chief Says There Is No Military Solution in Syria"
|
||||||
|
tgt_text = "Şeful ONU declară că nu există o soluţie militară în Siria"
|
||||||
|
|
||||||
|
model_inputs = tokenizer(src_text, return_tensors="pt")
|
||||||
|
with tokenizer.as_target_tokenizer():
|
||||||
|
labels = tokenizer(tgt_text, return_tensors="pt").input_ids
|
||||||
|
|
||||||
|
model(**model_inputs, labels=labels) # forward pass
|
||||||
|
|
||||||
|
|
||||||
|
- Generation
|
||||||
|
|
||||||
|
To generate using the mBART-50 multilingual translation models, :obj:`eos_token_id` is used as the
|
||||||
|
:obj:`decoder_start_token_id` and the target language id is forced as the first generated token. To force the
|
||||||
|
target language id as the first generated token, pass the `forced_bos_token_id` parameter to the `generate` method.
|
||||||
|
The following example shows how to translate between Hindi to French and Arabic to English using the
|
||||||
|
`facebook/mbart-50-large-many-to-many` checkpoint.
|
||||||
|
|
||||||
|
.. code-block::
|
||||||
|
|
||||||
|
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
|
||||||
|
|
||||||
|
article_hi = "संयुक्त राष्ट्र के प्रमुख का कहना है कि सीरिया में कोई सैन्य समाधान नहीं है"
|
||||||
|
article_ar = "الأمين العام للأمم المتحدة يقول إنه لا يوجد حل عسكري في سوريا."
|
||||||
|
|
||||||
|
model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
|
||||||
|
tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
|
||||||
|
|
||||||
|
# translate Hindi to French
|
||||||
|
tokenizer.src_lang = "hi_IN"
|
||||||
|
encoded_hi = tokenizer(article_hi, return_tensors="pt")
|
||||||
|
generated_tokens = model.generate(**encoded_hi, forced_bos_token_id=tokenizer.lang_code_to_id["fr_XX"])
|
||||||
|
tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
|
||||||
|
# => "Le chef de l 'ONU affirme qu 'il n 'y a pas de solution militaire en Syria."
|
||||||
|
|
||||||
|
# translate Arabic to English
|
||||||
|
tokenizer.src_lang = "ar_AR"
|
||||||
|
encoded_ar = tokenizer(article_ar, return_tensors="pt")
|
||||||
|
generated_tokens = model.generate(**encoded_ar, forced_bos_token_id=tokenizer.lang_code_to_id["en_XX"])
|
||||||
|
tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
|
||||||
|
# => "The Secretary-General of the United Nations says there is no military solution in Syria."
|
||||||
|
|
||||||
|
|
||||||
MBartConfig
|
MBartConfig
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
@@ -97,6 +170,20 @@ MBartTokenizerFast
|
|||||||
:members:
|
:members:
|
||||||
|
|
||||||
|
|
||||||
|
MBart50Tokenizer
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.MBart50Tokenizer
|
||||||
|
:members:
|
||||||
|
|
||||||
|
|
||||||
|
MBart50TokenizerFast
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.MBart50TokenizerFast
|
||||||
|
:members:
|
||||||
|
|
||||||
|
|
||||||
MBartModel
|
MBartModel
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
|||||||
@@ -381,6 +381,15 @@ For the full list, refer to `https://huggingface.co/models <https://huggingface.
|
|||||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||||
| | ``facebook/mbart-large-en-ro`` | | 24-layer, 1024-hidden, 16-heads, 610M parameters |
|
| | ``facebook/mbart-large-en-ro`` | | 24-layer, 1024-hidden, 16-heads, 610M parameters |
|
||||||
| | | | mbart-large-cc25 model finetuned on WMT english romanian translation. |
|
| | | | mbart-large-cc25 model finetuned on WMT english romanian translation. |
|
||||||
|
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||||
|
| | ``facebook/mbart-large-50`` | | 24-layer, 1024-hidden, 16-heads, |
|
||||||
|
| | | | mBART model trained on 50 languages' monolingual corpus. |
|
||||||
|
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||||
|
| | ``facebook/mbart-large-50-one-to-many-mmt`` | | 24-layer, 1024-hidden, 16-heads, |
|
||||||
|
| | | | mbart-50-large model finetuned for one (English) to many multilingual machine translation covering 50 languages. |
|
||||||
|
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||||
|
| | ``facebook/mbart-large-50-many-to-many-mmt`` | | 24-layer, 1024-hidden, 16-heads, |
|
||||||
|
| | | | mbart-50-large model finetuned for many to many multilingual machine translation covering 50 languages. |
|
||||||
+--------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
+--------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||||
| Lxmert | ``lxmert-base-uncased`` | | 9-language layers, 9-relationship layers, and 12-cross-modality layers |
|
| Lxmert | ``lxmert-base-uncased`` | | 9-language layers, 9-relationship layers, and 12-cross-modality layers |
|
||||||
| | | | 768-hidden, 12-heads (for each layer) ~ 228M parameters |
|
| | | | 768-hidden, 12-heads (for each layer) ~ 228M parameters |
|
||||||
|
|||||||
@@ -260,6 +260,7 @@ if is_sentencepiece_available():
|
|||||||
_import_structure["models.camembert"].append("CamembertTokenizer")
|
_import_structure["models.camembert"].append("CamembertTokenizer")
|
||||||
_import_structure["models.marian"].append("MarianTokenizer")
|
_import_structure["models.marian"].append("MarianTokenizer")
|
||||||
_import_structure["models.mbart"].append("MBartTokenizer")
|
_import_structure["models.mbart"].append("MBartTokenizer")
|
||||||
|
_import_structure["models.mbart"].append("MBart50Tokenizer")
|
||||||
_import_structure["models.mt5"].append("MT5Tokenizer")
|
_import_structure["models.mt5"].append("MT5Tokenizer")
|
||||||
_import_structure["models.pegasus"].append("PegasusTokenizer")
|
_import_structure["models.pegasus"].append("PegasusTokenizer")
|
||||||
_import_structure["models.reformer"].append("ReformerTokenizer")
|
_import_structure["models.reformer"].append("ReformerTokenizer")
|
||||||
@@ -296,6 +297,7 @@ if is_tokenizers_available():
|
|||||||
_import_structure["models.longformer"].append("LongformerTokenizerFast")
|
_import_structure["models.longformer"].append("LongformerTokenizerFast")
|
||||||
_import_structure["models.lxmert"].append("LxmertTokenizerFast")
|
_import_structure["models.lxmert"].append("LxmertTokenizerFast")
|
||||||
_import_structure["models.mbart"].append("MBartTokenizerFast")
|
_import_structure["models.mbart"].append("MBartTokenizerFast")
|
||||||
|
_import_structure["models.mbart"].append("MBart50TokenizerFast")
|
||||||
_import_structure["models.mobilebert"].append("MobileBertTokenizerFast")
|
_import_structure["models.mobilebert"].append("MobileBertTokenizerFast")
|
||||||
_import_structure["models.mpnet"].append("MPNetTokenizerFast")
|
_import_structure["models.mpnet"].append("MPNetTokenizerFast")
|
||||||
_import_structure["models.mt5"].append("MT5TokenizerFast")
|
_import_structure["models.mt5"].append("MT5TokenizerFast")
|
||||||
@@ -1391,7 +1393,7 @@ if TYPE_CHECKING:
|
|||||||
from .models.bert_generation import BertGenerationTokenizer
|
from .models.bert_generation import BertGenerationTokenizer
|
||||||
from .models.camembert import CamembertTokenizer
|
from .models.camembert import CamembertTokenizer
|
||||||
from .models.marian import MarianTokenizer
|
from .models.marian import MarianTokenizer
|
||||||
from .models.mbart import MBartTokenizer
|
from .models.mbart import MBart50Tokenizer, MBartTokenizer
|
||||||
from .models.mt5 import MT5Tokenizer
|
from .models.mt5 import MT5Tokenizer
|
||||||
from .models.pegasus import PegasusTokenizer
|
from .models.pegasus import PegasusTokenizer
|
||||||
from .models.reformer import ReformerTokenizer
|
from .models.reformer import ReformerTokenizer
|
||||||
@@ -1419,7 +1421,7 @@ if TYPE_CHECKING:
|
|||||||
from .models.led import LEDTokenizerFast
|
from .models.led import LEDTokenizerFast
|
||||||
from .models.longformer import LongformerTokenizerFast
|
from .models.longformer import LongformerTokenizerFast
|
||||||
from .models.lxmert import LxmertTokenizerFast
|
from .models.lxmert import LxmertTokenizerFast
|
||||||
from .models.mbart import MBartTokenizerFast
|
from .models.mbart import MBart50TokenizerFast, MBartTokenizerFast
|
||||||
from .models.mobilebert import MobileBertTokenizerFast
|
from .models.mobilebert import MobileBertTokenizerFast
|
||||||
from .models.mpnet import MPNetTokenizerFast
|
from .models.mpnet import MPNetTokenizerFast
|
||||||
from .models.mt5 import MT5TokenizerFast
|
from .models.mt5 import MT5TokenizerFast
|
||||||
|
|||||||
@@ -500,6 +500,35 @@ class MBartConverter(SpmConverter):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MBart50Converter(SpmConverter):
|
||||||
|
def vocab(self, proto):
|
||||||
|
vocab = [
|
||||||
|
("<s>", 0.0),
|
||||||
|
("<pad>", 0.0),
|
||||||
|
("</s>", 0.0),
|
||||||
|
("<unk>", 0.0),
|
||||||
|
]
|
||||||
|
vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
|
||||||
|
# fmt: off
|
||||||
|
vocab += [("ar_AR", 0.0), ("cs_CZ", 0.0), ("de_DE", 0.0), ("en_XX", 0.0), ("es_XX", 0.0), ("et_EE", 0.0), ("fi_FI", 0.0), ("fr_XX", 0.0), ("gu_IN", 0.0), ("hi_IN", 0.0), ("it_IT", 0.0), ("ja_XX", 0.0), ("kk_KZ", 0.0), ("ko_KR", 0.0), ("lt_LT", 0.0), ("lv_LV", 0.0), ("my_MM", 0.0), ("ne_NP", 0.0), ("nl_XX", 0.0), ("ro_RO", 0.0), ("ru_RU", 0.0), ("si_LK", 0.0), ("tr_TR", 0.0), ("vi_VN", 0.0), ("zh_CN", 0.0), ("af_ZA", 0.0), ("az_AZ", 0.0), ("bn_IN", 0.0), ("fa_IR", 0.0), ("he_IL", 0.0), ("hr_HR", 0.0), ("id_ID", 0.0), ("ka_GE", 0.0), ("km_KH", 0.0), ("mk_MK", 0.0), ("ml_IN", 0.0), ("mn_MN", 0.0), ("mr_IN", 0.0), ("pl_PL", 0.0), ("ps_AF", 0.0), ("pt_XX", 0.0), ("sv_SE", 0.0), ("sw_KE", 0.0), ("ta_IN", 0.0), ("te_IN", 0.0), ("th_TH", 0.0), ("tl_XX", 0.0), ("uk_UA", 0.0), ("ur_PK", 0.0), ("xh_ZA", 0.0), ("gl_ES", 0.0), ("sl_SI", 0.0)]
|
||||||
|
# fmt: on
|
||||||
|
vocab += [("<mask>", 0.0)]
|
||||||
|
return vocab
|
||||||
|
|
||||||
|
def unk_id(self, proto):
|
||||||
|
return 3
|
||||||
|
|
||||||
|
def post_processor(self):
|
||||||
|
return processors.TemplateProcessing(
|
||||||
|
single="en_XX $A </s>",
|
||||||
|
pair="en_XX $A $B </s>",
|
||||||
|
special_tokens=[
|
||||||
|
("en_XX", self.original_tokenizer.convert_tokens_to_ids("en_XX")),
|
||||||
|
("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class XLMRobertaConverter(SpmConverter):
|
class XLMRobertaConverter(SpmConverter):
|
||||||
def vocab(self, proto):
|
def vocab(self, proto):
|
||||||
vocab = [
|
vocab = [
|
||||||
@@ -637,6 +666,7 @@ SLOW_TO_FAST_CONVERTERS = {
|
|||||||
"LEDTokenizer": RobertaConverter,
|
"LEDTokenizer": RobertaConverter,
|
||||||
"LxmertTokenizer": BertConverter,
|
"LxmertTokenizer": BertConverter,
|
||||||
"MBartTokenizer": MBartConverter,
|
"MBartTokenizer": MBartConverter,
|
||||||
|
"MBart50Tokenizer": MBart50Converter,
|
||||||
"MPNetTokenizer": MPNetConverter,
|
"MPNetTokenizer": MPNetConverter,
|
||||||
"MobileBertTokenizer": BertConverter,
|
"MobileBertTokenizer": BertConverter,
|
||||||
"OpenAIGPTTokenizer": OpenAIGPTConverter,
|
"OpenAIGPTTokenizer": OpenAIGPTConverter,
|
||||||
|
|||||||
@@ -32,9 +32,11 @@ _import_structure = {
|
|||||||
|
|
||||||
if is_sentencepiece_available():
|
if is_sentencepiece_available():
|
||||||
_import_structure["tokenization_mbart"] = ["MBartTokenizer"]
|
_import_structure["tokenization_mbart"] = ["MBartTokenizer"]
|
||||||
|
_import_structure["tokenization_mbart50"] = ["MBart50Tokenizer"]
|
||||||
|
|
||||||
if is_tokenizers_available():
|
if is_tokenizers_available():
|
||||||
_import_structure["tokenization_mbart_fast"] = ["MBartTokenizerFast"]
|
_import_structure["tokenization_mbart_fast"] = ["MBartTokenizerFast"]
|
||||||
|
_import_structure["tokenization_mbart50_fast"] = ["MBart50TokenizerFast"]
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
_import_structure["modeling_mbart"] = [
|
_import_structure["modeling_mbart"] = [
|
||||||
@@ -56,8 +58,10 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
if is_sentencepiece_available():
|
if is_sentencepiece_available():
|
||||||
from .tokenization_mbart import MBartTokenizer
|
from .tokenization_mbart import MBartTokenizer
|
||||||
|
from .tokenization_mbart50 import MBart50Tokenizer
|
||||||
|
|
||||||
if is_tokenizers_available():
|
if is_tokenizers_available():
|
||||||
|
from .tokenization_mbart50_fast import MBart50TokenizerFast
|
||||||
from .tokenization_mbart_fast import MBartTokenizerFast
|
from .tokenization_mbart_fast import MBartTokenizerFast
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
|
|||||||
@@ -15,19 +15,49 @@
|
|||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
from transformers import BartForConditionalGeneration, MBartConfig
|
from transformers import MBartConfig, MBartForConditionalGeneration
|
||||||
from transformers.models.bart.convert_bart_original_pytorch_checkpoint_to_pytorch import remove_ignore_keys_
|
|
||||||
|
|
||||||
|
|
||||||
def convert_fairseq_mbart_checkpoint_from_disk(checkpoint_path, hf_config_path="facebook/mbart-large-en-ro"):
|
def remove_ignore_keys_(state_dict):
|
||||||
|
ignore_keys = [
|
||||||
|
"encoder.version",
|
||||||
|
"decoder.version",
|
||||||
|
"model.encoder.version",
|
||||||
|
"model.decoder.version",
|
||||||
|
"_float_tensor",
|
||||||
|
"decoder.output_projection.weight",
|
||||||
|
]
|
||||||
|
for k in ignore_keys:
|
||||||
|
state_dict.pop(k, None)
|
||||||
|
|
||||||
|
|
||||||
|
def make_linear_from_emb(emb):
|
||||||
|
vocab_size, emb_size = emb.weight.shape
|
||||||
|
lin_layer = nn.Linear(vocab_size, emb_size, bias=False)
|
||||||
|
lin_layer.weight.data = emb.weight.data
|
||||||
|
return lin_layer
|
||||||
|
|
||||||
|
|
||||||
|
def convert_fairseq_mbart_checkpoint_from_disk(
|
||||||
|
checkpoint_path, hf_config_path="facebook/mbart-large-en-ro", finetuned=False, mbart_50=False
|
||||||
|
):
|
||||||
state_dict = torch.load(checkpoint_path, map_location="cpu")["model"]
|
state_dict = torch.load(checkpoint_path, map_location="cpu")["model"]
|
||||||
remove_ignore_keys_(state_dict)
|
remove_ignore_keys_(state_dict)
|
||||||
vocab_size = state_dict["encoder.embed_tokens.weight"].shape[0]
|
vocab_size = state_dict["encoder.embed_tokens.weight"].shape[0]
|
||||||
|
|
||||||
mbart_config = MBartConfig.from_pretrained(hf_config_path, vocab_size=vocab_size)
|
mbart_config = MBartConfig.from_pretrained(hf_config_path, vocab_size=vocab_size)
|
||||||
|
if mbart_50 and finetuned:
|
||||||
|
mbart_config.activation_function = "relu"
|
||||||
|
|
||||||
state_dict["shared.weight"] = state_dict["decoder.embed_tokens.weight"]
|
state_dict["shared.weight"] = state_dict["decoder.embed_tokens.weight"]
|
||||||
model = BartForConditionalGeneration(mbart_config)
|
model = MBartForConditionalGeneration(mbart_config)
|
||||||
model.model.load_state_dict(state_dict)
|
model.model.load_state_dict(state_dict)
|
||||||
|
|
||||||
|
if finetuned:
|
||||||
|
model.lm_head = make_linear_from_emb(model.model.shared)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
@@ -42,8 +72,12 @@ if __name__ == "__main__":
|
|||||||
"--hf_config",
|
"--hf_config",
|
||||||
default="facebook/mbart-large-cc25",
|
default="facebook/mbart-large-cc25",
|
||||||
type=str,
|
type=str,
|
||||||
help="Which huggingface architecture to use: bart-large-xsum",
|
help="Which huggingface architecture to use: mbart-large",
|
||||||
)
|
)
|
||||||
|
parser.add_argument("--mbart_50", action="store_true", help="whether the model is mMART-50 checkpoint")
|
||||||
|
parser.add_argument("--finetuned", action="store_true", help="whether the model is a fine-tuned checkpoint")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
model = convert_fairseq_mbart_checkpoint_from_disk(args.fairseq_path, hf_config_path=args.hf_config)
|
model = convert_fairseq_mbart_checkpoint_from_disk(
|
||||||
|
args.fairseq_path, hf_config_path=args.hf_config, finetuned=args.finetuned, mbart_50=args.mbart_50
|
||||||
|
)
|
||||||
model.save_pretrained(args.pytorch_dump_folder_path)
|
model.save_pretrained(args.pytorch_dump_folder_path)
|
||||||
|
|||||||
308
src/transformers/models/mbart/tokenization_mbart50.py
Normal file
308
src/transformers/models/mbart/tokenization_mbart50.py
Normal file
@@ -0,0 +1,308 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2021 The Facebook AI Research Team Authors and The HuggingFace Inc. team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import os
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from shutil import copyfile
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import sentencepiece as spm
|
||||||
|
|
||||||
|
from ...tokenization_utils import AddedToken, BatchEncoding, PreTrainedTokenizer
|
||||||
|
from ...utils import logging
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
SPIECE_UNDERLINE = "▁"
|
||||||
|
|
||||||
|
VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model"}
|
||||||
|
|
||||||
|
_all_mbart50_models = ["facebook/mbart-large-50-one-to-many-mmt"]
|
||||||
|
SPM_URL = "https://huggingface.co/facebook/mbart-large-50-one-to-many-mmt/resolve/main/sentencepiece.bpe.model"
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
FAIRSEQ_LANGUAGE_CODES = ["ar_AR", "cs_CZ", "de_DE", "en_XX", "es_XX", "et_EE", "fi_FI", "fr_XX", "gu_IN", "hi_IN", "it_IT", "ja_XX", "kk_KZ", "ko_KR", "lt_LT", "lv_LV", "my_MM", "ne_NP", "nl_XX", "ro_RO", "ru_RU", "si_LK", "tr_TR", "vi_VN", "zh_CN", "af_ZA", "az_AZ", "bn_IN", "fa_IR", "he_IL", "hr_HR", "id_ID", "ka_GE", "km_KH", "mk_MK", "ml_IN", "mn_MN", "mr_IN", "pl_PL", "ps_AF", "pt_XX", "sv_SE", "sw_KE", "ta_IN", "te_IN", "th_TH", "tl_XX", "uk_UA", "ur_PK", "xh_ZA", "gl_ES", "sl_SI"]
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
|
||||||
|
class MBart50Tokenizer(PreTrainedTokenizer):
|
||||||
|
"""
|
||||||
|
Construct a MBart50 tokenizer. Based on `SentencePiece <https://github.com/google/sentencepiece>`__.
|
||||||
|
|
||||||
|
This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the main methods.
|
||||||
|
Users should refer to this superclass for more information regarding those methods.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vocab_file (:obj:`str`):
|
||||||
|
Path to the vocabulary file.
|
||||||
|
src_lang (:obj:`str`, `optional`):
|
||||||
|
A string representing the source language.
|
||||||
|
tgt_lang (:obj:`str`, `optional`):
|
||||||
|
A string representing the target language.
|
||||||
|
eos_token (:obj:`str`, `optional`, defaults to :obj:`"</s>"`):
|
||||||
|
The end of sequence token.
|
||||||
|
sep_token (:obj:`str`, `optional`, defaults to :obj:`"</s>"`):
|
||||||
|
The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
|
||||||
|
sequence classification or for a text and a question for question answering. It is also used as the last
|
||||||
|
token of a sequence built with special tokens.
|
||||||
|
cls_token (:obj:`str`, `optional`, defaults to :obj:`"<s>"`):
|
||||||
|
The classifier token which is used when doing sequence classification (classification of the whole sequence
|
||||||
|
instead of per-token classification). It is the first token of the sequence when built with special tokens.
|
||||||
|
unk_token (:obj:`str`, `optional`, defaults to :obj:`"<unk>"`):
|
||||||
|
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
||||||
|
token instead.
|
||||||
|
pad_token (:obj:`str`, `optional`, defaults to :obj:`"<pad>"`):
|
||||||
|
The token used for padding, for example when batching sequences of different lengths.
|
||||||
|
mask_token (:obj:`str`, `optional`, defaults to :obj:`"<mask>"`):
|
||||||
|
The token used for masking values. This is the token used when training this model with masked language
|
||||||
|
modeling. This is the token which the model will try to predict.
|
||||||
|
|
||||||
|
Examples::
|
||||||
|
|
||||||
|
>>> from transformers import MBart50Tokenizer
|
||||||
|
>>> tokenizer = MBart50Tokenizer.from_pretrained("facebook/mbart-large-50", src_lang="en_XX", tgt_lang="ro_RO")
|
||||||
|
>>> src_text = " UN Chief Says There Is No Military Solution in Syria"
|
||||||
|
>>> tgt_text = "Şeful ONU declară că nu există o soluţie militară în Siria"
|
||||||
|
>>> model_inputs = tokenizer(src_text, return_tensors="pt")
|
||||||
|
>>> with tokenizer.as_target_tokenizer():
|
||||||
|
... labels = tokenizer(tgt_text, return_tensors="pt").input_ids
|
||||||
|
>>> # model(**model_inputs, labels=labels) should work
|
||||||
|
"""
|
||||||
|
|
||||||
|
vocab_files_names = VOCAB_FILES_NAMES
|
||||||
|
max_model_input_sizes = {m: 1024 for m in _all_mbart50_models}
|
||||||
|
pretrained_vocab_files_map = {"vocab_file": {m: SPM_URL for m in _all_mbart50_models}}
|
||||||
|
model_input_names = ["input_ids", "attention_mask"]
|
||||||
|
|
||||||
|
prefix_tokens: List[int] = []
|
||||||
|
suffix_tokens: List[int] = []
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vocab_file,
|
||||||
|
src_lang=None,
|
||||||
|
tgt_lang=None,
|
||||||
|
eos_token="</s>",
|
||||||
|
sep_token="</s>",
|
||||||
|
cls_token="<s>",
|
||||||
|
unk_token="<unk>",
|
||||||
|
pad_token="<pad>",
|
||||||
|
mask_token="<mask>",
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
# Mask token behave like a normal word, i.e. include the space before it
|
||||||
|
mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
src_lang=src_lang,
|
||||||
|
tgt_lang=tgt_lang,
|
||||||
|
eos_token=eos_token,
|
||||||
|
unk_token=unk_token,
|
||||||
|
sep_token=sep_token,
|
||||||
|
cls_token=cls_token,
|
||||||
|
pad_token=pad_token,
|
||||||
|
mask_token=mask_token,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.sp_model = spm.SentencePieceProcessor()
|
||||||
|
self.sp_model.Load(str(vocab_file))
|
||||||
|
self.vocab_file = vocab_file
|
||||||
|
|
||||||
|
# Original fairseq vocab and spm vocab must be "aligned":
|
||||||
|
# Vocab | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9
|
||||||
|
# -------- | ------- | ------- | ------ | ------- | --- | --- | --- | ----- | ----- | ----
|
||||||
|
# fairseq | '<s>' | '<pad>' | '</s>' | '<unk>' | ',' | '.' | '▁' | 's' | '▁de' | '-'
|
||||||
|
# spm | '<unk>' | '<s>' | '</s>' | ',' | '.' | '▁' | 's' | '▁de' | '-' | '▁a'
|
||||||
|
|
||||||
|
# Mimic fairseq token-to-id alignment for the first 4 token
|
||||||
|
self.fairseq_tokens_to_ids = {"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3}
|
||||||
|
|
||||||
|
# The first "real" token "," has position 4 in the original fairseq vocab and position 3 in the spm vocab
|
||||||
|
self.fairseq_offset = 1
|
||||||
|
|
||||||
|
self.sp_model_size = len(self.sp_model)
|
||||||
|
self.lang_code_to_id = {
|
||||||
|
code: self.sp_model_size + i + self.fairseq_offset for i, code in enumerate(FAIRSEQ_LANGUAGE_CODES)
|
||||||
|
}
|
||||||
|
self.id_to_lang_code = {v: k for k, v in self.lang_code_to_id.items()}
|
||||||
|
self.fairseq_tokens_to_ids["<mask>"] = len(self.sp_model) + len(self.lang_code_to_id) + self.fairseq_offset
|
||||||
|
|
||||||
|
self.fairseq_tokens_to_ids.update(self.lang_code_to_id)
|
||||||
|
self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}
|
||||||
|
self._additional_special_tokens = list(self.lang_code_to_id.keys())
|
||||||
|
|
||||||
|
self._src_lang = src_lang if src_lang is not None else "en_XX"
|
||||||
|
self.cur_lang_code_id = self.lang_code_to_id[self._src_lang]
|
||||||
|
self.tgt_lang = tgt_lang
|
||||||
|
self.set_src_lang_special_tokens(self._src_lang)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def vocab_size(self) -> int:
|
||||||
|
return len(self.sp_model) + len(self.lang_code_to_id) + self.fairseq_offset + 1 # Plus 1 for the mask token
|
||||||
|
|
||||||
|
@property
|
||||||
|
def src_lang(self) -> str:
|
||||||
|
return self._src_lang
|
||||||
|
|
||||||
|
@src_lang.setter
|
||||||
|
def src_lang(self, new_src_lang: str) -> None:
|
||||||
|
self._src_lang = new_src_lang
|
||||||
|
self.set_src_lang_special_tokens(self._src_lang)
|
||||||
|
|
||||||
|
def __getstate__(self) -> Dict:
|
||||||
|
state = self.__dict__.copy()
|
||||||
|
state["sp_model"] = None
|
||||||
|
return state
|
||||||
|
|
||||||
|
def __setstate__(self, d: Dict) -> None:
|
||||||
|
self.__dict__ = d
|
||||||
|
self.sp_model = spm.SentencePieceProcessor()
|
||||||
|
self.sp_model.Load(self.vocab_file)
|
||||||
|
|
||||||
|
def get_vocab(self) -> Dict:
|
||||||
|
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
|
||||||
|
vocab.update(self.added_tokens_encoder)
|
||||||
|
return vocab
|
||||||
|
|
||||||
|
def _tokenize(self, text: str) -> List[str]:
|
||||||
|
return self.sp_model.EncodeAsPieces(text)
|
||||||
|
|
||||||
|
def _convert_token_to_id(self, token: str) -> int:
|
||||||
|
""" Converts a token (str) in an id using the vocab. """
|
||||||
|
if token in self.fairseq_tokens_to_ids:
|
||||||
|
return self.fairseq_tokens_to_ids[token]
|
||||||
|
spm_id = self.sp_model.PieceToId(token)
|
||||||
|
|
||||||
|
# Need to return unknown token if the SP model returned 0
|
||||||
|
return spm_id + self.fairseq_offset if spm_id else self.unk_token_id
|
||||||
|
|
||||||
|
def _convert_id_to_token(self, index: int) -> str:
|
||||||
|
"""Converts an index (integer) in a token (str) using the vocab."""
|
||||||
|
if index in self.fairseq_ids_to_tokens:
|
||||||
|
return self.fairseq_ids_to_tokens[index]
|
||||||
|
return self.sp_model.IdToPiece(index - self.fairseq_offset)
|
||||||
|
|
||||||
|
def convert_tokens_to_string(self, tokens: List[str]) -> str:
|
||||||
|
"""Converts a sequence of tokens (strings for sub-words) in a single string."""
|
||||||
|
out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
|
||||||
|
return out_string
|
||||||
|
|
||||||
|
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
||||||
|
if not os.path.isdir(save_directory):
|
||||||
|
logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
|
||||||
|
return
|
||||||
|
out_vocab_file = os.path.join(
|
||||||
|
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
|
||||||
|
)
|
||||||
|
|
||||||
|
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
|
||||||
|
copyfile(self.vocab_file, out_vocab_file)
|
||||||
|
|
||||||
|
return (out_vocab_file,)
|
||||||
|
|
||||||
|
def get_special_tokens_mask(
|
||||||
|
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
|
||||||
|
) -> List[int]:
|
||||||
|
"""
|
||||||
|
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
|
||||||
|
special tokens using the tokenizer ``prepare_for_model`` method.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token_ids_0 (:obj:`List[int]`):
|
||||||
|
List of IDs.
|
||||||
|
token_ids_1 (:obj:`List[int]`, `optional`):
|
||||||
|
Optional second list of IDs for sequence pairs.
|
||||||
|
already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
|
Whether or not the token list is already formatted with special tokens for the model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
:obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if already_has_special_tokens:
|
||||||
|
if token_ids_1 is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"You should not supply a second sequence if the provided sequence of "
|
||||||
|
"ids is already formatted with special tokens for the model."
|
||||||
|
)
|
||||||
|
return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
|
||||||
|
prefix_ones = [1] * len(self.prefix_tokens)
|
||||||
|
suffix_ones = [1] * len(self.suffix_tokens)
|
||||||
|
if token_ids_1 is None:
|
||||||
|
return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones
|
||||||
|
return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones
|
||||||
|
|
||||||
|
def build_inputs_with_special_tokens(
|
||||||
|
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
||||||
|
) -> List[int]:
|
||||||
|
"""
|
||||||
|
Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
|
||||||
|
adding special tokens. An MBART-50 sequence has the following format, where ``X`` represents the sequence:
|
||||||
|
|
||||||
|
- ``input_ids`` (for encoder) ``[src_lang_code] X [eos]``
|
||||||
|
- ``labels``: (for decoder) ``[tgt_lang_code] X [eos]``
|
||||||
|
|
||||||
|
BOS is never used. Pairs of sequences are not the expected use case, but they will be handled without a
|
||||||
|
separator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token_ids_0 (:obj:`List[int]`):
|
||||||
|
List of IDs to which the special tokens will be added.
|
||||||
|
token_ids_1 (:obj:`List[int]`, `optional`):
|
||||||
|
Optional second list of IDs for sequence pairs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
:obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
|
||||||
|
"""
|
||||||
|
if token_ids_1 is None:
|
||||||
|
return self.prefix_tokens + token_ids_0 + self.suffix_tokens
|
||||||
|
# We don't expect to process pairs, but leave the pair logic for API consistency
|
||||||
|
return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
|
||||||
|
|
||||||
|
def prepare_seq2seq_batch(
|
||||||
|
self,
|
||||||
|
src_texts: List[str],
|
||||||
|
src_lang: str = "en_XX",
|
||||||
|
tgt_texts: Optional[List[str]] = None,
|
||||||
|
tgt_lang: str = "ro_RO",
|
||||||
|
**kwargs,
|
||||||
|
) -> BatchEncoding:
|
||||||
|
self.src_lang = src_lang
|
||||||
|
self.tgt_lang = tgt_lang
|
||||||
|
return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def as_target_tokenizer(self):
|
||||||
|
"""
|
||||||
|
Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
|
||||||
|
sequence-to-sequence models that need a slightly different processing for the labels.
|
||||||
|
"""
|
||||||
|
self.set_tgt_lang_special_tokens(self.tgt_lang)
|
||||||
|
yield
|
||||||
|
self.set_src_lang_special_tokens(self.src_lang)
|
||||||
|
|
||||||
|
def set_src_lang_special_tokens(self, src_lang: str) -> None:
|
||||||
|
"""Reset the special tokens to the source lang setting. prefix=[src_lang_code] and suffix=[eos]."""
|
||||||
|
self.cur_lang_code_id = self.lang_code_to_id[src_lang]
|
||||||
|
self.prefix_tokens = [self.cur_lang_code_id]
|
||||||
|
self.suffix_tokens = [self.eos_token_id]
|
||||||
|
|
||||||
|
def set_tgt_lang_special_tokens(self, tgt_lang: str) -> None:
|
||||||
|
"""Reset the special tokens to the target language setting. prefix=[tgt_lang_code] and suffix=[eos]."""
|
||||||
|
self.cur_lang_code_id = self.lang_code_to_id[tgt_lang]
|
||||||
|
self.prefix_tokens = [self.cur_lang_code_id]
|
||||||
|
self.suffix_tokens = [self.eos_token_id]
|
||||||
278
src/transformers/models/mbart/tokenization_mbart50_fast.py
Normal file
278
src/transformers/models/mbart/tokenization_mbart50_fast.py
Normal file
@@ -0,0 +1,278 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2021 The Facebook AI Research Team Authors and The HuggingFace Inc. team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import os
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from shutil import copyfile
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
from tokenizers import processors
|
||||||
|
|
||||||
|
from ...file_utils import is_sentencepiece_available
|
||||||
|
from ...tokenization_utils import AddedToken, BatchEncoding
|
||||||
|
from ...tokenization_utils_fast import PreTrainedTokenizerFast
|
||||||
|
from ...utils import logging
|
||||||
|
|
||||||
|
|
||||||
|
if is_sentencepiece_available():
|
||||||
|
from .tokenization_mbart50 import MBart50Tokenizer
|
||||||
|
else:
|
||||||
|
MBart50Tokenizer = None
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model", "tokenizer_file": "tokenizer.json"}
|
||||||
|
|
||||||
|
_all_mbart50_models = ["facebook/mbart-large-50-one-to-many-mmt"]
|
||||||
|
SPM_URL = "https://huggingface.co/facebook/mbart-large-50-one-to-many-mmt/resolve/main/sentencepiece.bpe.model"
|
||||||
|
tokenizer_URL = "https://huggingface.co/facebook/mbart-large-50-one-to-many-mmt/resolve/main/tokenizer.json"
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
FAIRSEQ_LANGUAGE_CODES = ["ar_AR", "cs_CZ", "de_DE", "en_XX", "es_XX", "et_EE", "fi_FI", "fr_XX", "gu_IN", "hi_IN", "it_IT", "ja_XX", "kk_KZ", "ko_KR", "lt_LT", "lv_LV", "my_MM", "ne_NP", "nl_XX", "ro_RO", "ru_RU", "si_LK", "tr_TR", "vi_VN", "zh_CN", "af_ZA", "az_AZ", "bn_IN", "fa_IR", "he_IL", "hr_HR", "id_ID", "ka_GE", "km_KH", "mk_MK", "ml_IN", "mn_MN", "mr_IN", "pl_PL", "ps_AF", "pt_XX", "sv_SE", "sw_KE", "ta_IN", "te_IN", "th_TH", "tl_XX", "uk_UA", "ur_PK", "xh_ZA", "gl_ES", "sl_SI"]
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
|
||||||
|
class MBart50TokenizerFast(PreTrainedTokenizerFast):
|
||||||
|
"""
|
||||||
|
Construct a "fast" MBART tokenizer for mBART-50 (backed by HuggingFace's `tokenizers` library). Based on `BPE
|
||||||
|
<https://huggingface.co/docs/tokenizers/python/latest/components.html?highlight=BPE#models>`__.
|
||||||
|
|
||||||
|
This tokenizer inherits from :class:`~transformers.PreTrainedTokenizerFast` which contains most of the main
|
||||||
|
methods. Users should refer to this superclass for more information regarding those methods.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vocab_file (:obj:`str`):
|
||||||
|
Path to the vocabulary file.
|
||||||
|
src_lang (:obj:`str`, `optional`):
|
||||||
|
A string representing the source language.
|
||||||
|
tgt_lang (:obj:`str`, `optional`):
|
||||||
|
A string representing the target language.
|
||||||
|
eos_token (:obj:`str`, `optional`, defaults to :obj:`"</s>"`):
|
||||||
|
The end of sequence token.
|
||||||
|
sep_token (:obj:`str`, `optional`, defaults to :obj:`"</s>"`):
|
||||||
|
The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
|
||||||
|
sequence classification or for a text and a question for question answering. It is also used as the last
|
||||||
|
token of a sequence built with special tokens.
|
||||||
|
cls_token (:obj:`str`, `optional`, defaults to :obj:`"<s>"`):
|
||||||
|
The classifier token which is used when doing sequence classification (classification of the whole sequence
|
||||||
|
instead of per-token classification). It is the first token of the sequence when built with special tokens.
|
||||||
|
unk_token (:obj:`str`, `optional`, defaults to :obj:`"<unk>"`):
|
||||||
|
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
||||||
|
token instead.
|
||||||
|
pad_token (:obj:`str`, `optional`, defaults to :obj:`"<pad>"`):
|
||||||
|
The token used for padding, for example when batching sequences of different lengths.
|
||||||
|
mask_token (:obj:`str`, `optional`, defaults to :obj:`"<mask>"`):
|
||||||
|
The token used for masking values. This is the token used when training this model with masked language
|
||||||
|
modeling. This is the token which the model will try to predict.
|
||||||
|
|
||||||
|
Examples::
|
||||||
|
|
||||||
|
>>> from transformers import MBart50TokenizerFast
|
||||||
|
>>> tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50", src_lang="en_XX", tgt_lang="ro_RO")
|
||||||
|
>>> src_text = " UN Chief Says There Is No Military Solution in Syria"
|
||||||
|
>>> tgt_text = "Şeful ONU declară că nu există o soluţie militară în Siria"
|
||||||
|
>>> model_inputs = tokenizer(src_text, return_tensors="pt")
|
||||||
|
>>> with tokenizer.as_target_tokenizer():
|
||||||
|
... labels = tokenizer(tgt_text, return_tensors="pt").input_ids
|
||||||
|
>>> # model(**model_inputs, labels=labels) should work
|
||||||
|
"""
|
||||||
|
|
||||||
|
vocab_files_names = VOCAB_FILES_NAMES
|
||||||
|
max_model_input_sizes = {m: 1024 for m in _all_mbart50_models}
|
||||||
|
pretrained_vocab_files_map = {"vocab_file": {m: SPM_URL for m in _all_mbart50_models}}
|
||||||
|
model_input_names = ["input_ids", "attention_mask"]
|
||||||
|
slow_tokenizer_class = MBart50Tokenizer
|
||||||
|
|
||||||
|
prefix_tokens: List[int] = []
|
||||||
|
suffix_tokens: List[int] = []
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vocab_file,
|
||||||
|
src_lang=None,
|
||||||
|
tgt_lang=None,
|
||||||
|
tokenizer_file=None,
|
||||||
|
eos_token="</s>",
|
||||||
|
sep_token="</s>",
|
||||||
|
cls_token="<s>",
|
||||||
|
unk_token="<unk>",
|
||||||
|
pad_token="<pad>",
|
||||||
|
mask_token="<mask>",
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
# Mask token behave like a normal word, i.e. include the space before it
|
||||||
|
mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
vocab_file,
|
||||||
|
src_lang=src_lang,
|
||||||
|
tgt_lang=tgt_lang,
|
||||||
|
tokenizer_file=tokenizer_file,
|
||||||
|
eos_token=eos_token,
|
||||||
|
sep_token=sep_token,
|
||||||
|
cls_token=cls_token,
|
||||||
|
unk_token=unk_token,
|
||||||
|
pad_token=pad_token,
|
||||||
|
mask_token=mask_token,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.vocab_file = vocab_file
|
||||||
|
|
||||||
|
self.add_special_tokens({"additional_special_tokens": FAIRSEQ_LANGUAGE_CODES})
|
||||||
|
self.lang_code_to_id = {
|
||||||
|
lang_code: self.convert_tokens_to_ids(lang_code) for lang_code in FAIRSEQ_LANGUAGE_CODES
|
||||||
|
}
|
||||||
|
|
||||||
|
self._src_lang = src_lang if src_lang is not None else "en_XX"
|
||||||
|
self.tgt_lang = tgt_lang
|
||||||
|
self.cur_lang_code_id = self.lang_code_to_id[self._src_lang]
|
||||||
|
self.set_src_lang_special_tokens(self._src_lang)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def src_lang(self) -> str:
|
||||||
|
return self._src_lang
|
||||||
|
|
||||||
|
@src_lang.setter
|
||||||
|
def src_lang(self, new_src_lang: str) -> None:
|
||||||
|
self._src_lang = new_src_lang
|
||||||
|
self.set_src_lang_special_tokens(self._src_lang)
|
||||||
|
|
||||||
|
def get_special_tokens_mask(
|
||||||
|
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
|
||||||
|
) -> List[int]:
|
||||||
|
"""
|
||||||
|
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
|
||||||
|
special tokens using the tokenizer ``prepare_for_model`` method.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token_ids_0 (:obj:`List[int]`):
|
||||||
|
List of ids.
|
||||||
|
token_ids_1 (:obj:`List[int]`, `optional`):
|
||||||
|
Optional second list of IDs for sequence pairs.
|
||||||
|
already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
|
Whether or not the token list is already formatted with special tokens for the model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
:obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if already_has_special_tokens:
|
||||||
|
if token_ids_1 is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"You should not supply a second sequence if the provided sequence of "
|
||||||
|
"ids is already formatted with special tokens for the model."
|
||||||
|
)
|
||||||
|
return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
|
||||||
|
prefix_ones = [1] * len(self.prefix_tokens)
|
||||||
|
suffix_ones = [1] * len(self.suffix_tokens)
|
||||||
|
if token_ids_1 is None:
|
||||||
|
return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones
|
||||||
|
return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones
|
||||||
|
|
||||||
|
def build_inputs_with_special_tokens(
|
||||||
|
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
||||||
|
) -> List[int]:
|
||||||
|
"""
|
||||||
|
Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
|
||||||
|
adding special tokens. The special tokens depend on calling set_lang.
|
||||||
|
|
||||||
|
An MBART-50 sequence has the following format, where ``X`` represents the sequence:
|
||||||
|
|
||||||
|
- ``input_ids`` (for encoder) ``[src_lang_code] X [eos]``
|
||||||
|
- ``labels``: (for decoder) ``[tgt_lang_code] X [eos]``
|
||||||
|
|
||||||
|
BOS is never used. Pairs of sequences are not the expected use case, but they will be handled without a
|
||||||
|
separator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token_ids_0 (:obj:`List[int]`):
|
||||||
|
List of IDs to which the special tokens will be added.
|
||||||
|
token_ids_1 (:obj:`List[int]`, `optional`):
|
||||||
|
Optional second list of IDs for sequence pairs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
:obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
|
||||||
|
"""
|
||||||
|
if token_ids_1 is None:
|
||||||
|
return self.prefix_tokens + token_ids_0 + self.suffix_tokens
|
||||||
|
# We don't expect to process pairs, but leave the pair logic for API consistency
|
||||||
|
return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
|
||||||
|
|
||||||
|
def prepare_seq2seq_batch(
|
||||||
|
self,
|
||||||
|
src_texts: List[str],
|
||||||
|
src_lang: str = "en_XX",
|
||||||
|
tgt_texts: Optional[List[str]] = None,
|
||||||
|
tgt_lang: str = "ro_RO",
|
||||||
|
**kwargs,
|
||||||
|
) -> BatchEncoding:
|
||||||
|
self.src_lang = src_lang
|
||||||
|
self.tgt_lang = tgt_lang
|
||||||
|
return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def as_target_tokenizer(self):
|
||||||
|
"""
|
||||||
|
Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
|
||||||
|
sequence-to-sequence models that need a slightly different processing for the labels.
|
||||||
|
"""
|
||||||
|
self.set_tgt_lang_special_tokens(self.tgt_lang)
|
||||||
|
yield
|
||||||
|
self.set_src_lang_special_tokens(self.src_lang)
|
||||||
|
|
||||||
|
def set_src_lang_special_tokens(self, src_lang: str) -> None:
|
||||||
|
"""Reset the special tokens to the source lang setting. prefix=[src_lang_code] and suffix=[eos]."""
|
||||||
|
self.cur_lang_code_id = self.convert_tokens_to_ids(src_lang)
|
||||||
|
self.prefix_tokens = [self.cur_lang_code_id]
|
||||||
|
self.suffix_tokens = [self.eos_token_id]
|
||||||
|
|
||||||
|
prefix_tokens_str = self.convert_ids_to_tokens(self.prefix_tokens)
|
||||||
|
suffix_tokens_str = self.convert_ids_to_tokens(self.suffix_tokens)
|
||||||
|
|
||||||
|
self._tokenizer.post_processor = processors.TemplateProcessing(
|
||||||
|
single=prefix_tokens_str + ["$A"] + suffix_tokens_str,
|
||||||
|
pair=prefix_tokens_str + ["$A", "$B"] + suffix_tokens_str,
|
||||||
|
special_tokens=list(zip(prefix_tokens_str + suffix_tokens_str, self.prefix_tokens + self.suffix_tokens)),
|
||||||
|
)
|
||||||
|
|
||||||
|
def set_tgt_lang_special_tokens(self, tgt_lang: str) -> None:
|
||||||
|
"""Reset the special tokens to the target language setting. prefix=[src_lang_code] and suffix=[eos]."""
|
||||||
|
self.cur_lang_code_id = self.convert_tokens_to_ids(tgt_lang)
|
||||||
|
self.prefix_tokens = [self.cur_lang_code_id]
|
||||||
|
self.suffix_tokens = [self.eos_token_id]
|
||||||
|
|
||||||
|
prefix_tokens_str = self.convert_ids_to_tokens(self.prefix_tokens)
|
||||||
|
suffix_tokens_str = self.convert_ids_to_tokens(self.suffix_tokens)
|
||||||
|
|
||||||
|
self._tokenizer.post_processor = processors.TemplateProcessing(
|
||||||
|
single=prefix_tokens_str + ["$A"] + suffix_tokens_str,
|
||||||
|
pair=prefix_tokens_str + ["$A", "$B"] + suffix_tokens_str,
|
||||||
|
special_tokens=list(zip(prefix_tokens_str + suffix_tokens_str, self.prefix_tokens + self.suffix_tokens)),
|
||||||
|
)
|
||||||
|
|
||||||
|
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
||||||
|
if not os.path.isdir(save_directory):
|
||||||
|
logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
|
||||||
|
return
|
||||||
|
out_vocab_file = os.path.join(
|
||||||
|
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
|
||||||
|
)
|
||||||
|
|
||||||
|
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
|
||||||
|
copyfile(self.vocab_file, out_vocab_file)
|
||||||
|
|
||||||
|
return (out_vocab_file,)
|
||||||
@@ -47,6 +47,15 @@ class MarianTokenizer:
|
|||||||
requires_sentencepiece(self)
|
requires_sentencepiece(self)
|
||||||
|
|
||||||
|
|
||||||
|
class MBart50Tokenizer:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_sentencepiece(self)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(self, *args, **kwargs):
|
||||||
|
requires_sentencepiece(self)
|
||||||
|
|
||||||
|
|
||||||
class MBartTokenizer:
|
class MBartTokenizer:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
requires_sentencepiece(self)
|
requires_sentencepiece(self)
|
||||||
|
|||||||
@@ -164,6 +164,15 @@ class LxmertTokenizerFast:
|
|||||||
requires_tokenizers(self)
|
requires_tokenizers(self)
|
||||||
|
|
||||||
|
|
||||||
|
class MBart50TokenizerFast:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_tokenizers(self)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(self, *args, **kwargs):
|
||||||
|
requires_tokenizers(self)
|
||||||
|
|
||||||
|
|
||||||
class MBartTokenizerFast:
|
class MBartTokenizerFast:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
requires_tokenizers(self)
|
requires_tokenizers(self)
|
||||||
|
|||||||
200
tests/test_tokenization_mbart50.py
Normal file
200
tests/test_tokenization_mbart50.py
Normal file
@@ -0,0 +1,200 @@
|
|||||||
|
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from transformers import SPIECE_UNDERLINE, BatchEncoding, MBart50Tokenizer, MBart50TokenizerFast, is_torch_available
|
||||||
|
from transformers.file_utils import is_sentencepiece_available
|
||||||
|
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch
|
||||||
|
|
||||||
|
from .test_tokenization_common import TokenizerTesterMixin
|
||||||
|
|
||||||
|
|
||||||
|
if is_sentencepiece_available():
|
||||||
|
SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model")
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
from transformers.models.mbart.modeling_mbart import shift_tokens_right
|
||||||
|
|
||||||
|
EN_CODE = 250004
|
||||||
|
RO_CODE = 250020
|
||||||
|
|
||||||
|
|
||||||
|
@require_sentencepiece
|
||||||
|
@require_tokenizers
|
||||||
|
class MBartTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||||
|
tokenizer_class = MBart50Tokenizer
|
||||||
|
rust_tokenizer_class = MBart50TokenizerFast
|
||||||
|
test_rust_tokenizer = True
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super().setUp()
|
||||||
|
|
||||||
|
# We have a SentencePiece fixture for testing
|
||||||
|
tokenizer = MBart50Tokenizer(SAMPLE_VOCAB, src_lang="en_XX", tgt_lang="ro_RO", keep_accents=True)
|
||||||
|
tokenizer.save_pretrained(self.tmpdirname)
|
||||||
|
|
||||||
|
def test_full_tokenizer(self):
|
||||||
|
tokenizer = MBart50Tokenizer(SAMPLE_VOCAB, src_lang="en_XX", tgt_lang="ro_RO", keep_accents=True)
|
||||||
|
|
||||||
|
tokens = tokenizer.tokenize("This is a test")
|
||||||
|
self.assertListEqual(tokens, ["▁This", "▁is", "▁a", "▁t", "est"])
|
||||||
|
|
||||||
|
self.assertListEqual(
|
||||||
|
tokenizer.convert_tokens_to_ids(tokens),
|
||||||
|
[value + tokenizer.fairseq_offset for value in [285, 46, 10, 170, 382]],
|
||||||
|
)
|
||||||
|
|
||||||
|
tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.")
|
||||||
|
self.assertListEqual(
|
||||||
|
tokens,
|
||||||
|
# fmt: off
|
||||||
|
[SPIECE_UNDERLINE + "I", SPIECE_UNDERLINE + "was", SPIECE_UNDERLINE + "b", "or", "n", SPIECE_UNDERLINE + "in", SPIECE_UNDERLINE + "", "9", "2", "0", "0", "0", ",", SPIECE_UNDERLINE + "and", SPIECE_UNDERLINE + "this", SPIECE_UNDERLINE + "is", SPIECE_UNDERLINE + "f", "al", "s", "é", "."],
|
||||||
|
# fmt: on
|
||||||
|
)
|
||||||
|
ids = tokenizer.convert_tokens_to_ids(tokens)
|
||||||
|
self.assertListEqual(
|
||||||
|
ids,
|
||||||
|
[
|
||||||
|
value + tokenizer.fairseq_offset
|
||||||
|
for value in [8, 21, 84, 55, 24, 19, 7, 2, 602, 347, 347, 347, 3, 12, 66, 46, 72, 80, 6, 2, 4]
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
back_tokens = tokenizer.convert_ids_to_tokens(ids)
|
||||||
|
self.assertListEqual(
|
||||||
|
back_tokens,
|
||||||
|
# fmt: off
|
||||||
|
[SPIECE_UNDERLINE + "I", SPIECE_UNDERLINE + "was", SPIECE_UNDERLINE + "b", "or", "n", SPIECE_UNDERLINE + "in", SPIECE_UNDERLINE + "", "<unk>", "2", "0", "0", "0", ",", SPIECE_UNDERLINE + "and", SPIECE_UNDERLINE + "this", SPIECE_UNDERLINE + "is", SPIECE_UNDERLINE + "f", "al", "s", "<unk>", "."],
|
||||||
|
# fmt: on
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
@require_sentencepiece
|
||||||
|
@require_tokenizers
|
||||||
|
class MBartOneToManyIntegrationTest(unittest.TestCase):
|
||||||
|
checkpoint_name = "facebook/mbart-large-50-one-to-many-mmt"
|
||||||
|
src_text = [
|
||||||
|
" UN Chief Says There Is No Military Solution in Syria",
|
||||||
|
""" Secretary-General Ban Ki-moon says his response to Russia's stepped up military support for Syria is that "there is no military solution" to the nearly five-year conflict and more weapons will only worsen the violence and misery for millions of people.""",
|
||||||
|
]
|
||||||
|
tgt_text = [
|
||||||
|
"Şeful ONU declară că nu există o soluţie militară în Siria",
|
||||||
|
'Secretarul General Ban Ki-moon declară că răspunsul său la intensificarea sprijinului militar al Rusiei pentru Siria este că "nu există o soluţie militară" la conflictul de aproape cinci ani şi că noi arme nu vor face decât să înrăutăţească violenţele şi mizeria pentru milioane de oameni.',
|
||||||
|
]
|
||||||
|
expected_src_tokens = [EN_CODE, 8274, 127873, 25916, 7, 8622, 2071, 438, 67485, 53, 187895, 23, 51712, 2]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
cls.tokenizer: MBart50Tokenizer = MBart50Tokenizer.from_pretrained(
|
||||||
|
cls.checkpoint_name, src_lang="en_XX", tgt_lang="ro_RO"
|
||||||
|
)
|
||||||
|
cls.pad_token_id = 1
|
||||||
|
return cls
|
||||||
|
|
||||||
|
def check_language_codes(self):
|
||||||
|
self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["ar_AR"], 250001)
|
||||||
|
self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["en_EN"], 250004)
|
||||||
|
self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["ro_RO"], 250020)
|
||||||
|
self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["mr_IN"], 250038)
|
||||||
|
|
||||||
|
def test_tokenizer_batch_encode_plus(self):
|
||||||
|
ids = self.tokenizer.batch_encode_plus(self.src_text).input_ids[0]
|
||||||
|
self.assertListEqual(self.expected_src_tokens, ids)
|
||||||
|
|
||||||
|
def test_tokenizer_decode_ignores_language_codes(self):
|
||||||
|
self.assertIn(RO_CODE, self.tokenizer.all_special_ids)
|
||||||
|
generated_ids = [RO_CODE, 884, 9019, 96, 9, 916, 86792, 36, 18743, 15596, 5, 2]
|
||||||
|
result = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
|
||||||
|
expected_romanian = self.tokenizer.decode(generated_ids[1:], skip_special_tokens=True)
|
||||||
|
self.assertEqual(result, expected_romanian)
|
||||||
|
self.assertNotIn(self.tokenizer.eos_token, result)
|
||||||
|
|
||||||
|
def test_tokenizer_truncation(self):
|
||||||
|
src_text = ["this is gunna be a long sentence " * 20]
|
||||||
|
assert isinstance(src_text[0], str)
|
||||||
|
desired_max_length = 10
|
||||||
|
ids = self.tokenizer.prepare_seq2seq_batch(
|
||||||
|
src_text,
|
||||||
|
max_length=desired_max_length,
|
||||||
|
).input_ids[0]
|
||||||
|
self.assertEqual(ids[0], EN_CODE)
|
||||||
|
self.assertEqual(ids[-1], 2)
|
||||||
|
self.assertEqual(len(ids), desired_max_length)
|
||||||
|
|
||||||
|
def test_mask_token(self):
|
||||||
|
self.assertListEqual(self.tokenizer.convert_tokens_to_ids(["<mask>", "ar_AR"]), [250053, 250001])
|
||||||
|
|
||||||
|
def test_special_tokens_unaffacted_by_save_load(self):
|
||||||
|
tmpdirname = tempfile.mkdtemp()
|
||||||
|
original_special_tokens = self.tokenizer.fairseq_tokens_to_ids
|
||||||
|
self.tokenizer.save_pretrained(tmpdirname)
|
||||||
|
new_tok = MBart50Tokenizer.from_pretrained(tmpdirname)
|
||||||
|
self.assertDictEqual(new_tok.fairseq_tokens_to_ids, original_special_tokens)
|
||||||
|
|
||||||
|
# prepare_seq2seq_batch tests below
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_batch_fairseq_parity(self):
|
||||||
|
batch: BatchEncoding = self.tokenizer.prepare_seq2seq_batch(
|
||||||
|
self.src_text, tgt_texts=self.tgt_text, return_tensors="pt"
|
||||||
|
)
|
||||||
|
batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
|
||||||
|
|
||||||
|
for k in batch:
|
||||||
|
batch[k] = batch[k].tolist()
|
||||||
|
# batch = {k: v.tolist() for k,v in batch.items()}
|
||||||
|
# fairseq batch: https://gist.github.com/sshleifer/cba08bc2109361a74ac3760a7e30e4f4
|
||||||
|
# batch.decoder_inputs_ids[0][0] ==
|
||||||
|
assert batch.input_ids[1][0] == EN_CODE
|
||||||
|
assert batch.input_ids[1][-1] == 2
|
||||||
|
assert batch.labels[1][0] == RO_CODE
|
||||||
|
assert batch.labels[1][-1] == 2
|
||||||
|
assert batch.decoder_input_ids[1][:2] == [2, RO_CODE]
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_tokenizer_prepare_seq2seq_batch(self):
|
||||||
|
batch = self.tokenizer.prepare_seq2seq_batch(
|
||||||
|
self.src_text, tgt_texts=self.tgt_text, max_length=len(self.expected_src_tokens), return_tensors="pt"
|
||||||
|
)
|
||||||
|
batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
|
||||||
|
self.assertIsInstance(batch, BatchEncoding)
|
||||||
|
|
||||||
|
self.assertEqual((2, 14), batch.input_ids.shape)
|
||||||
|
self.assertEqual((2, 14), batch.attention_mask.shape)
|
||||||
|
result = batch.input_ids.tolist()[0]
|
||||||
|
self.assertListEqual(self.expected_src_tokens, result)
|
||||||
|
self.assertEqual(2, batch.decoder_input_ids[0, 0]) # decoder_start_token_id
|
||||||
|
# Test that special tokens are reset
|
||||||
|
self.assertEqual(self.tokenizer.prefix_tokens, [EN_CODE])
|
||||||
|
self.assertEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id])
|
||||||
|
|
||||||
|
def test_seq2seq_max_target_length(self):
|
||||||
|
batch = self.tokenizer.prepare_seq2seq_batch(
|
||||||
|
self.src_text, tgt_texts=self.tgt_text, max_length=3, max_target_length=10, return_tensors="pt"
|
||||||
|
)
|
||||||
|
batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
|
||||||
|
self.assertEqual(batch.input_ids.shape[1], 3)
|
||||||
|
self.assertEqual(batch.decoder_input_ids.shape[1], 10)
|
||||||
|
# max_target_length will default to max_length if not specified
|
||||||
|
batch = self.tokenizer.prepare_seq2seq_batch(
|
||||||
|
self.src_text, tgt_texts=self.tgt_text, max_length=3, return_tensors="pt"
|
||||||
|
)
|
||||||
|
batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
|
||||||
|
self.assertEqual(batch.input_ids.shape[1], 3)
|
||||||
|
self.assertEqual(batch.decoder_input_ids.shape[1], 3)
|
||||||
Reference in New Issue
Block a user