Merge pull request #1057 from huggingface/fixes

Add a few of typos corrections, bugs fixes and small improvements
This commit is contained in:
Thomas Wolf
2019-08-21 01:54:05 +02:00
committed by GitHub
13 changed files with 85 additions and 28 deletions

View File

@@ -193,6 +193,13 @@ class PreTrainedTokenizer(object):
cache_dir: (`optional`) string:
Path to a directory in which a downloaded predefined tokenizer vocabulary files should be cached if the standard cache should not be used.
force_download: (`optional`) boolean, default False:
Force to (re-)download the vocabulary files and override the cached versions if they exists.
proxies: (`optional`) dict, default None:
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
The proxies are used on each request.
inputs: (`optional`) positional arguments: will be passed to the Tokenizer ``__init__`` method.
kwargs: (`optional`) keyword arguments: will be passed to the Tokenizer ``__init__`` method. Can be used to set special tokens like ``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``, ``additional_special_tokens``. See parameters in the doc string of :class:`~pytorch_transformers.PreTrainedTokenizer` for details.
@@ -223,6 +230,8 @@ class PreTrainedTokenizer(object):
@classmethod
def _from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
cache_dir = kwargs.pop('cache_dir', None)
force_download = kwargs.pop('force_download', False)
proxies = kwargs.pop('proxies', None)
s3_models = list(cls.max_model_input_sizes.keys())
vocab_files = {}
@@ -283,7 +292,7 @@ class PreTrainedTokenizer(object):
if file_path is None:
resolved_vocab_files[file_id] = None
else:
resolved_vocab_files[file_id] = cached_path(file_path, cache_dir=cache_dir)
resolved_vocab_files[file_id] = cached_path(file_path, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
except EnvironmentError:
if pretrained_model_name_or_path in s3_models:
logger.error("Couldn't reach server to download vocabulary.")