relax network connection requirements

This commit is contained in:
thomwolf
2019-04-17 14:22:35 +02:00
parent fa76520240
commit 265550ec34
2 changed files with 23 additions and 8 deletions

View File

@@ -5,11 +5,13 @@ Copyright by the AllenNLP authors.
""" """
from __future__ import (absolute_import, division, print_function, unicode_literals) from __future__ import (absolute_import, division, print_function, unicode_literals)
import sys
import json import json
import logging import logging
import os import os
import shutil import shutil
import tempfile import tempfile
import fnmatch
from functools import wraps from functools import wraps
from hashlib import sha256 from hashlib import sha256
import sys import sys
@@ -191,17 +193,30 @@ def get_from_cache(url, cache_dir=None):
if url.startswith("s3://"): if url.startswith("s3://"):
etag = s3_etag(url) etag = s3_etag(url)
else: else:
response = requests.head(url, allow_redirects=True) try:
if response.status_code != 200: response = requests.head(url, allow_redirects=True)
raise IOError("HEAD request failed for url {} with status code {}" if response.status_code != 200:
.format(url, response.status_code)) etag = None
etag = response.headers.get("ETag") else:
etag = response.headers.get("ETag")
except EnvironmentError:
etag = None
if sys.version_info[0] == 2 and etag is not None:
etag = etag.decode('utf-8')
filename = url_to_filename(url, etag) filename = url_to_filename(url, etag)
# get cache path to put the file # get cache path to put the file
cache_path = os.path.join(cache_dir, filename) cache_path = os.path.join(cache_dir, filename)
# If we don't have a connection (etag is None) and can't identify the file
# try to get the last downloaded one
if not os.path.exists(cache_path) and etag is None:
matching_files = fnmatch.filter(os.listdir(cache_dir), filename + '.*')
matching_files = list(filter(lambda s: not s.endswith('.json'), matching_files))
if matching_files:
cache_path = os.path.join(cache_dir, matching_files[-1])
if not os.path.exists(cache_path): if not os.path.exists(cache_path):
# Download to temporary file, then copy to cache dir once finished. # Download to temporary file, then copy to cache dir once finished.
# Otherwise you get corrupt cache entries if the download gets interrupted. # Otherwise you get corrupt cache entries if the download gets interrupted.
@@ -226,8 +241,8 @@ def get_from_cache(url, cache_dir=None):
logger.info("creating metadata file for %s", cache_path) logger.info("creating metadata file for %s", cache_path)
meta = {'url': url, 'etag': etag} meta = {'url': url, 'etag': etag}
meta_path = cache_path + '.json' meta_path = cache_path + '.json'
with open(meta_path, 'w', encoding="utf-8") as meta_file: with open(meta_path, 'w') as meta_file:
json.dump(meta, meta_file) meta_file.write(json.dumps(meta, indent=4))
logger.info("removing temp file %s", temp_file.name) logger.info("removing temp file %s", temp_file.name)

View File

@@ -66,7 +66,7 @@ class GPT2TokenizationTest(unittest.TestCase):
[tokenizer_2.encoder, tokenizer_2.decoder, tokenizer_2.bpe_ranks, [tokenizer_2.encoder, tokenizer_2.decoder, tokenizer_2.bpe_ranks,
tokenizer_2.special_tokens, tokenizer_2.special_tokens_decoder]) tokenizer_2.special_tokens, tokenizer_2.special_tokens_decoder])
@pytest.mark.slow # @pytest.mark.slow
def test_tokenizer_from_pretrained(self): def test_tokenizer_from_pretrained(self):
cache_dir = "/tmp/pytorch_pretrained_bert_test/" cache_dir = "/tmp/pytorch_pretrained_bert_test/"
for model_name in list(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())[:1]: for model_name in list(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())[:1]: