From c603d099aa24410ec5a60c23794cc4a293d92850 Mon Sep 17 00:00:00 2001 From: Abhishek Rao Date: Thu, 22 Aug 2019 15:25:40 -0700 Subject: [PATCH] reraise EnvironmentError in from_pretrained functions of Model and Tokenizer --- pytorch_transformers/modeling_utils.py | 4 ++-- pytorch_transformers/tokenization_utils.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_transformers/modeling_utils.py b/pytorch_transformers/modeling_utils.py index 5066c42595..468d240fbc 100644 --- a/pytorch_transformers/modeling_utils.py +++ b/pytorch_transformers/modeling_utils.py @@ -473,7 +473,7 @@ class PreTrainedModel(nn.Module): # redirect to the cache, if necessary try: resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies) - except EnvironmentError: + except EnvironmentError as e: if pretrained_model_name_or_path in cls.pretrained_model_archive_map: logger.error( "Couldn't reach server at '{}' to download pretrained weights.".format( @@ -486,7 +486,7 @@ class PreTrainedModel(nn.Module): pretrained_model_name_or_path, ', '.join(cls.pretrained_model_archive_map.keys()), archive_file)) - return None + raise e if resolved_archive_file == archive_file: logger.info("loading weights file {}".format(archive_file)) else: diff --git a/pytorch_transformers/tokenization_utils.py b/pytorch_transformers/tokenization_utils.py index d2855e0922..4fef0e34fb 100644 --- a/pytorch_transformers/tokenization_utils.py +++ b/pytorch_transformers/tokenization_utils.py @@ -293,7 +293,7 @@ class PreTrainedTokenizer(object): resolved_vocab_files[file_id] = None else: resolved_vocab_files[file_id] = cached_path(file_path, cache_dir=cache_dir, force_download=force_download, proxies=proxies) - except EnvironmentError: + except EnvironmentError as e: if pretrained_model_name_or_path in s3_models: logger.error("Couldn't reach server to download vocabulary.") else: @@ -303,7 +303,7 @@ class PreTrainedTokenizer(object): "at this path or url.".format( pretrained_model_name_or_path, ', '.join(s3_models), pretrained_model_name_or_path, str(vocab_files.keys()))) - return None + raise e for file_id, file_path in vocab_files.items(): if file_path == resolved_vocab_files[file_id]: