better error message

This commit is contained in:
thomwolf
2019-06-18 10:51:31 +02:00
parent 382e2d1e50
commit 64e0adda81
4 changed files with 76 additions and 13 deletions

View File

@@ -658,7 +658,6 @@ class BertPreTrainedModel(nn.Module):
# redirect to the cache, if necessary
try:
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
except EnvironmentError:
if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
logger.error(
@@ -673,6 +672,22 @@ class BertPreTrainedModel(nn.Module):
', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
archive_file))
return None
try:
resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
except EnvironmentError:
if pretrained_model_name_or_path in PRETRAINED_CONFIG_ARCHIVE_MAP:
logger.error(
"Couldn't reach server at '{}' to download pretrained model configuration file.".format(
config_file))
else:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url.".format(
pretrained_model_name_or_path,
', '.join(PRETRAINED_CONFIG_ARCHIVE_MAP.keys()),
config_file))
return None
if resolved_archive_file == archive_file and resolved_config_file == config_file:
logger.info("loading weights file {}".format(archive_file))
logger.info("loading configuration file {}".format(config_file))