device agnostic models testing (#27146)
* device agnostic models testing * add decorator `require_torch_fp16` * make style * apply review suggestion * Oops, the fp16 decorator was misused
This commit is contained in:
@@ -21,7 +21,14 @@ import unittest
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import PersimmonConfig, is_torch_available, set_seed
|
||||
from transformers.testing_utils import require_torch, require_torch_gpu, slow, torch_device
|
||||
from transformers.testing_utils import (
|
||||
backend_empty_cache,
|
||||
require_torch,
|
||||
require_torch_accelerator,
|
||||
require_torch_fp16,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
@@ -413,12 +420,13 @@ class PersimmonIntegrationTest(unittest.TestCase):
|
||||
# fmt: on
|
||||
torch.testing.assert_close(out.cpu()[0, 0, :30], EXPECTED_SLICE, atol=1e-5, rtol=1e-5)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
del model
|
||||
gc.collect()
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator
|
||||
@require_torch_fp16
|
||||
def test_model_8b_chat_greedy_generation(self):
|
||||
EXPECTED_TEXT_COMPLETION = """human: Simply put, the theory of relativity states that?\n\nadept: The theory of relativity states that the laws of physics are the same for all observers, regardless of their relative motion."""
|
||||
prompt = "human: Simply put, the theory of relativity states that?\n\nadept:"
|
||||
@@ -433,6 +441,6 @@ class PersimmonIntegrationTest(unittest.TestCase):
|
||||
text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
del model
|
||||
gc.collect()
|
||||
|
||||
Reference in New Issue
Block a user