Fix BatchEncoding.word_to_tokens for removed tokens (#7939)

This commit is contained in:
Anthony MOI
2020-10-23 10:29:37 -04:00
committed by GitHub
parent 4acfd1a8dc
commit 5e323017a4
2 changed files with 16 additions and 5 deletions

View File

@@ -364,7 +364,7 @@ class BatchEncoding(UserDict):
token_index = self._seq_len + token_index
return self._encodings[batch_index].token_to_word(token_index)
def word_to_tokens(self, batch_or_word_index: int, word_index: Optional[int] = None) -> TokenSpan:
def word_to_tokens(self, batch_or_word_index: int, word_index: Optional[int] = None) -> Optional[TokenSpan]:
"""
Get the encoded token span corresponding to a word in the sequence of the batch.
@@ -391,8 +391,9 @@ class BatchEncoding(UserDict):
of the word in the sequence.
Returns:
:class:`~transformers.tokenization_utils_base.TokenSpan`
Span of tokens in the encoded sequence.
Optional :class:`~transformers.tokenization_utils_base.TokenSpan`
Span of tokens in the encoded sequence. Returns :obj:`None` if no tokens correspond
to the word.
"""
if not self._encodings:
@@ -406,7 +407,8 @@ class BatchEncoding(UserDict):
batch_index = self._batch_size + batch_index
if word_index < 0:
word_index = self._seq_len + word_index
return TokenSpan(*(self._encodings[batch_index].word_to_tokens(word_index)))
span = self._encodings[batch_index].word_to_tokens(word_index)
return TokenSpan(*span) if span is not None else None
def token_to_chars(self, batch_or_token_index: int, token_index: Optional[int] = None) -> CharSpan:
"""