From 64e0adda81cc8fca8de9b9a3639d02925b8fdffe Mon Sep 17 00:00:00 2001 From: thomwolf Date: Tue, 18 Jun 2019 10:51:31 +0200 Subject: [PATCH] better error message --- pytorch_pretrained_bert/modeling.py | 17 ++++++++++- pytorch_pretrained_bert/modeling_gpt2.py | 22 +++++++++++++-- pytorch_pretrained_bert/modeling_openai.py | 22 +++++++++++++-- .../modeling_transfo_xl.py | 28 +++++++++++++++---- 4 files changed, 76 insertions(+), 13 deletions(-) diff --git a/pytorch_pretrained_bert/modeling.py b/pytorch_pretrained_bert/modeling.py index 006e6a1c73..25f9fe79cf 100644 --- a/pytorch_pretrained_bert/modeling.py +++ b/pytorch_pretrained_bert/modeling.py @@ -658,7 +658,6 @@ class BertPreTrainedModel(nn.Module): # redirect to the cache, if necessary try: resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir) - resolved_config_file = cached_path(config_file, cache_dir=cache_dir) except EnvironmentError: if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP: logger.error( @@ -673,6 +672,22 @@ class BertPreTrainedModel(nn.Module): ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), archive_file)) return None + try: + resolved_config_file = cached_path(config_file, cache_dir=cache_dir) + except EnvironmentError: + if pretrained_model_name_or_path in PRETRAINED_CONFIG_ARCHIVE_MAP: + logger.error( + "Couldn't reach server at '{}' to download pretrained model configuration file.".format( + config_file)) + else: + logger.error( + "Model name '{}' was not found in model name list ({}). " + "We assumed '{}' was a path or url but couldn't find any file " + "associated to this path or url.".format( + pretrained_model_name_or_path, + ', '.join(PRETRAINED_CONFIG_ARCHIVE_MAP.keys()), + config_file)) + return None if resolved_archive_file == archive_file and resolved_config_file == config_file: logger.info("loading weights file {}".format(archive_file)) logger.info("loading configuration file {}".format(config_file)) diff --git a/pytorch_pretrained_bert/modeling_gpt2.py b/pytorch_pretrained_bert/modeling_gpt2.py index caa9cf809c..dd195fc880 100644 --- a/pytorch_pretrained_bert/modeling_gpt2.py +++ b/pytorch_pretrained_bert/modeling_gpt2.py @@ -493,7 +493,6 @@ class GPT2PreTrainedModel(nn.Module): # redirect to the cache, if necessary try: resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir) - resolved_config_file = cached_path(config_file, cache_dir=cache_dir) except EnvironmentError: if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP: logger.error( @@ -502,10 +501,27 @@ class GPT2PreTrainedModel(nn.Module): else: logger.error( "Model name '{}' was not found in model name list ({}). " - "We assumed '{}' was a path or url but couldn't find files {} and {} " + "We assumed '{}' was a path or url but couldn't find file {} " "at this path or url.".format( pretrained_model_name_or_path, ", ".join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), pretrained_model_name_or_path, - archive_file, config_file + archive_file + ) + ) + return None + try: + resolved_config_file = cached_path(config_file, cache_dir=cache_dir) + except EnvironmentError: + if pretrained_model_name_or_path in PRETRAINED_CONFIG_ARCHIVE_MAP: + logger.error( + "Couldn't reach server at '{}' to download pretrained model configuration file.".format( + config_file)) + else: + logger.error( + "Model name '{}' was not found in model name list ({}). " + "We assumed '{}' was a path or url but couldn't find file {} " + "at this path or url.".format( + pretrained_model_name_or_path, ", ".join(PRETRAINED_CONFIG_ARCHIVE_MAP.keys()), pretrained_model_name_or_path, + config_file ) ) return None diff --git a/pytorch_pretrained_bert/modeling_openai.py b/pytorch_pretrained_bert/modeling_openai.py index d525c96e77..91848f3c68 100644 --- a/pytorch_pretrained_bert/modeling_openai.py +++ b/pytorch_pretrained_bert/modeling_openai.py @@ -496,7 +496,6 @@ class OpenAIGPTPreTrainedModel(nn.Module): # redirect to the cache, if necessary try: resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir) - resolved_config_file = cached_path(config_file, cache_dir=cache_dir) except EnvironmentError: if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP: logger.error( @@ -505,10 +504,27 @@ class OpenAIGPTPreTrainedModel(nn.Module): else: logger.error( "Model name '{}' was not found in model name list ({}). " - "We assumed '{}' was a path or url but couldn't find files {} and {} " + "We assumed '{}' was a path or url but couldn't find file {} " "at this path or url.".format( pretrained_model_name_or_path, ", ".join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), pretrained_model_name_or_path, - archive_file, config_file + archive_file + ) + ) + return None + try: + resolved_config_file = cached_path(config_file, cache_dir=cache_dir) + except EnvironmentError: + if pretrained_model_name_or_path in PRETRAINED_CONFIG_ARCHIVE_MAP: + logger.error( + "Couldn't reach server at '{}' to download pretrained model configuration file.".format( + config_file)) + else: + logger.error( + "Model name '{}' was not found in model name list ({}). " + "We assumed '{}' was a path or url but couldn't find file {} " + "at this path or url.".format( + pretrained_model_name_or_path, ", ".join(PRETRAINED_CONFIG_ARCHIVE_MAP.keys()), pretrained_model_name_or_path, + config_file ) ) return None diff --git a/pytorch_pretrained_bert/modeling_transfo_xl.py b/pytorch_pretrained_bert/modeling_transfo_xl.py index b3e829670a..534a111c77 100644 --- a/pytorch_pretrained_bert/modeling_transfo_xl.py +++ b/pytorch_pretrained_bert/modeling_transfo_xl.py @@ -921,7 +921,6 @@ class TransfoXLPreTrainedModel(nn.Module): # redirect to the cache, if necessary try: resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir) - resolved_config_file = cached_path(config_file, cache_dir=cache_dir) except EnvironmentError: if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP: logger.error( @@ -930,12 +929,29 @@ class TransfoXLPreTrainedModel(nn.Module): else: logger.error( "Model name '{}' was not found in model name list ({}). " - "We assumed '{}' was a path or url but couldn't find files {} and {} " + "We assumed '{}' was a path or url but couldn't find file {} " "at this path or url.".format( - pretrained_model_name_or_path, - ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), - pretrained_model_name_or_path, - archive_file, config_file)) + pretrained_model_name_or_path, ", ".join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), pretrained_model_name_or_path, + archive_file + ) + ) + return None + try: + resolved_config_file = cached_path(config_file, cache_dir=cache_dir) + except EnvironmentError: + if pretrained_model_name_or_path in PRETRAINED_CONFIG_ARCHIVE_MAP: + logger.error( + "Couldn't reach server at '{}' to download pretrained model configuration file.".format( + config_file)) + else: + logger.error( + "Model name '{}' was not found in model name list ({}). " + "We assumed '{}' was a path or url but couldn't find file {} " + "at this path or url.".format( + pretrained_model_name_or_path, ", ".join(PRETRAINED_CONFIG_ARCHIVE_MAP.keys()), pretrained_model_name_or_path, + config_file + ) + ) return None if resolved_archive_file == archive_file and resolved_config_file == config_file: logger.info("loading weights file {}".format(archive_file))