Sequence IDS

This commit is contained in:
LysandreJik
2019-09-30 11:48:18 -04:00
parent 7c789c337d
commit 2f259b228e
6 changed files with 121 additions and 1 deletions

View File

@@ -826,7 +826,21 @@ class PreTrainedTokenizer(object):
or PyTorch torch.Tensor instead of a list of python integers.
Return:
a dictionary containing the `input_ids` as well as the `overflowing_tokens` if a `max_length` was given.
A Dictionary of shape::
{
input_ids: list[int],
overflowing_tokens: list[int] if a ``max_length`` is specified, else None
sequence_ids: list[int] if ``add_special_tokens`` if set to ``True``
}
With the fields:
``input_ids``: list of tokens to be fed to a model
``overflowing_tokens``: list of overflowing tokens if a max length is specified.
``sequence_ids``: if adding special tokens, this is a list of [0, 1], with 0 specifying special added
tokens and 1 specifying sequence tokens.
"""
pair = bool(pair_ids is not None)
len_ids = len(ids)
@@ -859,6 +873,7 @@ class PreTrainedTokenizer(object):
if add_special_tokens:
sequence = self.add_special_tokens_sequence_pair(ids, pair_ids) if pair else self.add_special_tokens_single_sequence(ids)
token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids) if pair else [0] * len(sequence)
encoded_inputs["sequence_ids"] = self.get_sequence_ids(ids, pair_ids)
else:
sequence = ids + pair_ids if pair else ids
token_type_ids = [0] * len(ids) + ([1] * len(pair_ids) if pair else [])
@@ -893,6 +908,9 @@ class PreTrainedTokenizer(object):
logger.warning("This tokenizer does not make use of special tokens. The two sequences have been concatenated.")
return token_ids_0 + token_ids_1
def get_sequence_ids(self, token_ids_0, token_ids_1=None):
return [1] * ((len(token_ids_1) if token_ids_1 else 0) + len(token_ids_0))
def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
""" Converts a single index or a sequence of indices (integers) in a token "
(resp.) a sequence of tokens (str/unicode), using the vocabulary and added tokens.