diff --git a/transformers/configuration_utils.py b/transformers/configuration_utils.py index 08cee75d81..8ae30f2a48 100644 --- a/transformers/configuration_utils.py +++ b/transformers/configuration_utils.py @@ -24,7 +24,7 @@ import logging import os from io import open -from .file_utils import cached_path, CONFIG_NAME +from .file_utils import CONFIG_NAME, cached_path, is_remote_url, hf_bucket_url logger = logging.getLogger(__name__) @@ -131,8 +131,10 @@ class PretrainedConfig(object): config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path] elif os.path.isdir(pretrained_model_name_or_path): config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME) - else: + elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path): config_file = pretrained_model_name_or_path + else: + config_file = hf_bucket_url(pretrained_model_name_or_path, postfix=CONFIG_NAME) # redirect to the cache, if necessary try: resolved_config_file = cached_path(config_file, cache_dir=cache_dir, force_download=force_download, @@ -187,7 +189,7 @@ class PretrainedConfig(object): @classmethod def from_json_file(cls, json_file): - """Constructs a `BertConfig` from a json file of parameters.""" + """Constructs a `Config` from a json file of parameters.""" with open(json_file, "r", encoding='utf-8') as reader: text = reader.read() return cls.from_dict(json.loads(text)) diff --git a/transformers/file_utils.py b/transformers/file_utils.py index 68de4e6e2f..5fd5e2ee39 100644 --- a/transformers/file_utils.py +++ b/transformers/file_utils.py @@ -73,6 +73,8 @@ TF2_WEIGHTS_NAME = 'tf_model.h5' TF_WEIGHTS_NAME = 'model.ckpt' CONFIG_NAME = "config.json" +S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert" + def is_torch_available(): return _torch_available @@ -103,6 +105,18 @@ else: return fn return docstring_decorator + +def is_remote_url(url_or_filename): + parsed = urlparse(url_or_filename) + return parsed.scheme in ('http', 'https', 's3') + +def hf_bucket_url(identifier, postfix=None): + if postfix is None: + return "/".join((S3_BUCKET_PREFIX, identifier)) + else: + return "/".join((S3_BUCKET_PREFIX, identifier, postfix)) + + def url_to_filename(url, etag=None): """ Convert `url` into a hashed filename in a repeatable way. @@ -171,9 +185,7 @@ def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=N if sys.version_info[0] == 3 and isinstance(cache_dir, Path): cache_dir = str(cache_dir) - parsed = urlparse(url_or_filename) - - if parsed.scheme in ('http', 'https', 's3'): + if is_remote_url(url_or_filename): # URL, so get it from the cache (downloading if necessary) return get_from_cache(url_or_filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies, @@ -181,7 +193,7 @@ def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=N elif os.path.exists(url_or_filename): # File, and it exists. return url_or_filename - elif parsed.scheme == '': + elif urlparse(url_or_filename).scheme == '': # File, but it doesn't exist. raise EnvironmentError("file {} not found".format(url_or_filename)) else: diff --git a/transformers/modeling_utils.py b/transformers/modeling_utils.py index 9e7ca8d689..eac4252336 100644 --- a/transformers/modeling_utils.py +++ b/transformers/modeling_utils.py @@ -31,7 +31,8 @@ from torch.nn import CrossEntropyLoss from torch.nn import functional as F from .configuration_utils import PretrainedConfig -from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME, TF2_WEIGHTS_NAME +from .file_utils import (TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME, WEIGHTS_NAME, + cached_path, hf_bucket_url, is_remote_url) logger = logging.getLogger(__name__) @@ -363,14 +364,15 @@ class PreTrainedModel(nn.Module): raise EnvironmentError("Error no file named {} found in directory {} or `from_tf` set to False".format( [WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + ".index"], pretrained_model_name_or_path)) - elif os.path.isfile(pretrained_model_name_or_path): + elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path): archive_file = pretrained_model_name_or_path elif os.path.isfile(pretrained_model_name_or_path + ".index"): assert from_tf, "We found a TensorFlow checkpoint at {}, please set from_tf to True to load from this checkpoint".format( pretrained_model_name_or_path + ".index") archive_file = pretrained_model_name_or_path + ".index" else: - archive_file = pretrained_model_name_or_path + archive_file = hf_bucket_url(pretrained_model_name_or_path, postfix=WEIGHTS_NAME) + # todo do we want to support TF checkpoints here? # redirect to the cache, if necessary try: diff --git a/transformers/tokenization_utils.py b/transformers/tokenization_utils.py index 68a767fe82..2b2cec0c15 100644 --- a/transformers/tokenization_utils.py +++ b/transformers/tokenization_utils.py @@ -25,7 +25,7 @@ import itertools import re from io import open -from .file_utils import cached_path, is_tf_available, is_torch_available +from .file_utils import cached_path, is_remote_url, hf_bucket_url, is_tf_available, is_torch_available if is_tf_available(): import tensorflow as tf @@ -327,12 +327,12 @@ class PreTrainedTokenizer(object): if os.path.isdir(pretrained_model_name_or_path): # If a directory is provided we look for the standard filenames full_file_name = os.path.join(pretrained_model_name_or_path, file_name) - else: + elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path): # If a path to a file is provided we use it (will only work for non-BPE tokenizer using a single vocabulary file) full_file_name = pretrained_model_name_or_path - if not os.path.exists(full_file_name): - logger.info("Didn't find file {}. We won't load it.".format(full_file_name)) - full_file_name = None + else: + full_file_name = hf_bucket_url(pretrained_model_name_or_path, postfix=file_name) + vocab_files[file_id] = full_file_name # Look for the additional tokens files