Take advantage of the cache when running tests.
Caching models across test cases and across runs of the test suite makes slow tests somewhat more bearable. Use gettempdir() instead of /tmp in tests. This makes it easier to change the location of the cache with semi-standard TMPDIR/TEMP/TMP environment variables. Fix #2222.
This commit is contained in:
@@ -17,12 +17,11 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import shutil
|
||||
import sys
|
||||
|
||||
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
|
||||
from .configuration_common_test import ConfigTester
|
||||
from .utils import require_tf, slow
|
||||
from .utils import CACHE_DIR, require_tf, slow
|
||||
|
||||
from transformers import XxxConfig, is_tf_available
|
||||
|
||||
@@ -245,10 +244,8 @@ class TFXxxModelTest(TFCommonTestCases.TFCommonModelTester):
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
cache_dir = "/tmp/transformers_test/"
|
||||
for model_name in ['xxx-base-uncased']:
|
||||
model = TFXxxModel.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
model = TFXxxModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -17,13 +17,12 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import shutil
|
||||
|
||||
from transformers import is_torch_available
|
||||
|
||||
from .modeling_common_test import (CommonTestCases, ids_tensor)
|
||||
from .configuration_common_test import ConfigTester
|
||||
from .utils import require_torch, slow, torch_device
|
||||
from .utils import CACHE_DIR, require_torch, slow, torch_device
|
||||
|
||||
if is_torch_available():
|
||||
from transformers import (XxxConfig, XxxModel, XxxForMaskedLM,
|
||||
@@ -249,10 +248,8 @@ class XxxModelTest(CommonTestCases.CommonModelTester):
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
cache_dir = "/tmp/transformers_test/"
|
||||
for model_name in list(XXX_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
model = XxxModel.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
model = XxxModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user