Sequence IDS

This commit is contained in:
LysandreJik
2019-09-30 11:48:18 -04:00
parent 7c789c337d
commit 2f259b228e
6 changed files with 121 additions and 1 deletions

View File

@@ -770,6 +770,24 @@ class XLMTokenizer(PreTrainedTokenizer):
cls = [self.cls_token_id]
return cls + token_ids_0 + sep + token_ids_1 + sep
def get_sequence_ids(self, token_ids_0, token_ids_1=None):
    """
    Build a special-token mask for a token list that has no special tokens added yet.

    This method is called when adding special tokens using the tokenizer
    ``prepare_for_model`` or ``encode_plus`` methods.

    Args:
        token_ids_0: list of ids (must not contain special tokens)
        token_ids_1: Optional list of ids (must not contain special tokens), necessary when fetching
            sequence ids for sequence pairs. ``None`` means a single sequence; an empty list is still
            treated as a (zero-length) second sequence.

    Returns:
        A list of integers, each either 0 or 1: 0 for a special token, 1 for a sequence token.
        Single sequence: ``[0] + [1] * len(token_ids_0) + [0]``.
        Sequence pair: the same, followed by ``[1] * len(token_ids_1) + [0]``.
    """
    first_segment = [0] + [1] * len(token_ids_0) + [0]
    # Compare against None instead of relying on truthiness: an empty second
    # sequence still gets its trailing special token in the pair layout, so the
    # mask must keep the same length as the pair-formatted ids.
    if token_ids_1 is None:
        return first_segment
    return first_segment + [1] * len(token_ids_1) + [0]
def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1):
"""
Creates a mask from the two sequences passed to be used in a sequence-pair classification task.