From 265550ec34bfa756538c60e0d5d4c906ee78e1ce Mon Sep 17 00:00:00 2001 From: thomwolf Date: Wed, 17 Apr 2019 14:22:35 +0200 Subject: [PATCH] relax network connection requirements --- pytorch_pretrained_bert/file_utils.py | 29 ++++++++++++++++++++------- tests/tokenization_gpt2_test.py | 2 +- 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/pytorch_pretrained_bert/file_utils.py b/pytorch_pretrained_bert/file_utils.py index 6de7e259e5..e7e1714f97 100644 --- a/pytorch_pretrained_bert/file_utils.py +++ b/pytorch_pretrained_bert/file_utils.py @@ -5,11 +5,13 @@ Copyright by the AllenNLP authors. """ from __future__ import (absolute_import, division, print_function, unicode_literals) +import sys import json import logging import os import shutil import tempfile +import fnmatch from functools import wraps from hashlib import sha256 import sys @@ -191,17 +193,30 @@ def get_from_cache(url, cache_dir=None): if url.startswith("s3://"): etag = s3_etag(url) else: - response = requests.head(url, allow_redirects=True) - if response.status_code != 200: - raise IOError("HEAD request failed for url {} with status code {}" - .format(url, response.status_code)) - etag = response.headers.get("ETag") + try: + response = requests.head(url, allow_redirects=True) + if response.status_code != 200: + etag = None + else: + etag = response.headers.get("ETag") + except EnvironmentError: + etag = None + if sys.version_info[0] == 2 and etag is not None: + etag = etag.decode('utf-8') filename = url_to_filename(url, etag) # get cache path to put the file cache_path = os.path.join(cache_dir, filename) + # If we don't have a connection (etag is None) and can't identify the file + # try to get the last downloaded one + if not os.path.exists(cache_path) and etag is None: + matching_files = fnmatch.filter(os.listdir(cache_dir), filename + '.*') + matching_files = list(filter(lambda s: not s.endswith('.json'), matching_files)) + if matching_files: + cache_path = os.path.join(cache_dir, matching_files[-1]) + if not os.path.exists(cache_path): # Download to temporary file, then copy to cache dir once finished. # Otherwise you get corrupt cache entries if the download gets interrupted. @@ -226,8 +241,8 @@ def get_from_cache(url, cache_dir=None): logger.info("creating metadata file for %s", cache_path) meta = {'url': url, 'etag': etag} meta_path = cache_path + '.json' - with open(meta_path, 'w', encoding="utf-8") as meta_file: - json.dump(meta, meta_file) + with open(meta_path, 'w') as meta_file: + meta_file.write(json.dumps(meta, indent=4)) logger.info("removing temp file %s", temp_file.name) diff --git a/tests/tokenization_gpt2_test.py b/tests/tokenization_gpt2_test.py index 870f61ca79..cfd13de391 100644 --- a/tests/tokenization_gpt2_test.py +++ b/tests/tokenization_gpt2_test.py @@ -66,7 +66,7 @@ class GPT2TokenizationTest(unittest.TestCase): [tokenizer_2.encoder, tokenizer_2.decoder, tokenizer_2.bpe_ranks, tokenizer_2.special_tokens, tokenizer_2.special_tokens_decoder]) - @pytest.mark.slow + # @pytest.mark.slow def test_tokenizer_from_pretrained(self): cache_dir = "/tmp/pytorch_pretrained_bert_test/" for model_name in list(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())[:1]: