diff --git a/templates/adding_a_new_model/tests/modeling_tf_xxx_test.py b/templates/adding_a_new_model/tests/modeling_tf_xxx_test.py index 912a4aa340..6eba932a8e 100644 --- a/templates/adding_a_new_model/tests/modeling_tf_xxx_test.py +++ b/templates/adding_a_new_model/tests/modeling_tf_xxx_test.py @@ -17,12 +17,11 @@ from __future__ import division from __future__ import print_function import unittest -import shutil import sys from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor) from .configuration_common_test import ConfigTester -from .utils import require_tf, slow +from .utils import CACHE_DIR, require_tf, slow from transformers import XxxConfig, is_tf_available @@ -245,10 +244,8 @@ class TFXxxModelTest(TFCommonTestCases.TFCommonModelTester): @slow def test_model_from_pretrained(self): - cache_dir = "/tmp/transformers_test/" for model_name in ['xxx-base-uncased']: - model = TFXxxModel.from_pretrained(model_name, cache_dir=cache_dir) - shutil.rmtree(cache_dir) + model = TFXxxModel.from_pretrained(model_name, cache_dir=CACHE_DIR) self.assertIsNotNone(model) if __name__ == "__main__": diff --git a/templates/adding_a_new_model/tests/modeling_xxx_test.py b/templates/adding_a_new_model/tests/modeling_xxx_test.py index 30e614b3f2..5e22392d00 100644 --- a/templates/adding_a_new_model/tests/modeling_xxx_test.py +++ b/templates/adding_a_new_model/tests/modeling_xxx_test.py @@ -17,13 +17,12 @@ from __future__ import division from __future__ import print_function import unittest -import shutil from transformers import is_torch_available from .modeling_common_test import (CommonTestCases, ids_tensor) from .configuration_common_test import ConfigTester -from .utils import require_torch, slow, torch_device +from .utils import CACHE_DIR, require_torch, slow, torch_device if is_torch_available(): from transformers import (XxxConfig, XxxModel, XxxForMaskedLM, @@ -249,10 +248,8 @@ class XxxModelTest(CommonTestCases.CommonModelTester): @slow def test_model_from_pretrained(self): - cache_dir = "/tmp/transformers_test/" for model_name in list(XXX_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: - model = XxxModel.from_pretrained(model_name, cache_dir=cache_dir) - shutil.rmtree(cache_dir) + model = XxxModel.from_pretrained(model_name, cache_dir=CACHE_DIR) self.assertIsNotNone(model) if __name__ == "__main__": diff --git a/transformers/tests/modeling_albert_test.py b/transformers/tests/modeling_albert_test.py index 1911d244e7..b726fd9278 100644 --- a/transformers/tests/modeling_albert_test.py +++ b/transformers/tests/modeling_albert_test.py @@ -17,13 +17,12 @@ from __future__ import division from __future__ import print_function import unittest -import shutil from transformers import is_torch_available from .modeling_common_test import (CommonTestCases, ids_tensor) from .configuration_common_test import ConfigTester -from .utils import require_torch, slow, torch_device +from .utils import CACHE_DIR, require_torch, slow, torch_device if is_torch_available(): from transformers import (AlbertConfig, AlbertModel, AlbertForMaskedLM, @@ -230,10 +229,8 @@ class AlbertModelTest(CommonTestCases.CommonModelTester): @slow def test_model_from_pretrained(self): - cache_dir = "/tmp/transformers_test/" for model_name in list(ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: - model = AlbertModel.from_pretrained(model_name, cache_dir=cache_dir) - shutil.rmtree(cache_dir) + model = AlbertModel.from_pretrained(model_name, cache_dir=CACHE_DIR) self.assertIsNotNone(model) if __name__ == "__main__": diff --git a/transformers/tests/modeling_bert_test.py b/transformers/tests/modeling_bert_test.py index 0eb7bc9a14..a5adff8f68 100644 --- a/transformers/tests/modeling_bert_test.py +++ b/transformers/tests/modeling_bert_test.py @@ -17,13 +17,12 @@ from __future__ import division from __future__ import print_function import unittest -import shutil from transformers import is_torch_available from .modeling_common_test import (CommonTestCases, ids_tensor, floats_tensor) from .configuration_common_test import ConfigTester -from .utils import require_torch, slow, torch_device +from .utils import CACHE_DIR, require_torch, slow, torch_device if is_torch_available(): from transformers import (BertConfig, BertModel, BertForMaskedLM, @@ -360,10 +359,8 @@ class BertModelTest(CommonTestCases.CommonModelTester): @slow def test_model_from_pretrained(self): - cache_dir = "/tmp/transformers_test/" for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: - model = BertModel.from_pretrained(model_name, cache_dir=cache_dir) - shutil.rmtree(cache_dir) + model = BertModel.from_pretrained(model_name, cache_dir=CACHE_DIR) self.assertIsNotNone(model) diff --git a/transformers/tests/modeling_common_test.py b/transformers/tests/modeling_common_test.py index cf36332207..2116651f4a 100644 --- a/transformers/tests/modeling_common_test.py +++ b/transformers/tests/modeling_common_test.py @@ -30,7 +30,7 @@ import logging from transformers import is_torch_available -from .utils import require_torch, slow, torch_device +from .utils import CACHE_DIR, require_torch, slow, torch_device if is_torch_available(): import torch @@ -753,10 +753,8 @@ class CommonTestCases: [[], []]) def create_and_check_model_from_pretrained(self): - cache_dir = "/tmp/transformers_test/" for model_name in list(self.base_model_class.pretrained_model_archive_map.keys())[:1]: - model = self.base_model_class.from_pretrained(model_name, cache_dir=cache_dir) - shutil.rmtree(cache_dir) + model = self.base_model_class.from_pretrained(model_name, cache_dir=CACHE_DIR) self.parent.assertIsNotNone(model) def prepare_config_and_inputs_for_common(self): diff --git a/transformers/tests/modeling_ctrl_test.py b/transformers/tests/modeling_ctrl_test.py index c7de49b2ab..ed0d62d1e6 100644 --- a/transformers/tests/modeling_ctrl_test.py +++ b/transformers/tests/modeling_ctrl_test.py @@ -16,7 +16,6 @@ from __future__ import division from __future__ import print_function import unittest -import shutil import pdb from transformers import is_torch_available @@ -27,7 +26,7 @@ if is_torch_available(): from .modeling_common_test import (CommonTestCases, ids_tensor) from .configuration_common_test import ConfigTester -from .utils import require_torch, slow, torch_device +from .utils import CACHE_DIR, require_torch, slow, torch_device @require_torch @@ -205,10 +204,8 @@ class CTRLModelTest(CommonTestCases.CommonModelTester): @slow def test_model_from_pretrained(self): - cache_dir = "/tmp/transformers_test/" for model_name in list(CTRL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: - model = CTRLModel.from_pretrained(model_name, cache_dir=cache_dir) - shutil.rmtree(cache_dir) + model = CTRLModel.from_pretrained(model_name, cache_dir=CACHE_DIR) self.assertIsNotNone(model) diff --git a/transformers/tests/modeling_distilbert_test.py b/transformers/tests/modeling_distilbert_test.py index 82f71c40da..ac6f5d248e 100644 --- a/transformers/tests/modeling_distilbert_test.py +++ b/transformers/tests/modeling_distilbert_test.py @@ -27,7 +27,7 @@ if is_torch_available(): from .modeling_common_test import (CommonTestCases, ids_tensor) from .configuration_common_test import ConfigTester -from .utils import require_torch, slow, torch_device +from .utils import CACHE_DIR, require_torch, slow, torch_device @require_torch @@ -235,10 +235,8 @@ class DistilBertModelTest(CommonTestCases.CommonModelTester): # @slow # def test_model_from_pretrained(self): - # cache_dir = "/tmp/transformers_test/" # for model_name in list(DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: - # model = DistilBertModel.from_pretrained(model_name, cache_dir=cache_dir) - # shutil.rmtree(cache_dir) + # model = DistilBertModel.from_pretrained(model_name, cache_dir=CACHE_DIR) # self.assertIsNotNone(model) if __name__ == "__main__": diff --git a/transformers/tests/modeling_gpt2_test.py b/transformers/tests/modeling_gpt2_test.py index a82e39c261..ad2ec1fd91 100644 --- a/transformers/tests/modeling_gpt2_test.py +++ b/transformers/tests/modeling_gpt2_test.py @@ -17,7 +17,6 @@ from __future__ import division from __future__ import print_function import unittest -import shutil from transformers import is_torch_available @@ -27,7 +26,7 @@ if is_torch_available(): from .modeling_common_test import (CommonTestCases, ids_tensor) from .configuration_common_test import ConfigTester -from .utils import require_torch, slow, torch_device +from .utils import CACHE_DIR, require_torch, slow, torch_device @require_torch @@ -239,10 +238,8 @@ class GPT2ModelTest(CommonTestCases.CommonModelTester): @slow def test_model_from_pretrained(self): - cache_dir = "/tmp/transformers_test/" for model_name in list(GPT2_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: - model = GPT2Model.from_pretrained(model_name, cache_dir=cache_dir) - shutil.rmtree(cache_dir) + model = GPT2Model.from_pretrained(model_name, cache_dir=CACHE_DIR) self.assertIsNotNone(model) diff --git a/transformers/tests/modeling_openai_test.py b/transformers/tests/modeling_openai_test.py index 7655e432e8..1880febcae 100644 --- a/transformers/tests/modeling_openai_test.py +++ b/transformers/tests/modeling_openai_test.py @@ -17,7 +17,6 @@ from __future__ import division from __future__ import print_function import unittest -import shutil from transformers import is_torch_available @@ -27,7 +26,7 @@ if is_torch_available(): from .modeling_common_test import (CommonTestCases, ids_tensor) from .configuration_common_test import ConfigTester -from .utils import require_torch, slow, torch_device +from .utils import CACHE_DIR, require_torch, slow, torch_device @require_torch @@ -207,10 +206,8 @@ class OpenAIGPTModelTest(CommonTestCases.CommonModelTester): @slow def test_model_from_pretrained(self): - cache_dir = "/tmp/transformers_test/" for model_name in list(OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: - model = OpenAIGPTModel.from_pretrained(model_name, cache_dir=cache_dir) - shutil.rmtree(cache_dir) + model = OpenAIGPTModel.from_pretrained(model_name, cache_dir=CACHE_DIR) self.assertIsNotNone(model) diff --git a/transformers/tests/modeling_roberta_test.py b/transformers/tests/modeling_roberta_test.py index 4d34a50528..299cbd01ad 100644 --- a/transformers/tests/modeling_roberta_test.py +++ b/transformers/tests/modeling_roberta_test.py @@ -17,7 +17,6 @@ from __future__ import division from __future__ import print_function import unittest -import shutil from transformers import is_torch_available @@ -29,7 +28,7 @@ if is_torch_available(): from .modeling_common_test import (CommonTestCases, ids_tensor) from .configuration_common_test import ConfigTester -from .utils import require_torch, slow, torch_device +from .utils import CACHE_DIR, require_torch, slow, torch_device @require_torch @@ -199,10 +198,8 @@ class RobertaModelTest(CommonTestCases.CommonModelTester): @slow def test_model_from_pretrained(self): - cache_dir = "/tmp/transformers_test/" for model_name in list(ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: - model = RobertaModel.from_pretrained(model_name, cache_dir=cache_dir) - shutil.rmtree(cache_dir) + model = RobertaModel.from_pretrained(model_name, cache_dir=CACHE_DIR) self.assertIsNotNone(model) diff --git a/transformers/tests/modeling_t5_test.py b/transformers/tests/modeling_t5_test.py index c337163375..9fd9a4b304 100644 --- a/transformers/tests/modeling_t5_test.py +++ b/transformers/tests/modeling_t5_test.py @@ -17,13 +17,12 @@ from __future__ import division from __future__ import print_function import unittest -import shutil from transformers import is_torch_available from .modeling_common_test import (CommonTestCases, ids_tensor, floats_tensor) from .configuration_common_test import ConfigTester -from .utils import require_torch, slow, torch_device +from .utils import CACHE_DIR, require_torch, slow, torch_device if is_torch_available(): from transformers import (T5Config, T5Model, T5WithLMHeadModel) @@ -175,10 +174,8 @@ class T5ModelTest(CommonTestCases.CommonModelTester): @slow def test_model_from_pretrained(self): - cache_dir = "/tmp/transformers_test/" for model_name in list(T5_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: - model = T5Model.from_pretrained(model_name, cache_dir=cache_dir) - shutil.rmtree(cache_dir) + model = T5Model.from_pretrained(model_name, cache_dir=CACHE_DIR) self.assertIsNotNone(model) if __name__ == "__main__": diff --git a/transformers/tests/modeling_tf_albert_test.py b/transformers/tests/modeling_tf_albert_test.py index 93aeab66c2..ee71371a18 100644 --- a/transformers/tests/modeling_tf_albert_test.py +++ b/transformers/tests/modeling_tf_albert_test.py @@ -17,12 +17,11 @@ from __future__ import division from __future__ import print_function import unittest -import shutil import sys from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor) from .configuration_common_test import ConfigTester -from .utils import require_tf, slow +from .utils import CACHE_DIR, require_tf, slow from transformers import AlbertConfig, is_tf_available @@ -217,12 +216,9 @@ class TFAlbertModelTest(TFCommonTestCases.TFCommonModelTester): @slow def test_model_from_pretrained(self): - cache_dir = "/tmp/transformers_test/" # for model_name in list(TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: for model_name in ['albert-base-uncased']: - model = TFAlbertModel.from_pretrained( - model_name, cache_dir=cache_dir) - shutil.rmtree(cache_dir) + model = TFAlbertModel.from_pretrained(model_name, cache_dir=CACHE_DIR) self.assertIsNotNone(model) diff --git a/transformers/tests/modeling_tf_auto_test.py b/transformers/tests/modeling_tf_auto_test.py index 7ab6eaa3d6..2ad39ddccf 100644 --- a/transformers/tests/modeling_tf_auto_test.py +++ b/transformers/tests/modeling_tf_auto_test.py @@ -46,11 +46,11 @@ class TFAutoModelTest(unittest.TestCase): logging.basicConfig(level=logging.INFO) # for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: for model_name in ['bert-base-uncased']: - config = AutoConfig.from_pretrained(model_name, force_download=True) + config = AutoConfig.from_pretrained(model_name) self.assertIsNotNone(config) self.assertIsInstance(config, BertConfig) - model = TFAutoModel.from_pretrained(model_name, force_download=True) + model = TFAutoModel.from_pretrained(model_name) self.assertIsNotNone(model) self.assertIsInstance(model, TFBertModel) @@ -59,11 +59,11 @@ class TFAutoModelTest(unittest.TestCase): logging.basicConfig(level=logging.INFO) # for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: for model_name in ['bert-base-uncased']: - config = AutoConfig.from_pretrained(model_name, force_download=True) + config = AutoConfig.from_pretrained(model_name) self.assertIsNotNone(config) self.assertIsInstance(config, BertConfig) - model = TFAutoModelWithLMHead.from_pretrained(model_name, force_download=True) + model = TFAutoModelWithLMHead.from_pretrained(model_name) self.assertIsNotNone(model) self.assertIsInstance(model, TFBertForMaskedLM) @@ -72,11 +72,11 @@ class TFAutoModelTest(unittest.TestCase): logging.basicConfig(level=logging.INFO) # for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: for model_name in ['bert-base-uncased']: - config = AutoConfig.from_pretrained(model_name, force_download=True) + config = AutoConfig.from_pretrained(model_name) self.assertIsNotNone(config) self.assertIsInstance(config, BertConfig) - model = TFAutoModelForSequenceClassification.from_pretrained(model_name, force_download=True) + model = TFAutoModelForSequenceClassification.from_pretrained(model_name) self.assertIsNotNone(model) self.assertIsInstance(model, TFBertForSequenceClassification) @@ -85,17 +85,17 @@ class TFAutoModelTest(unittest.TestCase): logging.basicConfig(level=logging.INFO) # for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: for model_name in ['bert-base-uncased']: - config = AutoConfig.from_pretrained(model_name, force_download=True) + config = AutoConfig.from_pretrained(model_name) self.assertIsNotNone(config) self.assertIsInstance(config, BertConfig) - model = TFAutoModelForQuestionAnswering.from_pretrained(model_name, force_download=True) + model = TFAutoModelForQuestionAnswering.from_pretrained(model_name) self.assertIsNotNone(model) self.assertIsInstance(model, TFBertForQuestionAnswering) def test_from_pretrained_identifier(self): logging.basicConfig(level=logging.INFO) - model = TFAutoModelWithLMHead.from_pretrained(SMALL_MODEL_IDENTIFIER, force_download=True) + model = TFAutoModelWithLMHead.from_pretrained(SMALL_MODEL_IDENTIFIER) self.assertIsInstance(model, TFBertForMaskedLM) diff --git a/transformers/tests/modeling_tf_bert_test.py b/transformers/tests/modeling_tf_bert_test.py index 20073e1ab8..abf20b1514 100644 --- a/transformers/tests/modeling_tf_bert_test.py +++ b/transformers/tests/modeling_tf_bert_test.py @@ -17,12 +17,11 @@ from __future__ import division from __future__ import print_function import unittest -import shutil import sys from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor) from .configuration_common_test import ConfigTester -from .utils import require_tf, slow +from .utils import CACHE_DIR, require_tf, slow from transformers import BertConfig, is_tf_available @@ -310,11 +309,9 @@ class TFBertModelTest(TFCommonTestCases.TFCommonModelTester): @slow def test_model_from_pretrained(self): - cache_dir = "/tmp/transformers_test/" # for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: for model_name in ['bert-base-uncased']: - model = TFBertModel.from_pretrained(model_name, cache_dir=cache_dir) - shutil.rmtree(cache_dir) + model = TFBertModel.from_pretrained(model_name, cache_dir=CACHE_DIR) self.assertIsNotNone(model) if __name__ == "__main__": diff --git a/transformers/tests/modeling_tf_ctrl_test.py b/transformers/tests/modeling_tf_ctrl_test.py index 0876582e57..93b231e517 100644 --- a/transformers/tests/modeling_tf_ctrl_test.py +++ b/transformers/tests/modeling_tf_ctrl_test.py @@ -17,12 +17,11 @@ from __future__ import division from __future__ import print_function import unittest -import shutil import sys from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor) from .configuration_common_test import ConfigTester -from .utils import require_tf, slow +from .utils import CACHE_DIR, require_tf, slow from transformers import CTRLConfig, is_tf_available @@ -189,10 +188,8 @@ class TFCTRLModelTest(TFCommonTestCases.TFCommonModelTester): @slow def test_model_from_pretrained(self): - cache_dir = "/tmp/transformers_test/" for model_name in list(TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: - model = TFCTRLModel.from_pretrained(model_name, cache_dir=cache_dir) - shutil.rmtree(cache_dir) + model = TFCTRLModel.from_pretrained(model_name, cache_dir=CACHE_DIR) self.assertIsNotNone(model) if __name__ == "__main__": diff --git a/transformers/tests/modeling_tf_distilbert_test.py b/transformers/tests/modeling_tf_distilbert_test.py index d9e971c2a5..f28b5c397b 100644 --- a/transformers/tests/modeling_tf_distilbert_test.py +++ b/transformers/tests/modeling_tf_distilbert_test.py @@ -20,7 +20,7 @@ import unittest from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor) from .configuration_common_test import ConfigTester -from .utils import require_tf, slow +from .utils import CACHE_DIR, require_tf, slow from transformers import DistilBertConfig, is_tf_available @@ -211,10 +211,8 @@ class TFDistilBertModelTest(TFCommonTestCases.TFCommonModelTester): # @slow # def test_model_from_pretrained(self): - # cache_dir = "/tmp/transformers_test/" # for model_name in list(DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: - # model = DistilBertModel.from_pretrained(model_name, cache_dir=cache_dir) - # shutil.rmtree(cache_dir) + # model = DistilBertModel.from_pretrained(model_name, cache_dir=CACHE_DIR) # self.assertIsNotNone(model) if __name__ == "__main__": diff --git a/transformers/tests/modeling_tf_gpt2_test.py b/transformers/tests/modeling_tf_gpt2_test.py index 3f30b32787..90920342ba 100644 --- a/transformers/tests/modeling_tf_gpt2_test.py +++ b/transformers/tests/modeling_tf_gpt2_test.py @@ -17,12 +17,11 @@ from __future__ import division from __future__ import print_function import unittest -import shutil import sys from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor) from .configuration_common_test import ConfigTester -from .utils import require_tf, slow +from .utils import CACHE_DIR, require_tf, slow from transformers import GPT2Config, is_tf_available @@ -220,10 +219,8 @@ class TFGPT2ModelTest(TFCommonTestCases.TFCommonModelTester): @slow def test_model_from_pretrained(self): - cache_dir = "/tmp/transformers_test/" for model_name in list(TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: - model = TFGPT2Model.from_pretrained(model_name, cache_dir=cache_dir) - shutil.rmtree(cache_dir) + model = TFGPT2Model.from_pretrained(model_name, cache_dir=CACHE_DIR) self.assertIsNotNone(model) if __name__ == "__main__": diff --git a/transformers/tests/modeling_tf_openai_gpt_test.py b/transformers/tests/modeling_tf_openai_gpt_test.py index 863dbf1bc0..065bf2acde 100644 --- a/transformers/tests/modeling_tf_openai_gpt_test.py +++ b/transformers/tests/modeling_tf_openai_gpt_test.py @@ -17,12 +17,11 @@ from __future__ import division from __future__ import print_function import unittest -import shutil import sys from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor) from .configuration_common_test import ConfigTester -from .utils import require_tf, slow +from .utils import CACHE_DIR, require_tf, slow from transformers import OpenAIGPTConfig, is_tf_available @@ -219,10 +218,8 @@ class TFOpenAIGPTModelTest(TFCommonTestCases.TFCommonModelTester): @slow def test_model_from_pretrained(self): - cache_dir = "/tmp/transformers_test/" for model_name in list(TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: - model = TFOpenAIGPTModel.from_pretrained(model_name, cache_dir=cache_dir) - shutil.rmtree(cache_dir) + model = TFOpenAIGPTModel.from_pretrained(model_name, cache_dir=CACHE_DIR) self.assertIsNotNone(model) if __name__ == "__main__": diff --git a/transformers/tests/modeling_tf_roberta_test.py b/transformers/tests/modeling_tf_roberta_test.py index f4ed97c44b..93c478ae28 100644 --- a/transformers/tests/modeling_tf_roberta_test.py +++ b/transformers/tests/modeling_tf_roberta_test.py @@ -17,11 +17,10 @@ from __future__ import division from __future__ import print_function import unittest -import shutil from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor) from .configuration_common_test import ConfigTester -from .utils import require_tf, slow +from .utils import CACHE_DIR, require_tf, slow from transformers import RobertaConfig, is_tf_available @@ -192,10 +191,8 @@ class TFRobertaModelTest(TFCommonTestCases.TFCommonModelTester): @slow def test_model_from_pretrained(self): - cache_dir = "/tmp/transformers_test/" for model_name in list(TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: - model = TFRobertaModel.from_pretrained(model_name, cache_dir=cache_dir) - shutil.rmtree(cache_dir) + model = TFRobertaModel.from_pretrained(model_name, cache_dir=CACHE_DIR) self.assertIsNotNone(model) diff --git a/transformers/tests/modeling_tf_t5_test.py b/transformers/tests/modeling_tf_t5_test.py index b905a9875b..da9ce6f89d 100644 --- a/transformers/tests/modeling_tf_t5_test.py +++ b/transformers/tests/modeling_tf_t5_test.py @@ -17,12 +17,11 @@ from __future__ import division from __future__ import print_function import unittest -import shutil import sys from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor) from .configuration_common_test import ConfigTester -from .utils import require_tf, slow +from .utils import CACHE_DIR, require_tf, slow from transformers import T5Config, is_tf_available @@ -162,10 +161,8 @@ class TFT5ModelTest(TFCommonTestCases.TFCommonModelTester): @slow def test_model_from_pretrained(self): - cache_dir = "/tmp/transformers_test/" for model_name in ['t5-small']: - model = TFT5Model.from_pretrained(model_name, cache_dir=cache_dir) - shutil.rmtree(cache_dir) + model = TFT5Model.from_pretrained(model_name, cache_dir=CACHE_DIR) self.assertIsNotNone(model) if __name__ == "__main__": diff --git a/transformers/tests/modeling_tf_transfo_xl_test.py b/transformers/tests/modeling_tf_transfo_xl_test.py index 746a6a1321..8225c09275 100644 --- a/transformers/tests/modeling_tf_transfo_xl_test.py +++ b/transformers/tests/modeling_tf_transfo_xl_test.py @@ -18,11 +18,10 @@ from __future__ import print_function import unittest import random -import shutil from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor) from .configuration_common_test import ConfigTester -from .utils import require_tf, slow +from .utils import CACHE_DIR, require_tf, slow from transformers import TransfoXLConfig, is_tf_available @@ -205,10 +204,8 @@ class TFTransfoXLModelTest(TFCommonTestCases.TFCommonModelTester): @slow def test_model_from_pretrained(self): - cache_dir = "/tmp/transformers_test/" for model_name in list(TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: - model = TFTransfoXLModel.from_pretrained(model_name, cache_dir=cache_dir) - shutil.rmtree(cache_dir) + model = TFTransfoXLModel.from_pretrained(model_name, cache_dir=CACHE_DIR) self.assertIsNotNone(model) diff --git a/transformers/tests/modeling_tf_xlm_test.py b/transformers/tests/modeling_tf_xlm_test.py index 228e436149..8b5ab6d742 100644 --- a/transformers/tests/modeling_tf_xlm_test.py +++ b/transformers/tests/modeling_tf_xlm_test.py @@ -17,7 +17,6 @@ from __future__ import division from __future__ import print_function import unittest -import shutil from transformers import is_tf_available @@ -31,7 +30,7 @@ if is_tf_available(): from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor) from .configuration_common_test import ConfigTester -from .utils import require_tf, slow +from .utils import CACHE_DIR, require_tf, slow @require_tf @@ -252,10 +251,8 @@ class TFXLMModelTest(TFCommonTestCases.TFCommonModelTester): @slow def test_model_from_pretrained(self): - cache_dir = "/tmp/transformers_test/" for model_name in list(TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: - model = XLMModel.from_pretrained(model_name, cache_dir=cache_dir) - shutil.rmtree(cache_dir) + model = XLMModel.from_pretrained(model_name, cache_dir=CACHE_DIR) self.assertIsNotNone(model) diff --git a/transformers/tests/modeling_tf_xlnet_test.py b/transformers/tests/modeling_tf_xlnet_test.py index eb66d92793..15fd917481 100644 --- a/transformers/tests/modeling_tf_xlnet_test.py +++ b/transformers/tests/modeling_tf_xlnet_test.py @@ -20,7 +20,6 @@ import os import unittest import json import random -import shutil from transformers import XLNetConfig, is_tf_available @@ -35,7 +34,7 @@ if is_tf_available(): from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor) from .configuration_common_test import ConfigTester -from .utils import require_tf, slow +from .utils import CACHE_DIR, require_tf, slow @require_tf @@ -319,10 +318,8 @@ class TFXLNetModelTest(TFCommonTestCases.TFCommonModelTester): @slow def test_model_from_pretrained(self): - cache_dir = "/tmp/transformers_test/" for model_name in list(TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: - model = TFXLNetModel.from_pretrained(model_name, cache_dir=cache_dir) - shutil.rmtree(cache_dir) + model = TFXLNetModel.from_pretrained(model_name, cache_dir=CACHE_DIR) self.assertIsNotNone(model) diff --git a/transformers/tests/modeling_transfo_xl_test.py b/transformers/tests/modeling_transfo_xl_test.py index f41d50a3a0..acbe95fe4a 100644 --- a/transformers/tests/modeling_transfo_xl_test.py +++ b/transformers/tests/modeling_transfo_xl_test.py @@ -18,7 +18,6 @@ from __future__ import print_function import unittest import random -import shutil from transformers import is_torch_available @@ -29,7 +28,7 @@ if is_torch_available(): from .modeling_common_test import (CommonTestCases, ids_tensor) from .configuration_common_test import ConfigTester -from .utils import require_torch, slow, torch_device +from .utils import CACHE_DIR, require_torch, slow, torch_device @require_torch @@ -208,10 +207,8 @@ class TransfoXLModelTest(CommonTestCases.CommonModelTester): @slow def test_model_from_pretrained(self): - cache_dir = "/tmp/transformers_test/" for model_name in list(TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: - model = TransfoXLModel.from_pretrained(model_name, cache_dir=cache_dir) - shutil.rmtree(cache_dir) + model = TransfoXLModel.from_pretrained(model_name, cache_dir=CACHE_DIR) self.assertIsNotNone(model) diff --git a/transformers/tests/modeling_xlm_test.py b/transformers/tests/modeling_xlm_test.py index 7cae6c848e..fcc2f4699b 100644 --- a/transformers/tests/modeling_xlm_test.py +++ b/transformers/tests/modeling_xlm_test.py @@ -17,7 +17,6 @@ from __future__ import division from __future__ import print_function import unittest -import shutil from transformers import is_torch_available @@ -28,7 +27,7 @@ if is_torch_available(): from .modeling_common_test import (CommonTestCases, ids_tensor) from .configuration_common_test import ConfigTester -from .utils import require_torch, slow, torch_device +from .utils import CACHE_DIR, require_torch, slow, torch_device @require_torch @@ -318,10 +317,8 @@ class XLMModelTest(CommonTestCases.CommonModelTester): @slow def test_model_from_pretrained(self): - cache_dir = "/tmp/transformers_test/" for model_name in list(XLM_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: - model = XLMModel.from_pretrained(model_name, cache_dir=cache_dir) - shutil.rmtree(cache_dir) + model = XLMModel.from_pretrained(model_name, cache_dir=CACHE_DIR) self.assertIsNotNone(model) diff --git a/transformers/tests/modeling_xlnet_test.py b/transformers/tests/modeling_xlnet_test.py index 6d901ee699..6d218d6ef4 100644 --- a/transformers/tests/modeling_xlnet_test.py +++ b/transformers/tests/modeling_xlnet_test.py @@ -20,7 +20,6 @@ import os import unittest import json import random -import shutil from transformers import is_torch_available @@ -33,7 +32,7 @@ if is_torch_available(): from .modeling_common_test import (CommonTestCases, ids_tensor) from .configuration_common_test import ConfigTester -from .utils import require_torch, slow, torch_device +from .utils import CACHE_DIR, require_torch, slow, torch_device @require_torch @@ -385,10 +384,8 @@ class XLNetModelTest(CommonTestCases.CommonModelTester): @slow def test_model_from_pretrained(self): - cache_dir = "/tmp/transformers_test/" for model_name in list(XLNET_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: - model = XLNetModel.from_pretrained(model_name, cache_dir=cache_dir) - shutil.rmtree(cache_dir) + model = XLNetModel.from_pretrained(model_name, cache_dir=CACHE_DIR) self.assertIsNotNone(model) diff --git a/transformers/tests/utils.py b/transformers/tests/utils.py index c950ad8f17..ba0e19f420 100644 --- a/transformers/tests/utils.py +++ b/transformers/tests/utils.py @@ -1,11 +1,14 @@ import os import unittest +import tempfile from distutils.util import strtobool from transformers.file_utils import _tf_available, _torch_available +CACHE_DIR = os.path.join(tempfile.gettempdir(), "transformers_test") + SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"