From 841d979190319098adc8101f9820a02ee3be4c8b Mon Sep 17 00:00:00 2001 From: "Li-Huai (Allan) Lin" Date: Wed, 19 Jan 2022 22:19:36 +0800 Subject: [PATCH] Add FastTokenizer to REALM (#15211) * Remove BertTokenizer abstraction * Add FastTokenizer to REALM * Fix config archive map * Fix copies * Update realm.mdx * Apply suggestions from code review --- docs/source/index.mdx | 2 +- docs/source/model_doc/realm.mdx | 5 + src/transformers/__init__.py | 2 + src/transformers/convert_slow_tokenizer.py | 1 + src/transformers/models/realm/__init__.py | 5 + .../models/realm/configuration_realm.py | 16 +- .../models/realm/tokenization_realm.py | 508 +++++++++++++++++- .../models/realm/tokenization_realm_fast.py | 298 ++++++++++ .../utils/dummy_tokenizers_objects.py | 7 + tests/test_tokenization_realm.py | 26 +- 10 files changed, 824 insertions(+), 46 deletions(-) create mode 100644 src/transformers/models/realm/tokenization_realm_fast.py diff --git a/docs/source/index.mdx b/docs/source/index.mdx index 2498e0ce3c..7306ff1a4d 100644 --- a/docs/source/index.mdx +++ b/docs/source/index.mdx @@ -246,7 +246,7 @@ Flax), PyTorch, and/or TensorFlow. | ProphetNet | ✅ | ❌ | ✅ | ❌ | ❌ | | QDQBert | ❌ | ❌ | ✅ | ❌ | ❌ | | RAG | ✅ | ❌ | ✅ | ✅ | ❌ | -| Realm | ✅ | ❌ | ✅ | ❌ | ❌ | +| Realm | ✅ | ✅ | ✅ | ❌ | ❌ | | Reformer | ✅ | ✅ | ✅ | ❌ | ❌ | | RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ | | RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ | diff --git a/docs/source/model_doc/realm.mdx b/docs/source/model_doc/realm.mdx index 9f87c8f16b..f96e322ebf 100644 --- a/docs/source/model_doc/realm.mdx +++ b/docs/source/model_doc/realm.mdx @@ -49,6 +49,11 @@ This model was contributed by [qqaatw](https://huggingface.co/qqaatw). The origi - save_vocabulary - batch_encode_candidates +## RealmTokenizerFast + +[[autodoc]] RealmTokenizerFast + - batch_encode_candidates + ## RealmRetriever [[autodoc]] RealmRetriever diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 9d6a006693..137f7ff939 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -419,6 +419,7 @@ else: # tokenizers-backed objects if is_tokenizers_available(): # Fast tokenizers + _import_structure["models.realm"].append("RealmTokenizerFast") _import_structure["models.fnet"].append("FNetTokenizerFast") _import_structure["models.roformer"].append("RoFormerTokenizerFast") _import_structure["models.clip"].append("CLIPTokenizerFast") @@ -2542,6 +2543,7 @@ if TYPE_CHECKING: from .models.mt5 import MT5TokenizerFast from .models.openai import OpenAIGPTTokenizerFast from .models.pegasus import PegasusTokenizerFast + from .models.realm import RealmTokenizerFast from .models.reformer import ReformerTokenizerFast from .models.rembert import RemBertTokenizerFast from .models.retribert import RetriBertTokenizerFast diff --git a/src/transformers/convert_slow_tokenizer.py b/src/transformers/convert_slow_tokenizer.py index e9611fdca6..8ebf9e2496 100644 --- a/src/transformers/convert_slow_tokenizer.py +++ b/src/transformers/convert_slow_tokenizer.py @@ -942,6 +942,7 @@ SLOW_TO_FAST_CONVERTERS = { "MobileBertTokenizer": BertConverter, "OpenAIGPTTokenizer": OpenAIGPTConverter, "PegasusTokenizer": PegasusConverter, + "RealmTokenizer": BertConverter, "ReformerTokenizer": ReformerConverter, "RemBertTokenizer": RemBertConverter, "RetriBertTokenizer": BertConverter, diff --git a/src/transformers/models/realm/__init__.py b/src/transformers/models/realm/__init__.py index 8fe1b83144..41fcce7be7 100644 --- a/src/transformers/models/realm/__init__.py +++ b/src/transformers/models/realm/__init__.py @@ -25,6 +25,8 @@ _import_structure = { "tokenization_realm": ["RealmTokenizer"], } +if is_tokenizers_available(): + _import_structure["tokenization_realm_fast"] = ["RealmTokenizerFast"] if is_torch_available(): _import_structure["modeling_realm"] = [ @@ -44,6 +46,9 @@ if TYPE_CHECKING: from .configuration_realm import REALM_PRETRAINED_CONFIG_ARCHIVE_MAP, RealmConfig from .tokenization_realm import RealmTokenizer + if is_tokenizers_available(): + from .tokenization_realm import RealmTokenizerFast + if is_torch_available(): from .modeling_realm import ( REALM_PRETRAINED_MODEL_ARCHIVE_LIST, diff --git a/src/transformers/models/realm/configuration_realm.py b/src/transformers/models/realm/configuration_realm.py index 49975c1a05..5762945515 100644 --- a/src/transformers/models/realm/configuration_realm.py +++ b/src/transformers/models/realm/configuration_realm.py @@ -21,14 +21,14 @@ from ...utils import logging logger = logging.get_logger(__name__) REALM_PRETRAINED_CONFIG_ARCHIVE_MAP = { - "realm-cc-news-pretrained-embedder": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-embedder/resolve/main/config.json", - "realm-cc-news-pretrained-encoder": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-encoder/resolve/main/config.json", - "realm-cc-news-pretrained-scorer": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-scorer/resolve/main/config.json", - "realm-cc-news-pretrained-openqa": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-openqa/aresolve/main/config.json", - "realm-orqa-nq-openqa": "https://huggingface.co/qqaatw/realm-orqa-nq-openqa/resolve/main/config.json", - "realm-orqa-nq-reader": "https://huggingface.co/qqaatw/realm-orqa-nq-reader/resolve/main/config.json", - "realm-orqa-wq-openqa": "https://huggingface.co/qqaatw/realm-orqa-wq-openqa/resolve/main/config.json", - "realm-orqa-wq-reader": "https://huggingface.co/qqaatw/realm-orqa-wq-reader/resolve/main/config.json", + "qqaatw/realm-cc-news-pretrained-embedder": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-embedder/resolve/main/config.json", + "qqaatw/realm-cc-news-pretrained-encoder": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-encoder/resolve/main/config.json", + "qqaatw/realm-cc-news-pretrained-scorer": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-scorer/resolve/main/config.json", + "qqaatw/realm-cc-news-pretrained-openqa": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-openqa/aresolve/main/config.json", + "qqaatw/realm-orqa-nq-openqa": "https://huggingface.co/qqaatw/realm-orqa-nq-openqa/resolve/main/config.json", + "qqaatw/realm-orqa-nq-reader": "https://huggingface.co/qqaatw/realm-orqa-nq-reader/resolve/main/config.json", + "qqaatw/realm-orqa-wq-openqa": "https://huggingface.co/qqaatw/realm-orqa-wq-openqa/resolve/main/config.json", + "qqaatw/realm-orqa-wq-reader": "https://huggingface.co/qqaatw/realm-orqa-wq-reader/resolve/main/config.json", # See all REALM models at https://huggingface.co/models?filter=realm } diff --git a/src/transformers/models/realm/tokenization_realm.py b/src/transformers/models/realm/tokenization_realm.py index 9d7b72ac89..571cc7c198 100644 --- a/src/transformers/models/realm/tokenization_realm.py +++ b/src/transformers/models/realm/tokenization_realm.py @@ -14,10 +14,15 @@ # limitations under the License. """Tokenization classes for REALM.""" +import collections +import os +import unicodedata +from typing import List, Optional, Tuple + from ...file_utils import PaddingStrategy +from ...tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace from ...tokenization_utils_base import BatchEncoding from ...utils import logging -from ..bert.tokenization_bert import BertTokenizer logger = logging.get_logger(__name__) @@ -26,54 +31,193 @@ VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"} PRETRAINED_VOCAB_FILES_MAP = { "vocab_file": { - "realm-cc-news-pretrained-embedder": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-embedder/resolve/main/vocab.txt", - "realm-cc-news-pretrained-encoder": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-encoder/resolve/main/vocab.txt", - "realm-cc-news-pretrained-scorer": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-scorer/resolve/main/vocab.txt", - "realm-cc-news-pretrained-openqa": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-openqa/aresolve/main/vocab.txt", - "realm-orqa-nq-openqa": "https://huggingface.co/qqaatw/realm-orqa-nq-openqa/resolve/main/vocab.txt", - "realm-orqa-nq-reader": "https://huggingface.co/qqaatw/realm-orqa-nq-reader/resolve/main/vocab.txt", - "realm-orqa-wq-openqa": "https://huggingface.co/qqaatw/realm-orqa-wq-openqa/resolve/main/vocab.txt", - "realm-orqa-wq-reader": "https://huggingface.co/qqaatw/realm-orqa-wq-reader/resolve/main/vocab.txt", + "qqaatw/realm-cc-news-pretrained-embedder": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-embedder/resolve/main/vocab.txt", + "qqaatw/realm-cc-news-pretrained-encoder": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-encoder/resolve/main/vocab.txt", + "qqaatw/realm-cc-news-pretrained-scorer": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-scorer/resolve/main/vocab.txt", + "qqaatw/realm-cc-news-pretrained-openqa": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-openqa/aresolve/main/vocab.txt", + "qqaatw/realm-orqa-nq-openqa": "https://huggingface.co/qqaatw/realm-orqa-nq-openqa/resolve/main/vocab.txt", + "qqaatw/realm-orqa-nq-reader": "https://huggingface.co/qqaatw/realm-orqa-nq-reader/resolve/main/vocab.txt", + "qqaatw/realm-orqa-wq-openqa": "https://huggingface.co/qqaatw/realm-orqa-wq-openqa/resolve/main/vocab.txt", + "qqaatw/realm-orqa-wq-reader": "https://huggingface.co/qqaatw/realm-orqa-wq-reader/resolve/main/vocab.txt", } } PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { - "realm-cc-news-pretrained-embedder": 512, - "realm-cc-news-pretrained-encoder": 512, - "realm-cc-news-pretrained-scorer": 512, - "realm-cc-news-pretrained-openqa": 512, - "realm-orqa-nq-openqa": 512, - "realm-orqa-nq-reader": 512, - "realm-orqa-wq-openqa": 512, - "realm-orqa-wq-reader": 512, + "qqaatw/realm-cc-news-pretrained-embedder": 512, + "qqaatw/realm-cc-news-pretrained-encoder": 512, + "qqaatw/realm-cc-news-pretrained-scorer": 512, + "qqaatw/realm-cc-news-pretrained-openqa": 512, + "qqaatw/realm-orqa-nq-openqa": 512, + "qqaatw/realm-orqa-nq-reader": 512, + "qqaatw/realm-orqa-wq-openqa": 512, + "qqaatw/realm-orqa-wq-reader": 512, } PRETRAINED_INIT_CONFIGURATION = { - "realm-cc-news-pretrained-embedder": {"do_lower_case": True}, - "realm-cc-news-pretrained-encoder": {"do_lower_case": True}, - "realm-cc-news-pretrained-scorer": {"do_lower_case": True}, - "realm-cc-news-pretrained-openqa": {"do_lower_case": True}, - "realm-orqa-nq-openqa": {"do_lower_case": True}, - "realm-orqa-nq-reader": {"do_lower_case": True}, - "realm-orqa-wq-openqa": {"do_lower_case": True}, - "realm-orqa-wq-reader": {"do_lower_case": True}, + "qqaatw/realm-cc-news-pretrained-embedder": {"do_lower_case": True}, + "qqaatw/realm-cc-news-pretrained-encoder": {"do_lower_case": True}, + "qqaatw/realm-cc-news-pretrained-scorer": {"do_lower_case": True}, + "qqaatw/realm-cc-news-pretrained-openqa": {"do_lower_case": True}, + "qqaatw/realm-orqa-nq-openqa": {"do_lower_case": True}, + "qqaatw/realm-orqa-nq-reader": {"do_lower_case": True}, + "qqaatw/realm-orqa-wq-openqa": {"do_lower_case": True}, + "qqaatw/realm-orqa-wq-reader": {"do_lower_case": True}, } -class RealmTokenizer(BertTokenizer): +def load_vocab(vocab_file): + """Loads a vocabulary file into a dictionary.""" + vocab = collections.OrderedDict() + with open(vocab_file, "r", encoding="utf-8") as reader: + tokens = reader.readlines() + for index, token in enumerate(tokens): + token = token.rstrip("\n") + vocab[token] = index + return vocab + + +def whitespace_tokenize(text): + """Runs basic whitespace cleaning and splitting on a piece of text.""" + text = text.strip() + if not text: + return [] + tokens = text.split() + return tokens + + +class RealmTokenizer(PreTrainedTokenizer): r""" Construct a REALM tokenizer. [`RealmTokenizer`] is identical to [`BertTokenizer`] and runs end-to-end tokenization: punctuation splitting and wordpiece. - Refer to superclass [`BertTokenizer`] for usage examples and documentation concerning parameters. + This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to + this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + File containing the vocabulary. + do_lower_case (`bool`, *optional*, defaults to `True`): + Whether or not to lowercase the input when tokenizing. + do_basic_tokenize (`bool`, *optional*, defaults to `True`): + Whether or not to do basic tokenization before WordPiece. + never_split (`Iterable`, *optional*): + Collection of tokens which will never be split during tokenization. Only has an effect when + `do_basic_tokenize=True` + unk_token (`str`, *optional*, defaults to `"[UNK]"`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + sep_token (`str`, *optional*, defaults to `"[SEP]"`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + pad_token (`str`, *optional*, defaults to `"[PAD]"`): + The token used for padding, for example when batching sequences of different lengths. + cls_token (`str`, *optional*, defaults to `"[CLS]"`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + mask_token (`str`, *optional*, defaults to `"[MASK]"`): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. + tokenize_chinese_chars (`bool`, *optional*, defaults to `True`): + Whether or not to tokenize Chinese characters. + + This should likely be deactivated for Japanese (see this + [issue](https://github.com/huggingface/transformers/issues/328)). + strip_accents: (`bool`, *optional*): + Whether or not to strip all accents. If this option is not specified, then it will be determined by the + value for `lowercase` (as in the original BERT). """ vocab_files_names = VOCAB_FILES_NAMES pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP - max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + + def __init__( + self, + vocab_file, + do_lower_case=True, + do_basic_tokenize=True, + never_split=None, + unk_token="[UNK]", + sep_token="[SEP]", + pad_token="[PAD]", + cls_token="[CLS]", + mask_token="[MASK]", + tokenize_chinese_chars=True, + strip_accents=None, + **kwargs + ): + super().__init__( + do_lower_case=do_lower_case, + do_basic_tokenize=do_basic_tokenize, + never_split=never_split, + unk_token=unk_token, + sep_token=sep_token, + pad_token=pad_token, + cls_token=cls_token, + mask_token=mask_token, + tokenize_chinese_chars=tokenize_chinese_chars, + strip_accents=strip_accents, + **kwargs, + ) + + if not os.path.isfile(vocab_file): + raise ValueError( + f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained " + "model use `tokenizer = RealmTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`" + ) + self.vocab = load_vocab(vocab_file) + self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()]) + self.do_basic_tokenize = do_basic_tokenize + if do_basic_tokenize: + self.basic_tokenizer = BasicTokenizer( + do_lower_case=do_lower_case, + never_split=never_split, + tokenize_chinese_chars=tokenize_chinese_chars, + strip_accents=strip_accents, + ) + self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token) + + @property + def do_lower_case(self): + return self.basic_tokenizer.do_lower_case + + @property + def vocab_size(self): + return len(self.vocab) + + def get_vocab(self): + return dict(self.vocab, **self.added_tokens_encoder) + + def _tokenize(self, text): + split_tokens = [] + if self.do_basic_tokenize: + for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens): + + # If the token is part of the never_split set + if token in self.basic_tokenizer.never_split: + split_tokens.append(token) + else: + split_tokens += self.wordpiece_tokenizer.tokenize(token) + else: + split_tokens = self.wordpiece_tokenizer.tokenize(text) + return split_tokens + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.vocab.get(token, self.vocab.get(self.unk_token)) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + return self.ids_to_tokens.get(index, self.unk_token) + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + out_string = " ".join(tokens).replace(" ##", "").strip() + return out_string def batch_encode_candidates(self, text, **kwargs): r""" @@ -147,3 +291,311 @@ class RealmTokenizer(BertTokenizer): output_data = dict((key, item) for key, item in output_data.items() if len(item) != 0) return BatchEncoding(output_data, tensor_type=return_tensors) + + def build_inputs_with_special_tokens( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A REALM sequence has the following format: + + - single sequence: `[CLS] X [SEP]` + - pair of sequences: `[CLS] A [SEP] B [SEP]` + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + if token_ids_1 is None: + return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + cls = [self.cls_token_id] + sep = [self.sep_token_id] + return cls + token_ids_0 + sep + token_ids_1 + sep + + def get_special_tokens_mask( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + ) -> List[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + if token_ids_1 is not None: + return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1] + + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. A REALM sequence + pair mask has the following format: + + ``` + 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 + | first sequence | second sequence | + ``` + + If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s). + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s). + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + index = 0 + if os.path.isdir(save_directory): + vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + else: + vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory + with open(vocab_file, "w", encoding="utf-8") as writer: + for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): + if index != token_index: + logger.warning( + f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive." + " Please check that the vocabulary is not corrupted!" + ) + index = token_index + writer.write(token + "\n") + index += 1 + return (vocab_file,) + + +class BasicTokenizer(object): + """ + Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.). + + Args: + do_lower_case (`bool`, *optional*, defaults to `True`): + Whether or not to lowercase the input when tokenizing. + never_split (`Iterable`, *optional*): + Collection of tokens which will never be split during tokenization. Only has an effect when + `do_basic_tokenize=True` + tokenize_chinese_chars (`bool`, *optional*, defaults to `True`): + Whether or not to tokenize Chinese characters. + + This should likely be deactivated for Japanese (see this + [issue](https://github.com/huggingface/transformers/issues/328)). + strip_accents (`bool`, *optional*): + Whether or not to strip all accents. If this option is not specified, then it will be determined by the + value for `lowercase` (as in the original BERT). + """ + + def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True, strip_accents=None): + if never_split is None: + never_split = [] + self.do_lower_case = do_lower_case + self.never_split = set(never_split) + self.tokenize_chinese_chars = tokenize_chinese_chars + self.strip_accents = strip_accents + + def tokenize(self, text, never_split=None): + """ + Basic Tokenization of a piece of text. Split on "white spaces" only, for sub-word tokenization, see + WordPieceTokenizer. + + Args: + never_split (`List[str]`, *optional*) + Kept for backward compatibility purposes. Now implemented directly at the base class level (see + [`PreTrainedTokenizer.tokenize`]) List of token not to split. + """ + # union() returns a new set by concatenating the two sets. + never_split = self.never_split.union(set(never_split)) if never_split else self.never_split + text = self._clean_text(text) + + # This was added on November 1st, 2018 for the multilingual and Chinese + # models. This is also applied to the English models now, but it doesn't + # matter since the English models were not trained on any Chinese data + # and generally don't have any Chinese data in them (there are Chinese + # characters in the vocabulary because Wikipedia does have some Chinese + # words in the English Wikipedia.). + if self.tokenize_chinese_chars: + text = self._tokenize_chinese_chars(text) + orig_tokens = whitespace_tokenize(text) + split_tokens = [] + for token in orig_tokens: + if token not in never_split: + if self.do_lower_case: + token = token.lower() + if self.strip_accents is not False: + token = self._run_strip_accents(token) + elif self.strip_accents: + token = self._run_strip_accents(token) + split_tokens.extend(self._run_split_on_punc(token, never_split)) + + output_tokens = whitespace_tokenize(" ".join(split_tokens)) + return output_tokens + + def _run_strip_accents(self, text): + """Strips accents from a piece of text.""" + text = unicodedata.normalize("NFD", text) + output = [] + for char in text: + cat = unicodedata.category(char) + if cat == "Mn": + continue + output.append(char) + return "".join(output) + + def _run_split_on_punc(self, text, never_split=None): + """Splits punctuation on a piece of text.""" + if never_split is not None and text in never_split: + return [text] + chars = list(text) + i = 0 + start_new_word = True + output = [] + while i < len(chars): + char = chars[i] + if _is_punctuation(char): + output.append([char]) + start_new_word = True + else: + if start_new_word: + output.append([]) + start_new_word = False + output[-1].append(char) + i += 1 + + return ["".join(x) for x in output] + + def _tokenize_chinese_chars(self, text): + """Adds whitespace around any CJK character.""" + output = [] + for char in text: + cp = ord(char) + if self._is_chinese_char(cp): + output.append(" ") + output.append(char) + output.append(" ") + else: + output.append(char) + return "".join(output) + + def _is_chinese_char(self, cp): + """Checks whether CP is the codepoint of a CJK character.""" + # This defines a "chinese character" as anything in the CJK Unicode block: + # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) + # + # Note that the CJK Unicode block is NOT all Japanese and Korean characters, + # despite its name. The modern Korean Hangul alphabet is a different block, + # as is Japanese Hiragana and Katakana. Those alphabets are used to write + # space-separated words, so they are not treated specially and handled + # like the all of the other languages. + if ( + (cp >= 0x4E00 and cp <= 0x9FFF) + or (cp >= 0x3400 and cp <= 0x4DBF) # + or (cp >= 0x20000 and cp <= 0x2A6DF) # + or (cp >= 0x2A700 and cp <= 0x2B73F) # + or (cp >= 0x2B740 and cp <= 0x2B81F) # + or (cp >= 0x2B820 and cp <= 0x2CEAF) # + or (cp >= 0xF900 and cp <= 0xFAFF) + or (cp >= 0x2F800 and cp <= 0x2FA1F) # + ): # + return True + + return False + + def _clean_text(self, text): + """Performs invalid character removal and whitespace cleanup on text.""" + output = [] + for char in text: + cp = ord(char) + if cp == 0 or cp == 0xFFFD or _is_control(char): + continue + if _is_whitespace(char): + output.append(" ") + else: + output.append(char) + return "".join(output) + + +class WordpieceTokenizer(object): + """Runs WordPiece tokenization.""" + + def __init__(self, vocab, unk_token, max_input_chars_per_word=100): + self.vocab = vocab + self.unk_token = unk_token + self.max_input_chars_per_word = max_input_chars_per_word + + def tokenize(self, text): + """ + Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform + tokenization using the given vocabulary. + + For example, `input = "unaffable"` wil return as output `["un", "##aff", "##able"]`. + + Args: + text: A single token or whitespace separated tokens. This should have + already been passed through *BasicTokenizer*. + + Returns: + A list of wordpiece tokens. + """ + + output_tokens = [] + for token in whitespace_tokenize(text): + chars = list(token) + if len(chars) > self.max_input_chars_per_word: + output_tokens.append(self.unk_token) + continue + + is_bad = False + start = 0 + sub_tokens = [] + while start < len(chars): + end = len(chars) + cur_substr = None + while start < end: + substr = "".join(chars[start:end]) + if start > 0: + substr = "##" + substr + if substr in self.vocab: + cur_substr = substr + break + end -= 1 + if cur_substr is None: + is_bad = True + break + sub_tokens.append(cur_substr) + start = end + + if is_bad: + output_tokens.append(self.unk_token) + else: + output_tokens.extend(sub_tokens) + return output_tokens diff --git a/src/transformers/models/realm/tokenization_realm_fast.py b/src/transformers/models/realm/tokenization_realm_fast.py new file mode 100644 index 0000000000..e78dc4f990 --- /dev/null +++ b/src/transformers/models/realm/tokenization_realm_fast.py @@ -0,0 +1,298 @@ +# coding=utf-8 +# Copyright 2022 The REALM authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast Tokenization classes for REALM.""" + +import json +from typing import List, Optional, Tuple + +from tokenizers import normalizers + +from ...file_utils import PaddingStrategy +from ...tokenization_utils_base import BatchEncoding +from ...tokenization_utils_fast import PreTrainedTokenizerFast +from ...utils import logging +from .tokenization_realm import RealmTokenizer + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.json"} + +PRETRAINED_VOCAB_FILES_MAP = { + "vocab_file": { + "qqaatw/realm-cc-news-pretrained-embedder": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-embedder/resolve/main/vocab.txt", + "qqaatw/realm-cc-news-pretrained-encoder": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-encoder/resolve/main/vocab.txt", + "qqaatw/realm-cc-news-pretrained-scorer": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-scorer/resolve/main/vocab.txt", + "qqaatw/realm-cc-news-pretrained-openqa": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-openqa/aresolve/main/vocab.txt", + "qqaatw/realm-orqa-nq-openqa": "https://huggingface.co/qqaatw/realm-orqa-nq-openqa/resolve/main/vocab.txt", + "qqaatw/realm-orqa-nq-reader": "https://huggingface.co/qqaatw/realm-orqa-nq-reader/resolve/main/vocab.txt", + "qqaatw/realm-orqa-wq-openqa": "https://huggingface.co/qqaatw/realm-orqa-wq-openqa/resolve/main/vocab.txt", + "qqaatw/realm-orqa-wq-reader": "https://huggingface.co/qqaatw/realm-orqa-wq-reader/resolve/main/vocab.txt", + }, + "tokenizer_file": { + "qqaatw/realm-cc-news-pretrained-embedder": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-embedder/resolve/main/tokenizer.jsont", + "qqaatw/realm-cc-news-pretrained-encoder": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-encoder/resolve/main/tokenizer.json", + "qqaatw/realm-cc-news-pretrained-scorer": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-scorer/resolve/main/tokenizer.json", + "qqaatw/realm-cc-news-pretrained-openqa": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-openqa/aresolve/main/tokenizer.json", + "qqaatw/realm-orqa-nq-openqa": "https://huggingface.co/qqaatw/realm-orqa-nq-openqa/resolve/main/tokenizer.json", + "qqaatw/realm-orqa-nq-reader": "https://huggingface.co/qqaatw/realm-orqa-nq-reader/resolve/main/tokenizer.json", + "qqaatw/realm-orqa-wq-openqa": "https://huggingface.co/qqaatw/realm-orqa-wq-openqa/resolve/main/tokenizer.json", + "qqaatw/realm-orqa-wq-reader": "https://huggingface.co/qqaatw/realm-orqa-wq-reader/resolve/main/tokenizer.json", + }, +} + +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { + "qqaatw/realm-cc-news-pretrained-embedder": 512, + "qqaatw/realm-cc-news-pretrained-encoder": 512, + "qqaatw/realm-cc-news-pretrained-scorer": 512, + "qqaatw/realm-cc-news-pretrained-openqa": 512, + "qqaatw/realm-orqa-nq-openqa": 512, + "qqaatw/realm-orqa-nq-reader": 512, + "qqaatw/realm-orqa-wq-openqa": 512, + "qqaatw/realm-orqa-wq-reader": 512, +} + +PRETRAINED_INIT_CONFIGURATION = { + "qqaatw/realm-cc-news-pretrained-embedder": {"do_lower_case": True}, + "qqaatw/realm-cc-news-pretrained-encoder": {"do_lower_case": True}, + "qqaatw/realm-cc-news-pretrained-scorer": {"do_lower_case": True}, + "qqaatw/realm-cc-news-pretrained-openqa": {"do_lower_case": True}, + "qqaatw/realm-orqa-nq-openqa": {"do_lower_case": True}, + "qqaatw/realm-orqa-nq-reader": {"do_lower_case": True}, + "qqaatw/realm-orqa-wq-openqa": {"do_lower_case": True}, + "qqaatw/realm-orqa-wq-reader": {"do_lower_case": True}, +} + + +class RealmTokenizerFast(PreTrainedTokenizerFast): + r""" + Construct a "fast" REALM tokenizer (backed by HuggingFace's *tokenizers* library). Based on WordPiece. + + [`RealmTokenizerFast`] is identical to [`BertTokenizerFast`] and runs end-to-end tokenization: punctuation + splitting and wordpiece. + + This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should + refer to this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + File containing the vocabulary. + do_lower_case (`bool`, *optional*, defaults to `True`): + Whether or not to lowercase the input when tokenizing. + unk_token (`str`, *optional*, defaults to `"[UNK]"`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + sep_token (`str`, *optional*, defaults to `"[SEP]"`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + pad_token (`str`, *optional*, defaults to `"[PAD]"`): + The token used for padding, for example when batching sequences of different lengths. + cls_token (`str`, *optional*, defaults to `"[CLS]"`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + mask_token (`str`, *optional*, defaults to `"[MASK]"`): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. + clean_text (`bool`, *optional*, defaults to `True`): + Whether or not to clean the text before tokenization by removing any control characters and replacing all + whitespaces by the classic one. + tokenize_chinese_chars (`bool`, *optional*, defaults to `True`): + Whether or not to tokenize Chinese characters. This should likely be deactivated for Japanese (see [this + issue](https://github.com/huggingface/transformers/issues/328)). + strip_accents (`bool`, *optional*): + Whether or not to strip all accents. If this option is not specified, then it will be determined by the + value for `lowercase` (as in the original BERT). + wordpieces_prefix (`str`, *optional*, defaults to `"##"`): + The prefix for subwords. + """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + slow_tokenizer_class = RealmTokenizer + + def __init__( + self, + vocab_file=None, + tokenizer_file=None, + do_lower_case=True, + unk_token="[UNK]", + sep_token="[SEP]", + pad_token="[PAD]", + cls_token="[CLS]", + mask_token="[MASK]", + tokenize_chinese_chars=True, + strip_accents=None, + **kwargs + ): + super().__init__( + vocab_file, + tokenizer_file=tokenizer_file, + do_lower_case=do_lower_case, + unk_token=unk_token, + sep_token=sep_token, + pad_token=pad_token, + cls_token=cls_token, + mask_token=mask_token, + tokenize_chinese_chars=tokenize_chinese_chars, + strip_accents=strip_accents, + **kwargs, + ) + + normalizer_state = json.loads(self.backend_tokenizer.normalizer.__getstate__()) + if ( + normalizer_state.get("lowercase", do_lower_case) != do_lower_case + or normalizer_state.get("strip_accents", strip_accents) != strip_accents + or normalizer_state.get("handle_chinese_chars", tokenize_chinese_chars) != tokenize_chinese_chars + ): + normalizer_class = getattr(normalizers, normalizer_state.pop("type")) + normalizer_state["lowercase"] = do_lower_case + normalizer_state["strip_accents"] = strip_accents + normalizer_state["handle_chinese_chars"] = tokenize_chinese_chars + self.backend_tokenizer.normalizer = normalizer_class(**normalizer_state) + + self.do_lower_case = do_lower_case + + def batch_encode_candidates(self, text, **kwargs): + r""" + Encode a batch of text or text pair. This method is similar to regular __call__ method but has the following + differences: + + 1. Handle additional num_candidate axis. (batch_size, num_candidates, text) + 2. Always pad the sequences to *max_length*. + 3. Must specify *max_length* in order to stack packs of candidates into a batch. + + - single sequence: `[CLS] X [SEP]` + - pair of sequences: `[CLS] A [SEP] B [SEP]` + + Args: + text (`List[List[str]]`): + The batch of sequences to be encoded. Each sequence must be in this format: (batch_size, + num_candidates, text). + text_pair (`List[List[str]]`, *optional*): + The batch of sequences to be encoded. Each sequence must be in this format: (batch_size, + num_candidates, text). + **kwargs: + Keyword arguments of the __call__ method. + + Returns: + [`BatchEncoding`]: Encoded text or text pair. + + Example: + + ```python + >>> from transformers import RealmTokenizerFast + + >>> # batch_size = 2, num_candidates = 2 + >>> text = [["Hello world!", "Nice to meet you!"], ["The cute cat.", "The adorable dog."]] + + >>> tokenizer = RealmTokenizerFast.from_pretrained("qqaatw/realm-cc-news-pretrained-encoder") + >>> tokenized_text = tokenizer.batch_encode_candidates(text, max_length=10, return_tensors="pt") + ```""" + + # Always using a fixed sequence length to encode in order to stack candidates into a batch. + kwargs["padding"] = PaddingStrategy.MAX_LENGTH + + batch_text = text + batch_text_pair = kwargs.pop("text_pair", None) + return_tensors = kwargs.pop("return_tensors", None) + + output_data = { + "input_ids": [], + "attention_mask": [], + "token_type_ids": [], + } + + for idx, candidate_text in enumerate(batch_text): + if batch_text_pair is not None: + candidate_text_pair = batch_text_pair[idx] + else: + candidate_text_pair = None + + encoded_candidates = super().__call__(candidate_text, candidate_text_pair, return_tensors=None, **kwargs) + + encoded_input_ids = encoded_candidates.get("input_ids") + encoded_attention_mask = encoded_candidates.get("attention_mask") + encoded_token_type_ids = encoded_candidates.get("token_type_ids") + + if encoded_input_ids is not None: + output_data["input_ids"].append(encoded_input_ids) + if encoded_attention_mask is not None: + output_data["attention_mask"].append(encoded_attention_mask) + if encoded_token_type_ids is not None: + output_data["token_type_ids"].append(encoded_token_type_ids) + + output_data = dict((key, item) for key, item in output_data.items() if len(item) != 0) + + return BatchEncoding(output_data, tensor_type=return_tensors) + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A REALM sequence has the following format: + + - single sequence: `[CLS] X [SEP]` + - pair of sequences: `[CLS] A [SEP] B [SEP]` + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + + if token_ids_1: + output += token_ids_1 + [self.sep_token_id] + + return output + + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. A REALM sequence + pair mask has the following format: + + ``` + 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 + | first sequence | second sequence | + ``` + + If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s). + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s). + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + files = self._tokenizer.model.save(save_directory, name=filename_prefix) + return tuple(files) diff --git a/src/transformers/utils/dummy_tokenizers_objects.py b/src/transformers/utils/dummy_tokenizers_objects.py index 28897493ce..488dc8928c 100644 --- a/src/transformers/utils/dummy_tokenizers_objects.py +++ b/src/transformers/utils/dummy_tokenizers_objects.py @@ -234,6 +234,13 @@ class PegasusTokenizerFast(metaclass=DummyObject): requires_backends(self, ["tokenizers"]) +class RealmTokenizerFast(metaclass=DummyObject): + _backends = ["tokenizers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tokenizers"]) + + class ReformerTokenizerFast(metaclass=DummyObject): _backends = ["tokenizers"] diff --git a/tests/test_tokenization_realm.py b/tests/test_tokenization_realm.py index 53f22dd7f2..95e0d720f8 100644 --- a/tests/test_tokenization_realm.py +++ b/tests/test_tokenization_realm.py @@ -16,6 +16,7 @@ import os import unittest +from transformers import RealmTokenizerFast from transformers.models.bert.tokenization_bert import ( VOCAB_FILES_NAMES, BasicTokenizer, @@ -34,8 +35,8 @@ from .test_tokenization_common import TokenizerTesterMixin, filter_non_english class RealmTokenizationTest(TokenizerTesterMixin, unittest.TestCase): tokenizer_class = RealmTokenizer - rust_tokenizer_class = None - test_rust_tokenizer = False + rust_tokenizer_class = RealmTokenizerFast + test_rust_tokenizer = True space_between_special_tokens = True from_pretrained_filter = filter_non_english @@ -301,14 +302,21 @@ class RealmTokenizationTest(TokenizerTesterMixin, unittest.TestCase): @slow def test_batch_encode_candidates(self): - tokenizer = self.tokenizer_class.from_pretrained("bert-base-uncased") + for tokenizer, pretrained_name, kwargs in self.tokenizers_list: + with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"): + tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs) + tokenizer_p = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs) + text = [["Hello world!", "Nice to meet you!"], ["The cute cat.", "The adorable dog."]] - text = [["Hello world!", "Nice to meet you!"], ["The cute cat.", "The adorable dog."]] + encoded_sentence_r = tokenizer_r.batch_encode_candidates(text, max_length=10, return_tensors="np") + encoded_sentence_p = tokenizer_p.batch_encode_candidates(text, max_length=10, return_tensors="np") - encoded_sentence = tokenizer.batch_encode_candidates(text, max_length=10, return_tensors="pt") + expected_shape = (2, 2, 10) - expected_shape = (2, 2, 10) + self.assertEqual(encoded_sentence_r["input_ids"].shape, expected_shape) + self.assertEqual(encoded_sentence_r["attention_mask"].shape, expected_shape) + self.assertEqual(encoded_sentence_r["token_type_ids"].shape, expected_shape) - assert encoded_sentence["input_ids"].shape == expected_shape - assert encoded_sentence["attention_mask"].shape == expected_shape - assert encoded_sentence["token_type_ids"].shape == expected_shape + self.assertEqual(encoded_sentence_p["input_ids"].shape, expected_shape) + self.assertEqual(encoded_sentence_p["attention_mask"].shape, expected_shape) + self.assertEqual(encoded_sentence_p["token_type_ids"].shape, expected_shape)