enable csm integration cases on xpu, all passed (#38140)
* enable csm test cases on XPU, all passed Signed-off-by: Matrix Yao <matrix.yao@intel.com> * fix style Signed-off-by: Matrix Yao <matrix.yao@intel.com> --------- Signed-off-by: Matrix Yao <matrix.yao@intel.com>
This commit is contained in:
@@ -30,7 +30,7 @@ from transformers import (
|
|||||||
)
|
)
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
cleanup,
|
cleanup,
|
||||||
require_torch_gpu,
|
require_torch_accelerator,
|
||||||
slow,
|
slow,
|
||||||
torch_device,
|
torch_device,
|
||||||
)
|
)
|
||||||
@@ -430,7 +430,7 @@ class CsmForConditionalGenerationIntegrationTest(unittest.TestCase):
|
|||||||
return ds[0]
|
return ds[0]
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch_gpu
|
@require_torch_accelerator
|
||||||
def test_1b_model_integration_generate(self):
|
def test_1b_model_integration_generate(self):
|
||||||
"""
|
"""
|
||||||
Tests the generated tokens match the ones from the original model implementation.
|
Tests the generated tokens match the ones from the original model implementation.
|
||||||
@@ -474,7 +474,7 @@ class CsmForConditionalGenerationIntegrationTest(unittest.TestCase):
|
|||||||
torch.testing.assert_close(output_tokens.cpu(), EXPECTED_OUTPUT_TOKENS)
|
torch.testing.assert_close(output_tokens.cpu(), EXPECTED_OUTPUT_TOKENS)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch_gpu
|
@require_torch_accelerator
|
||||||
def test_1b_model_integration_generate_no_audio(self):
|
def test_1b_model_integration_generate_no_audio(self):
|
||||||
"""
|
"""
|
||||||
Tests the generated tokens match the ones from the original model implementation.
|
Tests the generated tokens match the ones from the original model implementation.
|
||||||
@@ -535,7 +535,7 @@ class CsmForConditionalGenerationIntegrationTest(unittest.TestCase):
|
|||||||
torch.testing.assert_close(output_tokens.cpu(), EXPECTED_OUTPUT_TOKENS)
|
torch.testing.assert_close(output_tokens.cpu(), EXPECTED_OUTPUT_TOKENS)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch_gpu
|
@require_torch_accelerator
|
||||||
def test_1b_model_integration_generate_multiple_audio(self):
|
def test_1b_model_integration_generate_multiple_audio(self):
|
||||||
"""
|
"""
|
||||||
Test the generated tokens match the ones from the original model implementation.
|
Test the generated tokens match the ones from the original model implementation.
|
||||||
@@ -594,7 +594,7 @@ class CsmForConditionalGenerationIntegrationTest(unittest.TestCase):
|
|||||||
torch.testing.assert_close(output_tokens.cpu(), EXPECTED_OUTPUT_TOKENS)
|
torch.testing.assert_close(output_tokens.cpu(), EXPECTED_OUTPUT_TOKENS)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch_gpu
|
@require_torch_accelerator
|
||||||
def test_1b_model_integration_generate_batched(self):
|
def test_1b_model_integration_generate_batched(self):
|
||||||
"""
|
"""
|
||||||
Test the generated tokens match the ones from the original model implementation.
|
Test the generated tokens match the ones from the original model implementation.
|
||||||
|
|||||||
Reference in New Issue
Block a user