relax network connection requirements
This commit is contained in:
@@ -5,11 +5,13 @@ Copyright by the AllenNLP authors.
|
|||||||
"""
|
"""
|
||||||
from __future__ import (absolute_import, division, print_function, unicode_literals)
|
from __future__ import (absolute_import, division, print_function, unicode_literals)
|
||||||
|
|
||||||
|
import sys
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import tempfile
|
import tempfile
|
||||||
|
import fnmatch
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from hashlib import sha256
|
from hashlib import sha256
|
||||||
import sys
|
import sys
|
||||||
@@ -191,17 +193,30 @@ def get_from_cache(url, cache_dir=None):
|
|||||||
if url.startswith("s3://"):
|
if url.startswith("s3://"):
|
||||||
etag = s3_etag(url)
|
etag = s3_etag(url)
|
||||||
else:
|
else:
|
||||||
|
try:
|
||||||
response = requests.head(url, allow_redirects=True)
|
response = requests.head(url, allow_redirects=True)
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
raise IOError("HEAD request failed for url {} with status code {}"
|
etag = None
|
||||||
.format(url, response.status_code))
|
else:
|
||||||
etag = response.headers.get("ETag")
|
etag = response.headers.get("ETag")
|
||||||
|
except EnvironmentError:
|
||||||
|
etag = None
|
||||||
|
|
||||||
|
if sys.version_info[0] == 2 and etag is not None:
|
||||||
|
etag = etag.decode('utf-8')
|
||||||
filename = url_to_filename(url, etag)
|
filename = url_to_filename(url, etag)
|
||||||
|
|
||||||
# get cache path to put the file
|
# get cache path to put the file
|
||||||
cache_path = os.path.join(cache_dir, filename)
|
cache_path = os.path.join(cache_dir, filename)
|
||||||
|
|
||||||
|
# If we don't have a connection (etag is None) and can't identify the file,
|
||||||
|
# try to fall back to a previously downloaded copy in the cache directory.
|
||||||
|
if not os.path.exists(cache_path) and etag is None:
|
||||||
|
matching_files = fnmatch.filter(os.listdir(cache_dir), filename + '.*')
|
||||||
|
matching_files = list(filter(lambda s: not s.endswith('.json'), matching_files))
|
||||||
|
if matching_files:
|
||||||
|
cache_path = os.path.join(cache_dir, matching_files[-1])
|
||||||
|
|
||||||
if not os.path.exists(cache_path):
|
if not os.path.exists(cache_path):
|
||||||
# Download to temporary file, then copy to cache dir once finished.
|
# Download to temporary file, then copy to cache dir once finished.
|
||||||
# Otherwise you get corrupt cache entries if the download gets interrupted.
|
# Otherwise you get corrupt cache entries if the download gets interrupted.
|
||||||
@@ -226,8 +241,8 @@ def get_from_cache(url, cache_dir=None):
|
|||||||
logger.info("creating metadata file for %s", cache_path)
|
logger.info("creating metadata file for %s", cache_path)
|
||||||
meta = {'url': url, 'etag': etag}
|
meta = {'url': url, 'etag': etag}
|
||||||
meta_path = cache_path + '.json'
|
meta_path = cache_path + '.json'
|
||||||
with open(meta_path, 'w', encoding="utf-8") as meta_file:
|
with open(meta_path, 'w') as meta_file:
|
||||||
json.dump(meta, meta_file)
|
meta_file.write(json.dumps(meta, indent=4))
|
||||||
|
|
||||||
logger.info("removing temp file %s", temp_file.name)
|
logger.info("removing temp file %s", temp_file.name)
|
||||||
|
|
||||||
|
|||||||
@@ -66,7 +66,7 @@ class GPT2TokenizationTest(unittest.TestCase):
|
|||||||
[tokenizer_2.encoder, tokenizer_2.decoder, tokenizer_2.bpe_ranks,
|
[tokenizer_2.encoder, tokenizer_2.decoder, tokenizer_2.bpe_ranks,
|
||||||
tokenizer_2.special_tokens, tokenizer_2.special_tokens_decoder])
|
tokenizer_2.special_tokens, tokenizer_2.special_tokens_decoder])
|
||||||
|
|
||||||
@pytest.mark.slow
|
# @pytest.mark.slow
|
||||||
def test_tokenizer_from_pretrained(self):
|
def test_tokenizer_from_pretrained(self):
|
||||||
cache_dir = "/tmp/pytorch_pretrained_bert_test/"
|
cache_dir = "/tmp/pytorch_pretrained_bert_test/"
|
||||||
for model_name in list(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())[:1]:
|
for model_name in list(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())[:1]:
|
||||||
|
|||||||
Reference in New Issue
Block a user