Remove sys.version_info[0] == 2 or 3.
This commit is contained in:
@@ -15,11 +15,11 @@
|
||||
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers import is_torch_available
|
||||
|
||||
from .test_tokenization_common import TemporaryDirectory
|
||||
from .utils import require_torch
|
||||
|
||||
|
||||
@@ -50,7 +50,7 @@ def unwrap_and_save_reload_schedule(scheduler, num_steps=10):
|
||||
scheduler.step()
|
||||
lrs.append(scheduler.get_lr())
|
||||
if step == num_steps // 2:
|
||||
with TemporaryDirectory() as tmpdirname:
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
file_name = os.path.join(tmpdirname, "schedule.bin")
|
||||
torch.save(scheduler.state_dict(), file_name)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user