device agnostic models testing (#27146)
* device agnostic models testing * add decorator `require_torch_fp16` * make style * apply review suggestion * Oops, the fp16 decorator was misused
This commit is contained in:
@@ -19,7 +19,14 @@ import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers import SwitchTransformersConfig, is_torch_available
|
||||
from transformers.testing_utils import require_tokenizers, require_torch, require_torch_gpu, slow, torch_device
|
||||
from transformers.testing_utils import (
|
||||
require_tokenizers,
|
||||
require_torch,
|
||||
require_torch_accelerator,
|
||||
require_torch_bf16,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
@@ -1017,7 +1024,8 @@ class SwitchTransformerRouterTest(unittest.TestCase):
|
||||
@require_torch
|
||||
@require_tokenizers
|
||||
class SwitchTransformerModelIntegrationTests(unittest.TestCase):
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator
|
||||
@require_torch_bf16
|
||||
def test_small_logits(self):
|
||||
r"""
|
||||
Logits testing to check implementation consistency between `t5x` implementation
|
||||
|
||||
Reference in New Issue
Block a user