[cleanup] test_tokenization_common.py (#4390)

This commit is contained in:
Sam Shleifer
2020-05-19 10:46:55 -04:00
committed by GitHub
parent 8f1d047148
commit 07dd7c2fd8
13 changed files with 62 additions and 98 deletions

View File

@@ -36,9 +36,6 @@ class AlbertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer = AlbertTokenizer(SAMPLE_VOCAB)
tokenizer.save_pretrained(self.tmpdirname)
def get_tokenizer(self, **kwargs):
return AlbertTokenizer.from_pretrained(self.tmpdirname, **kwargs)
def get_input_output_texts(self):
input_text = "this is a test"
output_text = "this is a test"

View File

@@ -59,9 +59,6 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
def get_tokenizer(self, **kwargs):
return BertTokenizer.from_pretrained(self.tmpdirname, **kwargs)
def get_rust_tokenizer(self, **kwargs):
return BertTokenizerFast.from_pretrained(self.tmpdirname, **kwargs)

View File

@@ -60,9 +60,6 @@ class BertJapaneseTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
def get_tokenizer(self, **kwargs):
return BertJapaneseTokenizer.from_pretrained(self.tmpdirname, **kwargs)
def get_input_output_texts(self):
input_text = "こんにちは、世界。 \nこんばんは、世界。"
output_text = "こんにちは 、 世界 。 こんばんは 、 世界 。"

View File

@@ -22,12 +22,12 @@ from collections import OrderedDict
from typing import TYPE_CHECKING, Dict, Tuple, Union
from tests.utils import require_tf, require_torch
from transformers import PreTrainedTokenizer
if TYPE_CHECKING:
from transformers import (
PretrainedConfig,
PreTrainedTokenizer,
PreTrainedTokenizerFast,
PreTrainedModel,
TFPreTrainedModel,
@@ -67,19 +67,24 @@ class TokenizerTesterMixin:
def tearDown(self):
shutil.rmtree(self.tmpdirname)
def get_tokenizer(self, **kwargs):
raise NotImplementedError
def get_tokenizer(self, **kwargs) -> PreTrainedTokenizer:
return self.tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)
def get_rust_tokenizer(self, **kwargs):
raise NotImplementedError
def get_input_output_texts(self):
raise NotImplementedError
def get_input_output_texts(self) -> Tuple[str, str]:
"""Feel free to overwrite"""
# TODO: @property
return (
"This is a test",
"This is a test",
)
@staticmethod
def convert_batch_encode_plus_format_to_encode_plus(batch_encode_plus_sequences):
# Switch from batch_encode_plus format: {'input_ids': [[...], [...]], ...}
# to the concatenated encode_plus format: [{'input_ids': [...], ...}, {'input_ids': [...], ...}]
# to the list of examples/ encode_plus format: [{'input_ids': [...], ...}, {'input_ids': [...], ...}]
return [
{value: batch_encode_plus_sequences[value][i] for value in batch_encode_plus_sequences.keys()}
for i in range(len(batch_encode_plus_sequences["input_ids"]))
@@ -114,13 +119,13 @@ class TokenizerTesterMixin:
# Now let's start the test
tokenizer = self.get_tokenizer(max_len=42)
before_tokens = tokenizer.encode("He is very happy, UNwant\u00E9d,running", add_special_tokens=False)
sample_text = "He is very happy, UNwant\u00E9d,running"
before_tokens = tokenizer.encode(sample_text, add_special_tokens=False)
tokenizer.save_pretrained(self.tmpdirname)
tokenizer = self.tokenizer_class.from_pretrained(self.tmpdirname)
after_tokens = tokenizer.encode("He is very happy, UNwant\u00E9d,running", add_special_tokens=False)
after_tokens = tokenizer.encode(sample_text, add_special_tokens=False)
self.assertListEqual(before_tokens, after_tokens)
self.assertEqual(tokenizer.max_len, 42)
@@ -128,6 +133,7 @@ class TokenizerTesterMixin:
self.assertEqual(tokenizer.max_len, 43)
def test_pickle_tokenizer(self):
"""Google pickle __getstate__ __setstate__ if you are struggling with this."""
tokenizer = self.get_tokenizer()
self.assertIsNotNone(tokenizer)
@@ -253,7 +259,7 @@ class TokenizerTesterMixin:
decoded = tokenizer.decode(encoded, skip_special_tokens=True)
assert special_token not in decoded
def test_required_methods_tokenizer(self):
def test_internal_consistency(self):
tokenizer = self.get_tokenizer()
input_text, output_text = self.get_input_output_texts()
@@ -263,13 +269,12 @@ class TokenizerTesterMixin:
self.assertListEqual(ids, ids_2)
tokens_2 = tokenizer.convert_ids_to_tokens(ids)
self.assertNotEqual(len(tokens_2), 0)
text_2 = tokenizer.decode(ids)
self.assertIsInstance(text_2, str)
self.assertEqual(text_2, output_text)
self.assertNotEqual(len(tokens_2), 0)
self.assertIsInstance(text_2, str)
def test_encode_decode_with_spaces(self):
tokenizer = self.get_tokenizer()
@@ -429,10 +434,7 @@ class TokenizerTesterMixin:
def test_special_tokens_mask(self):
tokenizer = self.get_tokenizer()
sequence_0 = "Encode this."
sequence_1 = "This one too please."
# Testing single inputs
encoded_sequence = tokenizer.encode(sequence_0, add_special_tokens=False)
encoded_sequence_dict = tokenizer.encode_plus(
@@ -442,13 +444,13 @@ class TokenizerTesterMixin:
special_tokens_mask = encoded_sequence_dict["special_tokens_mask"]
self.assertEqual(len(special_tokens_mask), len(encoded_sequence_w_special))
filtered_sequence = [
(x if not special_tokens_mask[i] else None) for i, x in enumerate(encoded_sequence_w_special)
]
filtered_sequence = [x for x in filtered_sequence if x is not None]
filtered_sequence = [x for i, x in enumerate(encoded_sequence_w_special) if not special_tokens_mask[i]]
self.assertEqual(encoded_sequence, filtered_sequence)
# Testing inputs pairs
def test_special_tokens_mask_input_pairs(self):
tokenizer = self.get_tokenizer()
sequence_0 = "Encode this."
sequence_1 = "This one too please."
encoded_sequence = tokenizer.encode(sequence_0, add_special_tokens=False)
encoded_sequence += tokenizer.encode(sequence_1, add_special_tokens=False)
encoded_sequence_dict = tokenizer.encode_plus(
@@ -464,7 +466,9 @@ class TokenizerTesterMixin:
filtered_sequence = [x for x in filtered_sequence if x is not None]
self.assertEqual(encoded_sequence, filtered_sequence)
# Testing with already existing special tokens
def test_special_tokens_mask_already_has_special_tokens(self):
tokenizer = self.get_tokenizer()
sequence_0 = "Encode this."
if tokenizer.cls_token_id == tokenizer.unk_token_id and tokenizer.cls_token_id == tokenizer.unk_token_id:
tokenizer.add_special_tokens({"cls_token": "</s>", "sep_token": "<s>"})
encoded_sequence_dict = tokenizer.encode_plus(
@@ -514,13 +518,12 @@ class TokenizerTesterMixin:
tokenizer.padding_side = "right"
padded_sequence_right = tokenizer.encode(sequence, pad_to_max_length=True)
padded_sequence_right_length = len(padded_sequence_right)
assert sequence_length == padded_sequence_right_length
assert encoded_sequence == padded_sequence_right
tokenizer.padding_side = "left"
padded_sequence_left = tokenizer.encode(sequence, pad_to_max_length=True)
padded_sequence_left_length = len(padded_sequence_left)
assert sequence_length == padded_sequence_right_length
assert encoded_sequence == padded_sequence_right
assert sequence_length == padded_sequence_left_length
assert encoded_sequence == padded_sequence_left
@@ -617,6 +620,9 @@ class TokenizerTesterMixin:
self.assertIsInstance(vocab, dict)
self.assertEqual(len(vocab), len(tokenizer))
def test_conversion_reversible(self):
tokenizer = self.get_tokenizer()
vocab = tokenizer.get_vocab()
for word, ind in vocab.items():
self.assertEqual(tokenizer.convert_tokens_to_ids(word), ind)
self.assertEqual(tokenizer.convert_ids_to_tokens(ind), word)
@@ -746,6 +752,7 @@ class TokenizerTesterMixin:
@require_torch
def test_torch_encode_plus_sent_to_model(self):
import torch
from transformers import MODEL_MAPPING, TOKENIZER_MAPPING
MODEL_TOKENIZER_MAPPING = merge_model_tokenizer_mappings(MODEL_MAPPING, TOKENIZER_MAPPING)
@@ -773,8 +780,10 @@ class TokenizerTesterMixin:
encoded_sequence = tokenizer.encode_plus(sequence, return_tensors="pt")
batch_encoded_sequence = tokenizer.batch_encode_plus([sequence, sequence], return_tensors="pt")
# This should not fail
model(**encoded_sequence)
model(**batch_encoded_sequence)
with torch.no_grad(): # saves some time
model(**encoded_sequence)
model(**batch_encoded_sequence)
if self.test_rust_tokenizer:
fast_tokenizer = self.get_rust_tokenizer()

View File

@@ -24,9 +24,6 @@ class DistilBertTokenizationTest(BertTokenizationTest):
tokenizer_class = DistilBertTokenizer
def get_tokenizer(self, **kwargs):
return DistilBertTokenizer.from_pretrained(self.tmpdirname, **kwargs)
def get_rust_tokenizer(self, **kwargs):
return DistilBertTokenizerFast.from_pretrained(self.tmpdirname, **kwargs)

View File

@@ -64,13 +64,8 @@ class OpenAIGPTTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
with open(self.merges_file, "w") as fp:
fp.write("\n".join(merges))
def get_tokenizer(self, **kwargs):
return OpenAIGPTTokenizer.from_pretrained(self.tmpdirname, **kwargs)
def get_input_output_texts(self):
input_text = "lower newer"
output_text = "lower newer"
return input_text, output_text
return "lower newer", "lower newer"
def test_full_tokenizer(self):
tokenizer = OpenAIGPTTokenizer(self.vocab_file, self.merges_file)

View File

@@ -37,14 +37,6 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer = T5Tokenizer(SAMPLE_VOCAB)
tokenizer.save_pretrained(self.tmpdirname)
def get_tokenizer(self, **kwargs):
return T5Tokenizer.from_pretrained(self.tmpdirname, **kwargs)
def get_input_output_texts(self):
input_text = "This is a test"
output_text = "This is a test"
return input_text, output_text
def test_full_tokenizer(self):
tokenizer = T5Tokenizer(SAMPLE_VOCAB)

View File

@@ -65,9 +65,6 @@ class XLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
with open(self.merges_file, "w") as fp:
fp.write("\n".join(merges))
def get_tokenizer(self, **kwargs):
return XLMTokenizer.from_pretrained(self.tmpdirname, **kwargs)
def get_input_output_texts(self):
input_text = "lower newer"
output_text = "lower newer"

View File

@@ -17,6 +17,7 @@
import os
import unittest
from transformers.file_utils import cached_property
from transformers.tokenization_xlm_roberta import SPIECE_UNDERLINE, XLMRobertaTokenizer
from .test_tokenization_common import TokenizerTesterMixin
@@ -37,14 +38,6 @@ class XLMRobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer = XLMRobertaTokenizer(SAMPLE_VOCAB, keep_accents=True)
tokenizer.save_pretrained(self.tmpdirname)
def get_tokenizer(self, **kwargs):
return XLMRobertaTokenizer.from_pretrained(self.tmpdirname, **kwargs)
def get_input_output_texts(self):
input_text = "This is a test"
output_text = "This is a test"
return input_text, output_text
def test_full_tokenizer(self):
tokenizer = XLMRobertaTokenizer(SAMPLE_VOCAB, keep_accents=True)
@@ -121,22 +114,22 @@ class XLMRobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
],
)
@cached_property
def big_tokenizer(self):
return XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")
@slow
def test_tokenization_base_easy_symbols(self):
tokenizer = XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")
symbols = "Hello World!"
original_tokenizer_encodings = [0, 35378, 6661, 38, 2]
# xlmr = torch.hub.load('pytorch/fairseq', 'xlmr.base') # xlmr.large has same tokenizer
# xlmr.eval()
# xlmr.encode(symbols)
self.assertListEqual(original_tokenizer_encodings, tokenizer.encode(symbols))
self.assertListEqual(original_tokenizer_encodings, self.big_tokenizer.encode(symbols))
@slow
def test_tokenization_base_hard_symbols(self):
tokenizer = XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")
symbols = 'This is a very long text with a lot of weird characters, such as: . , ~ ? ( ) " [ ] ! : - . Also we will add words that should not exsist and be tokenized to <unk>, such as saoneuhaoesuth'
original_tokenizer_encodings = [
0,
@@ -209,4 +202,4 @@ class XLMRobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
# xlmr.eval()
# xlmr.encode(symbols)
self.assertListEqual(original_tokenizer_encodings, tokenizer.encode(symbols))
self.assertListEqual(original_tokenizer_encodings, self.big_tokenizer.encode(symbols))

View File

@@ -37,14 +37,6 @@ class XLNetTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True)
tokenizer.save_pretrained(self.tmpdirname)
def get_tokenizer(self, **kwargs):
return XLNetTokenizer.from_pretrained(self.tmpdirname, **kwargs)
def get_input_output_texts(self):
input_text = "This is a test"
output_text = "This is a test"
return input_text, output_text
def test_full_tokenizer(self):
tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True)