Add Mirror Option for Downloads (#6679)

* Add Tuna Mirror for Downloads from China

* format fix

* Use preset instead of hardcoding URL

* Fix

* make style

* update the mirror option doc

* update the mirror
This commit is contained in:
Kevin Canwen Xu
2020-09-14 23:50:22 +08:00
committed by GitHub
parent e0e0675ac7
commit 90cde2e938
6 changed files with 31 additions and 5 deletions

View File

@@ -337,7 +337,9 @@ class PretrainedConfig(object):
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path): elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
config_file = pretrained_model_name_or_path config_file = pretrained_model_name_or_path
else: else:
config_file = hf_bucket_url(pretrained_model_name_or_path, filename=CONFIG_NAME, use_cdn=False) config_file = hf_bucket_url(
pretrained_model_name_or_path, filename=CONFIG_NAME, use_cdn=False, mirror=None
)
try: try:
# Load from URL or cache if already cached # Load from URL or cache if already cached

View File

@@ -141,6 +141,10 @@ DUMMY_MASK = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [0, 0, 0, 1, 1]]
S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert" S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert"
CLOUDFRONT_DISTRIB_PREFIX = "https://cdn.huggingface.co" CLOUDFRONT_DISTRIB_PREFIX = "https://cdn.huggingface.co"
PRESET_MIRROR_DICT = {
"tuna": "https://mirrors.tuna.tsinghua.edu.cn/hugging-face-models",
"bfsu": "https://mirrors.bfsu.edu.cn/hugging-face-models",
}
def is_torch_available(): def is_torch_available():
@@ -570,7 +574,7 @@ def is_remote_url(url_or_filename):
return parsed.scheme in ("http", "https") return parsed.scheme in ("http", "https")
def hf_bucket_url(model_id: str, filename: str, use_cdn=True) -> str: def hf_bucket_url(model_id: str, filename: str, use_cdn=True, mirror=None) -> str:
""" """
Resolve a model identifier, and a file name, to a HF-hosted url Resolve a model identifier, and a file name, to a HF-hosted url
on either S3 or Cloudfront (a Content Delivery Network, or CDN). on either S3 or Cloudfront (a Content Delivery Network, or CDN).
@@ -586,7 +590,13 @@ def hf_bucket_url(model_id: str, filename: str, use_cdn=True) -> str:
are not shared between the two because the cached file's name contains are not shared between the two because the cached file's name contains
a hash of the url. a hash of the url.
""" """
endpoint = CLOUDFRONT_DISTRIB_PREFIX if use_cdn else S3_BUCKET_PREFIX endpoint = (
PRESET_MIRROR_DICT.get(mirror, mirror)
if mirror
else CLOUDFRONT_DISTRIB_PREFIX
if use_cdn
else S3_BUCKET_PREFIX
)
legacy_format = "/" not in model_id legacy_format = "/" not in model_id
if legacy_format: if legacy_format:
return f"{endpoint}/{model_id}-{filename}" return f"{endpoint}/{model_id}-{filename}"

View File

@@ -139,7 +139,9 @@ class ModelCard:
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path): elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
model_card_file = pretrained_model_name_or_path model_card_file = pretrained_model_name_or_path
else: else:
model_card_file = hf_bucket_url(pretrained_model_name_or_path, filename=MODEL_CARD_NAME, use_cdn=False) model_card_file = hf_bucket_url(
pretrained_model_name_or_path, filename=MODEL_CARD_NAME, use_cdn=False, mirror=None
)
if find_from_standard_name or pretrained_model_name_or_path in ALL_PRETRAINED_CONFIG_ARCHIVE_MAP: if find_from_standard_name or pretrained_model_name_or_path in ALL_PRETRAINED_CONFIG_ARCHIVE_MAP:
model_card_file = model_card_file.replace(CONFIG_NAME, MODEL_CARD_NAME) model_card_file = model_card_file.replace(CONFIG_NAME, MODEL_CARD_NAME)

View File

