From 5e323017a4f62c628f1146cb86362ca7d8bf32c4 Mon Sep 17 00:00:00 2001 From: Anthony MOI Date: Fri, 23 Oct 2020 10:29:37 -0400 Subject: [PATCH] Fix BatchEncoding.word_to_tokens for removed tokens (#7939) --- src/transformers/tokenization_utils_base.py | 10 ++++++---- tests/test_tokenization_utils.py | 11 ++++++++++- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 61f10228ea..d479186c4b 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -364,7 +364,7 @@ class BatchEncoding(UserDict): token_index = self._seq_len + token_index return self._encodings[batch_index].token_to_word(token_index) - def word_to_tokens(self, batch_or_word_index: int, word_index: Optional[int] = None) -> TokenSpan: + def word_to_tokens(self, batch_or_word_index: int, word_index: Optional[int] = None) -> Optional[TokenSpan]: """ Get the encoded token span corresponding to a word in the sequence of the batch. @@ -391,8 +391,9 @@ class BatchEncoding(UserDict): of the word in the sequence. Returns: - :class:`~transformers.tokenization_utils_base.TokenSpan` - Span of tokens in the encoded sequence. + Optional :class:`~transformers.tokenization_utils_base.TokenSpan` + Span of tokens in the encoded sequence. Returns :obj:`None` if no tokens correspond + to the word. """ if not self._encodings: @@ -406,7 +407,8 @@ class BatchEncoding(UserDict): batch_index = self._batch_size + batch_index if word_index < 0: word_index = self._seq_len + word_index - return TokenSpan(*(self._encodings[batch_index].word_to_tokens(word_index))) + span = self._encodings[batch_index].word_to_tokens(word_index) + return TokenSpan(*span) if span is not None else None def token_to_chars(self, batch_or_token_index: int, token_index: Optional[int] = None) -> CharSpan: """ diff --git a/tests/test_tokenization_utils.py b/tests/test_tokenization_utils.py index d2b1e69f00..3bc09d2f0f 100644 --- a/tests/test_tokenization_utils.py +++ b/tests/test_tokenization_utils.py @@ -18,7 +18,7 @@ from typing import Callable, Optional import numpy as np -from transformers import BatchEncoding, BertTokenizer, BertTokenizerFast, PreTrainedTokenizer, TensorType +from transformers import BatchEncoding, BertTokenizer, BertTokenizerFast, PreTrainedTokenizer, TensorType, TokenSpan from transformers.testing_utils import require_tf, require_tokenizers, require_torch, slow from transformers.tokenization_gpt2 import GPT2Tokenizer @@ -142,6 +142,15 @@ class TokenizerUtilsTest(unittest.TestCase): with self.subTest("Rust Tokenizer"): self.assertTrue(tokenizer_r("Small example to_encode").is_fast) + @require_tokenizers + def test_batch_encoding_word_to_tokens(self): + tokenizer_r = BertTokenizerFast.from_pretrained("bert-base-cased") + encoded = tokenizer_r(["Test", "\xad", "test"], is_split_into_words=True) + + self.assertEqual(encoded.word_to_tokens(0), TokenSpan(start=1, end=2)) + self.assertEqual(encoded.word_to_tokens(1), None) + self.assertEqual(encoded.word_to_tokens(2), TokenSpan(start=2, end=3)) + def test_batch_encoding_with_labels(self): batch = BatchEncoding({"inputs": [[1, 2, 3], [4, 5, 6]], "labels": [0, 1]}) tensor_batch = batch.convert_to_tensors(tensor_type="np")