update encode_plus - add truncation strategies
This commit is contained in:
@@ -754,23 +754,21 @@ class XLMTokenizer(PreTrainedTokenizer):
|
||||
out_string = ''.join(tokens).replace('</w>', ' ').strip()
|
||||
return out_string
|
||||
|
||||
def add_special_tokens_single_sequence(self, token_ids):
|
||||
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
|
||||
"""
|
||||
Adds special tokens to a sequence for sequence classification tasks.
|
||||
An XLM sequence has the following format: [CLS] X [SEP]
|
||||
"""
|
||||
return [self.cls_token_id] + token_ids + [self.sep_token_id]
|
||||
|
||||
def add_special_tokens_sequence_pair(self, token_ids_0, token_ids_1):
|
||||
"""
|
||||
Adds special tokens to a sequence pair for sequence classification tasks.
|
||||
An XLM sequence pair has the following format: [CLS] A [SEP] B [SEP]
|
||||
Build model inputs from a sequence or a pair of sequence for sequence classification tasks
|
||||
by concatenating and adding special tokens.
|
||||
A RoBERTa sequence has the following format:
|
||||
single sequence: <s> X </s>
|
||||
pair of sequences: <s> A </s></s> B </s>
|
||||
"""
|
||||
if token_ids_1 is None:
|
||||
return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
|
||||
sep = [self.sep_token_id]
|
||||
cls = [self.cls_token_id]
|
||||
return cls + token_ids_0 + sep + token_ids_1 + sep
|
||||
|
||||
def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, special_tokens_present=False):
|
||||
def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False):
|
||||
"""
|
||||
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
|
||||
special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
|
||||
@@ -778,30 +776,37 @@ class XLMTokenizer(PreTrainedTokenizer):
|
||||
Args:
|
||||
token_ids_0: list of ids (must not contain special tokens)
|
||||
token_ids_1: Optional list of ids (must not contain special tokens), necessary when fetching sequence ids
|
||||
for sequence pairs
|
||||
for sequence pairs
|
||||
already_has_special_tokens: (default False) Set to True if the token list is already formated with
|
||||
special tokens for the model
|
||||
|
||||
Returns:
|
||||
A list of integers in the range [0, 1]: 0 for a special token, 1 for a sequence token.
|
||||
"""
|
||||
|
||||
if special_tokens_present:
|
||||
return list(map(lambda x: 0 if x in [self.sep_token_id, self.cls_token_id] else 1, token_ids_0))
|
||||
if already_has_special_tokens:
|
||||
if token_ids_1 is not None:
|
||||
raise ValueError("You should not supply a second sequence if the provided sequence of "
|
||||
"ids is already formated with special tokens for the model.")
|
||||
return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
|
||||
|
||||
if token_ids_1:
|
||||
return [0] + ([1] * len(token_ids_0)) + [0] + ([1] * len(token_ids_1)) + [0]
|
||||
else:
|
||||
return [0] + ([1] * len(token_ids_0)) + [0]
|
||||
if token_ids_1 is not None:
|
||||
return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
|
||||
return [1] + ([0] * len(token_ids_0)) + [1]
|
||||
|
||||
def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1):
|
||||
def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None):
|
||||
"""
|
||||
Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
|
||||
An XLM sequence pair mask has the following format:
|
||||
0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1
|
||||
| first sequence | second sequence
|
||||
|
||||
if token_ids_1 is None, only returns the first portion of the mask (0's).
|
||||
"""
|
||||
sep = [self.sep_token_id]
|
||||
cls = [self.cls_token_id]
|
||||
|
||||
if token_ids_1 is None:
|
||||
return len(cls + token_ids_0 + sep) * [0]
|
||||
return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
|
||||
|
||||
def save_vocabulary(self, save_directory):
|
||||
|
||||
Reference in New Issue
Block a user