Merge pull request #1057 from huggingface/fixes
Add a few of typos corrections, bugs fixes and small improvements
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -128,3 +128,5 @@ proc_data
|
|||||||
# examples
|
# examples
|
||||||
runs
|
runs
|
||||||
examples/runs
|
examples/runs
|
||||||
|
|
||||||
|
data
|
||||||
@@ -72,16 +72,16 @@ Here is the full list of the currently provided pretrained models together with
|
|||||||
| | ``xlnet-large-cased`` | | 24-layer, 1024-hidden, 16-heads, 340M parameters. |
|
| | ``xlnet-large-cased`` | | 24-layer, 1024-hidden, 16-heads, 340M parameters. |
|
||||||
| | | | XLNet Large English model |
|
| | | | XLNet Large English model |
|
||||||
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||||
| XLM | ``xlm-mlm-en-2048`` | | 12-layer, 1024-hidden, 8-heads |
|
| XLM | ``xlm-mlm-en-2048`` | | 12-layer, 2048-hidden, 16-heads |
|
||||||
| | | | XLM English model |
|
| | | | XLM English model |
|
||||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||||
| | ``xlm-mlm-ende-1024`` | | 12-layer, 1024-hidden, 8-heads |
|
| | ``xlm-mlm-ende-1024`` | | 6-layer, 1024-hidden, 8-heads |
|
||||||
| | | | XLM English-German Multi-language model |
|
| | | | XLM English-German Multi-language model |
|
||||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||||
| | ``xlm-mlm-enfr-1024`` | | 12-layer, 1024-hidden, 8-heads |
|
| | ``xlm-mlm-enfr-1024`` | | 6-layer, 1024-hidden, 8-heads |
|
||||||
| | | | XLM English-French Multi-language model |
|
| | | | XLM English-French Multi-language model |
|
||||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||||
| | ``xlm-mlm-enro-1024`` | | 12-layer, 1024-hidden, 8-heads |
|
| | ``xlm-mlm-enro-1024`` | | 6-layer, 1024-hidden, 8-heads |
|
||||||
| | | | XLM English-Romanian Multi-language model |
|
| | | | XLM English-Romanian Multi-language model |
|
||||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||||
| | ``xlm-mlm-xnli15-1024`` | | 12-layer, 1024-hidden, 8-heads |
|
| | ``xlm-mlm-xnli15-1024`` | | 12-layer, 1024-hidden, 8-heads |
|
||||||
@@ -93,7 +93,7 @@ Here is the full list of the currently provided pretrained models together with
|
|||||||
| | ``xlm-clm-enfr-1024`` | | 12-layer, 1024-hidden, 8-heads |
|
| | ``xlm-clm-enfr-1024`` | | 12-layer, 1024-hidden, 8-heads |
|
||||||
| | | | XLM English model trained with CLM (Causal Language Modeling) |
|
| | | | XLM English model trained with CLM (Causal Language Modeling) |
|
||||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||||
| | ``xlm-clm-ende-1024`` | | 12-layer, 1024-hidden, 8-heads |
|
| | ``xlm-clm-ende-1024`` | | 6-layer, 1024-hidden, 8-heads |
|
||||||
| | | | XLM English-German Multi-language model trained with CLM (Causal Language Modeling) |
|
| | | | XLM English-German Multi-language model trained with CLM (Causal Language Modeling) |
|
||||||
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||||
| RoBERTa | ``roberta-base`` | | 12-layer, 768-hidden, 12-heads, 125M parameters |
|
| RoBERTa | ``roberta-base`` | | 12-layer, 768-hidden, 12-heads, 125M parameters |
|
||||||
|
|||||||
@@ -17,8 +17,9 @@ from hashlib import sha256
|
|||||||
from io import open
|
from io import open
|
||||||
|
|
||||||
import boto3
|
import boto3
|
||||||
import requests
|
from botocore.config import Config
|
||||||
from botocore.exceptions import ClientError
|
from botocore.exceptions import ClientError
|
||||||
|
import requests
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -93,12 +94,15 @@ def filename_to_url(filename, cache_dir=None):
|
|||||||
return url, etag
|
return url, etag
|
||||||
|
|
||||||
|
|
||||||
def cached_path(url_or_filename, cache_dir=None):
|
def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=None):
|
||||||
"""
|
"""
|
||||||
Given something that might be a URL (or might be a local path),
|
Given something that might be a URL (or might be a local path),
|
||||||
determine which. If it's a URL, download the file and cache it, and
|
determine which. If it's a URL, download the file and cache it, and
|
||||||
return the path to the cached file. If it's already a local path,
|
return the path to the cached file. If it's already a local path,
|
||||||
make sure the file exists and then return the path.
|
make sure the file exists and then return the path.
|
||||||
|
Args:
|
||||||
|
cache_dir: specify a cache directory to save the file to (overwrite the default cache dir).
|
||||||
|
force_download: if True, re-dowload the file even if it's already cached in the cache dir.
|
||||||
"""
|
"""
|
||||||
if cache_dir is None:
|
if cache_dir is None:
|
||||||
cache_dir = PYTORCH_TRANSFORMERS_CACHE
|
cache_dir = PYTORCH_TRANSFORMERS_CACHE
|
||||||
@@ -111,7 +115,7 @@ def cached_path(url_or_filename, cache_dir=None):
|
|||||||
|
|
||||||
if parsed.scheme in ('http', 'https', 's3'):
|
if parsed.scheme in ('http', 'https', 's3'):
|
||||||
# URL, so get it from the cache (downloading if necessary)
|
# URL, so get it from the cache (downloading if necessary)
|
||||||
return get_from_cache(url_or_filename, cache_dir)
|
return get_from_cache(url_or_filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
|
||||||
elif os.path.exists(url_or_filename):
|
elif os.path.exists(url_or_filename):
|
||||||
# File, and it exists.
|
# File, and it exists.
|
||||||
return url_or_filename
|
return url_or_filename
|
||||||
@@ -156,24 +160,24 @@ def s3_request(func):
|
|||||||
|
|
||||||
|
|
||||||
@s3_request
|
@s3_request
|
||||||
def s3_etag(url):
|
def s3_etag(url, proxies=None):
|
||||||
"""Check ETag on S3 object."""
|
"""Check ETag on S3 object."""
|
||||||
s3_resource = boto3.resource("s3")
|
s3_resource = boto3.resource("s3", config=Config(proxies=proxies))
|
||||||
bucket_name, s3_path = split_s3_path(url)
|
bucket_name, s3_path = split_s3_path(url)
|
||||||
s3_object = s3_resource.Object(bucket_name, s3_path)
|
s3_object = s3_resource.Object(bucket_name, s3_path)
|
||||||
return s3_object.e_tag
|
return s3_object.e_tag
|
||||||
|
|
||||||
|
|
||||||
@s3_request
|
@s3_request
|
||||||
def s3_get(url, temp_file):
|
def s3_get(url, temp_file, proxies=None):
|
||||||
"""Pull a file directly from S3."""
|
"""Pull a file directly from S3."""
|
||||||
s3_resource = boto3.resource("s3")
|
s3_resource = boto3.resource("s3", config=Config(proxies=proxies))
|
||||||
bucket_name, s3_path = split_s3_path(url)
|
bucket_name, s3_path = split_s3_path(url)
|
||||||
s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
|
s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
|
||||||
|
|
||||||
|
|
||||||
def http_get(url, temp_file):
|
def http_get(url, temp_file, proxies=None):
|
||||||
req = requests.get(url, stream=True)
|
req = requests.get(url, stream=True, proxies=proxies)
|
||||||
content_length = req.headers.get('Content-Length')
|
content_length = req.headers.get('Content-Length')
|
||||||
total = int(content_length) if content_length is not None else None
|
total = int(content_length) if content_length is not None else None
|
||||||
progress = tqdm(unit="B", total=total)
|
progress = tqdm(unit="B", total=total)
|
||||||
@@ -184,7 +188,7 @@ def http_get(url, temp_file):
|
|||||||
progress.close()
|
progress.close()
|
||||||
|
|
||||||
|
|
||||||
def get_from_cache(url, cache_dir=None):
|
def get_from_cache(url, cache_dir=None, force_download=False, proxies=None):
|
||||||
"""
|
"""
|
||||||
Given a URL, look for the corresponding dataset in the local cache.
|
Given a URL, look for the corresponding dataset in the local cache.
|
||||||
If it's not there, download it. Then return the path to the cached file.
|
If it's not there, download it. Then return the path to the cached file.
|
||||||
@@ -201,10 +205,10 @@ def get_from_cache(url, cache_dir=None):
|
|||||||
|
|
||||||
# Get eTag to add to filename, if it exists.
|
# Get eTag to add to filename, if it exists.
|
||||||
if url.startswith("s3://"):
|
if url.startswith("s3://"):
|
||||||
etag = s3_etag(url)
|
etag = s3_etag(url, proxies=proxies)
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
response = requests.head(url, allow_redirects=True)
|
response = requests.head(url, allow_redirects=True, proxies=proxies)
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
etag = None
|
etag = None
|
||||||
else:
|
else:
|
||||||
@@ -227,17 +231,17 @@ def get_from_cache(url, cache_dir=None):
|
|||||||
if matching_files:
|
if matching_files:
|
||||||
cache_path = os.path.join(cache_dir, matching_files[-1])
|
cache_path = os.path.join(cache_dir, matching_files[-1])
|
||||||
|
|
||||||
if not os.path.exists(cache_path):
|
if not os.path.exists(cache_path) or force_download:
|
||||||
# Download to temporary file, then copy to cache dir once finished.
|
# Download to temporary file, then copy to cache dir once finished.
|
||||||
# Otherwise you get corrupt cache entries if the download gets interrupted.
|
# Otherwise you get corrupt cache entries if the download gets interrupted.
|
||||||
with tempfile.NamedTemporaryFile() as temp_file:
|
with tempfile.NamedTemporaryFile() as temp_file:
|
||||||
logger.info("%s not found in cache, downloading to %s", url, temp_file.name)
|
logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name)
|
||||||
|
|
||||||
# GET file object
|
# GET file object
|
||||||
if url.startswith("s3://"):
|
if url.startswith("s3://"):
|
||||||
s3_get(url, temp_file)
|
s3_get(url, temp_file, proxies=proxies)
|
||||||
else:
|
else:
|
||||||
http_get(url, temp_file)
|
http_get(url, temp_file, proxies=proxies)
|
||||||
|
|
||||||
# we are copying the file before closing it, so flush to avoid truncation
|
# we are copying the file before closing it, so flush to avoid truncation
|
||||||
temp_file.flush()
|
temp_file.flush()
|
||||||
|
|||||||
@@ -600,6 +600,9 @@ BERT_INPUTS_DOCSTRING = r"""
|
|||||||
|
|
||||||
``token_type_ids: 0 0 0 0 0 0 0``
|
``token_type_ids: 0 0 0 0 0 0 0``
|
||||||
|
|
||||||
|
Bert is a model with absolute position embeddings so it's usually advised to pad the inputs on
|
||||||
|
the right rather than the left.
|
||||||
|
|
||||||
Indices can be obtained using :class:`pytorch_transformers.BertTokenizer`.
|
Indices can be obtained using :class:`pytorch_transformers.BertTokenizer`.
|
||||||
See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
|
See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
|
||||||
:func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
|
:func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
|
||||||
|
|||||||
@@ -390,6 +390,8 @@ GPT2_START_DOCSTRING = r""" OpenAI GPT-2 model was proposed in
|
|||||||
GPT2_INPUTS_DOCSTRING = r""" Inputs:
|
GPT2_INPUTS_DOCSTRING = r""" Inputs:
|
||||||
**input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
**input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||||
Indices of input sequence tokens in the vocabulary.
|
Indices of input sequence tokens in the vocabulary.
|
||||||
|
GPT-2 is a model with absolute position embeddings so it's usually advised to pad the inputs on
|
||||||
|
the right rather than the left.
|
||||||
Indices can be obtained using :class:`pytorch_transformers.BPT2Tokenizer`.
|
Indices can be obtained using :class:`pytorch_transformers.BPT2Tokenizer`.
|
||||||
See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
|
See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
|
||||||
:func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
|
:func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
|
||||||
|
|||||||
@@ -404,6 +404,8 @@ OPENAI_GPT_START_DOCSTRING = r""" OpenAI GPT model was proposed in
|
|||||||
OPENAI_GPT_INPUTS_DOCSTRING = r""" Inputs:
|
OPENAI_GPT_INPUTS_DOCSTRING = r""" Inputs:
|
||||||
**input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
**input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||||
Indices of input sequence tokens in the vocabulary.
|
Indices of input sequence tokens in the vocabulary.
|
||||||
|
GPT is a model with absolute position embeddings so it's usually advised to pad the inputs on
|
||||||
|
the right rather than the left.
|
||||||
Indices can be obtained using :class:`pytorch_transformers.BPT2Tokenizer`.
|
Indices can be obtained using :class:`pytorch_transformers.BPT2Tokenizer`.
|
||||||
See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
|
See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
|
||||||
:func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
|
:func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
|
||||||
|
|||||||
@@ -110,6 +110,10 @@ ROBERTA_INPUTS_DOCSTRING = r"""
|
|||||||
|
|
||||||
Fully encoded sequences or sequence pairs can be obtained using the RobertaTokenizer.encode function with
|
Fully encoded sequences or sequence pairs can be obtained using the RobertaTokenizer.encode function with
|
||||||
the ``add_special_tokens`` parameter set to ``True``.
|
the ``add_special_tokens`` parameter set to ``True``.
|
||||||
|
|
||||||
|
RoBERTa is a model with absolute position embeddings so it's usually advised to pad the inputs on
|
||||||
|
the right rather than the left.
|
||||||
|
|
||||||
See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
|
See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
|
||||||
:func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
|
:func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
|
||||||
**position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
**position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||||
|
|||||||
@@ -936,6 +936,8 @@ TRANSFO_XL_INPUTS_DOCSTRING = r"""
|
|||||||
Inputs:
|
Inputs:
|
||||||
**input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
**input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||||
Indices of input sequence tokens in the vocabulary.
|
Indices of input sequence tokens in the vocabulary.
|
||||||
|
Transformer-XL is a model with relative position embeddings so you can either pad the inputs on
|
||||||
|
the right or on the left.
|
||||||
Indices can be obtained using :class:`pytorch_transformers.TransfoXLTokenizer`.
|
Indices can be obtained using :class:`pytorch_transformers.TransfoXLTokenizer`.
|
||||||
See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
|
See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
|
||||||
:func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
|
:func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
|
||||||
|
|||||||
@@ -125,6 +125,13 @@ class PretrainedConfig(object):
|
|||||||
- The values in kwargs of any keys which are configuration attributes will be used to override the loaded values.
|
- The values in kwargs of any keys which are configuration attributes will be used to override the loaded values.
|
||||||
- Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled by the `return_unused_kwargs` keyword parameter.
|
- Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled by the `return_unused_kwargs` keyword parameter.
|
||||||
|
|
||||||
|
force_download: (`optional`) boolean, default False:
|
||||||
|
Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
|
||||||
|
|
||||||
|
proxies: (`optional`) dict, default None:
|
||||||
|
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
|
||||||
|
The proxies are used on each request.
|
||||||
|
|
||||||
return_unused_kwargs: (`optional`) bool:
|
return_unused_kwargs: (`optional`) bool:
|
||||||
|
|
||||||
- If False, then this function returns just the final configuration object.
|
- If False, then this function returns just the final configuration object.
|
||||||
@@ -146,6 +153,8 @@ class PretrainedConfig(object):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
cache_dir = kwargs.pop('cache_dir', None)
|
cache_dir = kwargs.pop('cache_dir', None)
|
||||||
|
force_download = kwargs.pop('force_download', False)
|
||||||
|
proxies = kwargs.pop('proxies', None)
|
||||||
return_unused_kwargs = kwargs.pop('return_unused_kwargs', False)
|
return_unused_kwargs = kwargs.pop('return_unused_kwargs', False)
|
||||||
|
|
||||||
if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
|
if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
|
||||||
@@ -156,7 +165,7 @@ class PretrainedConfig(object):
|
|||||||
config_file = pretrained_model_name_or_path
|
config_file = pretrained_model_name_or_path
|
||||||
# redirect to the cache, if necessary
|
# redirect to the cache, if necessary
|
||||||
try:
|
try:
|
||||||
resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
|
resolved_config_file = cached_path(config_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
|
||||||
except EnvironmentError:
|
except EnvironmentError:
|
||||||
if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
|
if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
|
||||||
logger.error(
|
logger.error(
|
||||||
@@ -400,6 +409,13 @@ class PreTrainedModel(nn.Module):
|
|||||||
Path to a directory in which a downloaded pre-trained model
|
Path to a directory in which a downloaded pre-trained model
|
||||||
configuration should be cached if the standard cache should not be used.
|
configuration should be cached if the standard cache should not be used.
|
||||||
|
|
||||||
|
force_download: (`optional`) boolean, default False:
|
||||||
|
Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
|
||||||
|
|
||||||
|
proxies: (`optional`) dict, default None:
|
||||||
|
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
|
||||||
|
The proxies are used on each request.
|
||||||
|
|
||||||
output_loading_info: (`optional`) boolean:
|
output_loading_info: (`optional`) boolean:
|
||||||
Set to ``True`` to also return a dictionnary containing missing keys, unexpected keys and error messages.
|
Set to ``True`` to also return a dictionnary containing missing keys, unexpected keys and error messages.
|
||||||
|
|
||||||
@@ -424,6 +440,8 @@ class PreTrainedModel(nn.Module):
|
|||||||
state_dict = kwargs.pop('state_dict', None)
|
state_dict = kwargs.pop('state_dict', None)
|
||||||
cache_dir = kwargs.pop('cache_dir', None)
|
cache_dir = kwargs.pop('cache_dir', None)
|
||||||
from_tf = kwargs.pop('from_tf', False)
|
from_tf = kwargs.pop('from_tf', False)
|
||||||
|
force_download = kwargs.pop('force_download', False)
|
||||||
|
proxies = kwargs.pop('proxies', None)
|
||||||
output_loading_info = kwargs.pop('output_loading_info', False)
|
output_loading_info = kwargs.pop('output_loading_info', False)
|
||||||
|
|
||||||
# Load config
|
# Load config
|
||||||
@@ -431,6 +449,7 @@ class PreTrainedModel(nn.Module):
|
|||||||
config, model_kwargs = cls.config_class.from_pretrained(
|
config, model_kwargs = cls.config_class.from_pretrained(
|
||||||
pretrained_model_name_or_path, *model_args,
|
pretrained_model_name_or_path, *model_args,
|
||||||
cache_dir=cache_dir, return_unused_kwargs=True,
|
cache_dir=cache_dir, return_unused_kwargs=True,
|
||||||
|
force_download=force_download,
|
||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -453,7 +472,7 @@ class PreTrainedModel(nn.Module):
|
|||||||
archive_file = pretrained_model_name_or_path
|
archive_file = pretrained_model_name_or_path
|
||||||
# redirect to the cache, if necessary
|
# redirect to the cache, if necessary
|
||||||
try:
|
try:
|
||||||
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
|
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
|
||||||
except EnvironmentError:
|
except EnvironmentError:
|
||||||
if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
|
if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
|
||||||
logger.error(
|
logger.error(
|
||||||
|
|||||||
@@ -424,6 +424,10 @@ XLM_INPUTS_DOCSTRING = r"""
|
|||||||
Inputs:
|
Inputs:
|
||||||
**input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
**input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||||
Indices of input sequence tokens in the vocabulary.
|
Indices of input sequence tokens in the vocabulary.
|
||||||
|
|
||||||
|
XLM is a model with absolute position embeddings so it's usually advised to pad the inputs on
|
||||||
|
the right rather than the left.
|
||||||
|
|
||||||
Indices can be obtained using :class:`pytorch_transformers.XLMTokenizer`.
|
Indices can be obtained using :class:`pytorch_transformers.XLMTokenizer`.
|
||||||
See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
|
See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
|
||||||
:func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
|
:func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
|
||||||
@@ -436,8 +440,10 @@ XLM_INPUTS_DOCSTRING = r"""
|
|||||||
Indices are selected in the vocabulary (unlike BERT which has a specific vocabulary for segment indices).
|
Indices are selected in the vocabulary (unlike BERT which has a specific vocabulary for segment indices).
|
||||||
**langs**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
**langs**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||||
A parallel sequence of tokens to be used to indicate the language of each token in the input.
|
A parallel sequence of tokens to be used to indicate the language of each token in the input.
|
||||||
Indices are selected in the pre-trained language vocabulary,
|
Indices are languages ids which can be obtained from the language names by using two conversion mappings
|
||||||
i.e. in the range ``[0, config.n_langs - 1[``.
|
provided in the configuration of the model (only provided for multilingual models).
|
||||||
|
More precisely, the `language name -> language id` mapping is in `model.config.lang2id` (dict str -> int) and
|
||||||
|
the `language id -> language name` mapping is `model.config.id2lang` (dict int -> str).
|
||||||
**attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
|
**attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
|
||||||
Mask to avoid performing attention on padding token indices.
|
Mask to avoid performing attention on padding token indices.
|
||||||
Mask values selected in ``[0, 1]``:
|
Mask values selected in ``[0, 1]``:
|
||||||
|
|||||||
@@ -655,6 +655,8 @@ XLNET_INPUTS_DOCSTRING = r"""
|
|||||||
Inputs:
|
Inputs:
|
||||||
**input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
**input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||||
Indices of input sequence tokens in the vocabulary.
|
Indices of input sequence tokens in the vocabulary.
|
||||||
|
XLNet is a model with relative position embeddings so you can either pad the inputs on
|
||||||
|
the right or on the left.
|
||||||
Indices can be obtained using :class:`pytorch_transformers.XLNetTokenizer`.
|
Indices can be obtained using :class:`pytorch_transformers.XLNetTokenizer`.
|
||||||
See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
|
See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
|
||||||
:func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
|
:func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
|
||||||
|
|||||||
@@ -187,6 +187,8 @@ class BertTokenizer(PreTrainedTokenizer):
|
|||||||
index = 0
|
index = 0
|
||||||
if os.path.isdir(vocab_path):
|
if os.path.isdir(vocab_path):
|
||||||
vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['vocab_file'])
|
vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['vocab_file'])
|
||||||
|
else:
|
||||||
|
vocab_file = vocab_path
|
||||||
with open(vocab_file, "w", encoding="utf-8") as writer:
|
with open(vocab_file, "w", encoding="utf-8") as writer:
|
||||||
for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
|
for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
|
||||||
if index != token_index:
|
if index != token_index:
|
||||||
|
|||||||
@@ -193,6 +193,13 @@ class PreTrainedTokenizer(object):
|
|||||||
cache_dir: (`optional`) string:
|
cache_dir: (`optional`) string:
|
||||||
Path to a directory in which a downloaded predefined tokenizer vocabulary files should be cached if the standard cache should not be used.
|
Path to a directory in which a downloaded predefined tokenizer vocabulary files should be cached if the standard cache should not be used.
|
||||||
|
|
||||||
|
force_download: (`optional`) boolean, default False:
|
||||||
|
Force to (re-)download the vocabulary files and override the cached versions if they exists.
|
||||||
|
|
||||||
|
proxies: (`optional`) dict, default None:
|
||||||
|
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
|
||||||
|
The proxies are used on each request.
|
||||||
|
|
||||||
inputs: (`optional`) positional arguments: will be passed to the Tokenizer ``__init__`` method.
|
inputs: (`optional`) positional arguments: will be passed to the Tokenizer ``__init__`` method.
|
||||||
|
|
||||||
kwargs: (`optional`) keyword arguments: will be passed to the Tokenizer ``__init__`` method. Can be used to set special tokens like ``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``, ``additional_special_tokens``. See parameters in the doc string of :class:`~pytorch_transformers.PreTrainedTokenizer` for details.
|
kwargs: (`optional`) keyword arguments: will be passed to the Tokenizer ``__init__`` method. Can be used to set special tokens like ``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``, ``additional_special_tokens``. See parameters in the doc string of :class:`~pytorch_transformers.PreTrainedTokenizer` for details.
|
||||||
@@ -223,6 +230,8 @@ class PreTrainedTokenizer(object):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def _from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
|
def _from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
|
||||||
cache_dir = kwargs.pop('cache_dir', None)
|
cache_dir = kwargs.pop('cache_dir', None)
|
||||||
|
force_download = kwargs.pop('force_download', False)
|
||||||
|
proxies = kwargs.pop('proxies', None)
|
||||||
|
|
||||||
s3_models = list(cls.max_model_input_sizes.keys())
|
s3_models = list(cls.max_model_input_sizes.keys())
|
||||||
vocab_files = {}
|
vocab_files = {}
|
||||||
@@ -283,7 +292,7 @@ class PreTrainedTokenizer(object):
|
|||||||
if file_path is None:
|
if file_path is None:
|
||||||
resolved_vocab_files[file_id] = None
|
resolved_vocab_files[file_id] = None
|
||||||
else:
|
else:
|
||||||
resolved_vocab_files[file_id] = cached_path(file_path, cache_dir=cache_dir)
|
resolved_vocab_files[file_id] = cached_path(file_path, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
|
||||||
except EnvironmentError:
|
except EnvironmentError:
|
||||||
if pretrained_model_name_or_path in s3_models:
|
if pretrained_model_name_or_path in s3_models:
|
||||||
logger.error("Couldn't reach server to download vocabulary.")
|
logger.error("Couldn't reach server to download vocabulary.")
|
||||||
|
|||||||
Reference in New Issue
Block a user