device agnostic models testing (#27146)
* device agnostic models testing * add decorator `require_torch_fp16` * make style * apply review suggestion * Oops, the fp16 decorator was misused
This commit is contained in:
@@ -21,7 +21,14 @@ import numpy as np
|
||||
from datasets import load_dataset
|
||||
|
||||
from transformers import Wav2Vec2ConformerConfig, is_torch_available
|
||||
from transformers.testing_utils import is_pt_flax_cross_test, require_torch, require_torch_gpu, slow, torch_device
|
||||
from transformers.testing_utils import (
|
||||
is_pt_flax_cross_test,
|
||||
require_torch,
|
||||
require_torch_accelerator,
|
||||
require_torch_fp16,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import (
|
||||
@@ -468,12 +475,14 @@ class Wav2Vec2ConformerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model_with_adapter_proj_dim(*config_and_inputs)
|
||||
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator
|
||||
@require_torch_fp16
|
||||
def test_model_float16_with_relative(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs(position_embeddings_type="relative")
|
||||
self.model_tester.create_and_check_model_float16(*config_and_inputs)
|
||||
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator
|
||||
@require_torch_fp16
|
||||
def test_model_float16_with_rotary(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs(position_embeddings_type="rotary")
|
||||
self.model_tester.create_and_check_model_float16(*config_and_inputs)
|
||||
|
||||
Reference in New Issue
Block a user