Fix DeBERTa token_type_ids (#17082)

This commit is contained in:
Patrick Deutschmann
2022-05-04 18:23:37 +02:00
committed by GitHub
parent 279bc5849b
commit 870e6f29a6
4 changed files with 10 additions and 4 deletions

View File

@@ -88,6 +88,12 @@ class DebertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
input_bpe_tokens = [0, 1, 2, 15, 10, 9, 3, 2, 15, 19]
self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
def test_token_type_ids(self):
tokenizer = self.get_tokenizer()
tokd = tokenizer("Hello", "World")
expected_token_type_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
self.assertListEqual(tokd["token_type_ids"], expected_token_type_ids)
@slow
def test_sequence_builders(self):
tokenizer = self.tokenizer_class.from_pretrained("microsoft/deberta-base")