Roberta tokenization + fixed tests (py3 + py2).
This commit is contained in:
@@ -157,42 +157,6 @@ class RobertaModelTest(CommonTestCases.CommonModelTester):
|
||||
inputs_dict = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'attention_mask': input_mask}
|
||||
return config, inputs_dict
|
||||
|
||||
def test_inference_masked_lm(self):
|
||||
model = RobertaForMaskedLM.from_pretrained('roberta-base')
|
||||
|
||||
input_ids = torch.tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
|
||||
output = model(input_ids)[0]
|
||||
expected_shape = torch.Size((1, 11, 50265))
|
||||
self.assertEqual(
|
||||
output.shape,
|
||||
expected_shape
|
||||
)
|
||||
# compare the actual values for a slice.
|
||||
expected_slice = torch.Tensor(
|
||||
[[[33.8843, -4.3107, 22.7779],
|
||||
[4.6533, -2.8099, 13.6252],
|
||||
[1.8222, -3.6898, 8.8600]]]
|
||||
)
|
||||
self.assertTrue(
|
||||
torch.allclose(output[:, :3, :3], expected_slice, atol=1e-3)
|
||||
)
|
||||
|
||||
# @pytest.mark.slow
|
||||
def test_inference_no_head(self):
|
||||
model = RobertaModel.from_pretrained('roberta-base')
|
||||
|
||||
input_ids = torch.tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
|
||||
output = model(input_ids)[0]
|
||||
# compare the actual values for a slice.
|
||||
expected_slice = torch.Tensor(
|
||||
[[[-0.0231, 0.0782, 0.0074],
|
||||
[-0.1854, 0.0539, -0.0174],
|
||||
[0.0548, 0.0799, 0.1687]]]
|
||||
)
|
||||
self.assertTrue(
|
||||
torch.allclose(output[:, :3, :3], expected_slice, atol=1e-3)
|
||||
)
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = RobertaModelTest.RobertaModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=RobertaConfig, hidden_size=37)
|
||||
@@ -220,7 +184,7 @@ class RobertaModelTest(CommonTestCases.CommonModelTester):
|
||||
|
||||
class RobertaModelIntegrationTest(unittest.TestCase):
|
||||
|
||||
# @pytest.mark.slow
|
||||
@pytest.mark.slow
|
||||
def test_inference_masked_lm(self):
|
||||
model = RobertaForMaskedLM.from_pretrained('roberta-base')
|
||||
|
||||
@@ -241,7 +205,7 @@ class RobertaModelIntegrationTest(unittest.TestCase):
|
||||
torch.allclose(output[:, :3, :3], expected_slice, atol=1e-3)
|
||||
)
|
||||
|
||||
# @pytest.mark.slow
|
||||
@pytest.mark.slow
|
||||
def test_inference_no_head(self):
|
||||
model = RobertaModel.from_pretrained('roberta-base')
|
||||
|
||||
|
||||
@@ -18,8 +18,7 @@ import os
|
||||
import json
|
||||
import unittest
|
||||
|
||||
from pytorch_transformers.tokenization_roberta import RobertaTokenizer, DICT_FILES_NAMES
|
||||
from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer, VOCAB_FILES_NAMES
|
||||
from pytorch_transformers.tokenization_roberta import RobertaTokenizer, VOCAB_FILES_NAMES
|
||||
from .tokenization_tests_commons import CommonTestCases
|
||||
|
||||
|
||||
@@ -45,8 +44,7 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester):
|
||||
fp.write("\n".join(merges))
|
||||
|
||||
def get_tokenizer(self):
|
||||
bpe_tokenizer = GPT2Tokenizer.from_pretrained(self.tmpdirname, **self.special_tokens_map)
|
||||
return RobertaTokenizer.from_pretrained("roberta-base", bpe_tokenizer=bpe_tokenizer)
|
||||
return RobertaTokenizer.from_pretrained(self.tmpdirname, **self.special_tokens_map)
|
||||
|
||||
def get_input_output_texts(self):
|
||||
input_text = u"lower newer"
|
||||
@@ -54,15 +52,14 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester):
|
||||
return input_text, output_text
|
||||
|
||||
def test_full_tokenizer(self):
|
||||
tokenizer = self.get_tokenizer()
|
||||
tokenizer = RobertaTokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map)
|
||||
text = "lower"
|
||||
bpe_tokens = ["low", "er"]
|
||||
tokens = tokenizer.tokenize(text)
|
||||
self.assertListEqual(tokens, bpe_tokens)
|
||||
|
||||
input_tokens = tokens + [tokenizer.unk_token]
|
||||
input_bpe_tokens = [0, 4, 12, 176, 2]
|
||||
tokenizer.convert_tokens_to_ids(input_tokens)
|
||||
input_bpe_tokens = [13, 12, 17]
|
||||
self.assertListEqual(
|
||||
tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user