update encode_plus - add truncation strategies

This commit is contained in:
thomwolf
2019-10-04 17:38:38 -04:00
parent 92c0f2fb90
commit 6c1d0bc066
12 changed files with 222 additions and 193 deletions

View File

@@ -187,24 +187,21 @@ class BertTokenizer(PreTrainedTokenizer):
out_string = ' '.join(tokens).replace(' ##', '').strip()
return out_string
def add_special_tokens_single_sequence(self, token_ids):
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
"""
Adds special tokens to the a sequence for sequence classification tasks.
A BERT sequence has the following format: [CLS] X [SEP]
Build model inputs from a sequence or a pair of sequence for sequence classification tasks
by concatenating and adding special tokens.
A BERT sequence has the following format:
single sequence: [CLS] X [SEP]
pair of sequences: [CLS] A [SEP] B [SEP]
"""
return [self.cls_token_id] + token_ids + [self.sep_token_id]
def add_special_tokens_sequence_pair(self, token_ids_0, token_ids_1):
"""
Adds special tokens to a sequence pair for sequence classification tasks.
A BERT sequence pair has the following format: [CLS] A [SEP] B [SEP]
"""
sep = [self.sep_token_id]
if token_ids_1 is None:
return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
cls = [self.cls_token_id]
sep = [self.sep_token_id]
return cls + token_ids_0 + sep + token_ids_1 + sep
def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, special_tokens_present=False):
def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False):
"""
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
@@ -212,30 +209,37 @@ class BertTokenizer(PreTrainedTokenizer):
Args:
token_ids_0: list of ids (must not contain special tokens)
token_ids_1: Optional list of ids (must not contain special tokens), necessary when fetching sequence ids
for sequence pairs
for sequence pairs
already_has_special_tokens: (default False) Set to True if the token list is already formated with
special tokens for the model
Returns:
A list of integers in the range [0, 1]: 0 for a special token, 1 for a sequence token.
"""
if special_tokens_present:
return list(map(lambda x: 0 if x in [self.sep_token_id, self.cls_token_id] else 1, token_ids_0))
if already_has_special_tokens:
if token_ids_1 is not None:
raise ValueError("You should not supply a second sequence if the provided sequence of "
"ids is already formated with special tokens for the model.")
return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
if token_ids_1:
return [0] + ([1] * len(token_ids_0)) + [0] + ([1] * len(token_ids_1)) + [0]
else:
return [0] + ([1] * len(token_ids_0)) + [0]
if token_ids_1 is not None:
return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
return [1] + ([0] * len(token_ids_0)) + [1]
def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1):
def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None):
"""
Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
A BERT sequence pair mask has the following format:
0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1
| first sequence | second sequence
if token_ids_1 is None, only returns the first portion of the mask (0's).
"""
sep = [self.sep_token_id]
cls = [self.cls_token_id]
if token_ids_1 is None:
return len(cls + token_ids_0 + sep) * [0]
return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
def save_vocabulary(self, vocab_path):