test suite independent of framework

This commit is contained in:
thomwolf
2019-09-05 11:18:55 +02:00
parent 9d0a11a68c
commit 518307dfcd
20 changed files with 596 additions and 262 deletions

View File

@@ -18,11 +18,17 @@ from __future__ import print_function
import unittest
import os
import pytest
import torch
from pytorch_transformers import is_torch_available
from pytorch_transformers import (AdamW, ConstantLRSchedule, WarmupConstantSchedule,
WarmupCosineSchedule, WarmupCosineWithHardRestartsSchedule, WarmupLinearSchedule)
try:
import torch
from pytorch_transformers import (AdamW, ConstantLRSchedule, WarmupConstantSchedule,
WarmupCosineSchedule, WarmupCosineWithHardRestartsSchedule, WarmupLinearSchedule)
except ImportError:
pytestmark = pytest.mark.skip("Require Torch")
from .tokenization_tests_commons import TemporaryDirectory
@@ -71,8 +77,8 @@ class OptimizationTest(unittest.TestCase):
class ScheduleInitTest(unittest.TestCase):
m = torch.nn.Linear(50, 50)
optimizer = AdamW(m.parameters(), lr=10.)
m = torch.nn.Linear(50, 50) if is_torch_available() else None
optimizer = AdamW(m.parameters(), lr=10.) if is_torch_available() else None
num_steps = 10
def assertListAlmostEqual(self, list1, list2, tol):