Remove sys.version_info[0] == 2 or 3.
This commit is contained in:
@@ -19,8 +19,6 @@ import json
|
||||
import logging
|
||||
import os.path
|
||||
import random
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
import uuid
|
||||
@@ -43,23 +41,6 @@ if is_torch_available():
|
||||
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
)
|
||||
|
||||
if sys.version_info[0] == 2:
|
||||
|
||||
class TemporaryDirectory(object):
|
||||
"""Context manager for tempfile.mkdtemp() so it's usable with "with" statement."""
|
||||
|
||||
def __enter__(self):
|
||||
self.name = tempfile.mkdtemp()
|
||||
return self.name
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
shutil.rmtree(self.name)
|
||||
|
||||
|
||||
else:
|
||||
TemporaryDirectory = tempfile.TemporaryDirectory
|
||||
unicode = str
|
||||
|
||||
|
||||
def _config_zero_init(config):
|
||||
configs_no_init = copy.deepcopy(config)
|
||||
@@ -92,7 +73,7 @@ class ModelTesterMixin:
|
||||
out_2 = outputs[0].numpy()
|
||||
out_2[np.isnan(out_2)] = 0
|
||||
|
||||
with TemporaryDirectory() as tmpdirname:
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
model = model_class.from_pretrained(tmpdirname)
|
||||
model.to(torch_device)
|
||||
@@ -238,7 +219,7 @@ class ModelTesterMixin:
|
||||
except RuntimeError:
|
||||
self.fail("Couldn't trace module.")
|
||||
|
||||
with TemporaryDirectory() as tmp_dir_name:
|
||||
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||
pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")
|
||||
|
||||
try:
|
||||
@@ -366,7 +347,7 @@ class ModelTesterMixin:
|
||||
heads_to_prune = {0: list(range(1, self.model_tester.num_attention_heads)), -1: [0]}
|
||||
model.prune_heads(heads_to_prune)
|
||||
|
||||
with TemporaryDirectory() as temp_dir_name:
|
||||
with tempfile.TemporaryDirectory() as temp_dir_name:
|
||||
model.save_pretrained(temp_dir_name)
|
||||
model = model_class.from_pretrained(temp_dir_name)
|
||||
model.to(torch_device)
|
||||
@@ -435,7 +416,7 @@ class ModelTesterMixin:
|
||||
self.assertEqual(attentions[2].shape[-3], self.model_tester.num_attention_heads)
|
||||
self.assertEqual(attentions[3].shape[-3], self.model_tester.num_attention_heads)
|
||||
|
||||
with TemporaryDirectory() as temp_dir_name:
|
||||
with tempfile.TemporaryDirectory() as temp_dir_name:
|
||||
model.save_pretrained(temp_dir_name)
|
||||
model = model_class.from_pretrained(temp_dir_name)
|
||||
model.to(torch_device)
|
||||
|
||||
Reference in New Issue
Block a user