@@ -484,6 +484,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
use_cdn(:obj:`bool`, `optional`, defaults to :obj:`True`): use_cdn(:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to use Cloudfront (a Content Delivery Network, or CDN) when searching for the model on Whether or not to use Cloudfront (a Content Delivery Network, or CDN) when searching for the model on
our S3 (faster). Should be set to :obj:`False` for checkpoints larger than 20GB. our S3 (faster). Should be set to :obj:`False` for checkpoints larger than 20GB.
mirror(:obj:`str`, `optional`, defaults to :obj:`None`):
Mirror source to accelerate downloads in China. If you are from China and have an accessibility problem,
you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. Please
refer to the mirror site for more information.
kwargs (remaining dictionary of keyword arguments, `optional`): kwargs (remaining dictionary of keyword arguments, `optional`):
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
:obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or :obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or
@@ -522,6 +526,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
output_loading_info = kwargs.pop("output_loading_info", False) output_loading_info = kwargs.pop("output_loading_info", False)
local_files_only = kwargs.pop("local_files_only", False) local_files_only = kwargs.pop("local_files_only", False)
use_cdn = kwargs.pop("use_cdn", True) use_cdn = kwargs.pop("use_cdn", True)
mirror = kwargs.pop("mirror", None)
# Load config if we don't provide a configuration # Load config if we don't provide a configuration
if not isinstance(config, PretrainedConfig): if not isinstance(config, PretrainedConfig):
@@ -564,6 +569,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
pretrained_model_name_or_path, pretrained_model_name_or_path,
filename=(WEIGHTS_NAME if from_pt else TF2_WEIGHTS_NAME), filename=(WEIGHTS_NAME if from_pt else TF2_WEIGHTS_NAME),
use_cdn=use_cdn, use_cdn=use_cdn,
mirror=mirror,
) )
try: try:

View File

@@ -784,6 +784,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
use_cdn(:obj:`bool`, `optional`, defaults to :obj:`True`): use_cdn(:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to use Cloudfront (a Content Delivery Network, or CDN) when searching for the model on Whether or not to use Cloudfront (a Content Delivery Network, or CDN) when searching for the model on
our S3 (faster). Should be set to :obj:`False` for checkpoints larger than 20GB. our S3 (faster). Should be set to :obj:`False` for checkpoints larger than 20GB.
mirror(:obj:`str`, `optional`, defaults to :obj:`None`):
Mirror source to accelerate downloads in China. If you are from China and have an accessibility problem,
you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. Please
refer to the mirror site for more information.
kwargs (remaining dictionary of keyword arguments, `optional`): kwargs (remaining dictionary of keyword arguments, `optional`):
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
:obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or :obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or
@@ -822,6 +826,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
output_loading_info = kwargs.pop("output_loading_info", False) output_loading_info = kwargs.pop("output_loading_info", False)
local_files_only = kwargs.pop("local_files_only", False) local_files_only = kwargs.pop("local_files_only", False)
use_cdn = kwargs.pop("use_cdn", True) use_cdn = kwargs.pop("use_cdn", True)
mirror = kwargs.pop("mirror", None)
# Load config if we don't provide a configuration # Load config if we don't provide a configuration
if not isinstance(config, PretrainedConfig): if not isinstance(config, PretrainedConfig):
@@ -873,6 +878,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
pretrained_model_name_or_path, pretrained_model_name_or_path,
filename=(TF2_WEIGHTS_NAME if from_tf else WEIGHTS_NAME), filename=(TF2_WEIGHTS_NAME if from_tf else WEIGHTS_NAME),
use_cdn=use_cdn, use_cdn=use_cdn,
mirror=mirror,
) )
try: try:

View File

@@ -1483,7 +1483,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
full_file_name = None full_file_name = None
else: else:
full_file_name = hf_bucket_url( full_file_name = hf_bucket_url(
pretrained_model_name_or_path, filename=file_name, use_cdn=False pretrained_model_name_or_path, filename=file_name, use_cdn=False, mirror=None
) )
vocab_files[file_id] = full_file_name vocab_files[file_id] = full_file_name