fix GPT-2 and RoBERTa tests to be clean now
This commit is contained in:
@@ -31,17 +31,18 @@ class GPT2TokenizationTest(CommonTestCases.CommonTokenizerTester):
|
|||||||
|
|
||||||
# Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
|
# Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
|
||||||
vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
|
vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
|
||||||
"lo", "low", "er",
|
"\u0120", "\u0120l", "\u0120n",
|
||||||
"low", "lowest", "newer", "wider", "<unk>"]
|
"\u0120lo", "\u0120low", "er",
|
||||||
|
"\u0120lowest", "\u0120newer", "\u0120wider", "<unk>"]
|
||||||
vocab_tokens = dict(zip(vocab, range(len(vocab))))
|
vocab_tokens = dict(zip(vocab, range(len(vocab))))
|
||||||
merges = ["#version: 0.2", "l o", "lo w", "e r", ""]
|
merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""]
|
||||||
self.special_tokens_map = {"unk_token": "<unk>"}
|
self.special_tokens_map = {"unk_token": "<unk>"}
|
||||||
|
|
||||||
self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
|
self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
|
||||||
self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file'])
|
self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file'])
|
||||||
with open(self.vocab_file, "w") as fp:
|
with open(self.vocab_file, "w", encoding="utf-8") as fp:
|
||||||
fp.write(json.dumps(vocab_tokens))
|
fp.write(json.dumps(vocab_tokens))
|
||||||
with open(self.merges_file, "w") as fp:
|
with open(self.merges_file, "w", encoding="utf-8") as fp:
|
||||||
fp.write("\n".join(merges))
|
fp.write("\n".join(merges))
|
||||||
|
|
||||||
def get_tokenizer(self):
|
def get_tokenizer(self):
|
||||||
@@ -49,18 +50,18 @@ class GPT2TokenizationTest(CommonTestCases.CommonTokenizerTester):
|
|||||||
|
|
||||||
def get_input_output_texts(self):
|
def get_input_output_texts(self):
|
||||||
input_text = u"lower newer"
|
input_text = u"lower newer"
|
||||||
output_text = u"lower<unk>newer"
|
output_text = u" lower newer"
|
||||||
return input_text, output_text
|
return input_text, output_text
|
||||||
|
|
||||||
def test_full_tokenizer(self):
|
def test_full_tokenizer(self):
|
||||||
tokenizer = GPT2Tokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map)
|
tokenizer = GPT2Tokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map)
|
||||||
text = "lower"
|
text = "lower"
|
||||||
bpe_tokens = ["low", "er"]
|
bpe_tokens = ["\u0120low", "er"]
|
||||||
tokens = tokenizer.tokenize(text)
|
tokens = tokenizer.tokenize(text)
|
||||||
self.assertListEqual(tokens, bpe_tokens)
|
self.assertListEqual(tokens, bpe_tokens)
|
||||||
|
|
||||||
input_tokens = tokens + [tokenizer.unk_token]
|
input_tokens = tokens + [tokenizer.unk_token]
|
||||||
input_bpe_tokens = [13, 12, 17]
|
input_bpe_tokens = [14, 15, 19]
|
||||||
self.assertListEqual(
|
self.assertListEqual(
|
||||||
tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
|
tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
|
||||||
|
|
||||||
|
|||||||
@@ -30,17 +30,18 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester):
|
|||||||
|
|
||||||
# Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
|
# Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
|
||||||
vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
|
vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
|
||||||
"lo", "low", "er",
|
"\u0120", "\u0120l", "\u0120n",
|
||||||
"low", "lowest", "newer", "wider", "<unk>"]
|
"\u0120lo", "\u0120low", "er",
|
||||||
|
"\u0120lowest", "\u0120newer", "\u0120wider", "<unk>"]
|
||||||
vocab_tokens = dict(zip(vocab, range(len(vocab))))
|
vocab_tokens = dict(zip(vocab, range(len(vocab))))
|
||||||
merges = ["#version: 0.2", "l o", "lo w", "e r", ""]
|
merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""]
|
||||||
self.special_tokens_map = {"unk_token": "<unk>"}
|
self.special_tokens_map = {"unk_token": "<unk>"}
|
||||||
|
|
||||||
self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
|
self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
|
||||||
self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file'])
|
self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file'])
|
||||||
with open(self.vocab_file, "w") as fp:
|
with open(self.vocab_file, "w", encoding="utf-8") as fp:
|
||||||
fp.write(json.dumps(vocab_tokens))
|
fp.write(json.dumps(vocab_tokens))
|
||||||
with open(self.merges_file, "w") as fp:
|
with open(self.merges_file, "w", encoding="utf-8") as fp:
|
||||||
fp.write("\n".join(merges))
|
fp.write("\n".join(merges))
|
||||||
|
|
||||||
def get_tokenizer(self):
|
def get_tokenizer(self):
|
||||||
@@ -48,18 +49,18 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester):
|
|||||||
|
|
||||||
def get_input_output_texts(self):
|
def get_input_output_texts(self):
|
||||||
input_text = u"lower newer"
|
input_text = u"lower newer"
|
||||||
output_text = u"lower<unk>newer"
|
output_text = u" lower newer"
|
||||||
return input_text, output_text
|
return input_text, output_text
|
||||||
|
|
||||||
def test_full_tokenizer(self):
|
def test_full_tokenizer(self):
|
||||||
tokenizer = RobertaTokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map)
|
tokenizer = RobertaTokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map)
|
||||||
text = "lower"
|
text = "lower"
|
||||||
bpe_tokens = ["low", "er"]
|
bpe_tokens = ["\u0120low", "er"]
|
||||||
tokens = tokenizer.tokenize(text)
|
tokens = tokenizer.tokenize(text)
|
||||||
self.assertListEqual(tokens, bpe_tokens)
|
self.assertListEqual(tokens, bpe_tokens)
|
||||||
|
|
||||||
input_tokens = tokens + [tokenizer.unk_token]
|
input_tokens = tokens + [tokenizer.unk_token]
|
||||||
input_bpe_tokens = [13, 12, 17]
|
input_bpe_tokens = [14, 15, 19]
|
||||||
self.assertListEqual(
|
self.assertListEqual(
|
||||||
tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
|
tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
|
||||||
|
|
||||||
|
|||||||
@@ -111,7 +111,7 @@ class CommonTestCases:
|
|||||||
self.assertGreater(tokens[-2], tokenizer.vocab_size - 1)
|
self.assertGreater(tokens[-2], tokenizer.vocab_size - 1)
|
||||||
|
|
||||||
new_toks_2 = {'eos_token': ">>>>|||<||<<|<<",
|
new_toks_2 = {'eos_token': ">>>>|||<||<<|<<",
|
||||||
'pad_token': "<<<<<|||>|>>>>|>"}
|
'pad_token': "<<<<<|||>|>>>>|>"}
|
||||||
added_toks_2 = tokenizer.add_special_tokens(new_toks_2)
|
added_toks_2 = tokenizer.add_special_tokens(new_toks_2)
|
||||||
vocab_size_3 = tokenizer.vocab_size
|
vocab_size_3 = tokenizer.vocab_size
|
||||||
all_size_3 = len(tokenizer)
|
all_size_3 = len(tokenizer)
|
||||||
@@ -129,7 +129,7 @@ class CommonTestCases:
|
|||||||
self.assertGreater(tokens[-2], tokenizer.vocab_size - 1)
|
self.assertGreater(tokens[-2], tokenizer.vocab_size - 1)
|
||||||
self.assertGreater(tokens[-2], tokens[-3])
|
self.assertGreater(tokens[-2], tokens[-3])
|
||||||
self.assertEqual(tokens[0], tokenizer.eos_token_id)
|
self.assertEqual(tokens[0], tokenizer.eos_token_id)
|
||||||
self.assertEqual(tokens[-2], tokenizer.eos_token_id)
|
self.assertEqual(tokens[-2], tokenizer.pad_token_id)
|
||||||
|
|
||||||
|
|
||||||
def test_required_methods_tokenizer(self):
|
def test_required_methods_tokenizer(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user