Fix BatchEncoding.word_to_tokens for removed tokens (#7939)
This commit is contained in:
@@ -364,7 +364,7 @@ class BatchEncoding(UserDict):
|
|||||||
token_index = self._seq_len + token_index
|
token_index = self._seq_len + token_index
|
||||||
return self._encodings[batch_index].token_to_word(token_index)
|
return self._encodings[batch_index].token_to_word(token_index)
|
||||||
|
|
||||||
def word_to_tokens(self, batch_or_word_index: int, word_index: Optional[int] = None) -> TokenSpan:
|
def word_to_tokens(self, batch_or_word_index: int, word_index: Optional[int] = None) -> Optional[TokenSpan]:
|
||||||
"""
|
"""
|
||||||
Get the encoded token span corresponding to a word in the sequence of the batch.
|
Get the encoded token span corresponding to a word in the sequence of the batch.
|
||||||
|
|
||||||
@@ -391,8 +391,9 @@ class BatchEncoding(UserDict):
|
|||||||
of the word in the sequence.
|
of the word in the sequence.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
:class:`~transformers.tokenization_utils_base.TokenSpan`
|
Optional :class:`~transformers.tokenization_utils_base.TokenSpan`
|
||||||
Span of tokens in the encoded sequence.
|
Span of tokens in the encoded sequence. Returns :obj:`None` if no tokens correspond
|
||||||
|
to the word.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if not self._encodings:
|
if not self._encodings:
|
||||||
@@ -406,7 +407,8 @@ class BatchEncoding(UserDict):
|
|||||||
batch_index = self._batch_size + batch_index
|
batch_index = self._batch_size + batch_index
|
||||||
if word_index < 0:
|
if word_index < 0:
|
||||||
word_index = self._seq_len + word_index
|
word_index = self._seq_len + word_index
|
||||||
return TokenSpan(*(self._encodings[batch_index].word_to_tokens(word_index)))
|
span = self._encodings[batch_index].word_to_tokens(word_index)
|
||||||
|
return TokenSpan(*span) if span is not None else None
|
||||||
|
|
||||||
def token_to_chars(self, batch_or_token_index: int, token_index: Optional[int] = None) -> CharSpan:
|
def token_to_chars(self, batch_or_token_index: int, token_index: Optional[int] = None) -> CharSpan:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ from typing import Callable, Optional
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from transformers import BatchEncoding, BertTokenizer, BertTokenizerFast, PreTrainedTokenizer, TensorType
|
from transformers import BatchEncoding, BertTokenizer, BertTokenizerFast, PreTrainedTokenizer, TensorType, TokenSpan
|
||||||
from transformers.testing_utils import require_tf, require_tokenizers, require_torch, slow
|
from transformers.testing_utils import require_tf, require_tokenizers, require_torch, slow
|
||||||
from transformers.tokenization_gpt2 import GPT2Tokenizer
|
from transformers.tokenization_gpt2 import GPT2Tokenizer
|
||||||
|
|
||||||
@@ -142,6 +142,15 @@ class TokenizerUtilsTest(unittest.TestCase):
|
|||||||
with self.subTest("Rust Tokenizer"):
|
with self.subTest("Rust Tokenizer"):
|
||||||
self.assertTrue(tokenizer_r("Small example to_encode").is_fast)
|
self.assertTrue(tokenizer_r("Small example to_encode").is_fast)
|
||||||
|
|
||||||
|
@require_tokenizers
|
||||||
|
def test_batch_encoding_word_to_tokens(self):
|
||||||
|
tokenizer_r = BertTokenizerFast.from_pretrained("bert-base-cased")
|
||||||
|
encoded = tokenizer_r(["Test", "\xad", "test"], is_split_into_words=True)
|
||||||
|
|
||||||
|
self.assertEqual(encoded.word_to_tokens(0), TokenSpan(start=1, end=2))
|
||||||
|
self.assertEqual(encoded.word_to_tokens(1), None)
|
||||||
|
self.assertEqual(encoded.word_to_tokens(2), TokenSpan(start=2, end=3))
|
||||||
|
|
||||||
def test_batch_encoding_with_labels(self):
|
def test_batch_encoding_with_labels(self):
|
||||||
batch = BatchEncoding({"inputs": [[1, 2, 3], [4, 5, 6]], "labels": [0, 1]})
|
batch = BatchEncoding({"inputs": [[1, 2, 3], [4, 5, 6]], "labels": [0, 1]})
|
||||||
tensor_batch = batch.convert_to_tensors(tensor_type="np")
|
tensor_batch = batch.convert_to_tensors(tensor_type="np")
|
||||||
|
|||||||
Reference in New Issue
Block a user