[HUGE] Refactoring tokenizers backend - padding - truncation - pre-tokenized pipeline - fast tokenizers - tests (#4510)
* Use tokenizers pre-tokenized pipeline * failing pretrokenized test * Fix is_pretokenized in python * add pretokenized tests * style and quality * better tests for batched pretokenized inputs * tokenizers clean up - new padding_strategy - split the files * [HUGE] refactoring tokenizers - padding - truncation - tests * style and quality * bump up requied tokenizers version to 0.8.0-rc1 * switched padding/truncation API - simpler better backward compat * updating tests for custom tokenizers * style and quality - tests on pad * fix QA pipeline * fix backward compatibility for max_length only * style and quality * Various cleans up - add verbose * fix tests * update docstrings * Fix tests * Docs reformatted * __call__ method documented Co-authored-by: Thomas Wolf <thomwolf@users.noreply.github.com> Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
This commit is contained in:
@@ -17,12 +17,14 @@ The base classes ``PreTrainedTokenizer`` and ``PreTrainedTokenizerFast`` impleme
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.PreTrainedTokenizer
|
||||
:special-members: __call__
|
||||
:members:
|
||||
|
||||
``PreTrainedTokenizerFast``
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.PreTrainedTokenizerFast
|
||||
:special-members: __call__
|
||||
:members:
|
||||
|
||||
``BatchEncoding``
|
||||
|
||||
@@ -3,8 +3,6 @@ import os
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from transformers.tokenization_utils import trim_batch
|
||||
|
||||
|
||||
def encode_file(tokenizer, data_path, max_length, pad_to_max_length=True, return_tensors="pt"):
|
||||
examples = []
|
||||
@@ -17,6 +15,17 @@ def encode_file(tokenizer, data_path, max_length, pad_to_max_length=True, return
|
||||
return examples
|
||||
|
||||
|
||||
def trim_batch(
|
||||
input_ids, pad_token_id, attention_mask=None,
|
||||
):
|
||||
"""Remove columns that are populated exclusively by pad_token_id"""
|
||||
keep_column_mask = input_ids.ne(pad_token_id).any(dim=0)
|
||||
if attention_mask is None:
|
||||
return input_ids[:, keep_column_mask]
|
||||
else:
|
||||
return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask])
|
||||
|
||||
|
||||
class SummarizationDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
2
setup.py
2
setup.py
@@ -108,7 +108,7 @@ setup(
|
||||
packages=find_packages("src"),
|
||||
install_requires=[
|
||||
"numpy",
|
||||
"tokenizers == 0.7.0",
|
||||
"tokenizers == 0.8.0-rc1",
|
||||
# dataclasses for Python versions that don't have it
|
||||
"dataclasses;python_version<'3.7'",
|
||||
# utilities from PyPA to e.g. compare versions
|
||||
|
||||
@@ -133,13 +133,16 @@ from .tokenization_reformer import ReformerTokenizer
|
||||
from .tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast
|
||||
from .tokenization_t5 import T5Tokenizer
|
||||
from .tokenization_transfo_xl import TransfoXLCorpus, TransfoXLTokenizer, TransfoXLTokenizerFast
|
||||
from .tokenization_utils import (
|
||||
from .tokenization_utils import PreTrainedTokenizer
|
||||
from .tokenization_utils_base import (
|
||||
BatchEncoding,
|
||||
PreTrainedTokenizer,
|
||||
PreTrainedTokenizerFast,
|
||||
CharSpan,
|
||||
PreTrainedTokenizerBase,
|
||||
SpecialTokensMixin,
|
||||
TensorType,
|
||||
TokenSpan,
|
||||
)
|
||||
from .tokenization_utils_fast import PreTrainedTokenizerFast
|
||||
from .tokenization_xlm import XLMTokenizer
|
||||
from .tokenization_xlm_roberta import XLMRobertaTokenizer
|
||||
from .tokenization_xlnet import SPIECE_UNDERLINE, XLNetTokenizer
|
||||
|
||||
@@ -1213,7 +1213,7 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
|
||||
model = TFAlbertForMultipleChoice.from_pretrained('albert-base-v2')
|
||||
choices = ["Hello, my dog is cute", "Hello, my cat is amazing"]
|
||||
|
||||
input_ids = tf.constant([tokenizer.encode(s, add_special_tokens=True) for s in choices])[None, :] # Batch size 1, 2 choices
|
||||
input_ids = tokenizer(choices, add_special_tokens=True, return_tensors='tf', truncation=True, padding=True)[None, :] # Batch size 1, 2 choices
|
||||
labels = tf.reshape(tf.constant(1), (-1, 1))
|
||||
outputs = model(input_ids, labels=labels)
|
||||
|
||||
|
||||
@@ -23,7 +23,8 @@ from typing import List, Optional
|
||||
|
||||
from tokenizers import BertWordPieceTokenizer
|
||||
|
||||
from .tokenization_utils import PreTrainedTokenizer, PreTrainedTokenizerFast
|
||||
from .tokenization_utils import PreTrainedTokenizer
|
||||
from .tokenization_utils_fast import PreTrainedTokenizerFast
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -23,7 +23,9 @@ from functools import lru_cache
|
||||
import regex as re
|
||||
from tokenizers import ByteLevelBPETokenizer
|
||||
|
||||
from .tokenization_utils import PreTrainedTokenizer, PreTrainedTokenizerFast
|
||||
from .tokenization_utils import PreTrainedTokenizer
|
||||
from .tokenization_utils_base import BatchEncoding
|
||||
from .tokenization_utils_fast import PreTrainedTokenizerFast
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -346,3 +348,24 @@ class GPT2TokenizerFast(PreTrainedTokenizerFast):
|
||||
unk_token=unk_token,
|
||||
**kwargs,
|
||||
)
|
||||
self.add_prefix_space = add_prefix_space
|
||||
|
||||
def _batch_encode_plus(self, *args, **kwargs) -> BatchEncoding:
|
||||
|
||||
is_pretokenized = kwargs.get("is_pretokenized", False)
|
||||
assert self.add_prefix_space or not is_pretokenized, (
|
||||
"You need to instantiate GPT2TokenizerFast with add_prefix_space=False "
|
||||
"to use it with pretokenized inputs."
|
||||
)
|
||||
|
||||
return super()._batch_encode_plus(*args, **kwargs)
|
||||
|
||||
def _encode_plus(self, *args, **kwargs) -> BatchEncoding:
|
||||
|
||||
is_pretokenized = kwargs.get("is_pretokenized", False)
|
||||
assert self.add_prefix_space or not is_pretokenized, (
|
||||
"You need to instantiate GPT2TokenizerFast with add_prefix_space=False "
|
||||
"to use it with pretokenized inputs."
|
||||
)
|
||||
|
||||
return super()._encode_plus(*args, **kwargs)
|
||||
|
||||
@@ -23,7 +23,8 @@ import re
|
||||
from tokenizers import CharBPETokenizer
|
||||
|
||||
from .tokenization_bert import BasicTokenizer
|
||||
from .tokenization_utils import PreTrainedTokenizer, PreTrainedTokenizerFast
|
||||
from .tokenization_utils import PreTrainedTokenizer
|
||||
from .tokenization_utils_fast import PreTrainedTokenizerFast
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -35,7 +35,8 @@ from tokenizers.pre_tokenizers import CharDelimiterSplit, WhitespaceSplit
|
||||
from tokenizers.processors import BertProcessing
|
||||
|
||||
from .file_utils import cached_path, is_torch_available
|
||||
from .tokenization_utils import PreTrainedTokenizer, PreTrainedTokenizerFast
|
||||
from .tokenization_utils import PreTrainedTokenizer
|
||||
from .tokenization_utils_fast import PreTrainedTokenizerFast
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
1775
src/transformers/tokenization_utils_base.py
Normal file
1775
src/transformers/tokenization_utils_base.py
Normal file
File diff suppressed because it is too large
Load Diff
476
src/transformers/tokenization_utils_fast.py
Normal file
476
src/transformers/tokenization_utils_fast.py
Normal file
@@ -0,0 +1,476 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Tokenization classes for fast tokenizers (provided by HuggingFace's tokenizers library).
|
||||
For slow (python) tokenizers see tokenization_utils.py
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from tokenizers import AddedToken as AddedTokenFast
|
||||
from tokenizers import Encoding as EncodingFast
|
||||
from tokenizers.decoders import Decoder as DecoderFast
|
||||
from tokenizers.implementations import BaseTokenizer as BaseTokenizerFast
|
||||
|
||||
from .tokenization_utils_base import (
|
||||
BatchEncoding,
|
||||
PaddingStrategy,
|
||||
PreTokenizedInput,
|
||||
PreTokenizedInputPair,
|
||||
PreTrainedTokenizerBase,
|
||||
TextInput,
|
||||
TextInputPair,
|
||||
TruncationStrategy,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
|
||||
""" Base class for all fast tokenizers (wrapping HuggingFace tokenizers library).
|
||||
|
||||
Inherit from PreTrainedTokenizer.
|
||||
|
||||
Handle all the shared methods for tokenization and special tokens as well as methods
|
||||
downloading/caching/loading pretrained tokenizers as well as adding tokens to the vocabulary.
|
||||
|
||||
This class also contain the added tokens in a unified way on top of all tokenizers so we don't
|
||||
have to handle the specific vocabulary augmentation methods of the various underlying
|
||||
dictionary structures (BPE, sentencepiece...).
|
||||
|
||||
Class attributes (overridden by derived classes):
|
||||
|
||||
- ``vocab_files_names``: a python ``dict`` with, as keys, the ``__init__`` keyword name of each vocabulary file
|
||||
required by the model, and as associated values, the filename for saving the associated file (string).
|
||||
- ``pretrained_vocab_files_map``: a python ``dict of dict`` the high-level keys
|
||||
being the ``__init__`` keyword name of each vocabulary file required by the model, the low-level being the
|
||||
`short-cut-names` (string) of the pretrained models with, as associated values, the `url` (string) to the
|
||||
associated pretrained vocabulary file.
|
||||
- ``max_model_input_sizes``: a python ``dict`` with, as keys, the `short-cut-names` (string) of the pretrained
|
||||
models, and as associated values, the maximum length of the sequence inputs of this model, or None if the
|
||||
model has no maximum input size.
|
||||
- ``pretrained_init_configuration``: a python ``dict`` with, as keys, the `short-cut-names` (string) of the
|
||||
pretrained models, and as associated values, a dictionnary of specific arguments to pass to the
|
||||
``__init__``method of the tokenizer class for this pretrained model when loading the tokenizer with the
|
||||
``from_pretrained()`` method.
|
||||
|
||||
Args:
|
||||
- ``tokenizer`` (`BaseTokenizerFast`): A Fast tokenizer from the HuggingFace tokenizer library (in low level Rust language)
|
||||
- ``model_max_length``: (`Optional`) int: the maximum length in number of tokens for the inputs to the transformer model.
|
||||
When the tokenizer is loaded with `from_pretrained`, this will be set to the value stored for the associated
|
||||
model in ``max_model_input_sizes`` (see above). If no value is provided, will default to VERY_LARGE_INTEGER (`int(1e30)`).
|
||||
no associated max_length can be found in ``max_model_input_sizes``.
|
||||
- ``padding_side``: (`Optional`) string: the side on which the model should have padding applied.
|
||||
Should be selected between ['right', 'left']
|
||||
- ``model_input_names``: (`Optional`) List[string]: the list of the forward pass inputs accepted by the
|
||||
model ("token_type_ids", "attention_mask"...).
|
||||
- ``bos_token``: (`Optional`) string: a beginning of sentence token.
|
||||
Will be associated to ``self.bos_token`` and ``self.bos_token_id``
|
||||
- ``eos_token``: (`Optional`) string: an end of sentence token.
|
||||
Will be associated to ``self.eos_token`` and ``self.eos_token_id``
|
||||
- ``unk_token``: (`Optional`) string: an unknown token.
|
||||
Will be associated to ``self.unk_token`` and ``self.unk_token_id``
|
||||
- ``sep_token``: (`Optional`) string: a separation token (e.g. to separate context and query in an input sequence).
|
||||
Will be associated to ``self.sep_token`` and ``self.sep_token_id``
|
||||
- ``pad_token``: (`Optional`) string: a padding token.
|
||||
Will be associated to ``self.pad_token`` and ``self.pad_token_id``
|
||||
- ``cls_token``: (`Optional`) string: a classification token (e.g. to extract a summary of an input sequence
|
||||
leveraging self-attention along the full depth of the model).
|
||||
Will be associated to ``self.cls_token`` and ``self.cls_token_id``
|
||||
- ``mask_token``: (`Optional`) string: a masking token (e.g. when training a model with masked-language
|
||||
modeling). Will be associated to ``self.mask_token`` and ``self.mask_token_id``
|
||||
- ``additional_special_tokens``: (`Optional`) list: a list of additional special tokens.
|
||||
Adding all special tokens here ensure they won't be split by the tokenization process.
|
||||
Will be associated to ``self.additional_special_tokens`` and ``self.additional_special_tokens_ids``
|
||||
|
||||
|
||||
.. automethod:: __call__
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer: BaseTokenizerFast, **kwargs):
|
||||
if not isinstance(tokenizer, BaseTokenizerFast):
|
||||
raise ValueError(
|
||||
"Tokenizer should be an instance of a Tokenizer " "provided by HuggingFace tokenizers library."
|
||||
)
|
||||
self._tokenizer: BaseTokenizerFast = tokenizer
|
||||
|
||||
# We call this after having initialized the backend tokenizer because we update it.
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@property
|
||||
def is_fast(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def vocab_size(self) -> int:
|
||||
return self._tokenizer.get_vocab_size(with_added_tokens=False)
|
||||
|
||||
def get_vocab(self) -> Dict[str, int]:
|
||||
return self._tokenizer.get_vocab(with_added_tokens=True)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self._tokenizer.get_vocab_size(with_added_tokens=True)
|
||||
|
||||
@property
|
||||
def backend_tokenizer(self) -> BaseTokenizerFast:
|
||||
return self._tokenizer
|
||||
|
||||
@property
|
||||
def decoder(self) -> DecoderFast:
|
||||
return self._tokenizer._tokenizer.decoder
|
||||
|
||||
def _maybe_update_backend(self, value):
|
||||
""" Update the backend fast tokenizer.
|
||||
Override method from base class SpecialTokensMixin """
|
||||
self._tokenizer.add_special_tokens(value)
|
||||
|
||||
def _convert_encoding(
|
||||
self,
|
||||
encoding: EncodingFast,
|
||||
return_token_type_ids: Optional[bool] = None,
|
||||
return_attention_mask: Optional[bool] = None,
|
||||
return_overflowing_tokens: bool = False,
|
||||
return_special_tokens_mask: bool = False,
|
||||
return_offsets_mapping: bool = False,
|
||||
verbose: bool = True,
|
||||
) -> Dict[str, Any]:
|
||||
""" Convert the encoding representation (from low-level HuggingFace tokenizer output) to a python Dict.
|
||||
|
||||
Overflowing tokens are converted to additional examples (like batches) so the output values of
|
||||
the dict are lists (overflows) of lists (tokens).
|
||||
|
||||
Output shape: (overflows, sequence length)
|
||||
"""
|
||||
if return_token_type_ids is None:
|
||||
return_token_type_ids = "token_type_ids" in self.model_input_names
|
||||
if return_attention_mask is None:
|
||||
return_attention_mask = "attention_mask" in self.model_input_names
|
||||
|
||||
if return_overflowing_tokens and encoding.overflowing is not None:
|
||||
encodings = [encoding] + encoding.overflowing
|
||||
else:
|
||||
encodings = [encoding]
|
||||
|
||||
encoding_dict = defaultdict(list)
|
||||
for e in encodings:
|
||||
encoding_dict["input_ids"].append(e.ids)
|
||||
|
||||
if return_token_type_ids:
|
||||
encoding_dict["token_type_ids"].append(e.type_ids)
|
||||
if return_attention_mask:
|
||||
encoding_dict["attention_mask"].append(e.attention_mask)
|
||||
if return_special_tokens_mask:
|
||||
encoding_dict["special_tokens_mask"].append(e.special_tokens_mask)
|
||||
if return_offsets_mapping:
|
||||
encoding_dict["offset_mapping"].append(e.offsets)
|
||||
|
||||
return encoding_dict
|
||||
|
||||
def convert_tokens_to_ids(self, tokens):
|
||||
""" Converts a token string (or a sequence of tokens) in a single integer id
|
||||
(or a sequence of ids), using the vocabulary.
|
||||
"""
|
||||
if tokens is None:
|
||||
return None
|
||||
|
||||
if isinstance(tokens, str):
|
||||
return self._convert_token_to_id_with_added_voc(tokens)
|
||||
|
||||
ids = []
|
||||
for token in tokens:
|
||||
ids.append(self._convert_token_to_id_with_added_voc(token))
|
||||
return ids
|
||||
|
||||
def _convert_token_to_id_with_added_voc(self, token: int) -> str:
|
||||
index = self._tokenizer.token_to_id(token)
|
||||
if index is None:
|
||||
return self.unk_token_id
|
||||
return index
|
||||
|
||||
def _convert_id_to_token(self, index: int) -> Optional[str]:
|
||||
return self._tokenizer.id_to_token(int(index))
|
||||
|
||||
def convert_tokens_to_string(self, tokens: List[int], skip_special_tokens: bool = False) -> str:
|
||||
return self._tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
|
||||
|
||||
def add_tokens(self, new_tokens: List[Union[str, AddedTokenFast]]) -> int:
|
||||
"""
|
||||
Add a list of new tokens to the tokenizer class. If the new tokens are not in the
|
||||
vocabulary, they are added to it with indices starting from length of the current vocabulary.
|
||||
|
||||
Args:
|
||||
new_tokens: string or list of string or :class:`~transformers.AddedTokenFast`. Each string is a token to add.
|
||||
Tokens are only added if they are not already in the vocabulary. AddedTokenFast wrap a string token to
|
||||
let you personnalize it's behavior (Whether this token should only match against single word, whether
|
||||
this token should strip all potential whitespaces on the left side, Whether this token should strip
|
||||
all potential whitespaces on the right side...).
|
||||
|
||||
See details for :class:`~transformers.AddedToken` in HuggingFace tokenizers library.
|
||||
|
||||
Returns:
|
||||
Number of tokens added to the vocabulary.
|
||||
|
||||
Examples::
|
||||
|
||||
# Let's see how to increase the vocabulary of Bert model and tokenizer
|
||||
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
|
||||
model = BertModel.from_pretrained('bert-base-uncased')
|
||||
|
||||
num_added_toks = tokenizer.add_tokens(['new_tok1', 'my_new-tok2'])
|
||||
print('We have added', num_added_toks, 'tokens')
|
||||
model.resize_token_embeddings(len(tokenizer)) # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
|
||||
"""
|
||||
if isinstance(new_tokens, str):
|
||||
new_tokens = [new_tokens]
|
||||
# TODO This should be done in tokenizers to be really clean.
|
||||
# Removing for now
|
||||
# tokens = []
|
||||
# for token in new_tokens:
|
||||
# if self.init_kwargs.get("do_lower_case", False) and token not in self.all_special_tokens:
|
||||
# token = token.lower()
|
||||
# if token not in tokens:
|
||||
# tokens.append(token)
|
||||
return self._tokenizer.add_tokens(new_tokens)
|
||||
|
||||
def num_special_tokens_to_add(self, pair: bool = False) -> int:
|
||||
return self._tokenizer.num_special_tokens_to_add(pair)
|
||||
|
||||
def convert_ids_to_tokens(
|
||||
self, ids: Union[int, List[int]], skip_special_tokens: bool = False
|
||||
) -> Union[int, List[int]]:
|
||||
""" Converts a single index or a sequence of indices (integers) in a token "
|
||||
(resp.) a sequence of tokens (str), using the vocabulary and added tokens.
|
||||
|
||||
Args:
|
||||
skip_special_tokens: Don't decode special tokens (self.all_special_tokens). Default: False
|
||||
"""
|
||||
if isinstance(ids, int):
|
||||
return self._tokenizer.id_to_token(ids)
|
||||
tokens = []
|
||||
for index in ids:
|
||||
index = int(index)
|
||||
if skip_special_tokens and index in self.all_special_ids:
|
||||
continue
|
||||
tokens.append(self._tokenizer.id_to_token(index))
|
||||
return tokens
|
||||
|
||||
def tokenize(
|
||||
self, text: TextInput, pair: Optional[TextInput] = None, add_special_tokens: bool = False
|
||||
) -> List[str]:
|
||||
return self._tokenizer.encode(text, pair, add_special_tokens=add_special_tokens).tokens
|
||||
|
||||
def set_truncation_and_padding(
|
||||
self, padding_strategy: PaddingStrategy, truncation_strategy: TruncationStrategy, max_length: int, stride: int,
|
||||
):
|
||||
""" This contextmanager is in charge of defining the truncation and the padding strategies for fast tokenizers
|
||||
(provided by HuggingFace tokenizers library) and restore the tokenizer settings afterwards.
|
||||
|
||||
This contextmanager assumes the provider tokenizer has no padding / truncation strategy
|
||||
before the managed section. If your tokenizer set a padding / truncation strategy before,
|
||||
then it will be reset to no padding/truncation when exiting the managed section.
|
||||
|
||||
Args:
|
||||
tokenizer (BaseTokenizerFast): The tokenizer which will be used
|
||||
max_length (int): The maximum size of the sequence
|
||||
stride (int): The stride to use when handling overflow
|
||||
strategy (str): Overflowing logic to use
|
||||
pad_to_max_length (bool): Boolean indicating if the output needs to be padded up to max_length
|
||||
padding_side (str): "left" or "right" indicating the direction the output sequence will be padded
|
||||
pad_token_id (int): The integer representation of the padding token to use
|
||||
pad_token_type_id (int): The integer representation of the padding token type to use
|
||||
pad_token (str): The string representation of the padding token to use
|
||||
|
||||
"""
|
||||
# Set truncation and padding on the backend tokenizer
|
||||
if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE:
|
||||
self._tokenizer.enable_truncation(max_length, stride=stride, strategy=truncation_strategy.value)
|
||||
else:
|
||||
self._tokenizer.no_truncation()
|
||||
|
||||
if padding_strategy != PaddingStrategy.DO_NOT_PAD:
|
||||
self._tokenizer.enable_padding(
|
||||
length=max_length if padding_strategy == PaddingStrategy.MAX_LENGTH else None,
|
||||
direction=self.padding_side,
|
||||
pad_id=self.pad_token_id,
|
||||
pad_type_id=self.pad_token_type_id,
|
||||
pad_token=self.pad_token,
|
||||
)
|
||||
else:
|
||||
self._tokenizer.no_padding()
|
||||
|
||||
def _batch_encode_plus(
|
||||
self,
|
||||
batch_text_or_text_pairs: Union[
|
||||
List[TextInput], List[TextInputPair], List[PreTokenizedInput], List[PreTokenizedInputPair]
|
||||
],
|
||||
add_special_tokens: bool = True,
|
||||
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
|
||||
truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
|
||||
max_length: Optional[int] = None,
|
||||
stride: int = 0,
|
||||
is_pretokenized: bool = False,
|
||||
return_tensors: Optional[str] = None,
|
||||
return_token_type_ids: Optional[bool] = None,
|
||||
return_attention_mask: Optional[bool] = None,
|
||||
return_overflowing_tokens: bool = False,
|
||||
return_special_tokens_mask: bool = False,
|
||||
return_offsets_mapping: bool = False,
|
||||
return_lengths: bool = False,
|
||||
verbose: bool = True,
|
||||
**kwargs
|
||||
) -> BatchEncoding:
|
||||
|
||||
if not isinstance(batch_text_or_text_pairs, list):
|
||||
raise ValueError(
|
||||
"batch_text_or_text_pairs has to be a list (got {})".format(type(batch_text_or_text_pairs))
|
||||
)
|
||||
|
||||
# Set the truncation and padding strategy and restore the initial configuration
|
||||
self.set_truncation_and_padding(
|
||||
padding_strategy=padding_strategy,
|
||||
truncation_strategy=truncation_strategy,
|
||||
max_length=max_length,
|
||||
stride=stride,
|
||||
)
|
||||
|
||||
# Avoid thread overhead if only one example.
|
||||
if len(batch_text_or_text_pairs) == 1:
|
||||
if isinstance(batch_text_or_text_pairs[0], tuple):
|
||||
# We got a Tuple with a pair of sequences
|
||||
encodings = self._tokenizer.encode(
|
||||
*batch_text_or_text_pairs[0],
|
||||
add_special_tokens=add_special_tokens,
|
||||
is_pretokenized=is_pretokenized,
|
||||
)
|
||||
else:
|
||||
# We got a single sequence
|
||||
encodings = self._tokenizer.encode(
|
||||
batch_text_or_text_pairs[0],
|
||||
add_special_tokens=add_special_tokens,
|
||||
is_pretokenized=is_pretokenized,
|
||||
)
|
||||
encodings = [encodings]
|
||||
else:
|
||||
encodings = self._tokenizer.encode_batch(
|
||||
batch_text_or_text_pairs, add_special_tokens=add_special_tokens, is_pretokenized=is_pretokenized
|
||||
)
|
||||
|
||||
# Convert encoding to dict
|
||||
# `Tokens` has type: List[Dict[str, List[List[int]]]] or List[Dict[str, 2D-Tensor]]
|
||||
# with nested dimensions corresponding to batch, overflows, sequence length
|
||||
tokens = [
|
||||
self._convert_encoding(
|
||||
encoding=encoding,
|
||||
return_token_type_ids=return_token_type_ids,
|
||||
return_attention_mask=return_attention_mask,
|
||||
return_overflowing_tokens=return_overflowing_tokens,
|
||||
return_special_tokens_mask=return_special_tokens_mask,
|
||||
return_offsets_mapping=return_offsets_mapping,
|
||||
verbose=verbose,
|
||||
)
|
||||
for encoding in encodings
|
||||
]
|
||||
|
||||
# Convert the output to have dict[list] from list[dict]
|
||||
sanitized = {}
|
||||
for key in tokens[0].keys():
|
||||
# To List[List[List[int]]] of shape (batch, overflows, sequence length)
|
||||
stack = [e for item in tokens for e in item[key]]
|
||||
sanitized[key] = stack
|
||||
|
||||
# If returning overflowing tokens, we need to return a mapping
|
||||
# from the batch idx to the original sample
|
||||
if return_overflowing_tokens:
|
||||
overflow_to_sample_mapping = []
|
||||
for i, enc in enumerate(tokens):
|
||||
overflow_to_sample_mapping += [i] * len(enc["input_ids"])
|
||||
sanitized["overflow_to_sample_mapping"] = overflow_to_sample_mapping
|
||||
|
||||
return BatchEncoding(sanitized, encodings, tensor_type=return_tensors)
|
||||
|
||||
def _encode_plus(
|
||||
self,
|
||||
text: Union[TextInput, PreTokenizedInput],
|
||||
text_pair: Optional[Union[TextInput, PreTokenizedInput]] = None,
|
||||
add_special_tokens: bool = True,
|
||||
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
|
||||
truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
|
||||
max_length: Optional[int] = None,
|
||||
stride: int = 0,
|
||||
is_pretokenized: bool = False,
|
||||
return_tensors: Optional[bool] = None,
|
||||
return_token_type_ids: Optional[bool] = None,
|
||||
return_attention_mask: Optional[bool] = None,
|
||||
return_overflowing_tokens: bool = False,
|
||||
return_special_tokens_mask: bool = False,
|
||||
return_offsets_mapping: bool = False,
|
||||
verbose: bool = True,
|
||||
**kwargs
|
||||
) -> BatchEncoding:
|
||||
|
||||
batched_input = [(text, text_pair)] if text_pair else [text]
|
||||
batched_output = self._batch_encode_plus(
|
||||
batched_input,
|
||||
is_pretokenized=is_pretokenized,
|
||||
add_special_tokens=add_special_tokens,
|
||||
padding_strategy=padding_strategy,
|
||||
truncation_strategy=truncation_strategy,
|
||||
max_length=max_length,
|
||||
stride=stride,
|
||||
return_tensors=return_tensors,
|
||||
return_token_type_ids=return_token_type_ids,
|
||||
return_attention_mask=return_attention_mask,
|
||||
return_overflowing_tokens=return_overflowing_tokens,
|
||||
return_special_tokens_mask=return_special_tokens_mask,
|
||||
return_offsets_mapping=return_offsets_mapping,
|
||||
verbose=verbose,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Return tensor is None, then we can remove the leading batch axis
|
||||
# Overfolwing tokens are returned as a batch of output so we keep them in this case
|
||||
if return_tensors is None and not return_overflowing_tokens:
|
||||
batched_output = BatchEncoding(
|
||||
{
|
||||
key: value[0] if len(value) > 0 and isinstance(value[0], list) else value
|
||||
for key, value in batched_output.items()
|
||||
},
|
||||
batched_output.encodings,
|
||||
)
|
||||
|
||||
return batched_output
|
||||
|
||||
def decode(
|
||||
self, token_ids: List[int], skip_special_tokens: bool = False, clean_up_tokenization_spaces: bool = True
|
||||
) -> str:
|
||||
text = self._tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
|
||||
|
||||
if clean_up_tokenization_spaces:
|
||||
clean_text = self.clean_up_tokenization(text)
|
||||
return clean_text
|
||||
else:
|
||||
return text
|
||||
|
||||
def save_vocabulary(self, save_directory: str) -> Tuple[str]:
|
||||
if os.path.isdir(save_directory):
|
||||
files = self._tokenizer.save_model(save_directory)
|
||||
else:
|
||||
folder, file = os.path.split(os.path.abspath(save_directory))
|
||||
files = self._tokenizer.save_model(folder, name=file)
|
||||
|
||||
return tuple(files)
|
||||
@@ -51,7 +51,7 @@ class XxxTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
def get_tokenizer(self, **kwargs):
|
||||
return XxxTokenizer.from_pretrained(self.tmpdirname, **kwargs)
|
||||
|
||||
def get_input_output_texts(self):
|
||||
def get_input_output_texts(self, tokenizer):
|
||||
input_text = "UNwant\u00E9d,running"
|
||||
output_text = "unwanted, running"
|
||||
return input_text, output_text
|
||||
|
||||
@@ -36,7 +36,7 @@ class AlbertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
tokenizer = AlbertTokenizer(SAMPLE_VOCAB)
|
||||
tokenizer.save_pretrained(self.tmpdirname)
|
||||
|
||||
def get_input_output_texts(self):
|
||||
def get_input_output_texts(self, tokenizer):
|
||||
input_text = "this is a test"
|
||||
output_text = "this is a test"
|
||||
return input_text, output_text
|
||||
|
||||
@@ -44,6 +44,8 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
"[UNK]",
|
||||
"[CLS]",
|
||||
"[SEP]",
|
||||
"[PAD]",
|
||||
"[MASK]",
|
||||
"want",
|
||||
"##want",
|
||||
"##ed",
|
||||
@@ -62,7 +64,7 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
def get_rust_tokenizer(self, **kwargs):
|
||||
return BertTokenizerFast.from_pretrained(self.tmpdirname, **kwargs)
|
||||
|
||||
def get_input_output_texts(self):
|
||||
def get_input_output_texts(self, tokenizer):
|
||||
input_text = "UNwant\u00E9d,running"
|
||||
output_text = "unwanted, running"
|
||||
return input_text, output_text
|
||||
@@ -72,7 +74,7 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
|
||||
tokens = tokenizer.tokenize("UNwant\u00E9d,running")
|
||||
self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
|
||||
self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
|
||||
self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [9, 6, 7, 12, 10, 11])
|
||||
|
||||
def test_rust_and_python_full_tokenizers(self):
|
||||
if not self.test_rust_tokenizer:
|
||||
@@ -96,6 +98,25 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
rust_ids = rust_tokenizer.encode(sequence)
|
||||
self.assertListEqual(ids, rust_ids)
|
||||
|
||||
# With lower casing
|
||||
tokenizer = self.get_tokenizer(do_lower_case=True)
|
||||
rust_tokenizer = self.get_rust_tokenizer(do_lower_case=True)
|
||||
|
||||
sequence = "UNwant\u00E9d,running"
|
||||
|
||||
tokens = tokenizer.tokenize(sequence)
|
||||
rust_tokens = rust_tokenizer.tokenize(sequence)
|
||||
self.assertListEqual(tokens, rust_tokens)
|
||||
|
||||
ids = tokenizer.encode(sequence, add_special_tokens=False)
|
||||
rust_ids = rust_tokenizer.encode(sequence, add_special_tokens=False)
|
||||
self.assertListEqual(ids, rust_ids)
|
||||
|
||||
rust_tokenizer = self.get_rust_tokenizer()
|
||||
ids = tokenizer.encode(sequence)
|
||||
rust_ids = rust_tokenizer.encode(sequence)
|
||||
self.assertListEqual(ids, rust_ids)
|
||||
|
||||
def test_chinese(self):
|
||||
tokenizer = BasicTokenizer()
|
||||
|
||||
|
||||
@@ -60,11 +60,26 @@ class BertJapaneseTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
|
||||
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
|
||||
|
||||
def get_input_output_texts(self):
|
||||
def get_input_output_texts(self, tokenizer):
|
||||
input_text = "こんにちは、世界。 \nこんばんは、世界。"
|
||||
output_text = "こんにちは 、 世界 。 こんばんは 、 世界 。"
|
||||
return input_text, output_text
|
||||
|
||||
def get_clean_sequence(self, tokenizer):
|
||||
input_text, output_text = self.get_input_output_texts(tokenizer)
|
||||
ids = tokenizer.encode(output_text, add_special_tokens=False)
|
||||
text = tokenizer.decode(ids, clean_up_tokenization_spaces=False)
|
||||
return text, ids
|
||||
|
||||
def test_pretokenized_inputs(self):
|
||||
pass # TODO add if relevant
|
||||
|
||||
def test_maximum_encoding_length_pair_input(self):
|
||||
pass # TODO add if relevant
|
||||
|
||||
def test_maximum_encoding_length_single_input(self):
|
||||
pass # TODO add if relevant
|
||||
|
||||
def test_full_tokenizer(self):
|
||||
tokenizer = self.tokenizer_class(self.vocab_file)
|
||||
|
||||
@@ -157,11 +172,20 @@ class BertJapaneseCharacterTokenizationTest(TokenizerTesterMixin, unittest.TestC
|
||||
def get_tokenizer(self, **kwargs):
|
||||
return BertJapaneseTokenizer.from_pretrained(self.tmpdirname, subword_tokenizer_type="character", **kwargs)
|
||||
|
||||
def get_input_output_texts(self):
|
||||
def get_input_output_texts(self, tokenizer):
|
||||
input_text = "こんにちは、世界。 \nこんばんは、世界。"
|
||||
output_text = "こ ん に ち は 、 世 界 。 こ ん ば ん は 、 世 界 。"
|
||||
return input_text, output_text
|
||||
|
||||
def test_pretokenized_inputs(self):
|
||||
pass # TODO add if relevant
|
||||
|
||||
def test_maximum_encoding_length_pair_input(self):
|
||||
pass # TODO add if relevant
|
||||
|
||||
def test_maximum_encoding_length_single_input(self):
|
||||
pass # TODO add if relevant
|
||||
|
||||
def test_full_tokenizer(self):
|
||||
tokenizer = self.tokenizer_class(self.vocab_file, subword_tokenizer_type="character")
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -46,7 +46,7 @@ class CTRLTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
kwargs.update(self.special_tokens_map)
|
||||
return CTRLTokenizer.from_pretrained(self.tmpdirname, **kwargs)
|
||||
|
||||
def get_input_output_texts(self):
|
||||
def get_input_output_texts(self, tokenizer):
|
||||
input_text = "adapt react readapt apt"
|
||||
output_text = "adapt react readapt apt"
|
||||
return input_text, output_text
|
||||
|
||||
@@ -27,7 +27,7 @@ logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
NON_ENGLISH_TAGS = ["chinese", "dutch", "french", "finnish", "german", "multilingual"]
|
||||
Tokenizer = namedtuple("Tokenizer", ["name", "rust_cls", "python_cls", "vocab_key", "filter"])
|
||||
Tokenizer = namedtuple("Tokenizer", ["name", "rust_cls", "python_cls", "vocab_key", "filter", "kwargs"])
|
||||
|
||||
|
||||
def filter_non_english(_: Tokenizer, pretrained_name: str):
|
||||
@@ -60,10 +60,10 @@ class CommonFastTokenizerTest(unittest.TestCase):
|
||||
tokenizer_r = tok_case.rust_cls.from_pretrained(pretrained_name)
|
||||
tokenizer_p = tok_case.python_cls.from_pretrained(pretrained_name)
|
||||
|
||||
self.fast_align_python(tokenizer_r, tokenizer_p)
|
||||
self.fast_align_python(tokenizer_r, tokenizer_p, tok_case, pretrained_name)
|
||||
self.fast_only(tokenizer_r)
|
||||
|
||||
def fast_align_python(self, tokenizer_r, tokenizer_p):
|
||||
def fast_align_python(self, tokenizer_r, tokenizer_p, tok_case, pretrained_name):
|
||||
# Check is_fast is set correctly
|
||||
self.assertFalse(tokenizer_p.is_fast)
|
||||
self.assertTrue(tokenizer_r.is_fast)
|
||||
@@ -75,6 +75,7 @@ class CommonFastTokenizerTest(unittest.TestCase):
|
||||
self.assert_special_tokens_map_equal(tokenizer_r, tokenizer_p)
|
||||
self.assert_embeded_special_tokens(tokenizer_r, tokenizer_p)
|
||||
self.assert_padding(tokenizer_r, tokenizer_p)
|
||||
self.assert_pretokenized_inputs(tokenizer_r, tokenizer_p)
|
||||
self.assert_create_token_type_ids(tokenizer_r, tokenizer_p)
|
||||
# TODO: enable for v3.0.0
|
||||
# self.assert_empty_output_no_special_tokens(tokenizer_r, tokenizer_p)
|
||||
@@ -90,6 +91,7 @@ class CommonFastTokenizerTest(unittest.TestCase):
|
||||
self.assert_offsets_mapping(tokenizer_r)
|
||||
self.assert_add_special_tokens(tokenizer_r)
|
||||
self.assert_alignement_methods(tokenizer_r)
|
||||
self.assert_batch_encode_dynamic_overflowing(tokenizer_r)
|
||||
|
||||
def assert_alignement_methods(self, tokenizer_r):
|
||||
words = ["Wonderful", "no", "inspiration", "example", "with", "subtoken"]
|
||||
@@ -169,7 +171,7 @@ class CommonFastTokenizerTest(unittest.TestCase):
|
||||
self.assertEqual(batch_encoding.word_to_chars(0, last_word_index).end, last_char_index + 1)
|
||||
self.assertEqual(batch_encoding.word_to_chars(last_batch_index, last_word_index).end, last_char_index + 1)
|
||||
|
||||
def assert_tokenization_python_rust_equals(self, tokenizer_p, tokenizer_r):
|
||||
def assert_tokenization_python_rust_equals(self, tokenizer_r, tokenizer_p):
|
||||
# Ensure basic input match
|
||||
input_p = tokenizer_p.encode_plus(self._data)
|
||||
input_r = tokenizer_r.encode_plus(self._data)
|
||||
@@ -184,18 +186,22 @@ class CommonFastTokenizerTest(unittest.TestCase):
|
||||
self.assertSequenceEqual(input_pairs_p[key], input_pairs_r[key])
|
||||
|
||||
# Ensure truncation match
|
||||
input_p = tokenizer_p.encode_plus(self._data, max_length=512)
|
||||
input_r = tokenizer_r.encode_plus(self._data, max_length=512)
|
||||
input_p = tokenizer_p.encode_plus(self._data, max_length=512, truncation=True)
|
||||
input_r = tokenizer_r.encode_plus(self._data, max_length=512, truncation=True)
|
||||
|
||||
for key in filter(lambda x: x in ["input_ids", "token_type_ids", "attention_mask"], input_p.keys()):
|
||||
self.assertSequenceEqual(input_p[key], input_r[key])
|
||||
|
||||
# Ensure truncation with stride match
|
||||
input_p = tokenizer_p.encode_plus(self._data, max_length=512, stride=3, return_overflowing_tokens=True)
|
||||
input_r = tokenizer_r.encode_plus(self._data, max_length=512, stride=3, return_overflowing_tokens=True)
|
||||
input_p = tokenizer_p.encode_plus(
|
||||
self._data, max_length=512, truncation=True, stride=3, return_overflowing_tokens=True
|
||||
)
|
||||
input_r = tokenizer_r.encode_plus(
|
||||
self._data, max_length=512, truncation=True, stride=3, return_overflowing_tokens=True
|
||||
)
|
||||
|
||||
for key in filter(lambda x: x in ["input_ids", "token_type_ids", "attention_mask"], input_p.keys()):
|
||||
self.assertSequenceEqual(input_p[key], input_r[key])
|
||||
self.assertSequenceEqual(input_p[key], input_r[key][0])
|
||||
|
||||
def assert_num_special_tokens_to_add_equal(self, tokenizer_r, tokenizer_p):
|
||||
# Check we have the same number of added_tokens for both pair and non-pair inputs.
|
||||
@@ -274,9 +280,14 @@ class CommonFastTokenizerTest(unittest.TestCase):
|
||||
"""
|
||||
returned_tensor = "pt" if is_torch_available() else "tf"
|
||||
|
||||
if not tokenizer.pad_token or tokenizer.pad_token_id < 0:
|
||||
return
|
||||
|
||||
tokens = tokenizer.encode_plus(
|
||||
"HuggingFace is solving NLP one commit at a time",
|
||||
max_length=6,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
return_tensors=returned_tensor,
|
||||
return_overflowing_tokens=True,
|
||||
)
|
||||
@@ -288,7 +299,8 @@ class CommonFastTokenizerTest(unittest.TestCase):
|
||||
tokens = tokenizer.batch_encode_plus(
|
||||
["HuggingFace is solving NLP one commit at a time"],
|
||||
max_length=6,
|
||||
pad_to_max_len=True,
|
||||
padding=True,
|
||||
truncation="only_first",
|
||||
return_tensors=returned_tensor,
|
||||
return_overflowing_tokens=True,
|
||||
)
|
||||
@@ -301,7 +313,8 @@ class CommonFastTokenizerTest(unittest.TestCase):
|
||||
tokens = tokenizer.batch_encode_plus(
|
||||
["HuggingFace is solving NLP one commit at a time", "Very tiny input"],
|
||||
max_length=6,
|
||||
pad_to_max_len=True,
|
||||
padding=True,
|
||||
truncation="only_first",
|
||||
return_tensors=returned_tensor,
|
||||
return_overflowing_tokens=True,
|
||||
)
|
||||
@@ -310,6 +323,58 @@ class CommonFastTokenizerTest(unittest.TestCase):
|
||||
self.assertEqual(len(tokens[key].shape), 2)
|
||||
self.assertEqual(tokens[key].shape[-1], 6)
|
||||
|
||||
def assert_pretokenized_inputs(self, tokenizer_r, tokenizer_p):
|
||||
# Input string
|
||||
pretokenized_input_simple = "This is a sample input".split()
|
||||
pretokenized_input_pair = "This is a sample pair".split()
|
||||
|
||||
# Test encode for pretokenized inputs
|
||||
output_r = tokenizer_r.encode(pretokenized_input_simple, is_pretokenized=True)
|
||||
output_p = tokenizer_p.encode(pretokenized_input_simple, is_pretokenized=True)
|
||||
self.assertEqual(output_p, output_r)
|
||||
|
||||
kwargs = {
|
||||
"is_pretokenized": True,
|
||||
"return_token_type_ids": True,
|
||||
"return_attention_mask": True,
|
||||
"return_overflowing_tokens": False,
|
||||
"return_special_tokens_mask": True,
|
||||
"return_offsets_mapping": False, # Not implemented in python tokenizers
|
||||
}
|
||||
# Test encode_plus for pretokenized inputs
|
||||
output_r = tokenizer_r.encode_plus(pretokenized_input_simple, **kwargs)
|
||||
output_p = tokenizer_p.encode_plus(pretokenized_input_simple, **kwargs)
|
||||
for key in output_p.keys():
|
||||
self.assertEqual(output_p[key], output_r[key])
|
||||
|
||||
# Test batch_encode_plus for pretokenized inputs
|
||||
input_batch = ([pretokenized_input_simple] * 2) + [pretokenized_input_simple + pretokenized_input_pair]
|
||||
output_r = tokenizer_r.batch_encode_plus(input_batch, **kwargs)
|
||||
output_p = tokenizer_p.batch_encode_plus(input_batch, **kwargs)
|
||||
for key in output_p.keys():
|
||||
self.assertEqual(output_p[key], output_r[key])
|
||||
|
||||
# Test encode for pretokenized inputs pairs
|
||||
output_r = tokenizer_r.encode(pretokenized_input_simple, pretokenized_input_pair, is_pretokenized=True)
|
||||
output_p = tokenizer_p.encode(pretokenized_input_simple, pretokenized_input_pair, is_pretokenized=True)
|
||||
self.assertEqual(output_p, output_r)
|
||||
|
||||
# Test encode_plus for pretokenized inputs
|
||||
output_r = tokenizer_r.encode_plus(pretokenized_input_simple, pretokenized_input_pair, **kwargs)
|
||||
output_p = tokenizer_p.encode_plus(pretokenized_input_simple, pretokenized_input_pair, **kwargs)
|
||||
for key in output_p.keys():
|
||||
self.assertEqual(output_p[key], output_r[key])
|
||||
|
||||
# Test batch_encode_plus for pretokenized inputs
|
||||
input_batch_pair = ([pretokenized_input_simple, pretokenized_input_pair] * 2) + [
|
||||
pretokenized_input_simple + pretokenized_input_pair,
|
||||
pretokenized_input_pair,
|
||||
]
|
||||
output_r = tokenizer_r.batch_encode_plus(input_batch_pair, **kwargs)
|
||||
output_p = tokenizer_p.batch_encode_plus(input_batch_pair, **kwargs)
|
||||
for key in output_p.keys():
|
||||
self.assertEqual(output_p[key], output_r[key])
|
||||
|
||||
def assert_create_token_type_ids(self, tokenizer_r, tokenizer_p):
|
||||
input_simple = [1, 2, 3]
|
||||
input_pair = [1, 2, 3]
|
||||
@@ -357,17 +422,22 @@ class CommonFastTokenizerTest(unittest.TestCase):
|
||||
def assert_padded_input_match(input_r: list, input_p: list, max_length: int):
|
||||
|
||||
# Ensure we match max_length
|
||||
self.assertEqual(len(input_r), max_length), self.assertEqual(len(input_p), max_length)
|
||||
self.assertEqual(len(input_r), max_length)
|
||||
self.assertEqual(len(input_p), max_length)
|
||||
|
||||
# Ensure the number of padded tokens is the same
|
||||
padded_tokens_r = list(takewhile(lambda i: i == tokenizer_r.pad_token_id, reversed(input_r)))
|
||||
padded_tokens_p = list(takewhile(lambda i: i == tokenizer_p.pad_token_id, reversed(input_p)))
|
||||
self.assertSequenceEqual(padded_tokens_r, padded_tokens_p)
|
||||
|
||||
def assert_batch_padded_input_match(input_r: dict, input_p: dict):
|
||||
def assert_batch_padded_input_match(input_r: dict, input_p: dict, max_length: int):
|
||||
for i_r in input_r.values():
|
||||
self.assertEqual(len(i_r), 2), self.assertEqual(len(i_r[0]), 15), self.assertEqual(len(i_r[1]), 15)
|
||||
self.assertEqual(len(i_r), 2), self.assertEqual(len(i_r[0]), 15), self.assertEqual(len(i_r[1]), 15)
|
||||
self.assertEqual(len(i_r), 2), self.assertEqual(len(i_r[0]), max_length), self.assertEqual(
|
||||
len(i_r[1]), max_length
|
||||
)
|
||||
self.assertEqual(len(i_r), 2), self.assertEqual(len(i_r[0]), max_length), self.assertEqual(
|
||||
len(i_r[1]), max_length
|
||||
)
|
||||
|
||||
for i_r, i_p in zip(input_r["input_ids"], input_p["input_ids"]):
|
||||
assert_padded_input_match(i_r, i_p, max_length)
|
||||
@@ -375,12 +445,19 @@ class CommonFastTokenizerTest(unittest.TestCase):
|
||||
for i_r, i_p in zip(input_r["attention_mask"], input_p["attention_mask"]):
|
||||
self.assertSequenceEqual(i_r, i_p)
|
||||
|
||||
# Simple input
|
||||
# Encode - Simple input
|
||||
input_r = tokenizer_r.encode("This is a simple input", max_length=max_length, pad_to_max_length=True)
|
||||
input_p = tokenizer_p.encode("This is a simple input", max_length=max_length, pad_to_max_length=True)
|
||||
assert_padded_input_match(input_r, input_p, max_length)
|
||||
input_r = tokenizer_r.encode("This is a simple input", max_length=max_length, padding="max_length")
|
||||
input_p = tokenizer_p.encode("This is a simple input", max_length=max_length, padding="max_length")
|
||||
assert_padded_input_match(input_r, input_p, max_length)
|
||||
|
||||
# Pair input
|
||||
input_r = tokenizer_r.encode("This is a simple input", padding="longest")
|
||||
input_p = tokenizer_p.encode("This is a simple input", padding=True)
|
||||
assert_padded_input_match(input_r, input_p, len(input_r))
|
||||
|
||||
# Encode - Pair input
|
||||
input_r = tokenizer_r.encode(
|
||||
"This is a simple input", "This is a pair", max_length=max_length, pad_to_max_length=True
|
||||
)
|
||||
@@ -388,14 +465,34 @@ class CommonFastTokenizerTest(unittest.TestCase):
|
||||
"This is a simple input", "This is a pair", max_length=max_length, pad_to_max_length=True
|
||||
)
|
||||
assert_padded_input_match(input_r, input_p, max_length)
|
||||
input_r = tokenizer_r.encode(
|
||||
"This is a simple input", "This is a pair", max_length=max_length, padding="max_length"
|
||||
)
|
||||
input_p = tokenizer_p.encode(
|
||||
"This is a simple input", "This is a pair", max_length=max_length, padding="max_length"
|
||||
)
|
||||
assert_padded_input_match(input_r, input_p, max_length)
|
||||
input_r = tokenizer_r.encode("This is a simple input", "This is a pair", padding=True)
|
||||
input_p = tokenizer_p.encode("This is a simple input", "This is a pair", padding="longest")
|
||||
assert_padded_input_match(input_r, input_p, len(input_r))
|
||||
|
||||
# Simple input
|
||||
# Encode_plus - Simple input
|
||||
input_r = tokenizer_r.encode_plus("This is a simple input", max_length=max_length, pad_to_max_length=True)
|
||||
input_p = tokenizer_p.encode_plus("This is a simple input", max_length=max_length, pad_to_max_length=True)
|
||||
assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length)
|
||||
self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
|
||||
input_r = tokenizer_r.encode_plus("This is a simple input", max_length=max_length, padding="max_length")
|
||||
input_p = tokenizer_p.encode_plus("This is a simple input", max_length=max_length, padding="max_length")
|
||||
assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length)
|
||||
self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
|
||||
|
||||
# Pair input
|
||||
input_r = tokenizer_r.encode_plus("This is a simple input", padding="longest")
|
||||
input_p = tokenizer_p.encode_plus("This is a simple input", padding=True)
|
||||
assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], len(input_r["input_ids"]))
|
||||
|
||||
self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
|
||||
|
||||
# Encode_plus - Pair input
|
||||
input_r = tokenizer_r.encode_plus(
|
||||
"This is a simple input", "This is a pair", max_length=max_length, pad_to_max_length=True
|
||||
)
|
||||
@@ -404,34 +501,130 @@ class CommonFastTokenizerTest(unittest.TestCase):
|
||||
)
|
||||
assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length)
|
||||
self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
|
||||
input_r = tokenizer_r.encode_plus(
|
||||
"This is a simple input", "This is a pair", max_length=max_length, padding="max_length"
|
||||
)
|
||||
input_p = tokenizer_p.encode_plus(
|
||||
"This is a simple input", "This is a pair", max_length=max_length, padding="max_length"
|
||||
)
|
||||
assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length)
|
||||
self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
|
||||
input_r = tokenizer_r.encode_plus("This is a simple input", "This is a pair", padding="longest")
|
||||
input_p = tokenizer_p.encode_plus("This is a simple input", "This is a pair", padding=True)
|
||||
assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], len(input_r["input_ids"]))
|
||||
self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
|
||||
|
||||
# Simple input
|
||||
# Batch_encode_plus - Simple input
|
||||
input_r = tokenizer_r.batch_encode_plus(
|
||||
["This is a simple input 1", "This is a simple input 2"], max_length=max_length, pad_to_max_length=True
|
||||
)
|
||||
input_p = tokenizer_p.batch_encode_plus(
|
||||
["This is a simple input 1", "This is a simple input 2"], max_length=max_length, pad_to_max_length=True
|
||||
)
|
||||
assert_batch_padded_input_match(input_r, input_p)
|
||||
assert_batch_padded_input_match(input_r, input_p, max_length)
|
||||
|
||||
# Pair input
|
||||
input_r = tokenizer_r.batch_encode_plus(
|
||||
["This is a simple input 1", "This is a simple input 2"], max_length=max_length, padding="max_length",
|
||||
)
|
||||
input_p = tokenizer_p.batch_encode_plus(
|
||||
["This is a simple input 1", "This is a simple input 2"], max_length=max_length, padding="max_length",
|
||||
)
|
||||
assert_batch_padded_input_match(input_r, input_p, max_length)
|
||||
|
||||
input_r = tokenizer_r.batch_encode_plus(
|
||||
["This is a simple input 1", "This is a simple input 2"], max_length=max_length, padding="longest",
|
||||
)
|
||||
input_p = tokenizer_p.batch_encode_plus(
|
||||
["This is a simple input 1", "This is a simple input 2"], max_length=max_length, padding=True,
|
||||
)
|
||||
assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]))
|
||||
|
||||
input_r = tokenizer_r.batch_encode_plus(
|
||||
["This is a simple input 1", "This is a simple input 2"], padding="longest"
|
||||
)
|
||||
input_p = tokenizer_p.batch_encode_plus(["This is a simple input 1", "This is a simple input 2"], padding=True)
|
||||
assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]))
|
||||
|
||||
# Batch_encode_plus - Pair input
|
||||
input_r = tokenizer_r.batch_encode_plus(
|
||||
[
|
||||
("This is a simple input 1", "This is a simple input 2"),
|
||||
("This is a simple pair 1", "This is a simple pair 2"),
|
||||
],
|
||||
max_length=15,
|
||||
pad_to_max_length=True,
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
padding="max_length",
|
||||
)
|
||||
input_p = tokenizer_p.batch_encode_plus(
|
||||
[
|
||||
("This is a simple input 1", "This is a simple input 2"),
|
||||
("This is a simple pair 1", "This is a simple pair 2"),
|
||||
],
|
||||
max_length=15,
|
||||
pad_to_max_length=True,
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
padding="max_length",
|
||||
)
|
||||
assert_batch_padded_input_match(input_r, input_p)
|
||||
assert_batch_padded_input_match(input_r, input_p, max_length)
|
||||
|
||||
input_r = tokenizer_r.batch_encode_plus(
|
||||
[
|
||||
("This is a simple input 1", "This is a simple input 2"),
|
||||
("This is a simple pair 1", "This is a simple pair 2"),
|
||||
],
|
||||
padding=True,
|
||||
)
|
||||
input_p = tokenizer_p.batch_encode_plus(
|
||||
[
|
||||
("This is a simple input 1", "This is a simple input 2"),
|
||||
("This is a simple pair 1", "This is a simple pair 2"),
|
||||
],
|
||||
padding="longest",
|
||||
)
|
||||
assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]))
|
||||
|
||||
# Using pad on single examples after tokenization
|
||||
input_r = tokenizer_r.encode_plus("This is a input 1")
|
||||
input_r = tokenizer_r.pad(input_r)
|
||||
|
||||
input_p = tokenizer_r.encode_plus("This is a input 1")
|
||||
input_p = tokenizer_r.pad(input_p)
|
||||
|
||||
assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], len(input_r["input_ids"]))
|
||||
|
||||
# Using pad on single examples after tokenization
|
||||
input_r = tokenizer_r.encode_plus("This is a input 1")
|
||||
input_r = tokenizer_r.pad(input_r, max_length=max_length, padding="max_length")
|
||||
|
||||
input_p = tokenizer_r.encode_plus("This is a input 1")
|
||||
input_p = tokenizer_r.pad(input_p, max_length=max_length, padding="max_length")
|
||||
|
||||
assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length)
|
||||
|
||||
# Using pad after tokenization
|
||||
input_r = tokenizer_r.batch_encode_plus(
|
||||
["This is a input 1", "This is a much longer input whilch should be padded"]
|
||||
)
|
||||
input_r = tokenizer_r.pad(input_r)
|
||||
|
||||
input_p = tokenizer_r.batch_encode_plus(
|
||||
["This is a input 1", "This is a much longer input whilch should be padded"]
|
||||
)
|
||||
input_p = tokenizer_r.pad(input_p)
|
||||
|
||||
assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]))
|
||||
|
||||
# Using pad after tokenization
|
||||
input_r = tokenizer_r.batch_encode_plus(
|
||||
["This is a input 1", "This is a much longer input whilch should be padded"]
|
||||
)
|
||||
input_r = tokenizer_r.pad(input_r, max_length=max_length, padding="max_length")
|
||||
|
||||
input_p = tokenizer_r.batch_encode_plus(
|
||||
["This is a input 1", "This is a much longer input whilch should be padded"]
|
||||
)
|
||||
input_p = tokenizer_r.pad(input_p, max_length=max_length, padding="max_length")
|
||||
|
||||
assert_batch_padded_input_match(input_r, input_p, max_length)
|
||||
|
||||
def assert_save_pretrained(self, tokenizer_r, tokenizer_p):
|
||||
# Checks it save with the same files
|
||||
@@ -503,8 +696,10 @@ class WordPieceFastTokenizerTest(CommonFastTokenizerTest):
|
||||
|
||||
TOKENIZERS_CLASSES = frozenset(
|
||||
[
|
||||
Tokenizer("Bert", BertTokenizerFast, BertTokenizer, "vocab_file", filter_non_english),
|
||||
Tokenizer("DistilBert", DistilBertTokenizerFast, DistilBertTokenizer, "vocab_file", filter_non_english),
|
||||
Tokenizer("Bert", BertTokenizerFast, BertTokenizer, "vocab_file", filter_non_english, None),
|
||||
Tokenizer(
|
||||
"DistilBert", DistilBertTokenizerFast, DistilBertTokenizer, "vocab_file", filter_non_english, None
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -552,7 +747,7 @@ class WordPieceFastTokenizerTest(CommonFastTokenizerTest):
|
||||
|
||||
class RobertaFastTokenizerTest(CommonFastTokenizerTest):
|
||||
TOKENIZERS_CLASSES = frozenset(
|
||||
[Tokenizer("Roberta", RobertaTokenizerFast, RobertaTokenizer, "vocab_file", filter_roberta_detectors)]
|
||||
[Tokenizer("Roberta", RobertaTokenizerFast, RobertaTokenizer, "vocab_file", filter_roberta_detectors, None)]
|
||||
)
|
||||
|
||||
def assert_embeded_special_tokens(self, tokenizer_r, tokenizer_p):
|
||||
@@ -580,10 +775,30 @@ class RobertaFastTokenizerTest(CommonFastTokenizerTest):
|
||||
|
||||
class NoPaddingTokenFastTokenizerMatchingTest(CommonFastTokenizerTest):
|
||||
TOKENIZERS_CLASSES = [
|
||||
Tokenizer("OpenAI GPT", OpenAIGPTTokenizerFast, OpenAIGPTTokenizer, "vocab_file", None),
|
||||
Tokenizer("GPT2", GPT2TokenizerFast, GPT2Tokenizer, "vocab_file", None),
|
||||
Tokenizer("OpenAI GPT", OpenAIGPTTokenizerFast, OpenAIGPTTokenizer, "vocab_file", None, None),
|
||||
Tokenizer("GPT2", GPT2TokenizerFast, GPT2Tokenizer, "vocab_file", None, [("add_prefix_space", True)]),
|
||||
]
|
||||
|
||||
def fast_align_python(self, tokenizer_r, tokenizer_p, tok_case, pretrained_name):
|
||||
# Check is_fast is set correctly
|
||||
self.assertFalse(tokenizer_p.is_fast)
|
||||
self.assertTrue(tokenizer_r.is_fast)
|
||||
|
||||
# Check that Rust and Python align
|
||||
self.assert_tokenization_python_rust_equals(tokenizer_r, tokenizer_p)
|
||||
self.assert_num_special_tokens_to_add_equal(tokenizer_r, tokenizer_p)
|
||||
self.assert_max_length_equal(tokenizer_r, tokenizer_p)
|
||||
self.assert_special_tokens_map_equal(tokenizer_r, tokenizer_p)
|
||||
self.assert_embeded_special_tokens(tokenizer_r, tokenizer_p)
|
||||
self.assert_padding(tokenizer_r, tokenizer_p)
|
||||
|
||||
# Specific for
|
||||
kwargs = {}
|
||||
if tok_case.kwargs is not None:
|
||||
kwargs = dict(tok_case.kwargs)
|
||||
tokenizer_r = tok_case.rust_cls.from_pretrained(pretrained_name, **kwargs)
|
||||
self.assert_pretokenized_inputs(tokenizer_r, tokenizer_p)
|
||||
|
||||
def assert_padding(self, tokenizer_r, tokenizer_p, max_length=15):
|
||||
# Simple input
|
||||
s = "This is a simple input"
|
||||
@@ -595,27 +810,31 @@ class NoPaddingTokenFastTokenizerMatchingTest(CommonFastTokenizerTest):
|
||||
]
|
||||
|
||||
# Simple input tests
|
||||
self.assertRaises(ValueError, tokenizer_r.encode, s, max_length=max_length, pad_to_max_length=True)
|
||||
self.assertRaises(ValueError, tokenizer_r.encode, s, max_length=max_length, padding="max_length")
|
||||
|
||||
# Simple input
|
||||
self.assertRaises(ValueError, tokenizer_r.encode_plus, s, max_length=max_length, pad_to_max_length=True)
|
||||
self.assertRaises(ValueError, tokenizer_r.encode_plus, s, max_length=max_length, padding="max_length")
|
||||
|
||||
# Simple input
|
||||
self.assertRaises(ValueError, tokenizer_r.batch_encode_plus, s2, max_length=max_length, pad_to_max_length=True)
|
||||
self.assertRaises(
|
||||
ValueError, tokenizer_r.batch_encode_plus, s2, max_length=max_length, padding="max_length",
|
||||
)
|
||||
|
||||
# Pair input
|
||||
self.assertRaises(ValueError, tokenizer_r.encode, p, max_length=max_length, pad_to_max_length=True)
|
||||
self.assertRaises(ValueError, tokenizer_r.encode, p, max_length=max_length, padding="max_length")
|
||||
|
||||
# Pair input
|
||||
self.assertRaises(ValueError, tokenizer_r.encode_plus, p, max_length=max_length, pad_to_max_length=True)
|
||||
self.assertRaises(ValueError, tokenizer_r.encode_plus, p, max_length=max_length, padding="max_length")
|
||||
|
||||
# Pair input
|
||||
self.assertRaises(ValueError, tokenizer_r.batch_encode_plus, p2, max_length=max_length, pad_to_max_length=True)
|
||||
self.assertRaises(
|
||||
ValueError, tokenizer_r.batch_encode_plus, p2, max_length=max_length, padding="max_length",
|
||||
)
|
||||
|
||||
|
||||
class TransfoXLFastTokenizerTest(NoPaddingTokenFastTokenizerMatchingTest):
|
||||
TOKENIZERS_CLASSES = frozenset(
|
||||
[Tokenizer("TransfoXL", TransfoXLTokenizerFast, TransfoXLTokenizer, "pretrained_vocab_file", None)]
|
||||
[Tokenizer("TransfoXL", TransfoXLTokenizerFast, TransfoXLTokenizer, "pretrained_vocab_file", None, None)]
|
||||
)
|
||||
|
||||
@require_torch
|
||||
|
||||
@@ -53,6 +53,7 @@ class GPT2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
"\u0120newer",
|
||||
"\u0120wider",
|
||||
"<unk>",
|
||||
"<|endoftext|>",
|
||||
]
|
||||
vocab_tokens = dict(zip(vocab, range(len(vocab))))
|
||||
merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""]
|
||||
@@ -73,7 +74,7 @@ class GPT2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
kwargs.update(self.special_tokens_map)
|
||||
return GPT2TokenizerFast.from_pretrained(self.tmpdirname, **kwargs)
|
||||
|
||||
def get_input_output_texts(self):
|
||||
def get_input_output_texts(self, tokenizer):
|
||||
input_text = "lower newer"
|
||||
output_text = "lower newer"
|
||||
return input_text, output_text
|
||||
@@ -118,3 +119,8 @@ class GPT2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
input_tokens = tokens + [rust_tokenizer.unk_token]
|
||||
input_bpe_tokens = [14, 15, 10, 9, 3, 2, 15, 19]
|
||||
self.assertListEqual(rust_tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
|
||||
|
||||
def test_pretokenized_inputs(self, *args, **kwargs):
|
||||
# It's very difficult to mix/test pretokenization with byte-level
|
||||
# And get both GPT2 and Roberta to work at the same time (mostly an issue of adding a space before the string)
|
||||
pass
|
||||
|
||||
@@ -51,10 +51,10 @@ class MarianTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
tokenizer = MarianTokenizer.from_pretrained(self.tmpdirname)
|
||||
tokenizer.save_pretrained(self.tmpdirname)
|
||||
|
||||
def get_tokenizer(self, max_len=None, **kwargs) -> MarianTokenizer:
|
||||
return MarianTokenizer.from_pretrained(self.tmpdirname, model_max_length=max_len, **kwargs)
|
||||
def get_tokenizer(self, **kwargs) -> MarianTokenizer:
|
||||
return MarianTokenizer.from_pretrained(self.tmpdirname, **kwargs)
|
||||
|
||||
def get_input_output_texts(self):
|
||||
def get_input_output_texts(self, tokenizer):
|
||||
return (
|
||||
"This is a test",
|
||||
"This is a test",
|
||||
|
||||
@@ -64,7 +64,7 @@ class OpenAIGPTTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
with open(self.merges_file, "w") as fp:
|
||||
fp.write("\n".join(merges))
|
||||
|
||||
def get_input_output_texts(self):
|
||||
def get_input_output_texts(self, tokenizer):
|
||||
return "lower newer", "lower newer"
|
||||
|
||||
def test_full_tokenizer(self):
|
||||
|
||||
@@ -18,7 +18,7 @@ import json
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from transformers.tokenization_roberta import VOCAB_FILES_NAMES, RobertaTokenizer
|
||||
from transformers.tokenization_roberta import VOCAB_FILES_NAMES, RobertaTokenizer, RobertaTokenizerFast
|
||||
|
||||
from .test_tokenization_common import TokenizerTesterMixin
|
||||
from .utils import slow
|
||||
@@ -68,7 +68,11 @@ class RobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
kwargs.update(self.special_tokens_map)
|
||||
return RobertaTokenizer.from_pretrained(self.tmpdirname, **kwargs)
|
||||
|
||||
def get_input_output_texts(self):
|
||||
def get_rust_tokenizer(self, **kwargs):
|
||||
kwargs.update(self.special_tokens_map)
|
||||
return RobertaTokenizerFast.from_pretrained(self.tmpdirname, **kwargs)
|
||||
|
||||
def get_input_output_texts(self, tokenizer):
|
||||
input_text = "lower newer"
|
||||
output_text = "lower newer"
|
||||
return input_text, output_text
|
||||
|
||||
@@ -56,7 +56,7 @@ class TransfoXLTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
kwargs["lower_case"] = True
|
||||
return TransfoXLTokenizer.from_pretrained(self.tmpdirname, **kwargs)
|
||||
|
||||
def get_input_output_texts(self):
|
||||
def get_input_output_texts(self, tokenizer):
|
||||
input_text = "<unk> UNwanted , running"
|
||||
output_text = "<unk> unwanted, running"
|
||||
return input_text, output_text
|
||||
|
||||
@@ -65,7 +65,7 @@ class XLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
with open(self.merges_file, "w") as fp:
|
||||
fp.write("\n".join(merges))
|
||||
|
||||
def get_input_output_texts(self):
|
||||
def get_input_output_texts(self, tokenizer):
|
||||
input_text = "lower newer"
|
||||
output_text = "lower newer"
|
||||
return input_text, output_text
|
||||
|
||||
Reference in New Issue
Block a user