Fix BatchEncoding.word_to_tokens for removed tokens (#7939)

This commit is contained in:
Anthony MOI
2020-10-23 10:29:37 -04:00
committed by GitHub
parent 4acfd1a8dc
commit 5e323017a4
2 changed files with 16 additions and 5 deletions

View File

@@ -364,7 +364,7 @@ class BatchEncoding(UserDict):
token_index = self._seq_len + token_index token_index = self._seq_len + token_index
return self._encodings[batch_index].token_to_word(token_index) return self._encodings[batch_index].token_to_word(token_index)
def word_to_tokens(self, batch_or_word_index: int, word_index: Optional[int] = None) -> TokenSpan: def word_to_tokens(self, batch_or_word_index: int, word_index: Optional[int] = None) -> Optional[TokenSpan]:
""" """
Get the encoded token span corresponding to a word in the sequence of the batch. Get the encoded token span corresponding to a word in the sequence of the batch.
@@ -391,8 +391,9 @@ class BatchEncoding(UserDict):
of the word in the sequence. of the word in the sequence.
Returns: Returns:
:class:`~transformers.tokenization_utils_base.TokenSpan` Optional :class:`~transformers.tokenization_utils_base.TokenSpan`
Span of tokens in the encoded sequence. Span of tokens in the encoded sequence. Returns :obj:`None` if no tokens correspond
to the word.
""" """
if not self._encodings: if not self._encodings:
@@ -406,7 +407,8 @@ class BatchEncoding(UserDict):
batch_index = self._batch_size + batch_index batch_index = self._batch_size + batch_index
if word_index < 0: if word_index < 0:
word_index = self._seq_len + word_index word_index = self._seq_len + word_index
return TokenSpan(*(self._encodings[batch_index].word_to_tokens(word_index))) span = self._encodings[batch_index].word_to_tokens(word_index)
return TokenSpan(*span) if span is not None else None
def token_to_chars(self, batch_or_token_index: int, token_index: Optional[int] = None) -> CharSpan: def token_to_chars(self, batch_or_token_index: int, token_index: Optional[int] = None) -> CharSpan:
""" """

View File

@@ -18,7 +18,7 @@ from typing import Callable, Optional
import numpy as np import numpy as np
from transformers import BatchEncoding, BertTokenizer, BertTokenizerFast, PreTrainedTokenizer, TensorType from transformers import BatchEncoding, BertTokenizer, BertTokenizerFast, PreTrainedTokenizer, TensorType, TokenSpan
from transformers.testing_utils import require_tf, require_tokenizers, require_torch, slow from transformers.testing_utils import require_tf, require_tokenizers, require_torch, slow
from transformers.tokenization_gpt2 import GPT2Tokenizer from transformers.tokenization_gpt2 import GPT2Tokenizer
@@ -142,6 +142,15 @@ class TokenizerUtilsTest(unittest.TestCase):
with self.subTest("Rust Tokenizer"): with self.subTest("Rust Tokenizer"):
self.assertTrue(tokenizer_r("Small example to_encode").is_fast) self.assertTrue(tokenizer_r("Small example to_encode").is_fast)
@require_tokenizers
def test_batch_encoding_word_to_tokens(self):
tokenizer_r = BertTokenizerFast.from_pretrained("bert-base-cased")
encoded = tokenizer_r(["Test", "\xad", "test"], is_split_into_words=True)
self.assertEqual(encoded.word_to_tokens(0), TokenSpan(start=1, end=2))
self.assertEqual(encoded.word_to_tokens(1), None)
self.assertEqual(encoded.word_to_tokens(2), TokenSpan(start=2, end=3))
def test_batch_encoding_with_labels(self): def test_batch_encoding_with_labels(self):
batch = BatchEncoding({"inputs": [[1, 2, 3], [4, 5, 6]], "labels": [0, 1]}) batch = BatchEncoding({"inputs": [[1, 2, 3], [4, 5, 6]], "labels": [0, 1]})
tensor_batch = batch.convert_to_tensors(tensor_type="np") tensor_batch = batch.convert_to_tensors(tensor_type="np")