update encode_plus - add truncation strategies

This commit is contained in:
thomwolf
2019-10-04 17:38:38 -04:00
parent 92c0f2fb90
commit 6c1d0bc066
12 changed files with 222 additions and 193 deletions

View File

@@ -84,23 +84,21 @@ class RobertaTokenizer(GPT2Tokenizer):
self.max_len_single_sentence = self.max_len - 2 # take into account special tokens
self.max_len_sentences_pair = self.max_len - 4 # take into account special tokens
def add_special_tokens_single_sequence(self, token_ids):
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
"""
Adds special tokens to a sequence for sequence classification tasks.
A RoBERTa sequence has the following format: <s> X </s>
Build model inputs from a sequence or a pair of sequence for sequence classification tasks
by concatenating and adding special tokens.
A RoBERTa sequence has the following format:
single sequence: <s> X </s>
pair of sequences: <s> A </s></s> B </s>
"""
return [self.cls_token_id] + token_ids + [self.sep_token_id]
def add_special_tokens_sequence_pair(self, token_ids_0, token_ids_1):
"""
Adds special tokens to a sequence pair for sequence classification tasks.
A RoBERTa sequence pair has the following format: <s> A </s></s> B </s>
"""
sep = [self.sep_token_id]
if token_ids_1 is None:
return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
cls = [self.cls_token_id]
sep = [self.sep_token_id]
return cls + token_ids_0 + sep + sep + token_ids_1 + sep
def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, special_tokens_present=False):
def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False):
"""
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
@@ -108,28 +106,35 @@ class RobertaTokenizer(GPT2Tokenizer):
Args:
token_ids_0: list of ids (must not contain special tokens)
token_ids_1: Optional list of ids (must not contain special tokens), necessary when fetching sequence ids
for sequence pairs
for sequence pairs
already_has_special_tokens: (default False) Set to True if the token list is already formated with
special tokens for the model
Returns:
A list of integers in the range [0, 1]: 0 for a special token, 1 for a sequence token.
"""
if already_has_special_tokens:
if token_ids_1 is not None:
raise ValueError("You should not supply a second sequence if the provided sequence of "
"ids is already formated with special tokens for the model.")
return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
if special_tokens_present:
return list(map(lambda x: 0 if x in [self.sep_token_id, self.cls_token_id] else 1, token_ids_0))
if token_ids_1 is None:
return [1] + ([0] * len(token_ids_0)) + [1]
return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
if token_ids_1:
return [0] + ([1] * len(token_ids_0)) + [0, 0] + ([1] * len(token_ids_1)) + [0]
else:
return [0] + ([1] * len(token_ids_0)) + [0]
def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1):
def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None):
"""
Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
A RoBERTa sequence pair mask has the following format:
0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1
| first sequence | second sequence
if token_ids_1 is None, only returns the first portion of the mask (0's).
"""
sep = [self.sep_token_id]
cls = [self.cls_token_id]
return len(cls + token_ids_0 + sep + sep) * [0] + len(token_ids_1 + sep) * [1]
if token_ids_1 is None:
return len(cls + token_ids_0 + sep) * [0]
return len(cls + token_ids_0 + sep + sep) * [0] + len(token_ids_1 + sep) * [1]