Remove sys.version_info[0] == 2 or 3.

This commit is contained in:
Aymeric Augustin
2019-12-22 18:12:11 +01:00
parent 8af25b1664
commit 798b3b3899
18 changed files with 41 additions and 170 deletions

View File

@@ -17,8 +17,6 @@
import copy
import os
import random
import shutil
import sys
import tempfile
from transformers import is_tf_available, is_torch_available
@@ -32,23 +30,6 @@ if is_tf_available():
# from transformers.modeling_bert import BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP
if sys.version_info[0] == 2:
class TemporaryDirectory(object):
"""Context manager for tempfile.mkdtemp() so it's usable with "with" statement."""
def __enter__(self):
self.name = tempfile.mkdtemp()
return self.name
def __exit__(self, exc_type, exc_value, traceback):
shutil.rmtree(self.name)
else:
TemporaryDirectory = tempfile.TemporaryDirectory
unicode = str
def _config_zero_init(config):
configs_no_init = copy.deepcopy(config)
@@ -87,7 +68,7 @@ class TFModelTesterMixin:
model = model_class(config)
outputs = model(inputs_dict)
with TemporaryDirectory() as tmpdirname:
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(tmpdirname)
after_outputs = model(inputs_dict)
@@ -137,7 +118,7 @@ class TFModelTesterMixin:
self.assertLessEqual(max_diff, 2e-2)
# Check we can load pt model in tf and vice-versa with checkpoint => model functions
with TemporaryDirectory() as tmpdirname:
with tempfile.TemporaryDirectory() as tmpdirname:
pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
torch.save(pt_model.state_dict(), pt_checkpoint_path)
tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(tf_model, pt_checkpoint_path)
@@ -180,7 +161,7 @@ class TFModelTesterMixin:
model = model_class(config)
# Let's load it from the disk to be sure we can use pretrained weights
with TemporaryDirectory() as tmpdirname:
with tempfile.TemporaryDirectory() as tmpdirname:
outputs = model(inputs_dict) # build the model
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(tmpdirname)