Allow from_pretrained to take a remote identifier
This commit is contained in:
@@ -24,7 +24,7 @@ import logging
|
|||||||
import os
|
import os
|
||||||
from io import open
|
from io import open
|
||||||
|
|
||||||
from .file_utils import cached_path, CONFIG_NAME
|
from .file_utils import CONFIG_NAME, cached_path, is_remote_url, hf_bucket_url
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -131,8 +131,10 @@ class PretrainedConfig(object):
|
|||||||
config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path]
|
config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path]
|
||||||
elif os.path.isdir(pretrained_model_name_or_path):
|
elif os.path.isdir(pretrained_model_name_or_path):
|
||||||
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
|
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
|
||||||
else:
|
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
||||||
config_file = pretrained_model_name_or_path
|
config_file = pretrained_model_name_or_path
|
||||||
|
else:
|
||||||
|
config_file = hf_bucket_url(pretrained_model_name_or_path, postfix=CONFIG_NAME)
|
||||||
# redirect to the cache, if necessary
|
# redirect to the cache, if necessary
|
||||||
try:
|
try:
|
||||||
resolved_config_file = cached_path(config_file, cache_dir=cache_dir, force_download=force_download,
|
resolved_config_file = cached_path(config_file, cache_dir=cache_dir, force_download=force_download,
|
||||||
@@ -187,7 +189,7 @@ class PretrainedConfig(object):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_json_file(cls, json_file):
|
def from_json_file(cls, json_file):
|
||||||
"""Constructs a `BertConfig` from a json file of parameters."""
|
"""Constructs a `Config` from a json file of parameters."""
|
||||||
with open(json_file, "r", encoding='utf-8') as reader:
|
with open(json_file, "r", encoding='utf-8') as reader:
|
||||||
text = reader.read()
|
text = reader.read()
|
||||||
return cls.from_dict(json.loads(text))
|
return cls.from_dict(json.loads(text))
|
||||||
|
|||||||
@@ -73,6 +73,8 @@ TF2_WEIGHTS_NAME = 'tf_model.h5'
|
|||||||
TF_WEIGHTS_NAME = 'model.ckpt'
|
TF_WEIGHTS_NAME = 'model.ckpt'
|
||||||
CONFIG_NAME = "config.json"
|
CONFIG_NAME = "config.json"
|
||||||
|
|
||||||
|
S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert"
|
||||||
|
|
||||||
def is_torch_available():
|
def is_torch_available():
|
||||||
return _torch_available
|
return _torch_available
|
||||||
|
|
||||||
@@ -103,6 +105,18 @@ else:
|
|||||||
return fn
|
return fn
|
||||||
return docstring_decorator
|
return docstring_decorator
|
||||||
|
|
||||||
|
|
||||||
|
def is_remote_url(url_or_filename):
|
||||||
|
parsed = urlparse(url_or_filename)
|
||||||
|
return parsed.scheme in ('http', 'https', 's3')
|
||||||
|
|
||||||
|
def hf_bucket_url(identifier, postfix=None):
|
||||||
|
if postfix is None:
|
||||||
|
return "/".join((S3_BUCKET_PREFIX, identifier))
|
||||||
|
else:
|
||||||
|
return "/".join((S3_BUCKET_PREFIX, identifier, postfix))
|
||||||
|
|
||||||
|
|
||||||
def url_to_filename(url, etag=None):
|
def url_to_filename(url, etag=None):
|
||||||
"""
|
"""
|
||||||
Convert `url` into a hashed filename in a repeatable way.
|
Convert `url` into a hashed filename in a repeatable way.
|
||||||
@@ -171,9 +185,7 @@ def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=N
|
|||||||
if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
|
if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
|
||||||
cache_dir = str(cache_dir)
|
cache_dir = str(cache_dir)
|
||||||
|
|
||||||
parsed = urlparse(url_or_filename)
|
if is_remote_url(url_or_filename):
|
||||||
|
|
||||||
if parsed.scheme in ('http', 'https', 's3'):
|
|
||||||
# URL, so get it from the cache (downloading if necessary)
|
# URL, so get it from the cache (downloading if necessary)
|
||||||
return get_from_cache(url_or_filename, cache_dir=cache_dir,
|
return get_from_cache(url_or_filename, cache_dir=cache_dir,
|
||||||
force_download=force_download, proxies=proxies,
|
force_download=force_download, proxies=proxies,
|
||||||
@@ -181,7 +193,7 @@ def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=N
|
|||||||
elif os.path.exists(url_or_filename):
|
elif os.path.exists(url_or_filename):
|
||||||
# File, and it exists.
|
# File, and it exists.
|
||||||
return url_or_filename
|
return url_or_filename
|
||||||
elif parsed.scheme == '':
|
elif urlparse(url_or_filename).scheme == '':
|
||||||
# File, but it doesn't exist.
|
# File, but it doesn't exist.
|
||||||
raise EnvironmentError("file {} not found".format(url_or_filename))
|
raise EnvironmentError("file {} not found".format(url_or_filename))
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -31,7 +31,8 @@ from torch.nn import CrossEntropyLoss
|
|||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
|
|
||||||
from .configuration_utils import PretrainedConfig
|
from .configuration_utils import PretrainedConfig
|
||||||
from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME, TF2_WEIGHTS_NAME
|
from .file_utils import (TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME, WEIGHTS_NAME,
|
||||||
|
cached_path, hf_bucket_url, is_remote_url)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -363,14 +364,15 @@ class PreTrainedModel(nn.Module):
|
|||||||
raise EnvironmentError("Error no file named {} found in directory {} or `from_tf` set to False".format(
|
raise EnvironmentError("Error no file named {} found in directory {} or `from_tf` set to False".format(
|
||||||
[WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + ".index"],
|
[WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + ".index"],
|
||||||
pretrained_model_name_or_path))
|
pretrained_model_name_or_path))
|
||||||
elif os.path.isfile(pretrained_model_name_or_path):
|
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
||||||
archive_file = pretrained_model_name_or_path
|
archive_file = pretrained_model_name_or_path
|
||||||
elif os.path.isfile(pretrained_model_name_or_path + ".index"):
|
elif os.path.isfile(pretrained_model_name_or_path + ".index"):
|
||||||
assert from_tf, "We found a TensorFlow checkpoint at {}, please set from_tf to True to load from this checkpoint".format(
|
assert from_tf, "We found a TensorFlow checkpoint at {}, please set from_tf to True to load from this checkpoint".format(
|
||||||
pretrained_model_name_or_path + ".index")
|
pretrained_model_name_or_path + ".index")
|
||||||
archive_file = pretrained_model_name_or_path + ".index"
|
archive_file = pretrained_model_name_or_path + ".index"
|
||||||
else:
|
else:
|
||||||
archive_file = pretrained_model_name_or_path
|
archive_file = hf_bucket_url(pretrained_model_name_or_path, postfix=WEIGHTS_NAME)
|
||||||
|
# todo do we want to support TF checkpoints here?
|
||||||
|
|
||||||
# redirect to the cache, if necessary
|
# redirect to the cache, if necessary
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ import itertools
|
|||||||
import re
|
import re
|
||||||
from io import open
|
from io import open
|
||||||
|
|
||||||
from .file_utils import cached_path, is_tf_available, is_torch_available
|
from .file_utils import cached_path, is_remote_url, hf_bucket_url, is_tf_available, is_torch_available
|
||||||
|
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
@@ -327,12 +327,12 @@ class PreTrainedTokenizer(object):
|
|||||||
if os.path.isdir(pretrained_model_name_or_path):
|
if os.path.isdir(pretrained_model_name_or_path):
|
||||||
# If a directory is provided we look for the standard filenames
|
# If a directory is provided we look for the standard filenames
|
||||||
full_file_name = os.path.join(pretrained_model_name_or_path, file_name)
|
full_file_name = os.path.join(pretrained_model_name_or_path, file_name)
|
||||||
else:
|
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
||||||
# If a path to a file is provided we use it (will only work for non-BPE tokenizer using a single vocabulary file)
|
# If a path to a file is provided we use it (will only work for non-BPE tokenizer using a single vocabulary file)
|
||||||
full_file_name = pretrained_model_name_or_path
|
full_file_name = pretrained_model_name_or_path
|
||||||
if not os.path.exists(full_file_name):
|
else:
|
||||||
logger.info("Didn't find file {}. We won't load it.".format(full_file_name))
|
full_file_name = hf_bucket_url(pretrained_model_name_or_path, postfix=file_name)
|
||||||
full_file_name = None
|
|
||||||
vocab_files[file_id] = full_file_name
|
vocab_files[file_id] = full_file_name
|
||||||
|
|
||||||
# Look for the additional tokens files
|
# Look for the additional tokens files
|
||||||
|
|||||||
Reference in New Issue
Block a user