Fix DeBERTa token_type_ids (#17082)
This commit is contained in:
committed by
GitHub
parent
279bc5849b
commit
870e6f29a6
@@ -88,6 +88,12 @@ class DebertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
input_bpe_tokens = [0, 1, 2, 15, 10, 9, 3, 2, 15, 19]
|
||||
self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
|
||||
|
||||
def test_token_type_ids(self):
|
||||
tokenizer = self.get_tokenizer()
|
||||
tokd = tokenizer("Hello", "World")
|
||||
expected_token_type_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
|
||||
self.assertListEqual(tokd["token_type_ids"], expected_token_type_ids)
|
||||
|
||||
@slow
|
||||
def test_sequence_builders(self):
|
||||
tokenizer = self.tokenizer_class.from_pretrained("microsoft/deberta-base")
|
||||
|
||||
Reference in New Issue
Block a user