Remove sys.version_info[0] == 2 or 3.

This commit is contained in:
Aymeric Augustin
2019-12-22 18:12:11 +01:00
parent 8af25b1664
commit 798b3b3899
18 changed files with 41 additions and 170 deletions

View File

@@ -19,8 +19,6 @@ import json
import logging
import os.path
import random
import shutil
import sys
import tempfile
import unittest
import uuid
@@ -43,23 +41,6 @@ if is_torch_available():
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
)
if sys.version_info[0] == 2:
class TemporaryDirectory(object):
"""Context manager for tempfile.mkdtemp() so it's usable with "with" statement."""
def __enter__(self):
self.name = tempfile.mkdtemp()
return self.name
def __exit__(self, exc_type, exc_value, traceback):
shutil.rmtree(self.name)
else:
TemporaryDirectory = tempfile.TemporaryDirectory
unicode = str
def _config_zero_init(config):
configs_no_init = copy.deepcopy(config)
@@ -92,7 +73,7 @@ class ModelTesterMixin:
out_2 = outputs[0].numpy()
out_2[np.isnan(out_2)] = 0
with TemporaryDirectory() as tmpdirname:
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(tmpdirname)
model.to(torch_device)
@@ -238,7 +219,7 @@ class ModelTesterMixin:
except RuntimeError:
self.fail("Couldn't trace module.")
with TemporaryDirectory() as tmp_dir_name:
with tempfile.TemporaryDirectory() as tmp_dir_name:
pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")
try:
@@ -366,7 +347,7 @@ class ModelTesterMixin:
heads_to_prune = {0: list(range(1, self.model_tester.num_attention_heads)), -1: [0]}
model.prune_heads(heads_to_prune)
with TemporaryDirectory() as temp_dir_name:
with tempfile.TemporaryDirectory() as temp_dir_name:
model.save_pretrained(temp_dir_name)
model = model_class.from_pretrained(temp_dir_name)
model.to(torch_device)
@@ -435,7 +416,7 @@ class ModelTesterMixin:
self.assertEqual(attentions[2].shape[-3], self.model_tester.num_attention_heads)
self.assertEqual(attentions[3].shape[-3], self.model_tester.num_attention_heads)
with TemporaryDirectory() as temp_dir_name:
with tempfile.TemporaryDirectory() as temp_dir_name:
model.save_pretrained(temp_dir_name)
model = model_class.from_pretrained(temp_dir_name)
model.to(torch_device)