[cleanup] test_tokenization_common.py (#4390)
This commit is contained in:
@@ -36,9 +36,6 @@ class AlbertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
tokenizer = AlbertTokenizer(SAMPLE_VOCAB)
|
||||
tokenizer.save_pretrained(self.tmpdirname)
|
||||
|
||||
def get_tokenizer(self, **kwargs):
|
||||
return AlbertTokenizer.from_pretrained(self.tmpdirname, **kwargs)
|
||||
|
||||
def get_input_output_texts(self):
|
||||
input_text = "this is a test"
|
||||
output_text = "this is a test"
|
||||
|
||||
@@ -59,9 +59,6 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
|
||||
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
|
||||
|
||||
def get_tokenizer(self, **kwargs):
|
||||
return BertTokenizer.from_pretrained(self.tmpdirname, **kwargs)
|
||||
|
||||
def get_rust_tokenizer(self, **kwargs):
|
||||
return BertTokenizerFast.from_pretrained(self.tmpdirname, **kwargs)
|
||||
|
||||
|
||||
@@ -60,9 +60,6 @@ class BertJapaneseTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
|
||||
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
|
||||
|
||||
def get_tokenizer(self, **kwargs):
|
||||
return BertJapaneseTokenizer.from_pretrained(self.tmpdirname, **kwargs)
|
||||
|
||||
def get_input_output_texts(self):
|
||||
input_text = "こんにちは、世界。 \nこんばんは、世界。"
|
||||
output_text = "こんにちは 、 世界 。 こんばんは 、 世界 。"
|
||||
|
||||
@@ -22,12 +22,12 @@ from collections import OrderedDict
|
||||
from typing import TYPE_CHECKING, Dict, Tuple, Union
|
||||
|
||||
from tests.utils import require_tf, require_torch
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import (
|
||||
PretrainedConfig,
|
||||
PreTrainedTokenizer,
|
||||
PreTrainedTokenizerFast,
|
||||
PreTrainedModel,
|
||||
TFPreTrainedModel,
|
||||
@@ -67,19 +67,24 @@ class TokenizerTesterMixin:
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.tmpdirname)
|
||||
|
||||
def get_tokenizer(self, **kwargs):
|
||||
raise NotImplementedError
|
||||
def get_tokenizer(self, **kwargs) -> PreTrainedTokenizer:
|
||||
return self.tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)
|
||||
|
||||
def get_rust_tokenizer(self, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_input_output_texts(self):
|
||||
raise NotImplementedError
|
||||
def get_input_output_texts(self) -> Tuple[str, str]:
|
||||
"""Feel free to overwrite"""
|
||||
# TODO: @property
|
||||
return (
|
||||
"This is a test",
|
||||
"This is a test",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def convert_batch_encode_plus_format_to_encode_plus(batch_encode_plus_sequences):
|
||||
# Switch from batch_encode_plus format: {'input_ids': [[...], [...]], ...}
|
||||
# to the concatenated encode_plus format: [{'input_ids': [...], ...}, {'input_ids': [...], ...}]
|
||||
# to the list of examples/ encode_plus format: [{'input_ids': [...], ...}, {'input_ids': [...], ...}]
|
||||
return [
|
||||
{value: batch_encode_plus_sequences[value][i] for value in batch_encode_plus_sequences.keys()}
|
||||
for i in range(len(batch_encode_plus_sequences["input_ids"]))
|
||||
@@ -114,13 +119,13 @@ class TokenizerTesterMixin:
|
||||
|
||||
# Now let's start the test
|
||||
tokenizer = self.get_tokenizer(max_len=42)
|
||||
|
||||
before_tokens = tokenizer.encode("He is very happy, UNwant\u00E9d,running", add_special_tokens=False)
|
||||
sample_text = "He is very happy, UNwant\u00E9d,running"
|
||||
before_tokens = tokenizer.encode(sample_text, add_special_tokens=False)
|
||||
|
||||
tokenizer.save_pretrained(self.tmpdirname)
|
||||
tokenizer = self.tokenizer_class.from_pretrained(self.tmpdirname)
|
||||
|
||||
after_tokens = tokenizer.encode("He is very happy, UNwant\u00E9d,running", add_special_tokens=False)
|
||||
after_tokens = tokenizer.encode(sample_text, add_special_tokens=False)
|
||||
self.assertListEqual(before_tokens, after_tokens)
|
||||
|
||||
self.assertEqual(tokenizer.max_len, 42)
|
||||
@@ -128,6 +133,7 @@ class TokenizerTesterMixin:
|
||||
self.assertEqual(tokenizer.max_len, 43)
|
||||
|
||||
def test_pickle_tokenizer(self):
|
||||
"""Google pickle __getstate__ __setstate__ if you are struggling with this."""
|
||||
tokenizer = self.get_tokenizer()
|
||||
self.assertIsNotNone(tokenizer)
|
||||
|
||||
@@ -253,7 +259,7 @@ class TokenizerTesterMixin:
|
||||
decoded = tokenizer.decode(encoded, skip_special_tokens=True)
|
||||
assert special_token not in decoded
|
||||
|
||||
def test_required_methods_tokenizer(self):
|
||||
def test_internal_consistency(self):
|
||||
tokenizer = self.get_tokenizer()
|
||||
input_text, output_text = self.get_input_output_texts()
|
||||
|
||||
@@ -263,13 +269,12 @@ class TokenizerTesterMixin:
|
||||
self.assertListEqual(ids, ids_2)
|
||||
|
||||
tokens_2 = tokenizer.convert_ids_to_tokens(ids)
|
||||
self.assertNotEqual(len(tokens_2), 0)
|
||||
text_2 = tokenizer.decode(ids)
|
||||
self.assertIsInstance(text_2, str)
|
||||
|
||||
self.assertEqual(text_2, output_text)
|
||||
|
||||
self.assertNotEqual(len(tokens_2), 0)
|
||||
self.assertIsInstance(text_2, str)
|
||||
|
||||
def test_encode_decode_with_spaces(self):
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
@@ -429,10 +434,7 @@ class TokenizerTesterMixin:
|
||||
|
||||
def test_special_tokens_mask(self):
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
sequence_0 = "Encode this."
|
||||
sequence_1 = "This one too please."
|
||||
|
||||
# Testing single inputs
|
||||
encoded_sequence = tokenizer.encode(sequence_0, add_special_tokens=False)
|
||||
encoded_sequence_dict = tokenizer.encode_plus(
|
||||
@@ -442,13 +444,13 @@ class TokenizerTesterMixin:
|
||||
special_tokens_mask = encoded_sequence_dict["special_tokens_mask"]
|
||||
self.assertEqual(len(special_tokens_mask), len(encoded_sequence_w_special))
|
||||
|
||||
filtered_sequence = [
|
||||
(x if not special_tokens_mask[i] else None) for i, x in enumerate(encoded_sequence_w_special)
|
||||
]
|
||||
filtered_sequence = [x for x in filtered_sequence if x is not None]
|
||||
filtered_sequence = [x for i, x in enumerate(encoded_sequence_w_special) if not special_tokens_mask[i]]
|
||||
self.assertEqual(encoded_sequence, filtered_sequence)
|
||||
|
||||
# Testing inputs pairs
|
||||
def test_special_tokens_mask_input_pairs(self):
|
||||
tokenizer = self.get_tokenizer()
|
||||
sequence_0 = "Encode this."
|
||||
sequence_1 = "This one too please."
|
||||
encoded_sequence = tokenizer.encode(sequence_0, add_special_tokens=False)
|
||||
encoded_sequence += tokenizer.encode(sequence_1, add_special_tokens=False)
|
||||
encoded_sequence_dict = tokenizer.encode_plus(
|
||||
@@ -464,7 +466,9 @@ class TokenizerTesterMixin:
|
||||
filtered_sequence = [x for x in filtered_sequence if x is not None]
|
||||
self.assertEqual(encoded_sequence, filtered_sequence)
|
||||
|
||||
# Testing with already existing special tokens
|
||||
def test_special_tokens_mask_already_has_special_tokens(self):
|
||||
tokenizer = self.get_tokenizer()
|
||||
sequence_0 = "Encode this."
|
||||
if tokenizer.cls_token_id == tokenizer.unk_token_id and tokenizer.cls_token_id == tokenizer.unk_token_id:
|
||||
tokenizer.add_special_tokens({"cls_token": "</s>", "sep_token": "<s>"})
|
||||
encoded_sequence_dict = tokenizer.encode_plus(
|
||||
@@ -514,13 +518,12 @@ class TokenizerTesterMixin:
|
||||
tokenizer.padding_side = "right"
|
||||
padded_sequence_right = tokenizer.encode(sequence, pad_to_max_length=True)
|
||||
padded_sequence_right_length = len(padded_sequence_right)
|
||||
assert sequence_length == padded_sequence_right_length
|
||||
assert encoded_sequence == padded_sequence_right
|
||||
|
||||
tokenizer.padding_side = "left"
|
||||
padded_sequence_left = tokenizer.encode(sequence, pad_to_max_length=True)
|
||||
padded_sequence_left_length = len(padded_sequence_left)
|
||||
|
||||
assert sequence_length == padded_sequence_right_length
|
||||
assert encoded_sequence == padded_sequence_right
|
||||
assert sequence_length == padded_sequence_left_length
|
||||
assert encoded_sequence == padded_sequence_left
|
||||
|
||||
@@ -617,6 +620,9 @@ class TokenizerTesterMixin:
|
||||
self.assertIsInstance(vocab, dict)
|
||||
self.assertEqual(len(vocab), len(tokenizer))
|
||||
|
||||
def test_conversion_reversible(self):
|
||||
tokenizer = self.get_tokenizer()
|
||||
vocab = tokenizer.get_vocab()
|
||||
for word, ind in vocab.items():
|
||||
self.assertEqual(tokenizer.convert_tokens_to_ids(word), ind)
|
||||
self.assertEqual(tokenizer.convert_ids_to_tokens(ind), word)
|
||||
@@ -746,6 +752,7 @@ class TokenizerTesterMixin:
|
||||
|
||||
@require_torch
|
||||
def test_torch_encode_plus_sent_to_model(self):
|
||||
import torch
|
||||
from transformers import MODEL_MAPPING, TOKENIZER_MAPPING
|
||||
|
||||
MODEL_TOKENIZER_MAPPING = merge_model_tokenizer_mappings(MODEL_MAPPING, TOKENIZER_MAPPING)
|
||||
@@ -773,8 +780,10 @@ class TokenizerTesterMixin:
|
||||
encoded_sequence = tokenizer.encode_plus(sequence, return_tensors="pt")
|
||||
batch_encoded_sequence = tokenizer.batch_encode_plus([sequence, sequence], return_tensors="pt")
|
||||
# This should not fail
|
||||
model(**encoded_sequence)
|
||||
model(**batch_encoded_sequence)
|
||||
|
||||
with torch.no_grad(): # saves some time
|
||||
model(**encoded_sequence)
|
||||
model(**batch_encoded_sequence)
|
||||
|
||||
if self.test_rust_tokenizer:
|
||||
fast_tokenizer = self.get_rust_tokenizer()
|
||||
|
||||
@@ -24,9 +24,6 @@ class DistilBertTokenizationTest(BertTokenizationTest):
|
||||
|
||||
tokenizer_class = DistilBertTokenizer
|
||||
|
||||
def get_tokenizer(self, **kwargs):
|
||||
return DistilBertTokenizer.from_pretrained(self.tmpdirname, **kwargs)
|
||||
|
||||
def get_rust_tokenizer(self, **kwargs):
|
||||
return DistilBertTokenizerFast.from_pretrained(self.tmpdirname, **kwargs)
|
||||
|
||||
|
||||
@@ -64,13 +64,8 @@ class OpenAIGPTTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
with open(self.merges_file, "w") as fp:
|
||||
fp.write("\n".join(merges))
|
||||
|
||||
def get_tokenizer(self, **kwargs):
|
||||
return OpenAIGPTTokenizer.from_pretrained(self.tmpdirname, **kwargs)
|
||||
|
||||
def get_input_output_texts(self):
|
||||
input_text = "lower newer"
|
||||
output_text = "lower newer"
|
||||
return input_text, output_text
|
||||
return "lower newer", "lower newer"
|
||||
|
||||
def test_full_tokenizer(self):
|
||||
tokenizer = OpenAIGPTTokenizer(self.vocab_file, self.merges_file)
|
||||
|
||||
@@ -37,14 +37,6 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
tokenizer = T5Tokenizer(SAMPLE_VOCAB)
|
||||
tokenizer.save_pretrained(self.tmpdirname)
|
||||
|
||||
def get_tokenizer(self, **kwargs):
|
||||
return T5Tokenizer.from_pretrained(self.tmpdirname, **kwargs)
|
||||
|
||||
def get_input_output_texts(self):
|
||||
input_text = "This is a test"
|
||||
output_text = "This is a test"
|
||||
return input_text, output_text
|
||||
|
||||
def test_full_tokenizer(self):
|
||||
tokenizer = T5Tokenizer(SAMPLE_VOCAB)
|
||||
|
||||
|
||||
@@ -65,9 +65,6 @@ class XLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
with open(self.merges_file, "w") as fp:
|
||||
fp.write("\n".join(merges))
|
||||
|
||||
def get_tokenizer(self, **kwargs):
|
||||
return XLMTokenizer.from_pretrained(self.tmpdirname, **kwargs)
|
||||
|
||||
def get_input_output_texts(self):
|
||||
input_text = "lower newer"
|
||||
output_text = "lower newer"
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from transformers.file_utils import cached_property
|
||||
from transformers.tokenization_xlm_roberta import SPIECE_UNDERLINE, XLMRobertaTokenizer
|
||||
|
||||
from .test_tokenization_common import TokenizerTesterMixin
|
||||
@@ -37,14 +38,6 @@ class XLMRobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
tokenizer = XLMRobertaTokenizer(SAMPLE_VOCAB, keep_accents=True)
|
||||
tokenizer.save_pretrained(self.tmpdirname)
|
||||
|
||||
def get_tokenizer(self, **kwargs):
|
||||
return XLMRobertaTokenizer.from_pretrained(self.tmpdirname, **kwargs)
|
||||
|
||||
def get_input_output_texts(self):
|
||||
input_text = "This is a test"
|
||||
output_text = "This is a test"
|
||||
return input_text, output_text
|
||||
|
||||
def test_full_tokenizer(self):
|
||||
tokenizer = XLMRobertaTokenizer(SAMPLE_VOCAB, keep_accents=True)
|
||||
|
||||
@@ -121,22 +114,22 @@ class XLMRobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
],
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def big_tokenizer(self):
|
||||
return XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")
|
||||
|
||||
@slow
|
||||
def test_tokenization_base_easy_symbols(self):
|
||||
tokenizer = XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")
|
||||
|
||||
symbols = "Hello World!"
|
||||
original_tokenizer_encodings = [0, 35378, 6661, 38, 2]
|
||||
# xlmr = torch.hub.load('pytorch/fairseq', 'xlmr.base') # xlmr.large has same tokenizer
|
||||
# xlmr.eval()
|
||||
# xlmr.encode(symbols)
|
||||
|
||||
self.assertListEqual(original_tokenizer_encodings, tokenizer.encode(symbols))
|
||||
self.assertListEqual(original_tokenizer_encodings, self.big_tokenizer.encode(symbols))
|
||||
|
||||
@slow
|
||||
def test_tokenization_base_hard_symbols(self):
|
||||
tokenizer = XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")
|
||||
|
||||
symbols = 'This is a very long text with a lot of weird characters, such as: . , ~ ? ( ) " [ ] ! : - . Also we will add words that should not exsist and be tokenized to <unk>, such as saoneuhaoesuth'
|
||||
original_tokenizer_encodings = [
|
||||
0,
|
||||
@@ -209,4 +202,4 @@ class XLMRobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
# xlmr.eval()
|
||||
# xlmr.encode(symbols)
|
||||
|
||||
self.assertListEqual(original_tokenizer_encodings, tokenizer.encode(symbols))
|
||||
self.assertListEqual(original_tokenizer_encodings, self.big_tokenizer.encode(symbols))
|
||||
|
||||
@@ -37,14 +37,6 @@ class XLNetTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True)
|
||||
tokenizer.save_pretrained(self.tmpdirname)
|
||||
|
||||
def get_tokenizer(self, **kwargs):
|
||||
return XLNetTokenizer.from_pretrained(self.tmpdirname, **kwargs)
|
||||
|
||||
def get_input_output_texts(self):
|
||||
input_text = "This is a test"
|
||||
output_text = "This is a test"
|
||||
return input_text, output_text
|
||||
|
||||
def test_full_tokenizer(self):
|
||||
tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user