Merge branch 'master' into squad-refactor
This commit is contained in:
@@ -22,6 +22,7 @@ import json
|
||||
import six
|
||||
import copy
|
||||
import itertools
|
||||
import re
|
||||
from io import open
|
||||
|
||||
from .file_utils import cached_path, is_tf_available, is_torch_available
|
||||
@@ -263,6 +264,9 @@ class PreTrainedTokenizer(object):
|
||||
force_download: (`optional`) boolean, default False:
|
||||
Force to (re-)download the vocabulary files and override the cached versions if they exist.
|
||||
|
||||
resume_download: (`optional`) boolean, default False:
|
||||
Do not delete an incompletely received file. Attempt to resume the download if such a file exists.
|
||||
|
||||
proxies: (`optional`) dict, default None:
|
||||
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
|
||||
The proxies are used on each request.
|
||||
@@ -298,6 +302,7 @@ class PreTrainedTokenizer(object):
|
||||
def _from_pretrained(cls, pretrained_model_name_or_path, *init_inputs, **kwargs):
|
||||
cache_dir = kwargs.pop('cache_dir', None)
|
||||
force_download = kwargs.pop('force_download', False)
|
||||
resume_download = kwargs.pop('resume_download', False)
|
||||
proxies = kwargs.pop('proxies', None)
|
||||
|
||||
s3_models = list(cls.max_model_input_sizes.keys())
|
||||
@@ -354,7 +359,7 @@ class PreTrainedTokenizer(object):
|
||||
"We assumed '{}' was a path or url to a directory containing vocabulary files "
|
||||
"named {} but couldn't find such vocabulary files at this path or url.".format(
|
||||
pretrained_model_name_or_path, ', '.join(s3_models),
|
||||
pretrained_model_name_or_path,
|
||||
pretrained_model_name_or_path,
|
||||
list(cls.vocab_files_names.values())))
|
||||
|
||||
# Get files from url, cache, or disk depending on the case
|
||||
@@ -364,7 +369,7 @@ class PreTrainedTokenizer(object):
|
||||
if file_path is None:
|
||||
resolved_vocab_files[file_id] = None
|
||||
else:
|
||||
resolved_vocab_files[file_id] = cached_path(file_path, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
|
||||
resolved_vocab_files[file_id] = cached_path(file_path, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download)
|
||||
except EnvironmentError:
|
||||
if pretrained_model_name_or_path in s3_models:
|
||||
msg = "Couldn't reach server at '{}' to download vocabulary files."
|
||||
@@ -389,7 +394,8 @@ class PreTrainedTokenizer(object):
|
||||
# Did we save some inputs and kwargs to reload?
|
||||
tokenizer_config_file = resolved_vocab_files.pop('tokenizer_config_file', None)
|
||||
if tokenizer_config_file is not None:
|
||||
init_kwargs = json.load(open(tokenizer_config_file, encoding="utf-8"))
|
||||
with open(tokenizer_config_file, encoding="utf-8") as tokenizer_config_handle:
|
||||
init_kwargs = json.load(tokenizer_config_handle)
|
||||
saved_init_inputs = init_kwargs.pop('init_inputs', ())
|
||||
if not init_inputs:
|
||||
init_inputs = saved_init_inputs
|
||||
@@ -414,7 +420,8 @@ class PreTrainedTokenizer(object):
|
||||
if args_name not in init_kwargs:
|
||||
init_kwargs[args_name] = file_path
|
||||
if special_tokens_map_file is not None:
|
||||
special_tokens_map = json.load(open(special_tokens_map_file, encoding="utf-8"))
|
||||
with open(special_tokens_map_file, encoding="utf-8") as special_tokens_map_handle:
|
||||
special_tokens_map = json.load(special_tokens_map_handle)
|
||||
for key, value in special_tokens_map.items():
|
||||
if key not in init_kwargs:
|
||||
init_kwargs[key] = value
|
||||
@@ -428,7 +435,8 @@ class PreTrainedTokenizer(object):
|
||||
|
||||
# Add supplementary tokens.
|
||||
if added_tokens_file is not None:
|
||||
added_tok_encoder = json.load(open(added_tokens_file, encoding="utf-8"))
|
||||
with open(added_tokens_file, encoding="utf-8") as added_tokens_handle:
|
||||
added_tok_encoder = json.load(added_tokens_handle)
|
||||
added_tok_decoder = {v:k for k, v in added_tok_encoder.items()}
|
||||
tokenizer.added_tokens_encoder.update(added_tok_encoder)
|
||||
tokenizer.added_tokens_decoder.update(added_tok_decoder)
|
||||
@@ -524,6 +532,8 @@ class PreTrainedTokenizer(object):
|
||||
to_add_tokens = []
|
||||
for token in new_tokens:
|
||||
assert isinstance(token, str) or (six.PY2 and isinstance(token, unicode))
|
||||
if self.init_kwargs.get('do_lower_case', False) and token not in self.all_special_tokens:
|
||||
token = token.lower()
|
||||
if token != self.unk_token and \
|
||||
self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token) and \
|
||||
token not in to_add_tokens:
|
||||
@@ -621,6 +631,19 @@ class PreTrainedTokenizer(object):
|
||||
return_tokens_mapped_to_origin: (optional) Set to True to return the index of each token in the initial whitespace tokenization. (default False).
|
||||
**kwargs: passed to the child `self.tokenize()` method
|
||||
"""
|
||||
def lowercase_text(t):
    """Lowercase ``t`` while leaving every special token untouched.

    Args:
        t: the text to lowercase.

    Returns:
        A copy of ``t`` in which each occurrence of a token listed in
        ``self.all_special_tokens`` is kept verbatim and all other
        characters are lowercased.
    """
    # Try longer special tokens first so that a token which is a prefix of
    # another one (e.g. '<s>' vs '<s>NOTUSED') cannot shadow the longer match.
    # Empty tokens are dropped: an empty alternative would match everywhere.
    special_toks = sorted(
        (s_tok for s_tok in self.all_special_tokens if s_tok),
        key=len,
        reverse=True)

    # Without special tokens there is nothing to protect; an empty
    # alternation would also match the empty string at every position.
    if not special_toks:
        return t.lower()

    # Group 1: a special token anywhere in the text (no '^' anchor, which
    # would only bind to the first alternative). Group 2: any other character.
    pattern = r'(' + r'|'.join(re.escape(s_tok) for s_tok in special_toks) + r')|' + \
              r'(.+?)'

    def keep_special_or_lower(m):
        special = m.group(1)
        return special if special is not None else m.group(2).lower()

    return re.sub(pattern, keep_special_or_lower, t)
|
||||
|
||||
if self.init_kwargs.get('do_lower_case', False):
|
||||
text = lowercase_text(text)
|
||||
|
||||
def split_on_token(tok, text):
|
||||
result = []
|
||||
split_text = text.split(tok)
|
||||
@@ -823,7 +846,6 @@ class PreTrainedTokenizer(object):
|
||||
``input_ids``: list of token ids to be fed to a model
|
||||
``token_type_ids``: list of token type ids to be fed to a model
|
||||
``attention_mask``: list of indices specifying which tokens should be attended to by the model
|
||||
|
||||
``overflowing_tokens``: list of overflowing tokens if a max length is specified.
|
||||
``num_truncated_tokens``: number of overflowing tokens if a ``max_length`` is specified
|
||||
``special_tokens_mask``: if adding special tokens, this is a list of [0, 1], with 0 specifying special added
|
||||
@@ -904,15 +926,18 @@ class PreTrainedTokenizer(object):
|
||||
|
||||
{
|
||||
input_ids: list[int],
|
||||
overflowing_tokens: list[int] if a ``max_length`` is specified, else None
|
||||
special_tokens_mask: list[int] if ``add_special_tokens`` is set to ``True``
|
||||
token_type_ids: list[int] if return_token_type_ids is True (default)
|
||||
overflowing_tokens: list[int] if a ``max_length`` is specified and return_overflowing_tokens is True
|
||||
num_truncated_tokens: int if a ``max_length`` is specified and return_overflowing_tokens is True
|
||||
special_tokens_mask: list[int] if ``add_special_tokens`` is set to ``True`` and return_special_tokens_mask is True
|
||||
}
|
||||
|
||||
With the fields:
|
||||
``input_ids``: list of tokens to be fed to a model
|
||||
``input_ids``: list of token ids to be fed to a model
|
||||
``token_type_ids``: list of token type ids to be fed to a model
|
||||
|
||||
``overflowing_tokens``: list of overflowing tokens if a max length is specified.
|
||||
|
||||
``num_truncated_tokens``: number of overflowing tokens if a ``max_length`` is specified
|
||||
``special_tokens_mask``: if adding special tokens, this is a list of [0, 1], with 0 specifying special added
|
||||
tokens and 1 specifying sequence tokens.
|
||||
"""
|
||||
@@ -921,23 +946,31 @@ class PreTrainedTokenizer(object):
|
||||
len_pair_ids = len(pair_ids) if pair else 0
|
||||
|
||||
encoded_inputs = {}
|
||||
|
||||
# Handle max sequence length
|
||||
total_len = len_ids + len_pair_ids + (self.num_added_tokens(pair=pair) if add_special_tokens else 0)
|
||||
if max_length and total_len > max_length:
|
||||
ids, pair_ids, overflowing_tokens = self.truncate_sequences(ids, pair_ids=pair_ids,
|
||||
num_tokens_to_remove=total_len-max_length,
|
||||
truncation_strategy=truncation_strategy,
|
||||
stride=stride)
|
||||
encoded_inputs["overflowing_tokens"] = overflowing_tokens
|
||||
encoded_inputs["num_truncated_tokens"] = total_len - max_length
|
||||
if return_overflowing_tokens:
|
||||
encoded_inputs["overflowing_tokens"] = overflowing_tokens
|
||||
encoded_inputs["num_truncated_tokens"] = total_len - max_length
|
||||
|
||||
# Handle special_tokens
|
||||
if add_special_tokens:
|
||||
sequence = self.build_inputs_with_special_tokens(ids, pair_ids)
|
||||
token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids)
|
||||
encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(ids, pair_ids)
|
||||
special_tokens_mask = self.get_special_tokens_mask(ids, pair_ids)
|
||||
else:
|
||||
sequence = ids + pair_ids if pair else ids
|
||||
token_type_ids = [0] * len(ids) + ([1] * len(pair_ids) if pair else [])
|
||||
special_tokens_mask = [0] * (len(ids) + (len(pair_ids) if pair else 0))
|
||||
if return_special_tokens_mask:
|
||||
encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(ids, pair_ids)
|
||||
|
||||
# Prepare inputs as tensors if asked
|
||||
if return_tensors == 'tf' and is_tf_available():
|
||||
sequence = tf.constant([sequence])
|
||||
token_type_ids = tf.constant([token_type_ids])
|
||||
@@ -948,12 +981,15 @@ class PreTrainedTokenizer(object):
|
||||
logger.warning("Unable to convert output to tensors format {}, PyTorch or TensorFlow is not available.".format(return_tensors))
|
||||
|
||||
encoded_inputs["input_ids"] = sequence
|
||||
encoded_inputs["token_type_ids"] = token_type_ids
|
||||
if return_token_type_ids:
|
||||
encoded_inputs["token_type_ids"] = token_type_ids
|
||||
|
||||
if max_length and len(encoded_inputs["input_ids"]) > max_length:
|
||||
encoded_inputs["input_ids"] = encoded_inputs["input_ids"][:max_length]
|
||||
encoded_inputs["token_type_ids"] = encoded_inputs["token_type_ids"][:max_length]
|
||||
encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"][:max_length]
|
||||
if return_token_type_ids:
|
||||
encoded_inputs["token_type_ids"] = encoded_inputs["token_type_ids"][:max_length]
|
||||
if return_special_tokens_mask:
|
||||
encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"][:max_length]
|
||||
|
||||
if max_length is None and len(encoded_inputs["input_ids"]) > self.max_len:
|
||||
logger.warning("Token indices sequence length is longer than the specified maximum sequence length "
|
||||
@@ -995,7 +1031,7 @@ class PreTrainedTokenizer(object):
|
||||
|
||||
elif return_attention_mask:
|
||||
encoded_inputs["attention_mask"] = [1] * len(encoded_inputs["input_ids"])
|
||||
|
||||
|
||||
return encoded_inputs
|
||||
|
||||
def truncate_sequences(self, ids, pair_ids=None, num_tokens_to_remove=0, truncation_strategy='longest_first', stride=0):
|
||||
@@ -1039,7 +1075,6 @@ class PreTrainedTokenizer(object):
|
||||
return (ids, pair_ids, overflowing_tokens)
|
||||
|
||||
def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None):
    """Build the token type ids for one sequence or a pair of sequences.

    Logs a warning on every call, then returns a list with one ``0`` per
    element of ``token_ids_0`` followed, when ``token_ids_1`` is given, by
    one ``1`` per element of ``token_ids_1``.
    """
    logger.warning("This tokenizer does not make use of special tokens.")
    first_segment = [0] * len(token_ids_0)
    if token_ids_1 is None:
        return first_segment
    second_segment = [1] * len(token_ids_1)
    return first_segment + second_segment
|
||||
@@ -1052,7 +1087,6 @@ class PreTrainedTokenizer(object):
|
||||
single sequence: <s> X </s>
|
||||
pair of sequences: <s> A </s></s> B </s>
|
||||
"""
|
||||
logger.warning("This tokenizer does not make use of special tokens. Input is returned with no modification.")
|
||||
if token_ids_1 is None:
|
||||
return token_ids_0
|
||||
return token_ids_0 + token_ids_1
|
||||
|
||||
Reference in New Issue
Block a user