Add Mirror Option for Downloads (#6679)
* Add Tuna Mirror for Downloads from China * format fix * Use preset instead of hardcoding URL * Fix * make style * update the mirror option doc * update the mirror
This commit is contained in:
@@ -337,7 +337,9 @@ class PretrainedConfig(object):
|
||||
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
||||
config_file = pretrained_model_name_or_path
|
||||
else:
|
||||
config_file = hf_bucket_url(pretrained_model_name_or_path, filename=CONFIG_NAME, use_cdn=False)
|
||||
config_file = hf_bucket_url(
|
||||
pretrained_model_name_or_path, filename=CONFIG_NAME, use_cdn=False, mirror=None
|
||||
)
|
||||
|
||||
try:
|
||||
# Load from URL or cache if already cached
|
||||
|
||||
@@ -141,6 +141,10 @@ DUMMY_MASK = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [0, 0, 0, 1, 1]]
|
||||
|
||||
S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert"
|
||||
CLOUDFRONT_DISTRIB_PREFIX = "https://cdn.huggingface.co"
|
||||
PRESET_MIRROR_DICT = {
|
||||
"tuna": "https://mirrors.tuna.tsinghua.edu.cn/hugging-face-models",
|
||||
"bfsu": "https://mirrors.bfsu.edu.cn/hugging-face-models",
|
||||
}
|
||||
|
||||
|
||||
def is_torch_available():
|
||||
@@ -570,7 +574,7 @@ def is_remote_url(url_or_filename):
|
||||
return parsed.scheme in ("http", "https")
|
||||
|
||||
|
||||
def hf_bucket_url(model_id: str, filename: str, use_cdn=True) -> str:
|
||||
def hf_bucket_url(model_id: str, filename: str, use_cdn=True, mirror=None) -> str:
|
||||
"""
|
||||
Resolve a model identifier, and a file name, to a HF-hosted url
|
||||
on either S3 or Cloudfront (a Content Delivery Network, or CDN).
|
||||
@@ -586,7 +590,13 @@ def hf_bucket_url(model_id: str, filename: str, use_cdn=True) -> str:
|
||||
are not shared between the two because the cached file's name contains
|
||||
a hash of the url.
|
||||
"""
|
||||
endpoint = CLOUDFRONT_DISTRIB_PREFIX if use_cdn else S3_BUCKET_PREFIX
|
||||
endpoint = (
|
||||
PRESET_MIRROR_DICT.get(mirror, mirror)
|
||||
if mirror
|
||||
else CLOUDFRONT_DISTRIB_PREFIX
|
||||
if use_cdn
|
||||
else S3_BUCKET_PREFIX
|
||||
)
|
||||
legacy_format = "/" not in model_id
|
||||
if legacy_format:
|
||||
return f"{endpoint}/{model_id}-{filename}"
|
||||
|
||||
@@ -139,7 +139,9 @@ class ModelCard:
|
||||
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
||||
model_card_file = pretrained_model_name_or_path
|
||||
else:
|
||||
model_card_file = hf_bucket_url(pretrained_model_name_or_path, filename=MODEL_CARD_NAME, use_cdn=False)
|
||||
model_card_file = hf_bucket_url(
|
||||
pretrained_model_name_or_path, filename=MODEL_CARD_NAME, use_cdn=False, mirror=None
|
||||
)
|
||||
|
||||
if find_from_standard_name or pretrained_model_name_or_path in ALL_PRETRAINED_CONFIG_ARCHIVE_MAP:
|
||||
model_card_file = model_card_file.replace(CONFIG_NAME, MODEL_CARD_NAME)
|
||||
|
||||
@@ -484,6 +484,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
|
||||
use_cdn(:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether or not to use Cloudfront (a Content Delivery Network, or CDN) when searching for the model on
|
||||
our S3 (faster). Should be set to :obj:`False` for checkpoints larger than 20GB.
|
||||
mirror(:obj:`str`, `optional`, defaults to :obj:`None`):
|
||||
Mirror source to accelerate downloads in China. If you are from China and have an accessibility problem,
|
||||
you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. Please
|
||||
refer to the mirror site for more information.
|
||||
kwargs (remaining dictionary of keyword arguments, `optional`):
|
||||
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
|
||||
:obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or
|
||||
@@ -522,6 +526,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
|
||||
output_loading_info = kwargs.pop("output_loading_info", False)
|
||||
local_files_only = kwargs.pop("local_files_only", False)
|
||||
use_cdn = kwargs.pop("use_cdn", True)
|
||||
mirror = kwargs.pop("mirror", None)
|
||||
|
||||
# Load config if we don't provide a configuration
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
@@ -564,6 +569,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
|
||||
pretrained_model_name_or_path,
|
||||
filename=(WEIGHTS_NAME if from_pt else TF2_WEIGHTS_NAME),
|
||||
use_cdn=use_cdn,
|
||||
mirror=mirror,
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@@ -784,6 +784,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
use_cdn(:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether or not to use Cloudfront (a Content Delivery Network, or CDN) when searching for the model on
|
||||
our S3 (faster). Should be set to :obj:`False` for checkpoints larger than 20GB.
|
||||
mirror(:obj:`str`, `optional`, defaults to :obj:`None`):
|
||||
Mirror source to accelerate downloads in China. If you are from China and have an accessibility problem,
|
||||
you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. Please
|
||||
refer to the mirror site for more information.
|
||||
kwargs (remaining dictionary of keyword arguments, `optional`):
|
||||
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
|
||||
:obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or
|
||||
@@ -822,6 +826,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
output_loading_info = kwargs.pop("output_loading_info", False)
|
||||
local_files_only = kwargs.pop("local_files_only", False)
|
||||
use_cdn = kwargs.pop("use_cdn", True)
|
||||
mirror = kwargs.pop("mirror", None)
|
||||
|
||||
# Load config if we don't provide a configuration
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
@@ -873,6 +878,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
pretrained_model_name_or_path,
|
||||
filename=(TF2_WEIGHTS_NAME if from_tf else WEIGHTS_NAME),
|
||||
use_cdn=use_cdn,
|
||||
mirror=mirror,
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@@ -1483,7 +1483,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
|
||||
full_file_name = None
|
||||
else:
|
||||
full_file_name = hf_bucket_url(
|
||||
pretrained_model_name_or_path, filename=file_name, use_cdn=False
|
||||
pretrained_model_name_or_path, filename=file_name, use_cdn=False, mirror=None
|
||||
)
|
||||
|
||||
vocab_files[file_id] = full_file_name
|
||||
|
||||
Reference in New Issue
Block a user