Merge branch 'master' into resumable_http

This commit is contained in:
Thomas Wolf
2019-11-27 17:10:36 +01:00
committed by GitHub
131 changed files with 11504 additions and 921 deletions

View File

@@ -21,6 +21,7 @@ import os
import json
import six
import copy
import itertools
from io import open
from .file_utils import cached_path, is_tf_available, is_torch_available
@@ -516,6 +517,8 @@ class PreTrainedTokenizer(object):
to_add_tokens = []
for token in new_tokens:
assert isinstance(token, str) or (six.PY2 and isinstance(token, unicode))
if self.init_kwargs.get('do_lower_case', False):
token = token.lower()
if token != self.unk_token and \
self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token) and \
token not in to_add_tokens:
@@ -609,6 +612,9 @@ class PreTrainedTokenizer(object):
Take care of added tokens.
"""
if self.init_kwargs.get('do_lower_case', False):
text = text.lower()
def split_on_token(tok, text):
result = []
split_text = text.split(tok)
@@ -645,9 +651,9 @@ class PreTrainedTokenizer(object):
tokenized_text += [sub_text]
text_list = tokenized_text
return sum((self._tokenize(token, **kwargs) if token not \
return list(itertools.chain.from_iterable((self._tokenize(token, **kwargs) if token not \
in self.added_tokens_encoder and token not in self.all_special_tokens \
else [token] for token in tokenized_text), [])
else [token] for token in tokenized_text)))
added_tokens = list(self.added_tokens_encoder.keys()) + self.all_special_tokens
tokenized_text = split_on_tokens(added_tokens, text)
@@ -675,10 +681,6 @@ class PreTrainedTokenizer(object):
ids = []
for token in tokens:
ids.append(self._convert_token_to_id_with_added_voc(token))
if len(ids) > self.max_len:
logger.warning("Token indices sequence length is longer than the specified maximum sequence length "
"for this model ({} > {}). Running this sequence through the model will result in "
"indexing errors".format(len(ids), self.max_len))
return ids
def _convert_token_to_id_with_added_voc(self, token):
@@ -693,14 +695,14 @@ class PreTrainedTokenizer(object):
raise NotImplementedError
def encode(self,
text,
text_pair=None,
add_special_tokens=False,
max_length=None,
stride=0,
truncation_strategy='longest_first',
return_tensors=None,
**kwargs):
text,
text_pair=None,
add_special_tokens=True,
max_length=None,
stride=0,
truncation_strategy='longest_first',
return_tensors=None,
**kwargs):
"""
Converts a string into a sequence of ids (integers), using the tokenizer and vocabulary.
@@ -743,7 +745,7 @@ class PreTrainedTokenizer(object):
def encode_plus(self,
text,
text_pair=None,
add_special_tokens=False,
add_special_tokens=True,
max_length=None,
stride=0,
truncation_strategy='longest_first',
@@ -798,7 +800,7 @@ class PreTrainedTokenizer(object):
truncation_strategy=truncation_strategy,
return_tensors=return_tensors)
def prepare_for_model(self, ids, pair_ids=None, max_length=None, add_special_tokens=False, stride=0,
def prepare_for_model(self, ids, pair_ids=None, max_length=None, add_special_tokens=True, stride=0,
truncation_strategy='longest_first', return_tensors=None):
"""
Prepares a sequence of input ids, or a pair of sequences of input ids, so that it can be used by the model.
@@ -881,6 +883,11 @@ class PreTrainedTokenizer(object):
encoded_inputs["token_type_ids"] = encoded_inputs["token_type_ids"][:max_length]
encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"][:max_length]
if max_length is None and len(encoded_inputs["input_ids"]) > self.max_len:
logger.warning("Token indices sequence length is longer than the specified maximum sequence length "
"for this model ({} > {}). Running this sequence through the model will result in "
"indexing errors".format(len(ids), self.max_len))
return encoded_inputs
def truncate_sequences(self, ids, pair_ids=None, num_tokens_to_remove=0, truncation_strategy='longest_first', stride=0):
@@ -955,7 +962,7 @@ class PreTrainedTokenizer(object):
special tokens for the model
Returns:
A list of integers in the range [0, 1]: 0 for a special token, 1 for a sequence token.
A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
"""
return [0] * ((len(token_ids_1) if token_ids_1 else 0) + len(token_ids_0))
@@ -1059,7 +1066,7 @@ class PreTrainedTokenizer(object):
class attributes (cls_token, unk_token...).
"""
all_toks = self.all_special_tokens
all_ids = list(self._convert_token_to_id(t) for t in all_toks)
all_ids = self.convert_tokens_to_ids(all_toks)
return all_ids
@staticmethod