test suite independent of framework
This commit is contained in:
@@ -18,11 +18,17 @@ from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import os
|
||||
import pytest
|
||||
|
||||
import torch
|
||||
from pytorch_transformers import is_torch_available
|
||||
|
||||
from pytorch_transformers import (AdamW, ConstantLRSchedule, WarmupConstantSchedule,
|
||||
WarmupCosineSchedule, WarmupCosineWithHardRestartsSchedule, WarmupLinearSchedule)
|
||||
try:
|
||||
import torch
|
||||
|
||||
from pytorch_transformers import (AdamW, ConstantLRSchedule, WarmupConstantSchedule,
|
||||
WarmupCosineSchedule, WarmupCosineWithHardRestartsSchedule, WarmupLinearSchedule)
|
||||
except ImportError:
|
||||
pytestmark = pytest.mark.skip("Require Torch")
|
||||
|
||||
from .tokenization_tests_commons import TemporaryDirectory
|
||||
|
||||
@@ -71,8 +77,8 @@ class OptimizationTest(unittest.TestCase):
|
||||
|
||||
|
||||
class ScheduleInitTest(unittest.TestCase):
|
||||
m = torch.nn.Linear(50, 50)
|
||||
optimizer = AdamW(m.parameters(), lr=10.)
|
||||
m = torch.nn.Linear(50, 50) if is_torch_available() else None
|
||||
optimizer = AdamW(m.parameters(), lr=10.) if is_torch_available() else None
|
||||
num_steps = 10
|
||||
|
||||
def assertListAlmostEqual(self, list1, list2, tol):
|
||||
|
||||
Reference in New Issue
Block a user