Remove sys.version_info[0] == 2 or 3.
This commit is contained in:
@@ -17,8 +17,6 @@
|
||||
import copy
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
from transformers import is_tf_available, is_torch_available
|
||||
@@ -32,23 +30,6 @@ if is_tf_available():
|
||||
|
||||
# from transformers.modeling_bert import BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
|
||||
if sys.version_info[0] == 2:
|
||||
|
||||
class TemporaryDirectory(object):
|
||||
"""Context manager for tempfile.mkdtemp() so it's usable with "with" statement."""
|
||||
|
||||
def __enter__(self):
|
||||
self.name = tempfile.mkdtemp()
|
||||
return self.name
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
shutil.rmtree(self.name)
|
||||
|
||||
|
||||
else:
|
||||
TemporaryDirectory = tempfile.TemporaryDirectory
|
||||
unicode = str
|
||||
|
||||
|
||||
def _config_zero_init(config):
|
||||
configs_no_init = copy.deepcopy(config)
|
||||
@@ -87,7 +68,7 @@ class TFModelTesterMixin:
|
||||
model = model_class(config)
|
||||
outputs = model(inputs_dict)
|
||||
|
||||
with TemporaryDirectory() as tmpdirname:
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
model = model_class.from_pretrained(tmpdirname)
|
||||
after_outputs = model(inputs_dict)
|
||||
@@ -137,7 +118,7 @@ class TFModelTesterMixin:
|
||||
self.assertLessEqual(max_diff, 2e-2)
|
||||
|
||||
# Check we can load pt model in tf and vice-versa with checkpoint => model functions
|
||||
with TemporaryDirectory() as tmpdirname:
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
|
||||
torch.save(pt_model.state_dict(), pt_checkpoint_path)
|
||||
tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(tf_model, pt_checkpoint_path)
|
||||
@@ -180,7 +161,7 @@ class TFModelTesterMixin:
|
||||
model = model_class(config)
|
||||
|
||||
# Let's load it from the disk to be sure we can use pretrained weights
|
||||
with TemporaryDirectory() as tmpdirname:
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
outputs = model(inputs_dict) # build the model
|
||||
model.save_pretrained(tmpdirname)
|
||||
model = model_class.from_pretrained(tmpdirname)
|
||||
|
||||
Reference in New Issue
Block a user