From 30646a0a3c5d547b2dc87c864ce263ddad37a727 Mon Sep 17 00:00:00 2001 From: Ryokan RI Date: Tue, 7 Dec 2021 14:25:28 +0900 Subject: [PATCH] Add mLUKE (#14640) * implement MLukeTokenizer and LukeForMaskedLM * update tests * update docs * add LukeForMaskedLM to check_repo.py * update README * fix test and specify the entity pad id in tokenization_(m)luke * fix EntityPredictionHeadTransform --- README.md | 1 + README_ko.md | 1 + README_zh-hans.md | 1 + README_zh-hant.md | 1 + docs/source/index.mdx | 1 + docs/source/main_classes/tokenizer.rst | 2 +- docs/source/model_doc/luke.rst | 6 + docs/source/model_doc/mluke.rst | 66 + docs/source/multilingual.rst | 16 +- src/transformers/__init__.py | 5 + src/transformers/models/__init__.py | 1 + .../models/auto/tokenization_auto.py | 1 + src/transformers/models/luke/__init__.py | 2 + src/transformers/models/luke/modeling_luke.py | 241 ++- .../models/luke/tokenization_luke.py | 282 +-- src/transformers/models/mluke/__init__.py | 38 + ..._original_pytorch_checkpoint_to_pytorch.py | 228 +++ .../models/mluke/tokenization_mluke.py | 1606 +++++++++++++++++ src/transformers/utils/dummy_pt_objects.py | 12 + .../utils/dummy_sentencepiece_objects.py | 9 + tests/test_modeling_luke.py | 73 + tests/test_tokenization_luke.py | 81 +- tests/test_tokenization_mluke.py | 666 +++++++ utils/check_repo.py | 1 + 24 files changed, 3107 insertions(+), 234 deletions(-) create mode 100644 docs/source/model_doc/mluke.rst create mode 100644 src/transformers/models/mluke/__init__.py create mode 100644 src/transformers/models/mluke/convert_mluke_original_pytorch_checkpoint_to_pytorch.py create mode 100644 src/transformers/models/mluke/tokenization_mluke.py create mode 100644 tests/test_tokenization_mluke.py diff --git a/README.md b/README.md index 79e25ece94..878d854467 100644 --- a/README.md +++ b/README.md @@ -275,6 +275,7 @@ Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih. 1. 
**[LED](https://huggingface.co/docs/transformers/model_doc/led)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan. 1. **[Longformer](https://huggingface.co/docs/transformers/model_doc/longformer)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan. 1. **[LUKE](https://huggingface.co/docs/transformers/model_doc/luke)** (from Studio Ousia) released with the paper [LUKE: Deep Contextualized Entity Representations with Entity-aware Self-attention](https://arxiv.org/abs/2010.01057) by Ikuya Yamada, Akari Asai, Hiroyuki Shindo, Hideaki Takeda, Yuji Matsumoto. +1. **[mLUKE](https://huggingface.co/docs/transformers/model_doc/mluke)** (from Studio Ousia) released with the paper [mLUKE: The Power of Entity Representations in Multilingual Pretrained Language Models](https://arxiv.org/abs/2110.08151) by Ryokan Ri, Ikuya Yamada, and Yoshimasa Tsuruoka. 1. **[LXMERT](https://huggingface.co/docs/transformers/model_doc/lxmert)** (from UNC Chapel Hill) released with the paper [LXMERT: Learning Cross-Modality Encoder Representations from Transformers for Open-Domain Question Answering](https://arxiv.org/abs/1908.07490) by Hao Tan and Mohit Bansal. 1. **[M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100)** (from Facebook) released with the paper [Beyond English-Centric Multilingual Machine Translation](https://arxiv.org/abs/2010.11125) by Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin. 1. 
**[MarianMT](https://huggingface.co/docs/transformers/model_doc/marian)** Machine translation models trained using [OPUS](http://opus.nlpl.eu/) data by Jörg Tiedemann. The [Marian Framework](https://marian-nmt.github.io/) is being developed by the Microsoft Translator Team. diff --git a/README_ko.md b/README_ko.md index 0ef945c7cb..a47e5c171d 100644 --- a/README_ko.md +++ b/README_ko.md @@ -261,6 +261,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는 1. **[MBart-50](https://huggingface.co/docs/transformers/model_doc/mbart)** (from Facebook) released with the paper [Multilingual Translation with Extensible Multilingual Pretraining and Finetuning](https://arxiv.org/abs/2008.00401) by Yuqing Tang, Chau Tran, Xian Li, Peng-Jen Chen, Naman Goyal, Vishrav Chaudhary, Jiatao Gu, Angela Fan. 1. **[Megatron-BERT](https://huggingface.co/docs/transformers/model_doc/megatron_bert)** (from NVIDIA) released with the paper [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053) by Mohammad Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper and Bryan Catanzaro. 1. **[Megatron-GPT2](https://huggingface.co/docs/transformers/model_doc/megatron_gpt2)** (from NVIDIA) released with the paper [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053) by Mohammad Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper and Bryan Catanzaro. +1. **[mLUKE](https://huggingface.co/docs/transformers/model_doc/mluke)** (from Studio Ousia) released with the paper [mLUKE: The Power of Entity Representations in Multilingual Pretrained Language Models](https://arxiv.org/abs/2110.08151) by Ryokan Ri, Ikuya Yamada, and Yoshimasa Tsuruoka. 1. 
**[MPNet](https://huggingface.co/docs/transformers/model_doc/mpnet)** (from Microsoft Research) released with the paper [MPNet: Masked and Permuted Pre-training for Language Understanding](https://arxiv.org/abs/2004.09297) by Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu. 1. **[MT5](https://huggingface.co/docs/transformers/model_doc/mt5)** (from Google AI) released with the paper [mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/abs/2010.11934) by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel. 1. **[Pegasus](https://huggingface.co/docs/transformers/model_doc/pegasus)** (from Google) released with the paper [PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization](https://arxiv.org/abs/1912.08777) by Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu. diff --git a/README_zh-hans.md b/README_zh-hans.md index c54bc9e8ed..57ec6ca17e 100644 --- a/README_zh-hans.md +++ b/README_zh-hans.md @@ -285,6 +285,7 @@ conda install -c huggingface transformers 1. **[MBart-50](https://huggingface.co/docs/transformers/model_doc/mbart)** (来自 Facebook) 伴随论文 [Multilingual Translation with Extensible Multilingual Pretraining and Finetuning](https://arxiv.org/abs/2008.00401) 由 Yuqing Tang, Chau Tran, Xian Li, Peng-Jen Chen, Naman Goyal, Vishrav Chaudhary, Jiatao Gu, Angela Fan 发布。 1. **[Megatron-BERT](https://huggingface.co/docs/transformers/model_doc/megatron_bert)** (来自 NVIDIA) 伴随论文 [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053) 由 Mohammad Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper and Bryan Catanzaro 发布。 1. 
**[Megatron-GPT2](https://huggingface.co/docs/transformers/model_doc/megatron_gpt2)** (来自 NVIDIA) 伴随论文 [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053) 由 Mohammad Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper and Bryan Catanzaro 发布。 +1. **[mLUKE](https://huggingface.co/docs/transformers/model_doc/mluke)** (来自 Studio Ousia) 伴随论文 [mLUKE: The Power of Entity Representations in Multilingual Pretrained Language Models](https://arxiv.org/abs/2110.08151) 由 Ryokan Ri, Ikuya Yamada, and Yoshimasa Tsuruoka 发布。 1. **[MPNet](https://huggingface.co/docs/transformers/model_doc/mpnet)** (来自 Microsoft Research) 伴随论文 [MPNet: Masked and Permuted Pre-training for Language Understanding](https://arxiv.org/abs/2004.09297) 由 Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu 发布。 1. **[MT5](https://huggingface.co/docs/transformers/model_doc/mt5)** (来自 Google AI) 伴随论文 [mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/abs/2010.11934) 由 Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel 发布。 1. **[Pegasus](https://huggingface.co/docs/transformers/model_doc/pegasus)** (来自 Google) 伴随论文 [PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization](https://arxiv.org/abs/1912.08777) 由 Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu 发布。 diff --git a/README_zh-hant.md b/README_zh-hant.md index 287c0c2800..f548b4f4a4 100644 --- a/README_zh-hant.md +++ b/README_zh-hant.md @@ -297,6 +297,7 @@ conda install -c huggingface transformers 1. **[MBart-50](https://huggingface.co/docs/transformers/model_doc/mbart)** (from Facebook) released with the paper [Multilingual Translation with Extensible Multilingual Pretraining and Finetuning](https://arxiv.org/abs/2008.00401) by Yuqing Tang, Chau Tran, Xian Li, Peng-Jen Chen, Naman Goyal, Vishrav Chaudhary, Jiatao Gu, Angela Fan. 1. 
**[Megatron-BERT](https://huggingface.co/docs/transformers/model_doc/megatron_bert)** (from NVIDIA) released with the paper [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053) by Mohammad Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper and Bryan Catanzaro. 1. **[Megatron-GPT2](https://huggingface.co/docs/transformers/model_doc/megatron_gpt2)** (from NVIDIA) released with the paper [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053) by Mohammad Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper and Bryan Catanzaro. +1. **[mLUKE](https://huggingface.co/docs/transformers/model_doc/mluke)** (from Studio Ousia) released with the paper [mLUKE: The Power of Entity Representations in Multilingual Pretrained Language Models](https://arxiv.org/abs/2110.08151) by Ryokan Ri, Ikuya Yamada, and Yoshimasa Tsuruoka. 1. **[MPNet](https://huggingface.co/docs/transformers/model_doc/mpnet)** (from Microsoft Research) released with the paper [MPNet: Masked and Permuted Pre-training for Language Understanding](https://arxiv.org/abs/2004.09297) by Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu. 1. **[MT5](https://huggingface.co/docs/transformers/model_doc/mt5)** (from Google AI) released with the paper [mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/abs/2010.11934) by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel. 1. **[Pegasus](https://huggingface.co/docs/transformers/model_doc/pegasus)** (from Google) released with the paper [PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization](https://arxiv.org/abs/1912.08777) by Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu. 
diff --git a/docs/source/index.mdx b/docs/source/index.mdx index 24c91bc798..4a7c597d08 100644 --- a/docs/source/index.mdx +++ b/docs/source/index.mdx @@ -135,6 +135,7 @@ conversion utilities for the following models. 1. **[LED](model_doc/led)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan. 1. **[Longformer](model_doc/longformer)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan. 1. **[LUKE](model_doc/luke)** (from Studio Ousia) released with the paper [LUKE: Deep Contextualized Entity Representations with Entity-aware Self-attention](https://arxiv.org/abs/2010.01057) by Ikuya Yamada, Akari Asai, Hiroyuki Shindo, Hideaki Takeda, Yuji Matsumoto. +1. **[mLUKE](model_doc/mluke)** (from Studio Ousia) released with the paper [mLUKE: The Power of Entity Representations in Multilingual Pretrained Language Models](https://arxiv.org/abs/2110.08151) by Ryokan Ri, Ikuya Yamada, and Yoshimasa Tsuruoka. 1. **[LXMERT](model_doc/lxmert)** (from UNC Chapel Hill) released with the paper [LXMERT: Learning Cross-Modality Encoder Representations from Transformers for Open-Domain Question Answering](https://arxiv.org/abs/1908.07490) by Hao Tan and Mohit Bansal. 1. **[M2M100](model_doc/m2m_100)** (from Facebook) released with the paper [Beyond English-Centric Multilingual Machine Translation](https://arxiv.org/abs/2010.11125) by Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin. 1. **[MarianMT](model_doc/marian)** Machine translation models trained using [OPUS](http://opus.nlpl.eu/) data by Jörg Tiedemann. 
The [Marian Framework](https://marian-nmt.github.io/) is being developed by the Microsoft Translator Team. diff --git a/docs/source/main_classes/tokenizer.rst b/docs/source/main_classes/tokenizer.rst index 18798e9b49..7b591a5e9a 100644 --- a/docs/source/main_classes/tokenizer.rst +++ b/docs/source/main_classes/tokenizer.rst @@ -20,7 +20,7 @@ Rust library `tokenizers `__. The "Fa 1. a significant speed-up in particular when doing batched tokenization and 2. additional methods to map between the original string (character and words) and the token space (e.g. getting the index of the token comprising a given character or the span of characters corresponding to a given token). Currently - no "Fast" implementation is available for the SentencePiece-based tokenizers (for T5, ALBERT, CamemBERT, XLMRoBERTa + no "Fast" implementation is available for the SentencePiece-based tokenizers (for T5, ALBERT, CamemBERT, XLM-RoBERTa and XLNet models). The base classes :class:`~transformers.PreTrainedTokenizer` and :class:`~transformers.PreTrainedTokenizerFast` diff --git a/docs/source/model_doc/luke.rst b/docs/source/model_doc/luke.rst index 34af117de9..e76f161eab 100644 --- a/docs/source/model_doc/luke.rst +++ b/docs/source/model_doc/luke.rst @@ -137,6 +137,12 @@ LukeModel .. autoclass:: transformers.LukeModel :members: forward +LukeForMaskedLM +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.LukeForMaskedLM + :members: forward + LukeForEntityClassification ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/model_doc/mluke.rst b/docs/source/model_doc/mluke.rst new file mode 100644 index 0000000000..ddca5fcec6 --- /dev/null +++ b/docs/source/model_doc/mluke.rst @@ -0,0 +1,66 @@ +.. + Copyright 2021 The HuggingFace Team. All rights reserved. 
+ + Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + specific language governing permissions and limitations under the License. + +mLUKE +----------------------------------------------------------------------------------------------------------------------- + +Overview +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The mLUKE model was proposed in `mLUKE: The Power of Entity Representations in Multilingual Pretrained Language Models +<https://arxiv.org/abs/2110.08151>`__ by Ryokan Ri, Ikuya Yamada, and Yoshimasa Tsuruoka. It's a multilingual extension +of the `LUKE model <https://arxiv.org/abs/2010.01057>`__ trained on the basis of XLM-RoBERTa. + +It is based on XLM-RoBERTa and adds entity embeddings, which helps improve performance on various downstream tasks +involving reasoning about entities such as named entity recognition, extractive question answering, relation +classification, and cloze-style knowledge completion. + +The abstract from the paper is the following: + +*Recent studies have shown that multilingual pretrained language models can be effectively improved with cross-lingual +alignment information from Wikipedia entities. However, existing methods only exploit entity information in pretraining +and do not explicitly use entities in downstream tasks. In this study, we explore the effectiveness of leveraging +entity representations for downstream cross-lingual tasks. 
We train a multilingual language model with 24 languages +with entity representations and show the model consistently outperforms word-based pretrained models in various +cross-lingual transfer tasks. We also analyze the model and the key insight is that incorporating entity +representations into the input allows us to extract more language-agnostic features. We also evaluate the model with a +multilingual cloze prompt task with the mLAMA dataset. We show that entity-based prompt elicits correct factual +knowledge more likely than using only word representations.* + +One can directly plug in the weights of mLUKE into a LUKE model, like so: + +.. code-block:: + + from transformers import LukeModel + + model = LukeModel.from_pretrained('studio-ousia/mluke-base') + +Note that mLUKE has its own tokenizer, :class:`~transformers.MLukeTokenizer`. You can initialize it as follows: + +.. code-block:: + + from transformers import MLukeTokenizer + + tokenizer = MLukeTokenizer.from_pretrained('studio-ousia/mluke-base') + + +As mLUKE's architecture is equivalent to that of LUKE, one can refer to :doc:`LUKE's documentation page <luke>` for all +tips, code examples and notebooks. + +This model was contributed by `ryo0634 <https://huggingface.co/ryo0634>`__. The original code can be found `here +<https://github.com/studio-ousia/luke>`__. + +MLukeTokenizer +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.MLukeTokenizer + :members: __call__, save_vocabulary diff --git a/docs/source/multilingual.rst b/docs/source/multilingual.rst index d65f947ddc..6df6eaa993 100644 --- a/docs/source/multilingual.rst +++ b/docs/source/multilingual.rst @@ -17,8 +17,6 @@ Most of the models available in this library are mono-lingual models (English, C models are available and have a different mechanisms than mono-lingual models. This page details the usage of these models. -The two models that currently support multiple languages are BERT and XLM. 
- XLM ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -127,3 +125,17 @@ Two XLM-RoBERTa checkpoints can be used for multi-lingual tasks: - ``xlm-roberta-base`` (Masked language modeling, 100 languages) - ``xlm-roberta-large`` (Masked language modeling, 100 languages) + +mLUKE +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +mLUKE is based on XLM-RoBERTa and further trained on Wikipedia articles in 24 languages with masked language modeling +as well as a masked entity prediction objective. + +The model can be used in the same way as other models solely based on word-piece inputs, but also can be used with +entity representations to achieve further performance gain, with entity-related tasks such as relation extraction, +named entity recognition and question answering (see :doc:`LUKE <model_doc/luke>`). + +Currently, one mLUKE checkpoint is available: + +- ``studio-ousia/mluke-base`` (Masked language modeling + Masked entity prediction, 100 languages) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index a3f03bb64b..c1f576c0e0 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -245,6 +245,7 @@ _import_structure = { "models.mbart": ["MBartConfig"], "models.mbart50": [], "models.megatron_bert": ["MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "MegatronBertConfig"], + "models.mluke": [], "models.mmbt": ["MMBTConfig"], "models.mobilebert": ["MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "MobileBertConfig", "MobileBertTokenizer"], "models.mpnet": ["MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "MPNetConfig", "MPNetTokenizer"], @@ -379,6 +380,7 @@ if is_sentencepiece_available(): _import_structure["models.marian"].append("MarianTokenizer") _import_structure["models.mbart"].append("MBartTokenizer") _import_structure["models.mbart50"].append("MBart50Tokenizer") + 
_import_structure["models.mluke"].append("MLukeTokenizer") _import_structure["models.mt5"].append("MT5Tokenizer") _import_structure["models.pegasus"].append("PegasusTokenizer") _import_structure["models.reformer"].append("ReformerTokenizer") @@ -1037,6 +1039,7 @@ if is_torch_available(): "LukeForEntityClassification", "LukeForEntityPairClassification", "LukeForEntitySpanClassification", + "LukeForMaskedLM", "LukeModel", "LukePreTrainedModel", ] @@ -2368,6 +2371,7 @@ if TYPE_CHECKING: from .models.m2m_100 import M2M100Tokenizer from .models.marian import MarianTokenizer from .models.mbart import MBart50Tokenizer, MBartTokenizer + from .models.mluke import MLukeTokenizer from .models.mt5 import MT5Tokenizer from .models.pegasus import PegasusTokenizer from .models.reformer import ReformerTokenizer @@ -2904,6 +2908,7 @@ if TYPE_CHECKING: LukeForEntityClassification, LukeForEntityPairClassification, LukeForEntitySpanClassification, + LukeForMaskedLM, LukeModel, LukePreTrainedModel, ) diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index b180e9401f..61a6bebed1 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -71,6 +71,7 @@ from . 
import ( mbart50, megatron_bert, megatron_gpt2, + mluke, mmbt, mobilebert, mpnet, diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index f25eb0606f..268ba9dcda 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -178,6 +178,7 @@ else: ("hubert", ("Wav2Vec2CTCTokenizer", None)), ("gpt_neo", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)), ("luke", ("LukeTokenizer", None)), + ("mluke", ("MLukeTokenizer" if is_sentencepiece_available() else None, None)), ("bigbird_pegasus", ("PegasusTokenizer", "PegasusTokenizerFast" if is_tokenizers_available() else None)), ("canine", ("CanineTokenizer", None)), ("bertweet", ("BertweetTokenizer", None)), diff --git a/src/transformers/models/luke/__init__.py b/src/transformers/models/luke/__init__.py index 32e81ddf8b..8c6275d12a 100644 --- a/src/transformers/models/luke/__init__.py +++ b/src/transformers/models/luke/__init__.py @@ -32,6 +32,7 @@ if is_torch_available(): "LukeForEntityClassification", "LukeForEntityPairClassification", "LukeForEntitySpanClassification", + "LukeForMaskedLM", "LukeModel", "LukePreTrainedModel", ] @@ -47,6 +48,7 @@ if TYPE_CHECKING: LukeForEntityClassification, LukeForEntityPairClassification, LukeForEntitySpanClassification, + LukeForMaskedLM, LukeModel, LukePreTrainedModel, ) diff --git a/src/transformers/models/luke/modeling_luke.py b/src/transformers/models/luke/modeling_luke.py index 1b2a4f6ffb..c2922935ad 100644 --- a/src/transformers/models/luke/modeling_luke.py +++ b/src/transformers/models/luke/modeling_luke.py @@ -22,7 +22,7 @@ import torch import torch.utils.checkpoint from torch import nn -from ...activations import ACT2FN +from ...activations import ACT2FN, gelu from ...file_utils import ( ModelOutput, add_start_docstrings, @@ -110,6 +110,49 @@ class BaseLukeModelOutput(BaseModelOutput): entity_hidden_states: 
Optional[Tuple[torch.FloatTensor]] = None +@dataclass +class LukeMaskedLMOutput(ModelOutput): + """ + Base class for model's outputs, with potential hidden states and attentions. + + Args: + loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided): + The sum of masked language modeling (MLM) loss and entity prediction loss. + mlm_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided): + Masked language modeling (MLM) loss. + mep_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided): + Masked entity prediction (MEP) loss. + logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + entity_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, entity_length, config.entity_vocab_size)`): + Prediction scores of the entity prediction head (scores for each entity vocabulary token before SoftMax). + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + entity_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output + of each layer plus the initial entity embedding outputs. 
+ attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + mlm_loss: Optional[torch.FloatTensor] = None + mep_loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + entity_logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + entity_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + @dataclass class EntityClassificationOutput(ModelOutput): """ @@ -674,6 +717,38 @@ class LukePooler(nn.Module): return pooled_output +class EntityPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.entity_emb_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.entity_emb_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class EntityPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.transform = EntityPredictionHeadTransform(config) + self.decoder = nn.Linear(config.entity_emb_size, config.entity_vocab_size, bias=False) + self.bias = nn.Parameter(torch.zeros(config.entity_vocab_size)) + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = 
self.decoder(hidden_states) + self.bias + + return hidden_states + + class LukePreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained @@ -1013,6 +1088,170 @@ def create_position_ids_from_input_ids(input_ids, padding_idx): return incremental_indices.long() + padding_idx +# Copied from transformers.models.roberta.modeling_roberta.RobertaLMHead +class LukeLMHead(nn.Module): + """Roberta Head for masked language modeling.""" + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + self.decoder = nn.Linear(config.hidden_size, config.vocab_size) + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + self.decoder.bias = self.bias + + def forward(self, features, **kwargs): + x = self.dense(features) + x = gelu(x) + x = self.layer_norm(x) + + # project back to size of vocabulary with bias + x = self.decoder(x) + + return x + + def _tie_weights(self): + # To tie those two weights if they get disconnected (on TPU or when the bias is resized) + self.bias = self.decoder.bias + + +@add_start_docstrings( + """ + The LUKE model with a language modeling head and entity prediction head on top for masked language modeling and + masked entity prediction. 
+ """, + LUKE_START_DOCSTRING, +) +class LukeForMaskedLM(LukePreTrainedModel): + _keys_to_ignore_on_save = [ + r"lm_head.decoder.weight", + r"lm_head.decoder.bias", + r"entity_predictions.decoder.weight", + ] + _keys_to_ignore_on_load_missing = [ + r"position_ids", + r"lm_head.decoder.weight", + r"lm_head.decoder.bias", + r"entity_predictions.decoder.weight", + ] + + def __init__(self, config): + super().__init__(config) + + self.luke = LukeModel(config) + + self.lm_head = LukeLMHead(config) + self.entity_predictions = EntityPredictionHead(config) + + self.loss_fn = nn.CrossEntropyLoss(ignore_index=-1) + + # Initialize weights and apply final processing + self.post_init() + + def tie_weights(self): + super().tie_weights() + self._tie_or_clone_weights(self.entity_predictions.decoder, self.luke.entity_embeddings.entity_embeddings) + + def get_output_embeddings(self): + return self.lm_head.decoder + + def set_output_embeddings(self, new_embeddings): + self.lm_head.decoder = new_embeddings + + @add_start_docstrings_to_model_forward(LUKE_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=LukeMaskedLMOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + entity_ids=None, + entity_attention_mask=None, + entity_token_type_ids=None, + entity_position_ids=None, + labels=None, + entity_labels=None, + head_mask=None, + inputs_embeds=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the masked language modeling loss. 
Indices should be in ``[-100, 0, ..., + config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored + (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` + entity_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, entity_length)`, `optional`): + Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ..., + config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored + (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` + + Returns: + + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.luke( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + entity_ids=entity_ids, + entity_attention_mask=entity_attention_mask, + entity_token_type_ids=entity_token_type_ids, + entity_position_ids=entity_position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + ) + + loss = None + + mlm_loss = None + logits = self.lm_head(outputs.last_hidden_state) + if labels is not None: + mlm_loss = self.loss_fn(logits.view(-1, self.config.vocab_size), labels.view(-1)) + if loss is None: + loss = mlm_loss + + mep_loss = None + entity_logits = self.entity_predictions(outputs.entity_last_hidden_state) + if entity_labels is not None: + mep_loss = self.loss_fn(entity_logits.view(-1, self.config.entity_vocab_size), entity_labels.view(-1)) + if loss is None: + loss = mep_loss + else: + loss = loss + mep_loss + + if not return_dict: + output = (logits, entity_logits, outputs.hidden_states, outputs.entity_hidden_states, outputs.attentions) + if mlm_loss is not None and mep_loss is not None: + return (loss, mlm_loss, mep_loss) + output + elif mlm_loss is 
not None: + return (loss, mlm_loss) + output + elif mep_loss is not None: + return (loss, mep_loss) + output + else: + return output + + return LukeMaskedLMOutput( + loss=loss, + mlm_loss=mlm_loss, + mep_loss=mep_loss, + logits=logits, + entity_logits=entity_logits, + hidden_states=outputs.hidden_states, + entity_hidden_states=outputs.entity_hidden_states, + attentions=outputs.attentions, + ) + + @add_start_docstrings( """ The LUKE model with a classification head on top (a linear layer on top of the hidden state of the first entity diff --git a/src/transformers/models/luke/tokenization_luke.py b/src/transformers/models/luke/tokenization_luke.py index 13b6536574..785fdf2233 100644 --- a/src/transformers/models/luke/tokenization_luke.py +++ b/src/transformers/models/luke/tokenization_luke.py @@ -312,17 +312,15 @@ class LukeTokenizer(RobertaTokenizer): # Input type checking for clearer error is_valid_single_text = isinstance(text, str) is_valid_batch_text = isinstance(text, (list, tuple)) and (len(text) == 0 or (isinstance(text[0], str))) - assert ( - is_valid_single_text or is_valid_batch_text - ), "text input must be of type `str` (single example) or `List[str]` (batch)." + if not (is_valid_single_text or is_valid_batch_text): + raise ValueError("text input must be of type `str` (single example) or `List[str]` (batch).") is_valid_single_text_pair = isinstance(text_pair, str) is_valid_batch_text_pair = isinstance(text_pair, (list, tuple)) and ( len(text_pair) == 0 or isinstance(text_pair[0], str) ) - assert ( - text_pair is None or is_valid_single_text_pair or is_valid_batch_text_pair - ), "text_pair input must be of type `str` (single example) or `List[str]` (batch)." 
+ if not (text_pair is None or is_valid_single_text_pair or is_valid_batch_text_pair): + raise ValueError("text_pair input must be of type `str` (single example) or `List[str]` (batch).") is_batched = bool(isinstance(text, (list, tuple))) @@ -391,105 +389,6 @@ class LukeTokenizer(RobertaTokenizer): **kwargs, ) - @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) - def encode_plus( - self, - text: Union[TextInput], - text_pair: Optional[Union[TextInput]] = None, - entity_spans: Optional[EntitySpanInput] = None, - entity_spans_pair: Optional[EntitySpanInput] = None, - entities: Optional[EntityInput] = None, - entities_pair: Optional[EntityInput] = None, - add_special_tokens: bool = True, - padding: Union[bool, str, PaddingStrategy] = False, - truncation: Union[bool, str, TruncationStrategy] = False, - max_length: Optional[int] = None, - max_entity_length: Optional[int] = None, - stride: int = 0, - is_split_into_words: Optional[bool] = False, - pad_to_multiple_of: Optional[int] = None, - return_tensors: Optional[Union[str, TensorType]] = None, - return_token_type_ids: Optional[bool] = None, - return_attention_mask: Optional[bool] = None, - return_overflowing_tokens: bool = False, - return_special_tokens_mask: bool = False, - return_offsets_mapping: bool = False, - return_length: bool = False, - verbose: bool = True, - **kwargs - ) -> BatchEncoding: - """ - Tokenize and prepare for the model a sequence or a pair of sequences. - - .. warning:: This method is deprecated, ``__call__`` should be used instead. - - Args: - text (:obj:`str`): - The first sequence to be encoded. Each sequence must be a string. - text_pair (:obj:`str`): - The second sequence to be encoded. Each sequence must be a string. - entity_spans (:obj:`List[Tuple[int, int]]`, :obj:`List[List[Tuple[int, int]]]`, `optional`):: - The first sequence of entity spans to be encoded. 
The sequence consists of tuples each with two - integers denoting character-based start and end positions of entities. If you specify - :obj:`"entity_classification"` or :obj:`"entity_pair_classification"` as the ``task`` argument in the - constructor, the length of each sequence must be 1 or 2, respectively. If you specify ``entities``, the - length of the sequence must be equal to the length of ``entities``. - entity_spans_pair (:obj:`List[Tuple[int, int]]`, :obj:`List[List[Tuple[int, int]]]`, `optional`):: - The second sequence of entity spans to be encoded. The sequence consists of tuples each with two - integers denoting character-based start and end positions of entities. If you specify the ``task`` - argument in the constructor, this argument is ignored. If you specify ``entities_pair``, the length of - the sequence must be equal to the length of ``entities_pair``. - entities (:obj:`List[str]` `optional`):: - The first sequence of entities to be encoded. The sequence consists of strings representing entities, - i.e., special entities (e.g., [MASK]) or entity titles of Wikipedia (e.g., Los Angeles). This argument - is ignored if you specify the ``task`` argument in the constructor. The length of the sequence must be - equal to the length of ``entity_spans``. If you specify ``entity_spans`` without specifying this - argument, the entity sequence is automatically constructed by filling it with the [MASK] entity. - entities_pair (:obj:`List[str]`, :obj:`List[List[str]]`, `optional`):: - The second sequence of entities to be encoded. The sequence consists of strings representing entities, - i.e., special entities (e.g., [MASK]) or entity titles of Wikipedia (e.g., Los Angeles). This argument - is ignored if you specify the ``task`` argument in the constructor. The length of the sequence must be - equal to the length of ``entity_spans_pair``. 
If you specify ``entity_spans_pair`` without specifying - this argument, the entity sequence is automatically constructed by filling it with the [MASK] entity. - max_entity_length (:obj:`int`, `optional`): - The maximum length of the entity sequence. - """ - # Backward compatibility for 'truncation_strategy', 'pad_to_max_length' - padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies( - padding=padding, - truncation=truncation, - max_length=max_length, - pad_to_multiple_of=pad_to_multiple_of, - verbose=verbose, - **kwargs, - ) - - return self._encode_plus( - text=text, - text_pair=text_pair, - entity_spans=entity_spans, - entity_spans_pair=entity_spans_pair, - entities=entities, - entities_pair=entities_pair, - add_special_tokens=add_special_tokens, - padding_strategy=padding_strategy, - truncation_strategy=truncation_strategy, - max_length=max_length, - max_entity_length=max_entity_length, - stride=stride, - is_split_into_words=is_split_into_words, - pad_to_multiple_of=pad_to_multiple_of, - return_tensors=return_tensors, - return_token_type_ids=return_token_type_ids, - return_attention_mask=return_attention_mask, - return_overflowing_tokens=return_overflowing_tokens, - return_special_tokens_mask=return_special_tokens_mask, - return_offsets_mapping=return_offsets_mapping, - return_length=return_length, - verbose=verbose, - **kwargs, - ) - def _encode_plus( self, text: Union[TextInput], @@ -571,89 +470,6 @@ class LukeTokenizer(RobertaTokenizer): verbose=verbose, ) - @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) - def batch_encode_plus( - self, - batch_text_or_text_pairs: Union[List[TextInput], List[TextInputPair]], - batch_entity_spans_or_entity_spans_pairs: Optional[ - Union[List[EntitySpanInput], List[Tuple[EntitySpanInput, EntitySpanInput]]] - ] = None, - batch_entities_or_entities_pairs: Optional[ - Union[List[EntityInput], List[Tuple[EntityInput, EntityInput]]] - ] = None, - 
add_special_tokens: bool = True, - padding: Union[bool, str, PaddingStrategy] = False, - truncation: Union[bool, str, TruncationStrategy] = False, - max_length: Optional[int] = None, - max_entity_length: Optional[int] = None, - stride: int = 0, - is_split_into_words: Optional[bool] = False, - pad_to_multiple_of: Optional[int] = None, - return_tensors: Optional[Union[str, TensorType]] = None, - return_token_type_ids: Optional[bool] = None, - return_attention_mask: Optional[bool] = None, - return_overflowing_tokens: bool = False, - return_special_tokens_mask: bool = False, - return_offsets_mapping: bool = False, - return_length: bool = False, - verbose: bool = True, - **kwargs - ) -> BatchEncoding: - """ - Tokenize and prepare for the model a list of sequences or a list of pairs of sequences. - - .. warning:: - This method is deprecated, ``__call__`` should be used instead. - - - Args: - batch_text_or_text_pairs (:obj:`List[str]`, :obj:`List[Tuple[str, str]]`): - Batch of sequences or pair of sequences to be encoded. This can be a list of string or a list of pair - of string (see details in ``encode_plus``). - batch_entity_spans_or_entity_spans_pairs (:obj:`List[List[Tuple[int, int]]]`, - :obj:`List[Tuple[List[Tuple[int, int]], List[Tuple[int, int]]]]`, `optional`):: - Batch of entity span sequences or pairs of entity span sequences to be encoded (see details in - ``encode_plus``). - batch_entities_or_entities_pairs (:obj:`List[List[str]]`, :obj:`List[Tuple[List[str], List[str]]]`, - `optional`): - Batch of entity sequences or pairs of entity sequences to be encoded (see details in ``encode_plus``). - max_entity_length (:obj:`int`, `optional`): - The maximum length of the entity sequence. 
- """ - - # Backward compatibility for 'truncation_strategy', 'pad_to_max_length' - padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies( - padding=padding, - truncation=truncation, - max_length=max_length, - pad_to_multiple_of=pad_to_multiple_of, - verbose=verbose, - **kwargs, - ) - - return self._batch_encode_plus( - batch_text_or_text_pairs=batch_text_or_text_pairs, - batch_entity_spans_or_entity_spans_pairs=batch_entity_spans_or_entity_spans_pairs, - batch_entities_or_entities_pairs=batch_entities_or_entities_pairs, - add_special_tokens=add_special_tokens, - padding_strategy=padding_strategy, - truncation_strategy=truncation_strategy, - max_length=max_length, - max_entity_length=max_entity_length, - stride=stride, - is_split_into_words=is_split_into_words, - pad_to_multiple_of=pad_to_multiple_of, - return_tensors=return_tensors, - return_token_type_ids=return_token_type_ids, - return_attention_mask=return_attention_mask, - return_overflowing_tokens=return_overflowing_tokens, - return_special_tokens_mask=return_special_tokens_mask, - return_offsets_mapping=return_offsets_mapping, - return_length=return_length, - verbose=verbose, - **kwargs, - ) - def _batch_encode_plus( self, batch_text_or_text_pairs: Union[List[TextInput], List[TextInputPair]], @@ -713,11 +529,12 @@ class LukeTokenizer(RobertaTokenizer): entity_spans, entity_spans_pair = None, None if batch_entity_spans_or_entity_spans_pairs is not None: entity_spans_or_entity_spans_pairs = batch_entity_spans_or_entity_spans_pairs[index] - if entity_spans_or_entity_spans_pairs: - if isinstance(entity_spans_or_entity_spans_pairs[0][0], int): - entity_spans, entity_spans_pair = entity_spans_or_entity_spans_pairs, None - else: - entity_spans, entity_spans_pair = entity_spans_or_entity_spans_pairs + if len(entity_spans_or_entity_spans_pairs) > 0 and isinstance( + entity_spans_or_entity_spans_pairs[0], list + ): + entity_spans, entity_spans_pair = 
entity_spans_or_entity_spans_pairs + else: + entity_spans, entity_spans_pair = entity_spans_or_entity_spans_pairs, None ( first_ids, @@ -761,6 +578,25 @@ class LukeTokenizer(RobertaTokenizer): return BatchEncoding(batch_outputs) + def _check_entity_input_format(self, entities: Optional[EntityInput], entity_spans: Optional[EntitySpanInput]): + if not isinstance(entity_spans, list): + raise ValueError("entity_spans should be given as a list") + elif len(entity_spans) > 0 and not isinstance(entity_spans[0], tuple): + raise ValueError( + "entity_spans should be given as a list of tuples " "containing the start and end character indices" + ) + + if entities is not None: + + if not isinstance(entities, list): + raise ValueError("If you specify entities, they should be given as a list") + + if len(entities) > 0 and not isinstance(entities[0], str): + raise ValueError("If you specify entities, they should be given as a list of entity names") + + if len(entities) != len(entity_spans): + raise ValueError("If you specify entities, entities and entity_spans must be the same length") + def _create_input_sequence( self, text: Union[TextInput], @@ -816,15 +652,7 @@ class LukeTokenizer(RobertaTokenizer): if entity_spans is None: first_ids = get_input_ids(text) else: - assert isinstance(entity_spans, list) and ( - len(entity_spans) == 0 or isinstance(entity_spans[0], tuple) - ), "entity_spans should be given as a list of tuples containing the start and end character indices" - assert entities is None or ( - isinstance(entities, list) and (len(entities) == 0 or isinstance(entities[0], str)) - ), "If you specify entities, they should be given as a list of entity names" - assert entities is None or len(entities) == len( - entity_spans - ), "If you specify entities, entities and entity_spans must be the same length" + self._check_entity_input_format(entities, entity_spans) first_ids, first_entity_token_spans = get_input_ids_and_entity_token_spans(text, entity_spans) if entities is 
None: @@ -836,16 +664,7 @@ class LukeTokenizer(RobertaTokenizer): if entity_spans_pair is None: second_ids = get_input_ids(text_pair) else: - assert isinstance(entity_spans_pair, list) and ( - len(entity_spans_pair) == 0 or isinstance(entity_spans_pair[0], tuple) - ), "entity_spans_pair should be given as a list of tuples containing the start and end character indices" - assert entities_pair is None or ( - isinstance(entities_pair, list) - and (len(entities_pair) == 0 or isinstance(entities_pair[0], str)) - ), "If you specify entities_pair, they should be given as a list of entity names" - assert entities_pair is None or len(entities_pair) == len( - entity_spans_pair - ), "If you specify entities_pair, entities_pair and entity_spans_pair must be the same length" + self._check_entity_input_format(entities_pair, entity_spans_pair) second_ids, second_entity_token_spans = get_input_ids_and_entity_token_spans( text_pair, entity_spans_pair @@ -856,10 +675,11 @@ class LukeTokenizer(RobertaTokenizer): second_entity_ids = [self.entity_vocab.get(entity, unk_entity_id) for entity in entities_pair] elif self.task == "entity_classification": - assert ( - isinstance(entity_spans, list) and len(entity_spans) == 1 and isinstance(entity_spans[0], tuple) - ), "Entity spans should be a list containing a single tuple containing the start and end character indices of an entity" - + if not (isinstance(entity_spans, list) and len(entity_spans) == 1 and isinstance(entity_spans[0], tuple)): + raise ValueError( + "Entity spans should be a list containing a single tuple " + "containing the start and end character indices of an entity" + ) first_entity_ids = [self.entity_vocab["[MASK]"]] first_ids, first_entity_token_spans = get_input_ids_and_entity_token_spans(text, entity_spans) @@ -876,12 +696,16 @@ class LukeTokenizer(RobertaTokenizer): first_entity_token_spans = [(entity_token_start, entity_token_end + 2)] elif self.task == "entity_pair_classification": - assert ( + if not ( 
isinstance(entity_spans, list) and len(entity_spans) == 2 and isinstance(entity_spans[0], tuple) and isinstance(entity_spans[1], tuple) - ), "Entity spans should be provided as a list of tuples, each tuple containing the start and end character indices of an entity" + ): + raise ValueError( + "Entity spans should be provided as a list of two tuples, " + "each tuple containing the start and end character indices of an entity" + ) head_span, tail_span = entity_spans first_entity_ids = [self.entity_vocab["[MASK]"], self.entity_vocab["[MASK2]"]] @@ -907,9 +731,11 @@ class LukeTokenizer(RobertaTokenizer): elif self.task == "entity_span_classification": mask_entity_id = self.entity_vocab["[MASK]"] - assert isinstance(entity_spans, list) and isinstance( - entity_spans[0], tuple - ), "Entity spans should be provided as a list of tuples, each tuple containing the start and end character indices of an entity" + if not (isinstance(entity_spans, list) and len(entity_spans) > 0 and isinstance(entity_spans[0], tuple)): + raise ValueError( + "Entity spans should be provided as a list of tuples, " + "each tuple containing the start and end character indices of an entity" + ) first_ids, first_entity_token_spans = get_input_ids_and_entity_token_spans(text, entity_spans) first_entity_ids = [mask_entity_id] * len(entity_spans) @@ -1218,7 +1044,6 @@ class LukeTokenizer(RobertaTokenizer): self._eventual_warn_about_too_long_sequence(encoded_inputs["input_ids"], max_length, verbose) # Padding - # To do: add padding of entities if padding_strategy != PaddingStrategy.DO_NOT_PAD or return_attention_mask: encoded_inputs = self.pad( encoded_inputs, @@ -1369,9 +1194,8 @@ class LukeTokenizer(RobertaTokenizer): return BatchEncoding(encoded_inputs, tensor_type=return_tensors) batch_size = len(required_input) - assert all( - len(v) == batch_size for v in encoded_inputs.values() - ), "Some items in the output dictionary have a different batch size than others." 
+ if any(len(v) != batch_size for v in encoded_inputs.values()): + raise ValueError("Some items in the output dictionary have a different batch size than others.") if padding_strategy == PaddingStrategy.LONGEST: max_length = max(len(inputs) for inputs in required_input) @@ -1487,7 +1311,9 @@ class LukeTokenizer(RobertaTokenizer): encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference encoded_inputs["input_ids"] = encoded_inputs["input_ids"] + [self.pad_token_id] * difference if entities_provided: - encoded_inputs["entity_ids"] = encoded_inputs["entity_ids"] + [0] * entity_difference + encoded_inputs["entity_ids"] = ( + encoded_inputs["entity_ids"] + [self.entity_vocab["[PAD]"]] * entity_difference + ) encoded_inputs["entity_position_ids"] = ( encoded_inputs["entity_position_ids"] + [[-1] * self.max_mention_length] * entity_difference ) @@ -1516,7 +1342,9 @@ class LukeTokenizer(RobertaTokenizer): encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"] encoded_inputs["input_ids"] = [self.pad_token_id] * difference + encoded_inputs["input_ids"] if entities_provided: - encoded_inputs["entity_ids"] = [0] * entity_difference + encoded_inputs["entity_ids"] + encoded_inputs["entity_ids"] = [self.entity_vocab["[PAD]"]] * entity_difference + encoded_inputs[ + "entity_ids" + ] encoded_inputs["entity_position_ids"] = [ [-1] * self.max_mention_length ] * entity_difference + encoded_inputs["entity_position_ids"] diff --git a/src/transformers/models/mluke/__init__.py b/src/transformers/models/mluke/__init__.py new file mode 100644 index 0000000000..8982d219f6 --- /dev/null +++ b/src/transformers/models/mluke/__init__.py @@ -0,0 +1,38 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + +# Copyright 2021 The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from ...file_utils import _LazyModule, is_sentencepiece_available + + +_import_structure = {} + + +if is_sentencepiece_available(): + _import_structure["tokenization_mluke"] = ["MLukeTokenizer"] + +if TYPE_CHECKING: + if is_sentencepiece_available(): + from .tokenization_mluke import MLukeTokenizer + + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure) diff --git a/src/transformers/models/mluke/convert_mluke_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/mluke/convert_mluke_original_pytorch_checkpoint_to_pytorch.py new file mode 100644 index 0000000000..c75a710cee --- /dev/null +++ b/src/transformers/models/mluke/convert_mluke_original_pytorch_checkpoint_to_pytorch.py @@ -0,0 +1,228 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Convert mLUKE checkpoint.""" + +import argparse +import json +import os +from collections import OrderedDict + +import torch + +from transformers import LukeConfig, LukeForMaskedLM, MLukeTokenizer, XLMRobertaTokenizer +from transformers.tokenization_utils_base import AddedToken + + +@torch.no_grad() +def convert_luke_checkpoint(checkpoint_path, metadata_path, entity_vocab_path, pytorch_dump_folder_path, model_size): + # Load configuration defined in the metadata file + with open(metadata_path) as metadata_file: + metadata = json.load(metadata_file) + config = LukeConfig(use_entity_aware_attention=True, **metadata["model_config"]) + + # Load in the weights from the checkpoint_path + state_dict = torch.load(checkpoint_path, map_location="cpu")["module"] + + # Load the entity vocab file + entity_vocab = load_original_entity_vocab(entity_vocab_path) + # add an entry for [MASK2] + entity_vocab["[MASK2]"] = max(entity_vocab.values()) + 1 + config.entity_vocab_size += 1 + + tokenizer = XLMRobertaTokenizer.from_pretrained(metadata["model_config"]["bert_model_name"]) + + # Add special tokens to the token vocabulary for downstream tasks + entity_token_1 = AddedToken("", lstrip=False, rstrip=False) + entity_token_2 = AddedToken("", lstrip=False, rstrip=False) + tokenizer.add_special_tokens(dict(additional_special_tokens=[entity_token_1, entity_token_2])) + config.vocab_size += 2 + + print(f"Saving tokenizer to {pytorch_dump_folder_path}") + tokenizer.save_pretrained(pytorch_dump_folder_path) + with open(os.path.join(pytorch_dump_folder_path, "tokenizer_config.json"), "r") as f: + tokenizer_config = json.load(f) + tokenizer_config["tokenizer_class"] = "MLukeTokenizer" + with open(os.path.join(pytorch_dump_folder_path, "tokenizer_config.json"), "w") as f: + json.dump(tokenizer_config, f) + + with open(os.path.join(pytorch_dump_folder_path, MLukeTokenizer.vocab_files_names["entity_vocab_file"]), "w") as f: + json.dump(entity_vocab, f) + + tokenizer = 
MLukeTokenizer.from_pretrained(pytorch_dump_folder_path) + + # Initialize the embeddings of the special tokens + ent_init_index = tokenizer.convert_tokens_to_ids(["@"])[0] + ent2_init_index = tokenizer.convert_tokens_to_ids(["#"])[0] + + word_emb = state_dict["embeddings.word_embeddings.weight"] + ent_emb = word_emb[ent_init_index].unsqueeze(0) + ent2_emb = word_emb[ent2_init_index].unsqueeze(0) + state_dict["embeddings.word_embeddings.weight"] = torch.cat([word_emb, ent_emb, ent2_emb]) + # add special tokens for 'entity_predictions.bias' + for bias_name in ["lm_head.decoder.bias", "lm_head.bias"]: + decoder_bias = state_dict[bias_name] + ent_decoder_bias = decoder_bias[ent_init_index].unsqueeze(0) + ent2_decoder_bias = decoder_bias[ent2_init_index].unsqueeze(0) + state_dict[bias_name] = torch.cat([decoder_bias, ent_decoder_bias, ent2_decoder_bias]) + + # Initialize the query layers of the entity-aware self-attention mechanism + for layer_index in range(config.num_hidden_layers): + for matrix_name in ["query.weight", "query.bias"]: + prefix = f"encoder.layer.{layer_index}.attention.self." 
+ state_dict[prefix + "w2e_" + matrix_name] = state_dict[prefix + matrix_name] + state_dict[prefix + "e2w_" + matrix_name] = state_dict[prefix + matrix_name] + state_dict[prefix + "e2e_" + matrix_name] = state_dict[prefix + matrix_name] + + # Initialize the embedding of the [MASK2] entity using that of the [MASK] entity for downstream tasks + entity_emb = state_dict["entity_embeddings.entity_embeddings.weight"] + entity_mask_emb = entity_emb[entity_vocab["[MASK]"]].unsqueeze(0) + state_dict["entity_embeddings.entity_embeddings.weight"] = torch.cat([entity_emb, entity_mask_emb]) + # add [MASK2] for 'entity_predictions.bias' + entity_prediction_bias = state_dict["entity_predictions.bias"] + entity_mask_bias = entity_prediction_bias[entity_vocab["[MASK]"]].unsqueeze(0) + state_dict["entity_predictions.bias"] = torch.cat([entity_prediction_bias, entity_mask_bias]) + + model = LukeForMaskedLM(config=config).eval() + + state_dict.pop("entity_predictions.decoder.weight") + state_dict.pop("lm_head.decoder.weight") + state_dict.pop("lm_head.decoder.bias") + state_dict_for_hugging_face = OrderedDict() + for key, value in state_dict.items(): + if not (key.startswith("lm_head") or key.startswith("entity_predictions")): + state_dict_for_hugging_face[f"luke.{key}"] = state_dict[key] + else: + state_dict_for_hugging_face[key] = state_dict[key] + + missing_keys, unexpected_keys = model.load_state_dict(state_dict_for_hugging_face, strict=False) + + if set(unexpected_keys) != {"luke.embeddings.position_ids"}: + raise ValueError(f"Unexpected unexpected_keys: {unexpected_keys}") + if set(missing_keys) != { + "lm_head.decoder.weight", + "lm_head.decoder.bias", + "entity_predictions.decoder.weight", + }: + raise ValueError(f"Unexpected missing_keys: {missing_keys}") + + model.tie_weights() + assert (model.luke.embeddings.word_embeddings.weight == model.lm_head.decoder.weight).all() + assert (model.luke.entity_embeddings.entity_embeddings.weight == 
model.entity_predictions.decoder.weight).all() + + # Check outputs + tokenizer = MLukeTokenizer.from_pretrained(pytorch_dump_folder_path, task="entity_classification") + + text = "ISO 639-3 uses the code fas for the dialects spoken across Iran and アフガニスタン (Afghanistan)." + span = (0, 9) + encoding = tokenizer(text, entity_spans=[span], return_tensors="pt") + + outputs = model(**encoding) + + # Verify word hidden states + if model_size == "large": + raise NotImplementedError + else: # base + expected_shape = torch.Size((1, 33, 768)) + expected_slice = torch.tensor([[0.0892, 0.0596, -0.2819], [0.0134, 0.1199, 0.0573], [-0.0169, 0.0927, 0.0644]]) + + if not (outputs.last_hidden_state.shape == expected_shape): + raise ValueError( + f"Outputs.last_hidden_state.shape is {outputs.last_hidden_state.shape}, Expected shape is {expected_shape}" + ) + if not torch.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4): + raise ValueError + + # Verify entity hidden states + if model_size == "large": + raise NotImplementedError + else: # base + expected_shape = torch.Size((1, 1, 768)) + expected_slice = torch.tensor([[-0.1482, 0.0609, 0.0322]]) + + if not (outputs.entity_last_hidden_state.shape == expected_shape): + raise ValueError( + f"Outputs.entity_last_hidden_state.shape is {outputs.entity_last_hidden_state.shape}, Expected shape is {expected_shape}" + ) + if not torch.allclose(outputs.entity_last_hidden_state[0, :3, :3], expected_slice, atol=1e-4): + raise ValueError + + # Verify masked word/entity prediction + tokenizer = MLukeTokenizer.from_pretrained(pytorch_dump_folder_path) + text = "Tokyo is the capital of ." 
+ span = (24, 30) + encoding = tokenizer(text, entity_spans=[span], return_tensors="pt") + + outputs = model(**encoding) + + input_ids = encoding["input_ids"][0].tolist() + mask_position_id = input_ids.index(tokenizer.convert_tokens_to_ids("")) + predicted_id = outputs.logits[0][mask_position_id].argmax(dim=-1) + assert "Japan" == tokenizer.decode(predicted_id) + + predicted_entity_id = outputs.entity_logits[0][0].argmax().item() + multilingual_predicted_entities = [ + entity for entity, entity_id in tokenizer.entity_vocab.items() if entity_id == predicted_entity_id + ] + assert [e for e in multilingual_predicted_entities if e.startswith("en:")][0] == "en:Japan" + + # Finally, save our PyTorch model and tokenizer + print("Saving PyTorch model to {}".format(pytorch_dump_folder_path)) + model.save_pretrained(pytorch_dump_folder_path) + + +def load_original_entity_vocab(entity_vocab_path): + SPECIAL_TOKENS = ["[MASK]", "[PAD]", "[UNK]"] + + data = [json.loads(line) for line in open(entity_vocab_path)] + + new_mapping = {} + for entry in data: + entity_id = entry["id"] + for entity_name, language in entry["entities"]: + if entity_name in SPECIAL_TOKENS: + new_mapping[entity_name] = entity_id + break + new_entity_name = f"{language}:{entity_name}" + new_mapping[new_entity_name] = entity_id + return new_mapping + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument("--checkpoint_path", type=str, help="Path to a pytorch_model.bin file.") + parser.add_argument( + "--metadata_path", default=None, type=str, help="Path to a metadata.json file, defining the configuration." + ) + parser.add_argument( + "--entity_vocab_path", + default=None, + type=str, + help="Path to an entity_vocab.tsv file, containing the entity vocabulary.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", default=None, type=str, help="Path to where to dump the output PyTorch model." 
+ ) + parser.add_argument( + "--model_size", default="base", type=str, choices=["base", "large"], help="Size of the model to be converted." + ) + args = parser.parse_args() + convert_luke_checkpoint( + args.checkpoint_path, + args.metadata_path, + args.entity_vocab_path, + args.pytorch_dump_folder_path, + args.model_size, + ) diff --git a/src/transformers/models/mluke/tokenization_mluke.py b/src/transformers/models/mluke/tokenization_mluke.py new file mode 100644 index 0000000000..aa547737c7 --- /dev/null +++ b/src/transformers/models/mluke/tokenization_mluke.py @@ -0,0 +1,1606 @@ +# coding=utf-8 +# Copyright 2021 Studio Ousia and the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License +""" Tokenization classes for mLUKE.""" + + +import itertools +import json +import os +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np + +import sentencepiece as spm + +from ...file_utils import add_end_docstrings, is_tf_available, is_torch_available +from ...tokenization_utils import PreTrainedTokenizer +from ...tokenization_utils_base import ( + ENCODE_KWARGS_DOCSTRING, + AddedToken, + BatchEncoding, + EncodedInput, + PaddingStrategy, + TensorType, + TextInput, + TextInputPair, + TruncationStrategy, + _is_tensorflow, + _is_torch, + to_py_obj, +) +from ...utils import logging + + +logger = logging.get_logger(__name__) + +EntitySpan = Tuple[int, int] +EntitySpanInput = List[EntitySpan] +Entity = str +EntityInput = List[Entity] + +SPIECE_UNDERLINE = "▁" + +VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model", "entity_vocab_file": "entity_vocab.json"} + +PRETRAINED_VOCAB_FILES_MAP = { + "vocab_file": { + "studio-ousia/mluke-base": "https://huggingface.co/studio-ousia/mluke-base/resolve/main/vocab.json", + }, + "merges_file": { + "studio-ousia/mluke-base": "https://huggingface.co/studio-ousia/mluke-base/resolve/main/merges.txt", + }, + "entity_vocab_file": { + "studio-ousia/mluke-base": "https://huggingface.co/studio-ousia/mluke-base/resolve/main/entity_vocab.json", + }, +} + +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { + "studio-ousia/mluke-base": 512, +} + +ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING = r""" + return_token_type_ids (:obj:`bool`, `optional`): + Whether to return token type IDs. If left to the default, will return the token type IDs according to + the specific tokenizer's default, defined by the :obj:`return_outputs` attribute. + + `What are token type IDs? <../glossary.html#token-type-ids>`__ + return_attention_mask (:obj:`bool`, `optional`): + Whether to return the attention mask. 
If left to the default, will return the attention mask according + to the specific tokenizer's default, defined by the :obj:`return_outputs` attribute. + + `What are attention masks? <../glossary.html#attention-mask>`__ + return_overflowing_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to return overflowing token sequences. If a pair of sequences of input ids (or a batch + of pairs) is provided with :obj:`truncation_strategy = longest_first` or :obj:`True`, an error is + raised instead of returning overflowing tokens. + return_special_tokens_mask (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to return special tokens mask information. + return_offsets_mapping (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to return :obj:`(char_start, char_end)` for each token. + + This is only available on fast tokenizers inheriting from + :class:`~transformers.PreTrainedTokenizerFast`, if using Python's tokenizer, this method will raise + :obj:`NotImplementedError`. + return_length (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to return the lengths of the encoded inputs. + verbose (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not to print more information and warnings. + **kwargs: passed to the :obj:`self.tokenize()` method + + Return: + :class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. + + `What are input IDs? <../glossary.html#input-ids>`__ + + - **token_type_ids** -- List of token type ids to be fed to a model (when :obj:`return_token_type_ids=True` + or if `"token_type_ids"` is in :obj:`self.model_input_names`). + + `What are token type IDs? 
<../glossary.html#token-type-ids>`__ + + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + :obj:`return_attention_mask=True` or if `"attention_mask"` is in :obj:`self.model_input_names`). + + `What are attention masks? <../glossary.html#attention-mask>`__ + + - **entity_ids** -- List of entity ids to be fed to a model. + + `What are input IDs? <../glossary.html#input-ids>`__ + + - **entity_position_ids** -- List of entity positions in the input sequence to be fed to a model. + + - **entity_token_type_ids** -- List of entity token type ids to be fed to a model (when + :obj:`return_token_type_ids=True` or if `"entity_token_type_ids"` is in :obj:`self.model_input_names`). + + `What are token type IDs? <../glossary.html#token-type-ids>`__ + + - **entity_attention_mask** -- List of indices specifying which entities should be attended to by the model + (when :obj:`return_attention_mask=True` or if `"entity_attention_mask"` is in + :obj:`self.model_input_names`). + + `What are attention masks? <../glossary.html#attention-mask>`__ + + - **entity_start_positions** -- List of the start positions of entities in the word token sequence (when + :obj:`task="entity_span_classification"`). + - **entity_end_positions** -- List of the end positions of entities in the word token sequence (when + :obj:`task="entity_span_classification"`). + - **overflowing_tokens** -- List of overflowing tokens sequences (when a :obj:`max_length` is specified and + :obj:`return_overflowing_tokens=True`). + - **num_truncated_tokens** -- Number of tokens truncated (when a :obj:`max_length` is specified and + :obj:`return_overflowing_tokens=True`). + - **special_tokens_mask** -- List of 0s and 1s, with 1 specifying added special tokens and 0 specifying + regular sequence tokens (when :obj:`add_special_tokens=True` and :obj:`return_special_tokens_mask=True`). 
+ - **length** -- The length of the inputs (when :obj:`return_length=True`) + +""" + + +class MLukeTokenizer(PreTrainedTokenizer): + """ + Adapted from :class:`~transformers.XLMRobertaTokenizer` and :class:`~transformers.LukeTokenizer`. Based on + `SentencePiece <https://github.com/google/sentencepiece>`__. + + This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the main methods. + Users should refer to this superclass for more information regarding those methods. + + Args: + vocab_file (:obj:`str`): + Path to the vocabulary file. + entity_vocab_file (:obj:`str`): + Path to the entity vocabulary file. + bos_token (:obj:`str`, `optional`, defaults to :obj:`"<s>"`): + The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token. + + .. note:: + + When building a sequence using special tokens, this is not the token that is used for the beginning of + sequence. The token used is the :obj:`cls_token`. + eos_token (:obj:`str`, `optional`, defaults to :obj:`"</s>"`): + The end of sequence token. + + .. note:: + + When building a sequence using special tokens, this is not the token that is used for the end of + sequence. The token used is the :obj:`sep_token`. + sep_token (:obj:`str`, `optional`, defaults to :obj:`"</s>"`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + cls_token (:obj:`str`, `optional`, defaults to :obj:`"<s>"`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + unk_token (:obj:`str`, `optional`, defaults to :obj:`"<unk>"`): + The unknown token.
A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + pad_token (:obj:`str`, `optional`, defaults to :obj:`"<pad>"`): + The token used for padding, for example when batching sequences of different lengths. + mask_token (:obj:`str`, `optional`, defaults to :obj:`"<mask>"`): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. + task (:obj:`str`, `optional`): + Task for which you want to prepare sequences. One of :obj:`"entity_classification"`, + :obj:`"entity_pair_classification"`, or :obj:`"entity_span_classification"`. If you specify this argument, + the entity sequence is automatically created based on the given entity span(s). + max_entity_length (:obj:`int`, `optional`, defaults to 32): + The maximum length of :obj:`entity_ids`. + max_mention_length (:obj:`int`, `optional`, defaults to 30): + The maximum number of tokens inside an entity span. + entity_token_1 (:obj:`str`, `optional`, defaults to :obj:`<ent>`): + The special token used to represent an entity span in a word token sequence. This token is only used when + ``task`` is set to :obj:`"entity_classification"` or :obj:`"entity_pair_classification"`. + entity_token_2 (:obj:`str`, `optional`, defaults to :obj:`<ent2>`): + The special token used to represent an entity span in a word token sequence. This token is only used when + ``task`` is set to :obj:`"entity_pair_classification"`. + additional_special_tokens (:obj:`List[str]`, `optional`, defaults to :obj:`[]`): + Additional special tokens used by the tokenizer. ``entity_token_1`` and ``entity_token_2`` are always + appended to this list. + sp_model_kwargs (:obj:`dict`, `optional`): + Will be passed to the ``SentencePieceProcessor.__init__()`` method. The `Python wrapper for SentencePiece + <https://github.com/google/sentencepiece/tree/master/python>`__ can be used, among other things, to set: + + - ``enable_sampling``: Enable subword regularization. + - ``nbest_size``: Sampling parameters for unigram.
Invalid for BPE-Dropout. + + - ``nbest_size = {0,1}``: No sampling is performed. + - ``nbest_size > 1``: samples from the nbest_size results. + - ``nbest_size < 0``: assuming that nbest_size is infinite and samples from the all hypothesis (lattice) + using forward-filtering-and-backward-sampling algorithm. + + - ``alpha``: Smoothing parameter for unigram sampling, and dropout probability of merge operations for + BPE-dropout. + + Attributes: + sp_model (:obj:`SentencePieceProcessor`): + The `SentencePiece` processor that is used for every conversion (string, tokens and IDs). + """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file, + entity_vocab_file, + bos_token="", + eos_token="", + sep_token="", + cls_token="", + unk_token="", + pad_token="", + mask_token="", + task=None, + max_entity_length=32, + max_mention_length=30, + entity_token_1="", + entity_token_2="", + sp_model_kwargs: Optional[Dict[str, Any]] = None, + **kwargs, + ) -> None: + # Mask token behave like a normal word, i.e. 
include the space before it + mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token + + # we add 2 special tokens for downstream tasks + # for more information about lstrip and rstrip, see https://github.com/huggingface/transformers/pull/2778 + entity_token_1 = ( + AddedToken(entity_token_1, lstrip=False, rstrip=False) + if isinstance(entity_token_1, str) + else entity_token_1 + ) + entity_token_2 = ( + AddedToken(entity_token_2, lstrip=False, rstrip=False) + if isinstance(entity_token_2, str) + else entity_token_2 + ) + kwargs["additional_special_tokens"] = kwargs.get("additional_special_tokens", []) + kwargs["additional_special_tokens"] += [entity_token_1, entity_token_2] + + self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs + + super().__init__( + bos_token=bos_token, + eos_token=eos_token, + unk_token=unk_token, + sep_token=sep_token, + cls_token=cls_token, + pad_token=pad_token, + mask_token=mask_token, + sp_model_kwargs=self.sp_model_kwargs, + task=task, + max_entity_length=max_entity_length, + max_mention_length=max_mention_length, + entity_token_1=entity_token_1, + entity_token_2=entity_token_2, + **kwargs, + ) + + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + self.sp_model.Load(str(vocab_file)) + self.vocab_file = vocab_file + + # Original fairseq vocab and spm vocab must be "aligned": + # Vocab | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 + # -------- | ------- | ------- | ------ | ------- | --- | --- | --- | ----- | ----- | ---- + # fairseq | '' | '' | '' | '' | ',' | '.' | '▁' | 's' | '▁de' | '-' + # spm | '' | '' | '' | ',' | '.' 
| '▁' | 's' | '▁de' | '-' | '▁a' + + # Mimic fairseq token-to-id alignment for the first 4 token + self.fairseq_tokens_to_ids = {"": 0, "": 1, "": 2, "": 3} + + # The first "real" token "," has position 4 in the original fairseq vocab and position 3 in the spm vocab + self.fairseq_offset = 1 + + self.fairseq_tokens_to_ids[""] = len(self.sp_model) + self.fairseq_offset + self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()} + + with open(entity_vocab_file, encoding="utf-8") as entity_vocab_handle: + self.entity_vocab = json.load(entity_vocab_handle) + + self.task = task + if task is None or task == "entity_span_classification": + self.max_entity_length = max_entity_length + elif task == "entity_classification": + self.max_entity_length = 1 + elif task == "entity_pair_classification": + self.max_entity_length = 2 + else: + raise ValueError( + f"Task {task} not supported. Select task from ['entity_classification', 'entity_pair_classification', 'entity_span_classification'] only." 
+ ) + + self.max_mention_length = max_mention_length + + def __getstate__(self): + state = self.__dict__.copy() + state["sp_model"] = None + state["sp_model_proto"] = self.sp_model.serialized_model_proto() + return state + + def __setstate__(self, d): + self.__dict__ = d + + # for backward compatibility + if not hasattr(self, "sp_model_kwargs"): + self.sp_model_kwargs = {} + + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + self.sp_model.LoadFromSerializedProto(self.sp_model_proto) + + @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) + # Copied from transformers.models.luke.tokenization_luke.LukeTokenizer.__call__ + def __call__( + self, + text: Union[TextInput, List[TextInput]], + text_pair: Optional[Union[TextInput, List[TextInput]]] = None, + entity_spans: Optional[Union[EntitySpanInput, List[EntitySpanInput]]] = None, + entity_spans_pair: Optional[Union[EntitySpanInput, List[EntitySpanInput]]] = None, + entities: Optional[Union[EntityInput, List[EntityInput]]] = None, + entities_pair: Optional[Union[EntityInput, List[EntityInput]]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = False, + max_length: Optional[int] = None, + max_entity_length: Optional[int] = None, + stride: int = 0, + is_split_into_words: Optional[bool] = False, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs + ) -> BatchEncoding: + """ + Main method to tokenize and prepare for the model one or several sequence(s) or one or several pair(s) of + sequences, depending on the task you want to prepare them 
for. + + Args: + text (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence must be a string. Note that this + tokenizer does not support tokenization based on pretokenized strings. + text_pair (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence must be a string. Note that this + tokenizer does not support tokenization based on pretokenized strings. + entity_spans (:obj:`List[Tuple[int, int]]`, :obj:`List[List[Tuple[int, int]]]`, `optional`): + The sequence or batch of sequences of entity spans to be encoded. Each sequence consists of tuples each + with two integers denoting character-based start and end positions of entities. If you specify + :obj:`"entity_classification"` or :obj:`"entity_pair_classification"` as the ``task`` argument in the + constructor, the length of each sequence must be 1 or 2, respectively. If you specify ``entities``, the + length of each sequence must be equal to the length of each sequence of ``entities``. + entity_spans_pair (:obj:`List[Tuple[int, int]]`, :obj:`List[List[Tuple[int, int]]]`, `optional`): + The sequence or batch of sequences of entity spans to be encoded. Each sequence consists of tuples each + with two integers denoting character-based start and end positions of entities. If you specify the + ``task`` argument in the constructor, this argument is ignored. If you specify ``entities_pair``, the + length of each sequence must be equal to the length of each sequence of ``entities_pair``. + entities (:obj:`List[str]`, :obj:`List[List[str]]`, `optional`): + The sequence or batch of sequences of entities to be encoded. Each sequence consists of strings + representing entities, i.e., special entities (e.g., [MASK]) or entity titles of Wikipedia (e.g., Los + Angeles). This argument is ignored if you specify the ``task`` argument in the constructor. 
The length + of each sequence must be equal to the length of each sequence of ``entity_spans``. If you specify + ``entity_spans`` without specifying this argument, the entity sequence or the batch of entity sequences + is automatically constructed by filling it with the [MASK] entity. + entities_pair (:obj:`List[str]`, :obj:`List[List[str]]`, `optional`): + The sequence or batch of sequences of entities to be encoded. Each sequence consists of strings + representing entities, i.e., special entities (e.g., [MASK]) or entity titles of Wikipedia (e.g., Los + Angeles). This argument is ignored if you specify the ``task`` argument in the constructor. The length + of each sequence must be equal to the length of each sequence of ``entity_spans_pair``. If you specify + ``entity_spans_pair`` without specifying this argument, the entity sequence or the batch of entity + sequences is automatically constructed by filling it with the [MASK] entity. + max_entity_length (:obj:`int`, `optional`): + The maximum length of :obj:`entity_ids`. 
+ """ + # Input type checking for clearer error + is_valid_single_text = isinstance(text, str) + is_valid_batch_text = isinstance(text, (list, tuple)) and (len(text) == 0 or (isinstance(text[0], str))) + if not (is_valid_single_text or is_valid_batch_text): + raise ValueError("text input must be of type `str` (single example) or `List[str]` (batch).") + + is_valid_single_text_pair = isinstance(text_pair, str) + is_valid_batch_text_pair = isinstance(text_pair, (list, tuple)) and ( + len(text_pair) == 0 or isinstance(text_pair[0], str) + ) + if not (text_pair is None or is_valid_single_text_pair or is_valid_batch_text_pair): + raise ValueError("text_pair input must be of type `str` (single example) or `List[str]` (batch).") + + is_batched = bool(isinstance(text, (list, tuple))) + + if is_batched: + batch_text_or_text_pairs = list(zip(text, text_pair)) if text_pair is not None else text + if entities is None: + batch_entities_or_entities_pairs = None + else: + batch_entities_or_entities_pairs = ( + list(zip(entities, entities_pair)) if entities_pair is not None else entities + ) + + if entity_spans is None: + batch_entity_spans_or_entity_spans_pairs = None + else: + batch_entity_spans_or_entity_spans_pairs = ( + list(zip(entity_spans, entity_spans_pair)) if entity_spans_pair is not None else entity_spans + ) + + return self.batch_encode_plus( + batch_text_or_text_pairs=batch_text_or_text_pairs, + batch_entity_spans_or_entity_spans_pairs=batch_entity_spans_or_entity_spans_pairs, + batch_entities_or_entities_pairs=batch_entities_or_entities_pairs, + add_special_tokens=add_special_tokens, + padding=padding, + truncation=truncation, + max_length=max_length, + max_entity_length=max_entity_length, + stride=stride, + is_split_into_words=is_split_into_words, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + 
return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + else: + return self.encode_plus( + text=text, + text_pair=text_pair, + entity_spans=entity_spans, + entity_spans_pair=entity_spans_pair, + entities=entities, + entities_pair=entities_pair, + add_special_tokens=add_special_tokens, + padding=padding, + truncation=truncation, + max_length=max_length, + max_entity_length=max_entity_length, + stride=stride, + is_split_into_words=is_split_into_words, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + + # Copied from transformers.models.luke.tokenization_luke.LukeTokenizer._encode_plus + def _encode_plus( + self, + text: Union[TextInput], + text_pair: Optional[Union[TextInput]] = None, + entity_spans: Optional[EntitySpanInput] = None, + entity_spans_pair: Optional[EntitySpanInput] = None, + entities: Optional[EntityInput] = None, + entities_pair: Optional[EntityInput] = None, + add_special_tokens: bool = True, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + max_length: Optional[int] = None, + max_entity_length: Optional[int] = None, + stride: int = 0, + is_split_into_words: Optional[bool] = False, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + 
return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs + ) -> BatchEncoding: + + if return_offsets_mapping: + raise NotImplementedError( + "return_offset_mapping is not available when using Python tokenizers. " + "To use this feature, change your tokenizer to one deriving from " + "transformers.PreTrainedTokenizerFast. " + "More information on available tokenizers at " + "https://github.com/huggingface/transformers/pull/2674" + ) + + if is_split_into_words: + raise NotImplementedError("is_split_into_words is not supported in this tokenizer.") + + ( + first_ids, + second_ids, + first_entity_ids, + second_entity_ids, + first_entity_token_spans, + second_entity_token_spans, + ) = self._create_input_sequence( + text=text, + text_pair=text_pair, + entities=entities, + entities_pair=entities_pair, + entity_spans=entity_spans, + entity_spans_pair=entity_spans_pair, + **kwargs, + ) + + # prepare_for_model will create the attention_mask and token_type_ids + return self.prepare_for_model( + first_ids, + pair_ids=second_ids, + entity_ids=first_entity_ids, + pair_entity_ids=second_entity_ids, + entity_token_spans=first_entity_token_spans, + pair_entity_token_spans=second_entity_token_spans, + add_special_tokens=add_special_tokens, + padding=padding_strategy.value, + truncation=truncation_strategy.value, + max_length=max_length, + max_entity_length=max_entity_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors=return_tensors, + prepend_batch_axis=True, + return_attention_mask=return_attention_mask, + return_token_type_ids=return_token_type_ids, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_length=return_length, + verbose=verbose, + ) + + # Copied from transformers.models.luke.tokenization_luke.LukeTokenizer._batch_encode_plus + def _batch_encode_plus( + self, + 
batch_text_or_text_pairs: Union[List[TextInput], List[TextInputPair]], + batch_entity_spans_or_entity_spans_pairs: Optional[ + Union[List[EntitySpanInput], List[Tuple[EntitySpanInput, EntitySpanInput]]] + ] = None, + batch_entities_or_entities_pairs: Optional[ + Union[List[EntityInput], List[Tuple[EntityInput, EntityInput]]] + ] = None, + add_special_tokens: bool = True, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + max_length: Optional[int] = None, + max_entity_length: Optional[int] = None, + stride: int = 0, + is_split_into_words: Optional[bool] = False, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs + ) -> BatchEncoding: + if return_offsets_mapping: + raise NotImplementedError( + "return_offset_mapping is not available when using Python tokenizers. " + "To use this feature, change your tokenizer to one deriving from " + "transformers.PreTrainedTokenizerFast." 
+ ) + + if is_split_into_words: + raise NotImplementedError("is_split_into_words is not supported in this tokenizer.") + + # input_ids is a list of tuples (one for each example in the batch) + input_ids = [] + entity_ids = [] + entity_token_spans = [] + for index, text_or_text_pair in enumerate(batch_text_or_text_pairs): + if not isinstance(text_or_text_pair, (list, tuple)): + text, text_pair = text_or_text_pair, None + else: + text, text_pair = text_or_text_pair + + entities, entities_pair = None, None + if batch_entities_or_entities_pairs is not None: + entities_or_entities_pairs = batch_entities_or_entities_pairs[index] + if entities_or_entities_pairs: + if isinstance(entities_or_entities_pairs[0], str): + entities, entities_pair = entities_or_entities_pairs, None + else: + entities, entities_pair = entities_or_entities_pairs + + entity_spans, entity_spans_pair = None, None + if batch_entity_spans_or_entity_spans_pairs is not None: + entity_spans_or_entity_spans_pairs = batch_entity_spans_or_entity_spans_pairs[index] + if len(entity_spans_or_entity_spans_pairs) > 0 and isinstance( + entity_spans_or_entity_spans_pairs[0], list + ): + entity_spans, entity_spans_pair = entity_spans_or_entity_spans_pairs + else: + entity_spans, entity_spans_pair = entity_spans_or_entity_spans_pairs, None + + ( + first_ids, + second_ids, + first_entity_ids, + second_entity_ids, + first_entity_token_spans, + second_entity_token_spans, + ) = self._create_input_sequence( + text=text, + text_pair=text_pair, + entities=entities, + entities_pair=entities_pair, + entity_spans=entity_spans, + entity_spans_pair=entity_spans_pair, + **kwargs, + ) + input_ids.append((first_ids, second_ids)) + entity_ids.append((first_entity_ids, second_entity_ids)) + entity_token_spans.append((first_entity_token_spans, second_entity_token_spans)) + + batch_outputs = self._batch_prepare_for_model( + input_ids, + batch_entity_ids_pairs=entity_ids, + batch_entity_token_spans_pairs=entity_token_spans, + 
add_special_tokens=add_special_tokens, + padding_strategy=padding_strategy, + truncation_strategy=truncation_strategy, + max_length=max_length, + max_entity_length=max_entity_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + return_attention_mask=return_attention_mask, + return_token_type_ids=return_token_type_ids, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_length=return_length, + return_tensors=return_tensors, + verbose=verbose, + ) + + return BatchEncoding(batch_outputs) + + # Copied from transformers.models.luke.tokenization_luke.LukeTokenizer._check_entity_input_format + def _check_entity_input_format(self, entities: Optional[EntityInput], entity_spans: Optional[EntitySpanInput]): + if not isinstance(entity_spans, list): + raise ValueError("entity_spans should be given as a list") + elif len(entity_spans) > 0 and not isinstance(entity_spans[0], tuple): + raise ValueError( + "entity_spans should be given as a list of tuples " "containing the start and end character indices" + ) + + if entities is not None: + + if not isinstance(entities, list): + raise ValueError("If you specify entities, they should be given as a list") + + if len(entities) > 0 and not isinstance(entities[0], str): + raise ValueError("If you specify entities, they should be given as a list of entity names") + + if len(entities) != len(entity_spans): + raise ValueError("If you specify entities, entities and entity_spans must be the same length") + + # Copied from transformers.models.luke.tokenization_luke.LukeTokenizer._create_input_sequence + def _create_input_sequence( + self, + text: Union[TextInput], + text_pair: Optional[Union[TextInput]] = None, + entities: Optional[EntityInput] = None, + entities_pair: Optional[EntityInput] = None, + entity_spans: Optional[EntitySpanInput] = None, + entity_spans_pair: Optional[EntitySpanInput] = None, + **kwargs + ) -> Tuple[list, list, list, list, list, list]: 
+ def get_input_ids(text): + tokens = self.tokenize(text, **kwargs) + return self.convert_tokens_to_ids(tokens) + + def get_input_ids_and_entity_token_spans(text, entity_spans): + if entity_spans is None: + return get_input_ids(text), None + + cur = 0 + input_ids = [] + entity_token_spans = [None] * len(entity_spans) + + split_char_positions = sorted(frozenset(itertools.chain(*entity_spans))) + char_pos2token_pos = {} + + for split_char_position in split_char_positions: + orig_split_char_position = split_char_position + if ( + split_char_position > 0 and text[split_char_position - 1] == " " + ): # whitespace should be prepended to the following token + split_char_position -= 1 + if cur != split_char_position: + input_ids += get_input_ids(text[cur:split_char_position]) + cur = split_char_position + char_pos2token_pos[orig_split_char_position] = len(input_ids) + + input_ids += get_input_ids(text[cur:]) + + entity_token_spans = [ + (char_pos2token_pos[char_start], char_pos2token_pos[char_end]) for char_start, char_end in entity_spans + ] + + return input_ids, entity_token_spans + + first_ids, second_ids = None, None + first_entity_ids, second_entity_ids = None, None + first_entity_token_spans, second_entity_token_spans = None, None + + if self.task is None: + unk_entity_id = self.entity_vocab["[UNK]"] + mask_entity_id = self.entity_vocab["[MASK]"] + + if entity_spans is None: + first_ids = get_input_ids(text) + else: + self._check_entity_input_format(entities, entity_spans) + + first_ids, first_entity_token_spans = get_input_ids_and_entity_token_spans(text, entity_spans) + if entities is None: + first_entity_ids = [mask_entity_id] * len(entity_spans) + else: + first_entity_ids = [self.entity_vocab.get(entity, unk_entity_id) for entity in entities] + + if text_pair is not None: + if entity_spans_pair is None: + second_ids = get_input_ids(text_pair) + else: + self._check_entity_input_format(entities_pair, entity_spans_pair) + + second_ids, second_entity_token_spans = 
get_input_ids_and_entity_token_spans( + text_pair, entity_spans_pair + ) + if entities_pair is None: + second_entity_ids = [mask_entity_id] * len(entity_spans_pair) + else: + second_entity_ids = [self.entity_vocab.get(entity, unk_entity_id) for entity in entities_pair] + + elif self.task == "entity_classification": + if not (isinstance(entity_spans, list) and len(entity_spans) == 1 and isinstance(entity_spans[0], tuple)): + raise ValueError( + "Entity spans should be a list containing a single tuple " + "containing the start and end character indices of an entity" + ) + first_entity_ids = [self.entity_vocab["[MASK]"]] + first_ids, first_entity_token_spans = get_input_ids_and_entity_token_spans(text, entity_spans) + + # add special tokens to input ids + entity_token_start, entity_token_end = first_entity_token_spans[0] + first_ids = ( + first_ids[:entity_token_end] + [self.additional_special_tokens_ids[0]] + first_ids[entity_token_end:] + ) + first_ids = ( + first_ids[:entity_token_start] + + [self.additional_special_tokens_ids[0]] + + first_ids[entity_token_start:] + ) + first_entity_token_spans = [(entity_token_start, entity_token_end + 2)] + + elif self.task == "entity_pair_classification": + if not ( + isinstance(entity_spans, list) + and len(entity_spans) == 2 + and isinstance(entity_spans[0], tuple) + and isinstance(entity_spans[1], tuple) + ): + raise ValueError( + "Entity spans should be provided as a list of two tuples, " + "each tuple containing the start and end character indices of an entity" + ) + + head_span, tail_span = entity_spans + first_entity_ids = [self.entity_vocab["[MASK]"], self.entity_vocab["[MASK2]"]] + first_ids, first_entity_token_spans = get_input_ids_and_entity_token_spans(text, entity_spans) + + head_token_span, tail_token_span = first_entity_token_spans + token_span_with_special_token_ids = [ + (head_token_span, self.additional_special_tokens_ids[0]), + (tail_token_span, self.additional_special_tokens_ids[1]), + ] + if 
head_token_span[0] < tail_token_span[0]: + first_entity_token_spans[0] = (head_token_span[0], head_token_span[1] + 2) + first_entity_token_spans[1] = (tail_token_span[0] + 2, tail_token_span[1] + 4) + token_span_with_special_token_ids = reversed(token_span_with_special_token_ids) + else: + first_entity_token_spans[0] = (head_token_span[0] + 2, head_token_span[1] + 4) + first_entity_token_spans[1] = (tail_token_span[0], tail_token_span[1] + 2) + + for (entity_token_start, entity_token_end), special_token_id in token_span_with_special_token_ids: + first_ids = first_ids[:entity_token_end] + [special_token_id] + first_ids[entity_token_end:] + first_ids = first_ids[:entity_token_start] + [special_token_id] + first_ids[entity_token_start:] + + elif self.task == "entity_span_classification": + mask_entity_id = self.entity_vocab["[MASK]"] + + if not (isinstance(entity_spans, list) and len(entity_spans) > 0 and isinstance(entity_spans[0], tuple)): + raise ValueError( + "Entity spans should be provided as a list of tuples, " + "each tuple containing the start and end character indices of an entity" + ) + + first_ids, first_entity_token_spans = get_input_ids_and_entity_token_spans(text, entity_spans) + first_entity_ids = [mask_entity_id] * len(entity_spans) + + else: + raise ValueError(f"Task {self.task} not supported") + + return ( + first_ids, + second_ids, + first_entity_ids, + second_entity_ids, + first_entity_token_spans, + second_entity_token_spans, + ) + + @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) + # Copied from transformers.models.luke.tokenization_luke.LukeTokenizer._batch_prepare_for_model + def _batch_prepare_for_model( + self, + batch_ids_pairs: List[Tuple[List[int], None]], + batch_entity_ids_pairs: List[Tuple[Optional[List[int]], Optional[List[int]]]], + batch_entity_token_spans_pairs: List[Tuple[Optional[List[Tuple[int, int]]], Optional[List[Tuple[int, int]]]]], + add_special_tokens: bool = True, + padding_strategy: 
PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + max_length: Optional[int] = None, + max_entity_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[str] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_length: bool = False, + verbose: bool = True, + ) -> BatchEncoding: + """ + Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model. It + adds special tokens, truncates sequences if overflowing while taking into account the special tokens and + manages a moving window (with user defined stride) for overflowing tokens + + + Args: + batch_ids_pairs: list of tokenized input ids or input ids pairs + batch_entity_ids_pairs: list of entity ids or entity ids pairs + batch_entity_token_spans_pairs: list of entity spans or entity spans pairs + max_entity_length: The maximum length of the entity sequence. 
+ """ + + batch_outputs = {} + for input_ids, entity_ids, entity_token_span_pairs in zip( + batch_ids_pairs, batch_entity_ids_pairs, batch_entity_token_spans_pairs + ): + first_ids, second_ids = input_ids + first_entity_ids, second_entity_ids = entity_ids + first_entity_token_spans, second_entity_token_spans = entity_token_span_pairs + outputs = self.prepare_for_model( + first_ids, + second_ids, + entity_ids=first_entity_ids, + pair_entity_ids=second_entity_ids, + entity_token_spans=first_entity_token_spans, + pair_entity_token_spans=second_entity_token_spans, + add_special_tokens=add_special_tokens, + padding=PaddingStrategy.DO_NOT_PAD.value, # we pad in batch afterward + truncation=truncation_strategy.value, + max_length=max_length, + max_entity_length=max_entity_length, + stride=stride, + pad_to_multiple_of=None, # we pad in batch afterward + return_attention_mask=False, # we pad in batch afterward + return_token_type_ids=return_token_type_ids, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_length=return_length, + return_tensors=None, # We convert the whole batch to tensors at the end + prepend_batch_axis=False, + verbose=verbose, + ) + + for key, value in outputs.items(): + if key not in batch_outputs: + batch_outputs[key] = [] + batch_outputs[key].append(value) + + batch_outputs = self.pad( + batch_outputs, + padding=padding_strategy.value, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + return_attention_mask=return_attention_mask, + ) + + batch_outputs = BatchEncoding(batch_outputs, tensor_type=return_tensors) + + return batch_outputs + + @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) + # Copied from transformers.models.luke.tokenization_luke.LukeTokenizer.prepare_for_model + def prepare_for_model( + self, + ids: List[int], + pair_ids: Optional[List[int]] = None, + entity_ids: Optional[List[int]] = None, + pair_entity_ids: 
Optional[List[int]] = None, + entity_token_spans: Optional[List[Tuple[int, int]]] = None, + pair_entity_token_spans: Optional[List[Tuple[int, int]]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = False, + max_length: Optional[int] = None, + max_entity_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + prepend_batch_axis: bool = False, + **kwargs + ) -> BatchEncoding: + """ + Prepares a sequence of input id, entity id and entity span, or a pair of sequences of inputs ids, entity ids, + entity spans so that it can be used by the model. It adds special tokens, truncates sequences if overflowing + while taking into account the special tokens and manages a moving window (with user defined stride) for + overflowing tokens. Please Note, for `pair_ids` different than `None` and `truncation_strategy = longest_first` + or `True`, it is not possible to return overflowing tokens. Such a combination of arguments will raise an + error. + + Args: + ids (:obj:`List[int]`): + Tokenized input ids of the first sequence. + pair_ids (:obj:`List[int]`, `optional`): + Tokenized input ids of the second sequence. + entity_ids (:obj:`List[int]`, `optional`): + Entity ids of the first sequence. + pair_entity_ids (:obj:`List[int]`, `optional`): + Entity ids of the second sequence. + entity_token_spans (:obj:`List[Tuple[int, int]]`, `optional`): + Entity spans of the first sequence. + pair_entity_token_spans (:obj:`List[Tuple[int, int]]`, `optional`): + Entity spans of the second sequence. 
+ max_entity_length (:obj:`int`, `optional`): + The maximum length of the entity sequence. + """ + + # Backward compatibility for 'truncation_strategy', 'pad_to_max_length' + padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies( + padding=padding, + truncation=truncation, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + verbose=verbose, + **kwargs, + ) + + # Compute lengths + pair = bool(pair_ids is not None) + len_ids = len(ids) + len_pair_ids = len(pair_ids) if pair else 0 + + if return_token_type_ids and not add_special_tokens: + raise ValueError( + "Asking to return token_type_ids while setting add_special_tokens to False " + "results in an undefined behavior. Please set add_special_tokens to True or " + "set return_token_type_ids to None." + ) + if ( + return_overflowing_tokens + and truncation_strategy == TruncationStrategy.LONGEST_FIRST + and pair_ids is not None + ): + raise ValueError( + "Not possible to return overflowing tokens for pair of sequences with the " + "`longest_first`. Please select another truncation strategy than `longest_first`, " + "for instance `only_second` or `only_first`." 
+ ) + + # Load from model defaults + if return_token_type_ids is None: + return_token_type_ids = "token_type_ids" in self.model_input_names + if return_attention_mask is None: + return_attention_mask = "attention_mask" in self.model_input_names + + encoded_inputs = {} + + # Compute the total size of the returned word encodings + total_len = len_ids + len_pair_ids + (self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0) + + # Truncation: Handle max sequence length and max_entity_length + overflowing_tokens = [] + if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and max_length and total_len > max_length: + # truncate words up to max_length + ids, pair_ids, overflowing_tokens = self.truncate_sequences( + ids, + pair_ids=pair_ids, + num_tokens_to_remove=total_len - max_length, + truncation_strategy=truncation_strategy, + stride=stride, + ) + + if return_overflowing_tokens: + encoded_inputs["overflowing_tokens"] = overflowing_tokens + encoded_inputs["num_truncated_tokens"] = total_len - max_length + + # Add special tokens + if add_special_tokens: + sequence = self.build_inputs_with_special_tokens(ids, pair_ids) + token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids) + entity_token_offset = 1 # 1 * token + pair_entity_token_offset = len(ids) + 3 # 1 * token & 2 * tokens + else: + sequence = ids + pair_ids if pair else ids + token_type_ids = [0] * len(ids) + ([0] * len(pair_ids) if pair else []) + entity_token_offset = 0 + pair_entity_token_offset = len(ids) + + # Build output dictionary + encoded_inputs["input_ids"] = sequence + if return_token_type_ids: + encoded_inputs["token_type_ids"] = token_type_ids + if return_special_tokens_mask: + if add_special_tokens: + encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(ids, pair_ids) + else: + encoded_inputs["special_tokens_mask"] = [0] * len(sequence) + + # Set max entity length + if not max_entity_length: + max_entity_length = self.max_entity_length + + if 
entity_ids is not None: + total_entity_len = 0 + num_invalid_entities = 0 + valid_entity_ids = [ent_id for ent_id, span in zip(entity_ids, entity_token_spans) if span[1] <= len(ids)] + valid_entity_token_spans = [span for span in entity_token_spans if span[1] <= len(ids)] + + total_entity_len += len(valid_entity_ids) + num_invalid_entities += len(entity_ids) - len(valid_entity_ids) + + valid_pair_entity_ids, valid_pair_entity_token_spans = None, None + if pair_entity_ids is not None: + valid_pair_entity_ids = [ + ent_id + for ent_id, span in zip(pair_entity_ids, pair_entity_token_spans) + if span[1] <= len(pair_ids) + ] + valid_pair_entity_token_spans = [span for span in pair_entity_token_spans if span[1] <= len(pair_ids)] + total_entity_len += len(valid_pair_entity_ids) + num_invalid_entities += len(pair_entity_ids) - len(valid_pair_entity_ids) + + if num_invalid_entities != 0: + logger.warning( + f"{num_invalid_entities} entities are ignored because their entity spans are invalid due to the truncation of input tokens" + ) + + if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and total_entity_len > max_entity_length: + # truncate entities up to max_entity_length + valid_entity_ids, valid_pair_entity_ids, overflowing_entities = self.truncate_sequences( + valid_entity_ids, + pair_ids=valid_pair_entity_ids, + num_tokens_to_remove=total_entity_len - max_entity_length, + truncation_strategy=truncation_strategy, + stride=stride, + ) + valid_entity_token_spans = valid_entity_token_spans[: len(valid_entity_ids)] + if valid_pair_entity_token_spans is not None: + valid_pair_entity_token_spans = valid_pair_entity_token_spans[: len(valid_pair_entity_ids)] + + if return_overflowing_tokens: + encoded_inputs["overflowing_entities"] = overflowing_entities + encoded_inputs["num_truncated_entities"] = total_entity_len - max_entity_length + + final_entity_ids = valid_entity_ids + valid_pair_entity_ids if valid_pair_entity_ids else valid_entity_ids + 
encoded_inputs["entity_ids"] = list(final_entity_ids) + entity_position_ids = [] + entity_start_positions = [] + entity_end_positions = [] + for (token_spans, offset) in ( + (valid_entity_token_spans, entity_token_offset), + (valid_pair_entity_token_spans, pair_entity_token_offset), + ): + if token_spans is not None: + for start, end in token_spans: + start += offset + end += offset + position_ids = list(range(start, end))[: self.max_mention_length] + position_ids += [-1] * (self.max_mention_length - end + start) + entity_position_ids.append(position_ids) + entity_start_positions.append(start) + entity_end_positions.append(end - 1) + + encoded_inputs["entity_position_ids"] = entity_position_ids + if self.task == "entity_span_classification": + encoded_inputs["entity_start_positions"] = entity_start_positions + encoded_inputs["entity_end_positions"] = entity_end_positions + + if return_token_type_ids: + encoded_inputs["entity_token_type_ids"] = [0] * len(encoded_inputs["entity_ids"]) + + # Check lengths + self._eventual_warn_about_too_long_sequence(encoded_inputs["input_ids"], max_length, verbose) + + # Padding + if padding_strategy != PaddingStrategy.DO_NOT_PAD or return_attention_mask: + encoded_inputs = self.pad( + encoded_inputs, + max_length=max_length, + max_entity_length=max_entity_length, + padding=padding_strategy.value, + pad_to_multiple_of=pad_to_multiple_of, + return_attention_mask=return_attention_mask, + ) + + if return_length: + encoded_inputs["length"] = len(encoded_inputs["input_ids"]) + + batch_outputs = BatchEncoding( + encoded_inputs, tensor_type=return_tensors, prepend_batch_axis=prepend_batch_axis + ) + + return batch_outputs + + # Copied from transformers.models.luke.tokenization_luke.LukeTokenizer.pad + def pad( + self, + encoded_inputs: Union[ + BatchEncoding, + List[BatchEncoding], + Dict[str, EncodedInput], + Dict[str, List[EncodedInput]], + List[Dict[str, EncodedInput]], + ], + padding: Union[bool, str, PaddingStrategy] = True, + 
max_length: Optional[int] = None, + max_entity_length: Optional[int] = None, + pad_to_multiple_of: Optional[int] = None, + return_attention_mask: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + verbose: bool = True, + ) -> BatchEncoding: + """ + Pad a single encoded input or a batch of encoded inputs up to predefined length or to the max sequence length + in the batch. Padding side (left/right) padding token ids are defined at the tokenizer level (with + ``self.padding_side``, ``self.pad_token_id`` and ``self.pad_token_type_id``) .. note:: If the + ``encoded_inputs`` passed are dictionary of numpy arrays, PyTorch tensors or TensorFlow tensors, the result + will use the same type unless you provide a different tensor type with ``return_tensors``. In the case of + PyTorch tensors, you will lose the specific device of your tensors however. + + Args: + encoded_inputs (:class:`~transformers.BatchEncoding`, list of :class:`~transformers.BatchEncoding`, :obj:`Dict[str, List[int]]`, :obj:`Dict[str, List[List[int]]` or :obj:`List[Dict[str, List[int]]]`): + Tokenized inputs. Can represent one input (:class:`~transformers.BatchEncoding` or :obj:`Dict[str, + List[int]]`) or a batch of tokenized inputs (list of :class:`~transformers.BatchEncoding`, `Dict[str, + List[List[int]]]` or `List[Dict[str, List[int]]]`) so you can use this method during preprocessing as + well as in a PyTorch Dataloader collate function. Instead of :obj:`List[int]` you can have tensors + (numpy arrays, PyTorch tensors or TensorFlow tensors), see the note above for the return type. + padding (:obj:`bool`, :obj:`str` or :class:`~transformers.file_utils.PaddingStrategy`, `optional`, defaults to :obj:`True`): + Select a strategy to pad the returned sequences (according to the model's padding side and padding + index) among: + + * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a + single sequence if provided). 
+ * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the + maximum acceptable input length for the model if that argument is not provided. + * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of + different lengths). + max_length (:obj:`int`, `optional`): + Maximum length of the returned list and optionally padding length (see above). + max_entity_length (:obj:`int`, `optional`): + The maximum length of the entity sequence. + pad_to_multiple_of (:obj:`int`, `optional`): + If set will pad the sequence to a multiple of the provided value. This is especially useful to enable + the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta). + return_attention_mask (:obj:`bool`, `optional`): + Whether to return the attention mask. If left to the default, will return the attention mask according + to the specific tokenizer's default, defined by the :obj:`return_outputs` attribute. `What are + attention masks? <../glossary.html#attention-mask>`__ + return_tensors (:obj:`str` or :class:`~transformers.file_utils.TensorType`, `optional`): + If set, will return tensors instead of list of python integers. Acceptable values are: + + * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects. + * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects. + * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects. + verbose (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not to print more information and warnings. 
+ """ + # If we have a list of dicts, let's convert it in a dict of lists + # We do this to allow using this method as a collate_fn function in PyTorch Dataloader + if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], (dict, BatchEncoding)): + encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0].keys()} + + # The model's main input name, usually `input_ids`, has be passed for padding + if self.model_input_names[0] not in encoded_inputs: + raise ValueError( + "You should supply an encoding or a list of encodings to this method " + f"that includes {self.model_input_names[0]}, but you provided {list(encoded_inputs.keys())}" + ) + + required_input = encoded_inputs[self.model_input_names[0]] + + if not required_input: + if return_attention_mask: + encoded_inputs["attention_mask"] = [] + return encoded_inputs + + # If we have PyTorch/TF/NumPy tensors/arrays as inputs, we cast them as python objects + # and rebuild them afterwards if no return_tensors is specified + # Note that we lose the specific device the tensor may be on for PyTorch + + first_element = required_input[0] + if isinstance(first_element, (list, tuple)): + # first_element might be an empty list/tuple in some edge cases so we grab the first non empty element. + index = 0 + while len(required_input[index]) == 0: + index += 1 + if index < len(required_input): + first_element = required_input[index][0] + # At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do. 
+ if not isinstance(first_element, (int, list, tuple)): + if is_tf_available() and _is_tensorflow(first_element): + return_tensors = "tf" if return_tensors is None else return_tensors + elif is_torch_available() and _is_torch(first_element): + return_tensors = "pt" if return_tensors is None else return_tensors + elif isinstance(first_element, np.ndarray): + return_tensors = "np" if return_tensors is None else return_tensors + else: + raise ValueError( + f"type of {first_element} unknown: {type(first_element)}. " + f"Should be one of a python, numpy, pytorch or tensorflow object." + ) + + for key, value in encoded_inputs.items(): + encoded_inputs[key] = to_py_obj(value) + + # Convert padding_strategy in PaddingStrategy + padding_strategy, _, max_length, _ = self._get_padding_truncation_strategies( + padding=padding, max_length=max_length, verbose=verbose + ) + + if max_entity_length is None: + max_entity_length = self.max_entity_length + + required_input = encoded_inputs[self.model_input_names[0]] + if required_input and not isinstance(required_input[0], (list, tuple)): + encoded_inputs = self._pad( + encoded_inputs, + max_length=max_length, + max_entity_length=max_entity_length, + padding_strategy=padding_strategy, + pad_to_multiple_of=pad_to_multiple_of, + return_attention_mask=return_attention_mask, + ) + return BatchEncoding(encoded_inputs, tensor_type=return_tensors) + + batch_size = len(required_input) + if any(len(v) != batch_size for v in encoded_inputs.values()): + raise ValueError("Some items in the output dictionary have a different batch size than others.") + + if padding_strategy == PaddingStrategy.LONGEST: + max_length = max(len(inputs) for inputs in required_input) + max_entity_length = ( + max(len(inputs) for inputs in encoded_inputs["entity_ids"]) if "entity_ids" in encoded_inputs else 0 + ) + padding_strategy = PaddingStrategy.MAX_LENGTH + + batch_outputs = {} + for i in range(batch_size): + inputs = dict((k, v[i]) for k, v in 
encoded_inputs.items()) + outputs = self._pad( + inputs, + max_length=max_length, + max_entity_length=max_entity_length, + padding_strategy=padding_strategy, + pad_to_multiple_of=pad_to_multiple_of, + return_attention_mask=return_attention_mask, + ) + + for key, value in outputs.items(): + if key not in batch_outputs: + batch_outputs[key] = [] + batch_outputs[key].append(value) + + return BatchEncoding(batch_outputs, tensor_type=return_tensors) + + # Copied from transformers.models.luke.tokenization_luke.LukeTokenizer._pad + def _pad( + self, + encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding], + max_length: Optional[int] = None, + max_entity_length: Optional[int] = None, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + pad_to_multiple_of: Optional[int] = None, + return_attention_mask: Optional[bool] = None, + ) -> dict: + """ + Pad encoded inputs (on left/right and up to predefined length or max length in the batch) + + + Args: + encoded_inputs: Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`). + max_length: maximum length of the returned list and optionally padding length (see below). + Will truncate by taking into account the special tokens. + max_entity_length: The maximum length of the entity sequence. + padding_strategy: PaddingStrategy to use for padding. + + + - PaddingStrategy.LONGEST Pad to the longest sequence in the batch + - PaddingStrategy.MAX_LENGTH: Pad to the max length (default) + - PaddingStrategy.DO_NOT_PAD: Do not pad + The tokenizer padding sides are defined in self.padding_side: + + + - 'left': pads on the left of the sequences + - 'right': pads on the right of the sequences + pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value. + This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability + >= 7.5 (Volta). 
+ return_attention_mask: (optional) Set to False to avoid returning attention mask (default: set to model specifics) + """ + entities_provided = bool("entity_ids" in encoded_inputs) + + # Load from model defaults + if return_attention_mask is None: + return_attention_mask = "attention_mask" in self.model_input_names + + if padding_strategy == PaddingStrategy.LONGEST: + max_length = len(encoded_inputs["input_ids"]) + if entities_provided: + max_entity_length = len(encoded_inputs["entity_ids"]) + + if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): + max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of + + if ( + entities_provided + and max_entity_length is not None + and pad_to_multiple_of is not None + and (max_entity_length % pad_to_multiple_of != 0) + ): + max_entity_length = ((max_entity_length // pad_to_multiple_of) + 1) * pad_to_multiple_of + + needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and ( + len(encoded_inputs["input_ids"]) != max_length + or (entities_provided and len(encoded_inputs["entity_ids"]) != max_entity_length) + ) + + # Initialize attention mask if not present. 
+ if return_attention_mask and "attention_mask" not in encoded_inputs: + encoded_inputs["attention_mask"] = [1] * len(encoded_inputs["input_ids"]) + if entities_provided and return_attention_mask and "entity_attention_mask" not in encoded_inputs: + encoded_inputs["entity_attention_mask"] = [1] * len(encoded_inputs["entity_ids"]) + + if needs_to_be_padded: + difference = max_length - len(encoded_inputs["input_ids"]) + if entities_provided: + entity_difference = max_entity_length - len(encoded_inputs["entity_ids"]) + if self.padding_side == "right": + if return_attention_mask: + encoded_inputs["attention_mask"] = encoded_inputs["attention_mask"] + [0] * difference + if entities_provided: + encoded_inputs["entity_attention_mask"] = ( + encoded_inputs["entity_attention_mask"] + [0] * entity_difference + ) + if "token_type_ids" in encoded_inputs: + encoded_inputs["token_type_ids"] = encoded_inputs["token_type_ids"] + [0] * difference + if entities_provided: + encoded_inputs["entity_token_type_ids"] = ( + encoded_inputs["entity_token_type_ids"] + [0] * entity_difference + ) + if "special_tokens_mask" in encoded_inputs: + encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference + encoded_inputs["input_ids"] = encoded_inputs["input_ids"] + [self.pad_token_id] * difference + if entities_provided: + encoded_inputs["entity_ids"] = ( + encoded_inputs["entity_ids"] + [self.entity_vocab["[PAD]"]] * entity_difference + ) + encoded_inputs["entity_position_ids"] = ( + encoded_inputs["entity_position_ids"] + [[-1] * self.max_mention_length] * entity_difference + ) + if self.task == "entity_span_classification": + encoded_inputs["entity_start_positions"] = ( + encoded_inputs["entity_start_positions"] + [0] * entity_difference + ) + encoded_inputs["entity_end_positions"] = ( + encoded_inputs["entity_end_positions"] + [0] * entity_difference + ) + + elif self.padding_side == "left": + if return_attention_mask: + encoded_inputs["attention_mask"] = 
[0] * difference + encoded_inputs["attention_mask"] + if entities_provided: + encoded_inputs["entity_attention_mask"] = [0] * entity_difference + encoded_inputs[ + "entity_attention_mask" + ] + if "token_type_ids" in encoded_inputs: + encoded_inputs["token_type_ids"] = [0] * difference + encoded_inputs["token_type_ids"] + if entities_provided: + encoded_inputs["entity_token_type_ids"] = [0] * entity_difference + encoded_inputs[ + "entity_token_type_ids" + ] + if "special_tokens_mask" in encoded_inputs: + encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"] + encoded_inputs["input_ids"] = [self.pad_token_id] * difference + encoded_inputs["input_ids"] + if entities_provided: + encoded_inputs["entity_ids"] = [self.entity_vocab["[PAD]"]] * entity_difference + encoded_inputs[ + "entity_ids" + ] + encoded_inputs["entity_position_ids"] = [ + [-1] * self.max_mention_length + ] * entity_difference + encoded_inputs["entity_position_ids"] + if self.task == "entity_span_classification": + encoded_inputs["entity_start_positions"] = [0] * entity_difference + encoded_inputs[ + "entity_start_positions" + ] + encoded_inputs["entity_end_positions"] = [0] * entity_difference + encoded_inputs[ + "entity_end_positions" + ] + else: + raise ValueError("Invalid padding strategy:" + str(self.padding_side)) + + return encoded_inputs + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str, str]: + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + + out_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): + copyfile(self.vocab_file, out_vocab_file) + + entity_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + 
VOCAB_FILES_NAMES["entity_vocab_file"] + ) + + with open(entity_vocab_file, "w", encoding="utf-8") as f: + f.write(json.dumps(self.entity_vocab, ensure_ascii=False)) + + return out_vocab_file, entity_vocab_file + + # Copied from transformers.models.xlm_roberta.tokenization_xlm_roberta.XLMRobertaTokenizer.build_inputs_with_special_tokens + def build_inputs_with_special_tokens( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. An XLM-RoBERTa sequence has the following format: + + - single sequence: `` X `` + - pair of sequences: `` A B `` + + Args: + token_ids_0 (:obj:`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (:obj:`List[int]`, `optional`): + Optional second list of IDs for sequence pairs. + + Returns: + :obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens. + """ + + if token_ids_1 is None: + return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + cls = [self.cls_token_id] + sep = [self.sep_token_id] + return cls + token_ids_0 + sep + sep + token_ids_1 + sep + + # Copied from transformers.models.xlm_roberta.tokenization_xlm_roberta.XLMRobertaTokenizer.get_special_tokens_mask + def get_special_tokens_mask( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + ) -> List[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer ``prepare_for_model`` method. + + Args: + token_ids_0 (:obj:`List[int]`): + List of IDs. + token_ids_1 (:obj:`List[int]`, `optional`): + Optional second list of IDs for sequence pairs. 
+ already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + if token_ids_1 is None: + return [1] + ([0] * len(token_ids_0)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1] + + # Copied from transformers.models.xlm_roberta.tokenization_xlm_roberta.XLMRobertaTokenizer.create_token_type_ids_from_sequences + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. XLM-RoBERTa does + not make use of token type ids, therefore a list of zeros is returned. + + Args: + token_ids_0 (:obj:`List[int]`): + List of IDs. + token_ids_1 (:obj:`List[int]`, `optional`): + Optional second list of IDs for sequence pairs. + + Returns: + :obj:`List[int]`: List of zeros. 
+ + """ + + sep = [self.sep_token_id] + cls = [self.cls_token_id] + + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0] + + @property + # Copied from transformers.models.xlm_roberta.tokenization_xlm_roberta.XLMRobertaTokenizer.vocab_size + def vocab_size(self): + return len(self.sp_model) + self.fairseq_offset + 1 # Add the token + + # Copied from transformers.models.xlm_roberta.tokenization_xlm_roberta.XLMRobertaTokenizer.get_vocab + def get_vocab(self): + vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} + vocab.update(self.added_tokens_encoder) + return vocab + + # Copied from transformers.models.xlm_roberta.tokenization_xlm_roberta.XLMRobertaTokenizer._tokenize + def _tokenize(self, text: str) -> List[str]: + return self.sp_model.encode(text, out_type=str) + + # Copied from transformers.models.xlm_roberta.tokenization_xlm_roberta.XLMRobertaTokenizer._convert_token_to_id + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + if token in self.fairseq_tokens_to_ids: + return self.fairseq_tokens_to_ids[token] + spm_id = self.sp_model.PieceToId(token) + + # Need to return unknown token if the SP model returned 0 + return spm_id + self.fairseq_offset if spm_id else self.unk_token_id + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + if index in self.fairseq_ids_to_tokens: + return self.fairseq_ids_to_tokens[index] + return self.sp_model.IdToPiece(index - self.fairseq_offset) + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (strings for sub-words) in a single string.""" + out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip() + return out_string diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 685b5b9944..3ab5f6a416 100644 --- 
a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -3020,6 +3020,18 @@ class LukeForEntitySpanClassification: requires_backends(self, ["torch"]) +class LukeForMaskedLM: + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + def forward(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class LukeModel: def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) diff --git a/src/transformers/utils/dummy_sentencepiece_objects.py b/src/transformers/utils/dummy_sentencepiece_objects.py index ba59da4c0b..9bdc03411a 100644 --- a/src/transformers/utils/dummy_sentencepiece_objects.py +++ b/src/transformers/utils/dummy_sentencepiece_objects.py @@ -110,6 +110,15 @@ class MBartTokenizer: requires_backends(cls, ["sentencepiece"]) +class MLukeTokenizer: + def __init__(self, *args, **kwargs): + requires_backends(self, ["sentencepiece"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["sentencepiece"]) + + class MT5Tokenizer: def __init__(self, *args, **kwargs): requires_backends(self, ["sentencepiece"]) diff --git a/tests/test_modeling_luke.py b/tests/test_modeling_luke.py index 99ef4686ad..94b9fe7695 100644 --- a/tests/test_modeling_luke.py +++ b/tests/test_modeling_luke.py @@ -29,6 +29,7 @@ if is_torch_available(): LukeForEntityClassification, LukeForEntityPairClassification, LukeForEntitySpanClassification, + LukeForMaskedLM, LukeModel, LukeTokenizer, ) @@ -138,12 +139,17 @@ class LukeModelTester: ) sequence_labels = None + labels = None + entity_labels = None entity_classification_labels = None entity_pair_classification_labels = None entity_span_classification_labels = None if self.use_labels: sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) + labels = ids_tensor([self.batch_size, self.seq_length], 
self.vocab_size) + entity_labels = ids_tensor([self.batch_size, self.entity_length], self.entity_vocab_size) + entity_classification_labels = ids_tensor([self.batch_size], self.num_entity_classification_labels) entity_pair_classification_labels = ids_tensor( [self.batch_size], self.num_entity_pair_classification_labels @@ -164,6 +170,8 @@ class LukeModelTester: entity_token_type_ids, entity_position_ids, sequence_labels, + labels, + entity_labels, entity_classification_labels, entity_pair_classification_labels, entity_span_classification_labels, @@ -199,6 +207,8 @@ class LukeModelTester: entity_token_type_ids, entity_position_ids, sequence_labels, + labels, + entity_labels, entity_classification_labels, entity_pair_classification_labels, entity_span_classification_labels, @@ -226,6 +236,44 @@ class LukeModelTester: result = model(input_ids) self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + def create_and_check_for_masked_lm( + self, + config, + input_ids, + attention_mask, + token_type_ids, + entity_ids, + entity_attention_mask, + entity_token_type_ids, + entity_position_ids, + sequence_labels, + labels, + entity_labels, + entity_classification_labels, + entity_pair_classification_labels, + entity_span_classification_labels, + ): + config.num_labels = self.num_entity_classification_labels + model = LukeForMaskedLM(config) + model.to(torch_device) + model.eval() + + result = model( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + entity_ids=entity_ids, + entity_attention_mask=entity_attention_mask, + entity_token_type_ids=entity_token_type_ids, + entity_position_ids=entity_position_ids, + labels=labels, + entity_labels=entity_labels, + ) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) + self.parent.assertEqual( + result.entity_logits.shape, (self.batch_size, self.entity_length, self.entity_vocab_size) + ) + def 
create_and_check_for_entity_classification( self, config, @@ -237,6 +285,8 @@ class LukeModelTester: entity_token_type_ids, entity_position_ids, sequence_labels, + labels, + entity_labels, entity_classification_labels, entity_pair_classification_labels, entity_span_classification_labels, @@ -269,6 +319,8 @@ class LukeModelTester: entity_token_type_ids, entity_position_ids, sequence_labels, + labels, + entity_labels, entity_classification_labels, entity_pair_classification_labels, entity_span_classification_labels, @@ -301,6 +353,8 @@ class LukeModelTester: entity_token_type_ids, entity_position_ids, sequence_labels, + labels, + entity_labels, entity_classification_labels, entity_pair_classification_labels, entity_span_classification_labels, @@ -341,6 +395,8 @@ class LukeModelTester: entity_token_type_ids, entity_position_ids, sequence_labels, + labels, + entity_labels, entity_classification_labels, entity_pair_classification_labels, entity_span_classification_labels, @@ -363,6 +419,7 @@ class LukeModelTest(ModelTesterMixin, unittest.TestCase): all_model_classes = ( ( LukeModel, + LukeForMaskedLM, LukeForEntityClassification, LukeForEntityPairClassification, LukeForEntitySpanClassification, @@ -396,6 +453,18 @@ class LukeModelTest(ModelTesterMixin, unittest.TestCase): dtype=torch.long, device=torch_device, ) + elif model_class == LukeForMaskedLM: + inputs_dict["labels"] = torch.zeros( + (self.model_tester.batch_size, self.model_tester.seq_length), + dtype=torch.long, + device=torch_device, + ) + inputs_dict["entity_labels"] = torch.zeros( + (self.model_tester.batch_size, self.model_tester.entity_length), + dtype=torch.long, + device=torch_device, + ) + return inputs_dict def setUp(self): @@ -415,6 +484,10 @@ class LukeModelTest(ModelTesterMixin, unittest.TestCase): model = LukeModel.from_pretrained(model_name) self.assertIsNotNone(model) + def test_for_masked_lm(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + 
self.model_tester.create_and_check_for_masked_lm(*config_and_inputs) + def test_for_entity_classification(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_for_entity_classification(*config_and_inputs) diff --git a/tests/test_tokenization_luke.py b/tests/test_tokenization_luke.py index 148e7de27b..a869eadfbe 100644 --- a/tests/test_tokenization_luke.py +++ b/tests/test_tokenization_luke.py @@ -23,7 +23,7 @@ from transformers.testing_utils import require_torch, slow from .test_tokenization_common import TokenizerTesterMixin -class Luke(TokenizerTesterMixin, unittest.TestCase): +class LukeTokenizerTest(TokenizerTesterMixin, unittest.TestCase): tokenizer_class = LukeTokenizer test_rust_tokenizer = False from_pretrained_kwargs = {"cls_token": ""} @@ -79,8 +79,8 @@ class Luke(TokenizerTesterMixin, unittest.TestCase): encoded_sentence = tokenizer.build_inputs_with_special_tokens(text) encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2) - assert encoded_sentence == encoded_text_from_decode - assert encoded_pair == encoded_pair_from_decode + self.assertEqual(encoded_sentence, encoded_text_from_decode) + self.assertEqual(encoded_pair, encoded_pair_from_decode) def get_clean_sequence(self, tokenizer, max_length=20) -> Tuple[str, list]: txt = "Beyonce lives in Los Angeles" @@ -159,6 +159,81 @@ class Luke(TokenizerTesterMixin, unittest.TestCase): tokens_p_str, ["", "A", ",", "", "ĠAllen", "N", "LP", "Ġsentence", ".", ""] ) + def test_padding_entity_inputs(self): + tokenizer = self.get_tokenizer() + + sentence = "Japanese is an East Asian language spoken by about 128 million people, primarily in Japan." 
+ span = (15, 34) + pad_id = tokenizer.entity_vocab["[PAD]"] + mask_id = tokenizer.entity_vocab["[MASK]"] + + encoding = tokenizer([sentence, sentence], entity_spans=[[span], [span, span]], padding=True) + self.assertEqual(encoding["entity_ids"], [[mask_id, pad_id], [mask_id, mask_id]]) + + # test with a sentence with no entity + encoding = tokenizer([sentence, sentence], entity_spans=[[], [span, span]], padding=True) + self.assertEqual(encoding["entity_ids"], [[pad_id, pad_id], [mask_id, mask_id]]) + + def test_if_tokenize_single_text_raise_error_with_invalid_inputs(self): + tokenizer = self.get_tokenizer() + + sentence = "Japanese is an East Asian language spoken by about 128 million people, primarily in Japan." + spans = [(15, 34)] + entities = ["East Asian language"] + + with self.assertRaises(ValueError): + tokenizer(sentence, entities=tuple(entities), entity_spans=spans) + + with self.assertRaises(ValueError): + tokenizer(sentence, entities=entities, entity_spans=tuple(spans)) + + with self.assertRaises(ValueError): + tokenizer(sentence, entities=[0], entity_spans=spans) + + with self.assertRaises(ValueError): + tokenizer(sentence, entities=entities, entity_spans=[0]) + + with self.assertRaises(ValueError): + tokenizer(sentence, entities=entities, entity_spans=spans + [(0, 9)]) + + def test_if_tokenize_entity_classification_raise_error_with_invalid_inputs(self): + tokenizer = self.get_tokenizer(task="entity_classification") + + sentence = "Japanese is an East Asian language spoken by about 128 million people, primarily in Japan." 
+ span = (15, 34) + + with self.assertRaises(ValueError): + tokenizer(sentence, entity_spans=[]) + + with self.assertRaises(ValueError): + tokenizer(sentence, entity_spans=[span, span]) + + with self.assertRaises(ValueError): + tokenizer(sentence, entity_spans=[0]) + + def test_if_tokenize_entity_pair_classification_raise_error_with_invalid_inputs(self): + tokenizer = self.get_tokenizer(task="entity_pair_classification") + + sentence = "Japanese is an East Asian language spoken by about 128 million people, primarily in Japan." + # head and tail information + + with self.assertRaises(ValueError): + tokenizer(sentence, entity_spans=[]) + + with self.assertRaises(ValueError): + tokenizer(sentence, entity_spans=[0, 0]) + + def test_if_tokenize_entity_span_classification_raise_error_with_invalid_inputs(self): + tokenizer = self.get_tokenizer(task="entity_span_classification") + + sentence = "Japanese is an East Asian language spoken by about 128 million people, primarily in Japan." + + with self.assertRaises(ValueError): + tokenizer(sentence, entity_spans=[]) + + with self.assertRaises(ValueError): + tokenizer(sentence, entity_spans=[0, 0, 0]) + @require_torch class LukeTokenizerIntegrationTests(unittest.TestCase): diff --git a/tests/test_tokenization_mluke.py b/tests/test_tokenization_mluke.py new file mode 100644 index 0000000000..f869bc3292 --- /dev/null +++ b/tests/test_tokenization_mluke.py @@ -0,0 +1,666 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest +from typing import Tuple + +from transformers.models.mluke.tokenization_mluke import MLukeTokenizer +from transformers.testing_utils import require_torch, slow + +from .test_tokenization_common import TokenizerTesterMixin + + +class MLukeTokenizerTest(TokenizerTesterMixin, unittest.TestCase): + tokenizer_class = MLukeTokenizer + test_rust_tokenizer = False + from_pretrained_kwargs = {"cls_token": ""} + + def setUp(self): + super().setUp() + + self.special_tokens_map = {"entity_token_1": "", "entity_token_2": ""} + + def get_tokenizer(self, task=None, **kwargs): + kwargs.update(self.special_tokens_map) + kwargs.update({"task": task}) + return self.tokenizer_class.from_pretrained("studio-ousia/mluke-base", **kwargs) + + def get_input_output_texts(self, tokenizer): + input_text = "lower newer" + output_text = "lower newer" + return input_text, output_text + + def test_full_tokenizer(self): + tokenizer = self.tokenizer_class.from_pretrained("studio-ousia/mluke-base") + text = "lower newer" + spm_tokens = ["▁lower", "▁new", "er"] + tokens = tokenizer.tokenize(text) + self.assertListEqual(tokens, spm_tokens) + + input_tokens = tokens + [tokenizer.unk_token] + input_spm_tokens = [92319, 3525, 56, 3] + self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_spm_tokens) + + def mluke_dict_integration_testing(self): + tokenizer = self.get_tokenizer() + + self.assertListEqual(tokenizer.encode("Hello world!", add_special_tokens=False), [35378, 8999, 38]) + self.assertListEqual( + tokenizer.encode("Hello world! 
cécé herlolip 418", add_special_tokens=False), + [35378, 8999, 38, 33273, 11676, 604, 365, 21392, 201, 1819], + ) + + @slow + def test_sequence_builders(self): + tokenizer = self.tokenizer_class.from_pretrained("studio-ousia/mluke-base") + + text = tokenizer.encode("sequence builders", add_special_tokens=False) + text_2 = tokenizer.encode("multi-sequence build", add_special_tokens=False) + + encoded_text_from_decode = tokenizer.encode( + "sequence builders", add_special_tokens=True, add_prefix_space=False + ) + encoded_pair_from_decode = tokenizer.encode( + "sequence builders", "multi-sequence build", add_special_tokens=True, add_prefix_space=False + ) + + encoded_sentence = tokenizer.build_inputs_with_special_tokens(text) + encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2) + + self.assertEqual(encoded_sentence, encoded_text_from_decode) + self.assertEqual(encoded_pair, encoded_pair_from_decode) + + def get_clean_sequence(self, tokenizer, max_length=20) -> Tuple[str, list]: + txt = "Beyonce lives in Los Angeles" + ids = tokenizer.encode(txt, add_special_tokens=False) + return txt, ids + + def test_pretokenized_inputs(self): + pass + + def test_embeded_special_tokens(self): + for tokenizer, pretrained_name, kwargs in self.tokenizers_list: + with self.subTest("{} ({})".format(tokenizer.__class__.__name__, pretrained_name)): + tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs) + tokenizer_p = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs) + sentence = "A, AllenNLP sentence." 
+ tokens_r = tokenizer_r.encode_plus(sentence, add_special_tokens=True, return_token_type_ids=True) + tokens_p = tokenizer_p.encode_plus(sentence, add_special_tokens=True, return_token_type_ids=True) + + # token_type_ids should put 0 everywhere + self.assertEqual(sum(tokens_r["token_type_ids"]), sum(tokens_p["token_type_ids"])) + + # token_type_ids should put 0 everywhere + self.assertEqual(sum(tokens_r["token_type_ids"]), sum(tokens_p["token_type_ids"])) + + # attention_mask should put 1 everywhere, so sum over length should be 1 + self.assertEqual( + sum(tokens_p["attention_mask"]) / len(tokens_p["attention_mask"]), + ) + + tokens_p_str = tokenizer_p.convert_ids_to_tokens(tokens_p["input_ids"]) + + # Rust correctly handles the space before the mask while python doesnt + self.assertSequenceEqual(tokens_p["input_ids"], [0, 250, 6, 50264, 3823, 487, 21992, 3645, 4, 2]) + + self.assertSequenceEqual( + tokens_p_str, ["", "A", ",", "", "ĠAllen", "N", "LP", "Ġsentence", ".", ""] + ) + + def test_padding_entity_inputs(self): + tokenizer = self.get_tokenizer() + + sentence = "Japanese is an East Asian language spoken by about 128 million people, primarily in Japan." + span = (15, 34) + pad_id = tokenizer.entity_vocab["[PAD]"] + mask_id = tokenizer.entity_vocab["[MASK]"] + + encoding = tokenizer([sentence, sentence], entity_spans=[[span], [span, span]], padding=True) + self.assertEqual(encoding["entity_ids"], [[mask_id, pad_id], [mask_id, mask_id]]) + + # test with a sentence with no entity + encoding = tokenizer([sentence, sentence], entity_spans=[[], [span, span]], padding=True) + self.assertEqual(encoding["entity_ids"], [[pad_id, pad_id], [mask_id, mask_id]]) + + def test_if_tokenize_single_text_raise_error_with_invalid_inputs(self): + tokenizer = self.get_tokenizer() + + sentence = "ISO 639-3 uses the code fas for the dialects spoken across Iran and Afghanistan." 
+ entities = ["en:ISO 639-3"] + spans = [(0, 9)] + + with self.assertRaises(ValueError): + tokenizer(sentence, entities=tuple(entities), entity_spans=spans) + + with self.assertRaises(ValueError): + tokenizer(sentence, entities=entities, entity_spans=tuple(spans)) + + with self.assertRaises(ValueError): + tokenizer(sentence, entities=[0], entity_spans=spans) + + with self.assertRaises(ValueError): + tokenizer(sentence, entities=entities, entity_spans=[0]) + + with self.assertRaises(ValueError): + tokenizer(sentence, entities=entities, entity_spans=spans + [(0, 9)]) + + def test_if_tokenize_entity_classification_raise_error_with_invalid_inputs(self): + tokenizer = self.get_tokenizer(task="entity_classification") + + sentence = "Japanese is an East Asian language spoken by about 128 million people, primarily in Japan." + span = (15, 34) + + with self.assertRaises(ValueError): + tokenizer(sentence, entity_spans=[]) + + with self.assertRaises(ValueError): + tokenizer(sentence, entity_spans=[span, span]) + + with self.assertRaises(ValueError): + tokenizer(sentence, entity_spans=[0]) + + def test_if_tokenize_entity_pair_classification_raise_error_with_invalid_inputs(self): + tokenizer = self.get_tokenizer(task="entity_pair_classification") + + sentence = "Japanese is an East Asian language spoken by about 128 million people, primarily in Japan." + # head and tail information + + with self.assertRaises(ValueError): + tokenizer(sentence, entity_spans=[]) + + with self.assertRaises(ValueError): + tokenizer(sentence, entity_spans=[0, 0]) + + def test_if_tokenize_entity_span_classification_raise_error_with_invalid_inputs(self): + tokenizer = self.get_tokenizer(task="entity_span_classification") + + sentence = "Japanese is an East Asian language spoken by about 128 million people, primarily in Japan." 
+ + with self.assertRaises(ValueError): + tokenizer(sentence, entity_spans=[]) + + with self.assertRaises(ValueError): + tokenizer(sentence, entity_spans=[0, 0, 0]) + + +@require_torch +class MLukeTokenizerIntegrationTests(unittest.TestCase): + tokenizer_class = MLukeTokenizer + from_pretrained_kwargs = {"cls_token": ""} + + @classmethod + def setUpClass(cls): + cls.tokenizer = MLukeTokenizer.from_pretrained("studio-ousia/mluke-base", return_token_type_ids=True) + cls.entity_classification_tokenizer = MLukeTokenizer.from_pretrained( + "studio-ousia/mluke-base", return_token_type_ids=True, task="entity_classification" + ) + cls.entity_pair_tokenizer = MLukeTokenizer.from_pretrained( + "studio-ousia/mluke-base", return_token_type_ids=True, task="entity_pair_classification" + ) + + cls.entity_span_tokenizer = MLukeTokenizer.from_pretrained( + "studio-ousia/mluke-base", return_token_type_ids=True, task="entity_span_classification" + ) + + def test_single_text_no_padding_or_truncation(self): + tokenizer = self.tokenizer + sentence = "ISO 639-3 uses the code fas for the dialects spoken across Iran and アフガニスタン (Afghanistan)." 
+ entities = ["en:ISO 639-3", "DUMMY_ENTITY", "ja:アフガニスタン", "en:Afghanistan"] + spans = [(0, 9), (59, 63), (68, 75), (77, 88)] + + encoding = tokenizer(sentence, entities=entities, entity_spans=spans, return_token_type_ids=True) + + self.assertEqual( + tokenizer.decode(encoding["input_ids"], spaces_between_special_tokens=False), + " ISO 639-3 uses the code fas for the dialects spoken across Iran and アフガニスタン ( Afghanistan ).", + ) + self.assertEqual( + tokenizer.decode(encoding["input_ids"][1:5], spaces_between_special_tokens=False), "ISO 639-3" + ) + self.assertEqual(tokenizer.decode(encoding["input_ids"][17], spaces_between_special_tokens=False), "Iran") + self.assertEqual( + tokenizer.decode(encoding["input_ids"][19:25], spaces_between_special_tokens=False), "アフガニスタン" + ) + self.assertEqual( + tokenizer.decode(encoding["input_ids"][26], spaces_between_special_tokens=False), "Afghanistan" + ) + + self.assertEqual( + encoding["entity_ids"], + [ + tokenizer.entity_vocab["en:ISO 639-3"], + tokenizer.entity_vocab["[UNK]"], + tokenizer.entity_vocab["ja:アフガニスタン"], + tokenizer.entity_vocab["en:Afghanistan"], + ], + ) + self.assertEqual(encoding["entity_attention_mask"], [1, 1, 1, 1]) + self.assertEqual(encoding["entity_token_type_ids"], [0, 0, 0, 0]) + # fmt: off + self.assertEqual( + encoding["entity_position_ids"], + [ + [1, 2, 3, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], + [17, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], + [19, 20, 21, 22, 23, 24, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], + [26, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1] + ] + ) + # fmt: on + + def test_single_text_only_entity_spans_no_padding_or_truncation(self): + tokenizer = self.tokenizer + + sentence = "ISO 639-3 uses the code fas for 
the dialects spoken across Iran and アフガニスタン (Afghanistan)." + entities = ["en:ISO 639-3", "DUMMY_ENTITY", "ja:アフガニスタン", "en:Afghanistan"] + spans = [(0, 9), (59, 63), (68, 75), (77, 88)] + + encoding = tokenizer(sentence, entities=entities, entity_spans=spans, return_token_type_ids=True) + + self.assertEqual( + tokenizer.decode(encoding["input_ids"], spaces_between_special_tokens=False), + " ISO 639-3 uses the code fas for the dialects spoken across Iran and アフガニスタン ( Afghanistan ).", + ) + self.assertEqual( + tokenizer.decode(encoding["input_ids"][1:5], spaces_between_special_tokens=False), "ISO 639-3" + ) + self.assertEqual(tokenizer.decode(encoding["input_ids"][17], spaces_between_special_tokens=False), "Iran") + self.assertEqual( + tokenizer.decode(encoding["input_ids"][20:25], spaces_between_special_tokens=False), "アフガニスタン" + ) + self.assertEqual( + tokenizer.decode(encoding["input_ids"][26], spaces_between_special_tokens=False), "Afghanistan" + ) + + self.assertEqual( + encoding["entity_ids"], + [ + tokenizer.entity_vocab["en:ISO 639-3"], + tokenizer.entity_vocab["[UNK]"], + tokenizer.entity_vocab["ja:アフガニスタン"], + tokenizer.entity_vocab["en:Afghanistan"], + ], + ) + self.assertEqual(encoding["entity_attention_mask"], [1, 1, 1, 1]) + self.assertEqual(encoding["entity_token_type_ids"], [0, 0, 0, 0]) + # fmt: off + self.assertEqual( + encoding["entity_position_ids"], + [ + [1, 2, 3, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], + [17, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], + [19, 20, 21, 22, 23, 24, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], + [26, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1] + ] + ) + # fmt: on + + def test_single_text_padding_pytorch_tensors(self): + tokenizer = self.tokenizer + + 
sentence = "ISO 639-3 uses the code fas for the dialects spoken across Iran and アフガニスタン (Afghanistan)." + entities = ["en:ISO 639-3", "DUMMY_ENTITY", "ja:アフガニスタン", "en:Afghanistan"] + spans = [(0, 9), (59, 63), (68, 75), (77, 88)] + + encoding = tokenizer( + sentence, + entities=entities, + entity_spans=spans, + return_token_type_ids=True, + padding="max_length", + max_length=30, + max_entity_length=16, + return_tensors="pt", + ) + + # test words + self.assertEqual(encoding["input_ids"].shape, (1, 30)) + self.assertEqual(encoding["attention_mask"].shape, (1, 30)) + self.assertEqual(encoding["token_type_ids"].shape, (1, 30)) + + # test entities + self.assertEqual(encoding["entity_ids"].shape, (1, 16)) + self.assertEqual(encoding["entity_attention_mask"].shape, (1, 16)) + self.assertEqual(encoding["entity_token_type_ids"].shape, (1, 16)) + self.assertEqual(encoding["entity_position_ids"].shape, (1, 16, tokenizer.max_mention_length)) + + def test_text_pair_no_padding_or_truncation(self): + tokenizer = self.tokenizer + + sentence = "ISO 639-3 uses the code fas" + sentence_pair = "for the dialects spoken across Iran and アフガニスタン (Afghanistan)." 
+ entities = ["en:ISO 639-3"] + entities_pair = ["DUMMY_ENTITY", "ja:アフガニスタン", "en:Afghanistan"] + spans = [(0, 9)] + spans_pair = [(31, 35), (40, 47), (49, 60)] + + encoding = tokenizer( + sentence, + sentence_pair, + entities=entities, + entities_pair=entities_pair, + entity_spans=spans, + entity_spans_pair=spans_pair, + return_token_type_ids=True, + ) + + self.assertEqual( + tokenizer.decode(encoding["input_ids"], spaces_between_special_tokens=False), + " ISO 639-3 uses the code fas for the dialects spoken across Iran and アフガニスタン ( Afghanistan ).", + ) + self.assertEqual( + tokenizer.decode(encoding["input_ids"][1:5], spaces_between_special_tokens=False), "ISO 639-3" + ) + self.assertEqual(tokenizer.decode(encoding["input_ids"][19], spaces_between_special_tokens=False), "Iran") + self.assertEqual( + tokenizer.decode(encoding["input_ids"][21:27], spaces_between_special_tokens=False), "アフガニスタン" + ) + self.assertEqual( + tokenizer.decode(encoding["input_ids"][28], spaces_between_special_tokens=False), "Afghanistan" + ) + + self.assertEqual( + encoding["entity_ids"], + [ + tokenizer.entity_vocab["en:ISO 639-3"], + tokenizer.entity_vocab["[UNK]"], + tokenizer.entity_vocab["ja:アフガニスタン"], + tokenizer.entity_vocab["en:Afghanistan"], + ], + ) + self.assertEqual(encoding["entity_attention_mask"], [1, 1, 1, 1]) + self.assertEqual(encoding["entity_token_type_ids"], [0, 0, 0, 0]) + # fmt: off + self.assertEqual( + encoding["entity_position_ids"], + [ + [1, 2, 3, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], + [19, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], + [21, 22, 23, 24, 25, 26, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], + [28, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1] + ] + ) + # fmt: on + + def 
test_text_pair_only_entity_spans_no_padding_or_truncation(self): + tokenizer = self.tokenizer + + sentence = "ISO 639-3 uses the code fas" + sentence_pair = "for the dialects spoken across Iran and アフガニスタン (Afghanistan)." + entities = ["en:ISO 639-3"] + entities_pair = ["DUMMY_ENTITY", "ja:アフガニスタン", "en:Afghanistan"] + spans = [(0, 9)] + spans_pair = [(31, 35), (40, 47), (49, 60)] + + encoding = tokenizer( + sentence, + sentence_pair, + entities=entities, + entities_pair=entities_pair, + entity_spans=spans, + entity_spans_pair=spans_pair, + return_token_type_ids=True, + ) + + self.assertEqual( + tokenizer.decode(encoding["input_ids"], spaces_between_special_tokens=False), + " ISO 639-3 uses the code fas for the dialects spoken across Iran and アフガニスタン ( Afghanistan ).", + ) + self.assertEqual( + tokenizer.decode(encoding["input_ids"][1:5], spaces_between_special_tokens=False), "ISO 639-3" + ) + self.assertEqual(tokenizer.decode(encoding["input_ids"][19], spaces_between_special_tokens=False), "Iran") + self.assertEqual( + tokenizer.decode(encoding["input_ids"][21:27], spaces_between_special_tokens=False), "アフガニスタン" + ) + self.assertEqual( + tokenizer.decode(encoding["input_ids"][28], spaces_between_special_tokens=False), "Afghanistan" + ) + + self.assertEqual( + encoding["entity_ids"], + [ + tokenizer.entity_vocab["en:ISO 639-3"], + tokenizer.entity_vocab["[UNK]"], + tokenizer.entity_vocab["ja:アフガニスタン"], + tokenizer.entity_vocab["en:Afghanistan"], + ], + ) + # fmt: off + self.assertEqual( + encoding["entity_position_ids"], + [ + [1, 2, 3, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], + [19, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], + [21, 22, 23, 24, 25, 26, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], + [28, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1] + ] + ) + # fmt: on + + def test_text_pair_padding_pytorch_tensors(self): + tokenizer = self.tokenizer + + sentence = "ISO 639-3 uses the code fas" + sentence_pair = "for the dialects spoken across Iran and アフガニスタン (Afghanistan)." + entities = ["en:ISO 639-3"] + entities_pair = ["DUMMY_ENTITY", "ja:アフガニスタン", "en:Afghanistan"] + spans = [(0, 9)] + spans_pair = [(31, 35), (40, 47), (49, 60)] + + encoding = tokenizer( + sentence, + sentence_pair, + entities=entities, + entities_pair=entities_pair, + entity_spans=spans, + entity_spans_pair=spans_pair, + return_token_type_ids=True, + padding="max_length", + max_length=40, + max_entity_length=16, + return_tensors="pt", + ) + + # test words + self.assertEqual(encoding["input_ids"].shape, (1, 40)) + self.assertEqual(encoding["attention_mask"].shape, (1, 40)) + self.assertEqual(encoding["token_type_ids"].shape, (1, 40)) + + # test entities + self.assertEqual(encoding["entity_ids"].shape, (1, 16)) + self.assertEqual(encoding["entity_attention_mask"].shape, (1, 16)) + self.assertEqual(encoding["entity_token_type_ids"].shape, (1, 16)) + self.assertEqual(encoding["entity_position_ids"].shape, (1, 16, tokenizer.max_mention_length)) + + def test_entity_classification_no_padding_or_truncation(self): + tokenizer = self.entity_classification_tokenizer + + sentence = "Japanese is an East Asian language spoken by about 128 million people, primarily in Japan." 
+ span = (15, 34) + + encoding = tokenizer(sentence, entity_spans=[span], return_token_type_ids=True) + + # test words + self.assertEqual(len(encoding["input_ids"]), 23) + self.assertEqual(len(encoding["attention_mask"]), 23) + self.assertEqual(len(encoding["token_type_ids"]), 23) + self.assertEqual( + tokenizer.decode(encoding["input_ids"], spaces_between_special_tokens=False), + " Japanese is anEast Asian languagespoken by about 128 million people, primarily in Japan.", + ) + self.assertEqual( + tokenizer.decode(encoding["input_ids"][4:9], spaces_between_special_tokens=False), + "East Asian language", + ) + + # test entities + mask_id = tokenizer.entity_vocab["[MASK]"] + self.assertEqual(encoding["entity_ids"], [mask_id]) + self.assertEqual(encoding["entity_attention_mask"], [1]) + self.assertEqual(encoding["entity_token_type_ids"], [0]) + # fmt: off + self.assertEqual( + encoding["entity_position_ids"], + [[4, 5, 6, 7, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]] + ) + # fmt: on + + def test_entity_classification_padding_pytorch_tensors(self): + tokenizer = self.entity_classification_tokenizer + + sentence = "Japanese is an East Asian language spoken by about 128 million people, primarily in Japan." 
+        span = (15, 34)
+
+        encoding = tokenizer(
+            sentence, entity_spans=[span], return_token_type_ids=True, padding="max_length", return_tensors="pt"
+        )
+
+        # test words
+        self.assertEqual(encoding["input_ids"].shape, (1, 512))
+        self.assertEqual(encoding["attention_mask"].shape, (1, 512))
+        self.assertEqual(encoding["token_type_ids"].shape, (1, 512))
+
+        # test entities
+        self.assertEqual(encoding["entity_ids"].shape, (1, 1))
+        self.assertEqual(encoding["entity_attention_mask"].shape, (1, 1))
+        self.assertEqual(encoding["entity_token_type_ids"].shape, (1, 1))
+        self.assertEqual(
+            encoding["entity_position_ids"].shape, (1, tokenizer.max_entity_length, tokenizer.max_mention_length)
+        )
+
+    def test_entity_pair_classification_no_padding_or_truncation(self):  # unpadded, plain-list output
+        tokenizer = self.entity_pair_tokenizer
+
+        sentence = "Japanese is an East Asian language spoken by about 128 million people, primarily in Japan."
+        # head and tail information (character offsets of "Japanese" and "Japan" in the sentence)
+        spans = [(0, 8), (84, 89)]
+
+        encoding = tokenizer(sentence, entity_spans=spans, return_token_type_ids=True)
+
+        self.assertEqual(  # special tokens are not skipped here, so the markers must appear in the text
+            tokenizer.decode(encoding["input_ids"], spaces_between_special_tokens=False),
+            "<s><ent>Japanese<ent>is an East Asian language spoken by about 128 million people, primarily in<ent2>Japan<ent2>.</s>",
+        )
+        self.assertEqual(  # head mention covers three tokens (positions 1-3): marker, word, marker
+            tokenizer.decode(encoding["input_ids"][1:4], spaces_between_special_tokens=False),
+            "<ent>Japanese<ent>",
+        )
+        self.assertEqual(  # tail mention covers three tokens (positions 20-22): marker, word, marker
+            tokenizer.decode(encoding["input_ids"][20:23], spaces_between_special_tokens=False), "<ent2>Japan<ent2>"
+        )
+
+        mask_id = tokenizer.entity_vocab["[MASK]"]
+        mask2_id = tokenizer.entity_vocab["[MASK2]"]
+        self.assertEqual(encoding["entity_ids"], [mask_id, mask2_id])  # head -> [MASK], tail -> [MASK2]
+        self.assertEqual(encoding["entity_attention_mask"], [1, 1])
+        self.assertEqual(encoding["entity_token_type_ids"], [0, 0])
+        # fmt: off
+        self.assertEqual(  # each row lists the mention's token positions, filled with -1 up to 30 entries
+            encoding["entity_position_ids"],
+            [
+                [1, 2, 3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
+                [20, 21, 22, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]
+            ]
+        )
+        # fmt: on
+
+    def test_entity_pair_classification_padding_pytorch_tensors(self):  # shape checks only
+        tokenizer = self.entity_pair_tokenizer
+
+        sentence = "Japanese is an East Asian language spoken by about 128 million people, primarily in Japan."
+        # head and tail information
+        spans = [(0, 8), (84, 89)]
+
+        encoding = tokenizer(
+            sentence,
+            entity_spans=spans,
+            return_token_type_ids=True,
+            padding="max_length",
+            max_length=30,
+            return_tensors="pt",
+        )
+
+        # test words (padded to max_length=30, with a leading batch dimension of 1)
+        self.assertEqual(encoding["input_ids"].shape, (1, 30))
+        self.assertEqual(encoding["attention_mask"].shape, (1, 30))
+        self.assertEqual(encoding["token_type_ids"].shape, (1, 30))
+
+        # test entities (two entities: head and tail)
+        self.assertEqual(encoding["entity_ids"].shape, (1, 2))
+        self.assertEqual(encoding["entity_attention_mask"].shape, (1, 2))
+        self.assertEqual(encoding["entity_token_type_ids"].shape, (1, 2))
+        self.assertEqual(
+            encoding["entity_position_ids"].shape, (1, tokenizer.max_entity_length, tokenizer.max_mention_length)
+        )
+
+    def test_entity_span_classification_no_padding_or_truncation(self):  # unpadded, plain-list output
+        tokenizer = self.entity_span_tokenizer
+
+        sentence = "Japanese is an East Asian language spoken by about 128 million people, primarily in Japan."
+        spans = [(0, 8), (15, 34), (84, 89)]
+
+        encoding = tokenizer(sentence, entity_spans=spans, return_token_type_ids=True)
+
+        self.assertEqual(  # special tokens are not skipped here, so <s> and </s> must appear in the text
+            tokenizer.decode(encoding["input_ids"], spaces_between_special_tokens=False),
+            "<s> Japanese is an East Asian language spoken by about 128 million people, primarily in Japan.</s>",
+        )
+
+        mask_id = tokenizer.entity_vocab["[MASK]"]
+        self.assertEqual(encoding["entity_ids"], [mask_id, mask_id, mask_id])  # every span uses [MASK]
+        self.assertEqual(encoding["entity_attention_mask"], [1, 1, 1])
+        self.assertEqual(encoding["entity_token_type_ids"], [0, 0, 0])
+        # fmt: off
+        self.assertEqual(  # each row lists the span's token positions, filled with -1 up to 30 entries
+            encoding["entity_position_ids"],
+            [
+                [1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
+                [4, 5, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
+                [18, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]]
+        )
+        # fmt: on
+        self.assertEqual(encoding["entity_start_positions"], [1, 4, 18])  # first token position of each span
+        self.assertEqual(encoding["entity_end_positions"], [1, 6, 18])  # last token position of each span
+
+    def test_entity_span_classification_padding_pytorch_tensors(self):  # shape checks only
+        tokenizer = self.entity_span_tokenizer
+
+        sentence = "Japanese is an East Asian language spoken by about 128 million people, primarily in Japan."
+        spans = [(0, 8), (15, 34), (84, 89)]
+
+        encoding = tokenizer(
+            sentence,
+            entity_spans=spans,
+            return_token_type_ids=True,
+            padding="max_length",
+            max_length=30,
+            max_entity_length=16,
+            return_tensors="pt",
+        )
+
+        # test words (padded to max_length=30, with a leading batch dimension of 1)
+        self.assertEqual(encoding["input_ids"].shape, (1, 30))
+        self.assertEqual(encoding["attention_mask"].shape, (1, 30))
+        self.assertEqual(encoding["token_type_ids"].shape, (1, 30))
+
+        # test entities (entity axis has length max_entity_length=16 although only 3 spans were given)
+        self.assertEqual(encoding["entity_ids"].shape, (1, 16))
+        self.assertEqual(encoding["entity_attention_mask"].shape, (1, 16))
+        self.assertEqual(encoding["entity_token_type_ids"].shape, (1, 16))
+        self.assertEqual(encoding["entity_position_ids"].shape, (1, 16, tokenizer.max_mention_length))
+        self.assertEqual(encoding["entity_start_positions"].shape, (1, 16))
+        self.assertEqual(encoding["entity_end_positions"].shape, (1, 16))
diff --git a/utils/check_repo.py b/utils/check_repo.py
index bf32af687e..2cb204313e 100644
--- a/utils/check_repo.py
+++ b/utils/check_repo.py
@@ -116,6 +116,7 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
     "DPRReader",
     "FlaubertForQuestionAnswering",
     "GPT2DoubleHeadsModel",
+    "LukeForMaskedLM",
     "LukeForEntityClassification",
     "LukeForEntityPairClassification",
     "LukeForEntitySpanClassification",