Add Mirror Option for Downloads (#6679)

* Add Tuna Mirror for Downloads from China

* format fix

* Use preset instead of hardcoding URL

* Fix

* make style

* update the mirror option doc

* update the mirror
This commit is contained in:
Kevin Canwen Xu
2020-09-14 23:50:22 +08:00
committed by GitHub
parent e0e0675ac7
commit 90cde2e938
6 changed files with 31 additions and 5 deletions

View File

@@ -337,7 +337,9 @@ class PretrainedConfig(object):
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
config_file = pretrained_model_name_or_path
else:
config_file = hf_bucket_url(pretrained_model_name_or_path, filename=CONFIG_NAME, use_cdn=False)
config_file = hf_bucket_url(
pretrained_model_name_or_path, filename=CONFIG_NAME, use_cdn=False, mirror=None
)
try:
# Load from URL or cache if already cached

View File

@@ -141,6 +141,10 @@ DUMMY_MASK = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [0, 0, 0, 1, 1]]
S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert"
CLOUDFRONT_DISTRIB_PREFIX = "https://cdn.huggingface.co"
PRESET_MIRROR_DICT = {
"tuna": "https://mirrors.tuna.tsinghua.edu.cn/hugging-face-models",
"bfsu": "https://mirrors.bfsu.edu.cn/hugging-face-models",
}
def is_torch_available():
@@ -570,7 +574,7 @@ def is_remote_url(url_or_filename):
return parsed.scheme in ("http", "https")
def hf_bucket_url(model_id: str, filename: str, use_cdn=True) -> str:
def hf_bucket_url(model_id: str, filename: str, use_cdn=True, mirror=None) -> str:
"""
Resolve a model identifier, and a file name, to a HF-hosted url
on either S3 or Cloudfront (a Content Delivery Network, or CDN).
@@ -586,7 +590,13 @@ def hf_bucket_url(model_id: str, filename: str, use_cdn=True) -> str:
are not shared between the two because the cached file's name contains
a hash of the url.
"""
endpoint = CLOUDFRONT_DISTRIB_PREFIX if use_cdn else S3_BUCKET_PREFIX
endpoint = (
PRESET_MIRROR_DICT.get(mirror, mirror)
if mirror
else CLOUDFRONT_DISTRIB_PREFIX
if use_cdn
else S3_BUCKET_PREFIX
)
legacy_format = "/" not in model_id
if legacy_format:
return f"{endpoint}/{model_id}-{filename}"

View File

@@ -139,7 +139,9 @@ class ModelCard:
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
model_card_file = pretrained_model_name_or_path
else:
model_card_file = hf_bucket_url(pretrained_model_name_or_path, filename=MODEL_CARD_NAME, use_cdn=False)
model_card_file = hf_bucket_url(
pretrained_model_name_or_path, filename=MODEL_CARD_NAME, use_cdn=False, mirror=None
)
if find_from_standard_name or pretrained_model_name_or_path in ALL_PRETRAINED_CONFIG_ARCHIVE_MAP:
model_card_file = model_card_file.replace(CONFIG_NAME, MODEL_CARD_NAME)

View File

@@ -484,6 +484,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
use_cdn(:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to use Cloudfront (a Content Delivery Network, or CDN) when searching for the model on
our S3 (faster). Should be set to :obj:`False` for checkpoints larger than 20GB.
mirror(:obj:`str`, `optional`, defaults to :obj:`None`):
Mirror source to accelerate downloads in China. If you are from China and have an accessibility problem,
you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. Please
refer to the mirror site for more information.
kwargs (remaining dictionary of keyword arguments, `optional`):
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
:obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or
@@ -522,6 +526,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
output_loading_info = kwargs.pop("output_loading_info", False)
local_files_only = kwargs.pop("local_files_only", False)
use_cdn = kwargs.pop("use_cdn", True)
mirror = kwargs.pop("mirror", None)
# Load config if we don't provide a configuration
if not isinstance(config, PretrainedConfig):
@@ -564,6 +569,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
pretrained_model_name_or_path,
filename=(WEIGHTS_NAME if from_pt else TF2_WEIGHTS_NAME),
use_cdn=use_cdn,
mirror=mirror,
)
try:

View File

@@ -784,6 +784,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
use_cdn(:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to use Cloudfront (a Content Delivery Network, or CDN) when searching for the model on
our S3 (faster). Should be set to :obj:`False` for checkpoints larger than 20GB.
mirror(:obj:`str`, `optional`, defaults to :obj:`None`):
Mirror source to accelerate downloads in China. If you are from China and have an accessibility problem,
you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. Please
refer to the mirror site for more information.
kwargs (remaining dictionary of keyword arguments, `optional`):
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
:obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or
@@ -822,6 +826,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
output_loading_info = kwargs.pop("output_loading_info", False)
local_files_only = kwargs.pop("local_files_only", False)
use_cdn = kwargs.pop("use_cdn", True)
mirror = kwargs.pop("mirror", None)
# Load config if we don't provide a configuration
if not isinstance(config, PretrainedConfig):
@@ -873,6 +878,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
pretrained_model_name_or_path,
filename=(TF2_WEIGHTS_NAME if from_tf else WEIGHTS_NAME),
use_cdn=use_cdn,
mirror=mirror,
)
try:

View File

@@ -1483,7 +1483,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
full_file_name = None
else:
full_file_name = hf_bucket_url(
pretrained_model_name_or_path, filename=file_name, use_cdn=False
pretrained_model_name_or_path, filename=file_name, use_cdn=False, mirror=None
)
vocab_files[file_id] = full_file_name