Fix BatchEncoding.word_to_tokens for removed tokens (#7939)
This commit is contained in:
@@ -364,7 +364,7 @@ class BatchEncoding(UserDict):
|
||||
token_index = self._seq_len + token_index
|
||||
return self._encodings[batch_index].token_to_word(token_index)
|
||||
|
||||
def word_to_tokens(self, batch_or_word_index: int, word_index: Optional[int] = None) -> TokenSpan:
|
||||
def word_to_tokens(self, batch_or_word_index: int, word_index: Optional[int] = None) -> Optional[TokenSpan]:
|
||||
"""
|
||||
Get the encoded token span corresponding to a word in the sequence of the batch.
|
||||
|
||||
@@ -391,8 +391,9 @@ class BatchEncoding(UserDict):
|
||||
of the word in the sequence.
|
||||
|
||||
Returns:
|
||||
:class:`~transformers.tokenization_utils_base.TokenSpan`
|
||||
Span of tokens in the encoded sequence.
|
||||
Optional :class:`~transformers.tokenization_utils_base.TokenSpan`
|
||||
Span of tokens in the encoded sequence. Returns :obj:`None` if no tokens correspond
|
||||
to the word.
|
||||
"""
|
||||
|
||||
if not self._encodings:
|
||||
@@ -406,7 +407,8 @@ class BatchEncoding(UserDict):
|
||||
batch_index = self._batch_size + batch_index
|
||||
if word_index < 0:
|
||||
word_index = self._seq_len + word_index
|
||||
return TokenSpan(*(self._encodings[batch_index].word_to_tokens(word_index)))
|
||||
span = self._encodings[batch_index].word_to_tokens(word_index)
|
||||
return TokenSpan(*span) if span is not None else None
|
||||
|
||||
def token_to_chars(self, batch_or_token_index: int, token_index: Optional[int] = None) -> CharSpan:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user