device agnostic models testing (#27146)
* device agnostic models testing * add decorator `require_torch_fp16` * make style * apply review suggestion * Oops, the fp16 decorator was misused
This commit is contained in:
@@ -22,7 +22,14 @@ import unittest
|
||||
from pytest import mark
|
||||
|
||||
from transformers import AutoTokenizer, MistralConfig, is_torch_available
|
||||
from transformers.testing_utils import require_flash_attn, require_torch, require_torch_gpu, slow, torch_device
|
||||
from transformers.testing_utils import (
|
||||
backend_empty_cache,
|
||||
require_flash_attn,
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
@@ -450,7 +457,7 @@ class MistralIntegrationTest(unittest.TestCase):
|
||||
torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, atol=1e-4, rtol=1e-4)
|
||||
|
||||
del model
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
gc.collect()
|
||||
|
||||
@slow
|
||||
@@ -467,5 +474,5 @@ class MistralIntegrationTest(unittest.TestCase):
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
|
||||
|
||||
del model
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
gc.collect()
|
||||
|
||||
Reference in New Issue
Block a user