Merge branch 'master' into resumable_http
This commit is contained in:
@@ -21,6 +21,7 @@ import os
|
||||
import json
|
||||
import six
|
||||
import copy
|
||||
import itertools
|
||||
from io import open
|
||||
|
||||
from .file_utils import cached_path, is_tf_available, is_torch_available
|
||||
@@ -516,6 +517,8 @@ class PreTrainedTokenizer(object):
|
||||
to_add_tokens = []
|
||||
for token in new_tokens:
|
||||
assert isinstance(token, str) or (six.PY2 and isinstance(token, unicode))
|
||||
if self.init_kwargs.get('do_lower_case', False):
|
||||
token = token.lower()
|
||||
if token != self.unk_token and \
|
||||
self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token) and \
|
||||
token not in to_add_tokens:
|
||||
@@ -609,6 +612,9 @@ class PreTrainedTokenizer(object):
|
||||
|
||||
Take care of added tokens.
|
||||
"""
|
||||
if self.init_kwargs.get('do_lower_case', False):
|
||||
text = text.lower()
|
||||
|
||||
def split_on_token(tok, text):
|
||||
result = []
|
||||
split_text = text.split(tok)
|
||||
@@ -645,9 +651,9 @@ class PreTrainedTokenizer(object):
|
||||
tokenized_text += [sub_text]
|
||||
text_list = tokenized_text
|
||||
|
||||
return sum((self._tokenize(token, **kwargs) if token not \
|
||||
return list(itertools.chain.from_iterable((self._tokenize(token, **kwargs) if token not \
|
||||
in self.added_tokens_encoder and token not in self.all_special_tokens \
|
||||
else [token] for token in tokenized_text), [])
|
||||
else [token] for token in tokenized_text)))
|
||||
|
||||
added_tokens = list(self.added_tokens_encoder.keys()) + self.all_special_tokens
|
||||
tokenized_text = split_on_tokens(added_tokens, text)
|
||||
@@ -675,10 +681,6 @@ class PreTrainedTokenizer(object):
|
||||
ids = []
|
||||
for token in tokens:
|
||||
ids.append(self._convert_token_to_id_with_added_voc(token))
|
||||
if len(ids) > self.max_len:
|
||||
logger.warning("Token indices sequence length is longer than the specified maximum sequence length "
|
||||
"for this model ({} > {}). Running this sequence through the model will result in "
|
||||
"indexing errors".format(len(ids), self.max_len))
|
||||
return ids
|
||||
|
||||
def _convert_token_to_id_with_added_voc(self, token):
|
||||
@@ -693,14 +695,14 @@ class PreTrainedTokenizer(object):
|
||||
raise NotImplementedError
|
||||
|
||||
def encode(self,
|
||||
text,
|
||||
text_pair=None,
|
||||
add_special_tokens=False,
|
||||
max_length=None,
|
||||
stride=0,
|
||||
truncation_strategy='longest_first',
|
||||
return_tensors=None,
|
||||
**kwargs):
|
||||
text,
|
||||
text_pair=None,
|
||||
add_special_tokens=True,
|
||||
max_length=None,
|
||||
stride=0,
|
||||
truncation_strategy='longest_first',
|
||||
return_tensors=None,
|
||||
**kwargs):
|
||||
"""
|
||||
Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary.
|
||||
|
||||
@@ -743,7 +745,7 @@ class PreTrainedTokenizer(object):
|
||||
def encode_plus(self,
|
||||
text,
|
||||
text_pair=None,
|
||||
add_special_tokens=False,
|
||||
add_special_tokens=True,
|
||||
max_length=None,
|
||||
stride=0,
|
||||
truncation_strategy='longest_first',
|
||||
@@ -798,7 +800,7 @@ class PreTrainedTokenizer(object):
|
||||
truncation_strategy=truncation_strategy,
|
||||
return_tensors=return_tensors)
|
||||
|
||||
def prepare_for_model(self, ids, pair_ids=None, max_length=None, add_special_tokens=False, stride=0,
|
||||
def prepare_for_model(self, ids, pair_ids=None, max_length=None, add_special_tokens=True, stride=0,
|
||||
truncation_strategy='longest_first', return_tensors=None):
|
||||
"""
|
||||
Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model.
|
||||
@@ -881,6 +883,11 @@ class PreTrainedTokenizer(object):
|
||||
encoded_inputs["token_type_ids"] = encoded_inputs["token_type_ids"][:max_length]
|
||||
encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"][:max_length]
|
||||
|
||||
if max_length is None and len(encoded_inputs["input_ids"]) > self.max_len:
|
||||
logger.warning("Token indices sequence length is longer than the specified maximum sequence length "
|
||||
"for this model ({} > {}). Running this sequence through the model will result in "
|
||||
"indexing errors".format(len(ids), self.max_len))
|
||||
|
||||
return encoded_inputs
|
||||
|
||||
def truncate_sequences(self, ids, pair_ids=None, num_tokens_to_remove=0, truncation_strategy='longest_first', stride=0):
|
||||
@@ -955,7 +962,7 @@ class PreTrainedTokenizer(object):
|
||||
special tokens for the model
|
||||
|
||||
Returns:
|
||||
A list of integers in the range [0, 1]: 0 for a special token, 1 for a sequence token.
|
||||
A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
|
||||
"""
|
||||
return [0] * ((len(token_ids_1) if token_ids_1 else 0) + len(token_ids_0))
|
||||
|
||||
@@ -1059,7 +1066,7 @@ class PreTrainedTokenizer(object):
|
||||
class attributes (cls_token, unk_token...).
|
||||
"""
|
||||
all_toks = self.all_special_tokens
|
||||
all_ids = list(self._convert_token_to_id(t) for t in all_toks)
|
||||
all_ids = self.convert_tokens_to_ids(all_toks)
|
||||
return all_ids
|
||||
|
||||
@staticmethod
|
||||
|
||||
Reference in New Issue
Block a user