Enable granite speech 3.3 tests (#37560)

* Enable granite speech 3.3 tests

* Skip SDPA test for granite speech

* Explicitly move model to device

* Use granite speech 2b in tests

---------

Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
This commit is contained in:
Alex Brooks
2025-05-06 09:56:18 -06:00
committed by GitHub
parent 031ef8802c
commit 06c4d05fe6
2 changed files with 14 additions and 9 deletions

View File

@@ -33,6 +33,7 @@ from transformers.testing_utils import (
) )
from transformers.utils import ( from transformers.utils import (
is_datasets_available, is_datasets_available,
is_peft_available,
is_torch_available, is_torch_available,
) )
@@ -306,11 +307,17 @@ class GraniteSpeechForConditionalGenerationModelTest(ModelTesterMixin, Generatio
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
raise ValueError("The eager model should not have SDPA attention layers") raise ValueError("The eager model should not have SDPA attention layers")
@pytest.mark.generate
@require_torch_sdpa
@slow
@unittest.skip(reason="Granite Speech doesn't support SDPA for all backbones")
def test_eager_matches_sdpa_generate(self):
pass
class GraniteSpeechForConditionalGenerationIntegrationTest(unittest.TestCase): class GraniteSpeechForConditionalGenerationIntegrationTest(unittest.TestCase):
def setUp(self): def setUp(self):
# TODO - use the actual model path on HF hub after release. self.model_path = "ibm-granite/granite-speech-3.3-2b"
self.model_path = "ibm-granite/granite-speech"
self.processor = AutoProcessor.from_pretrained(self.model_path) self.processor = AutoProcessor.from_pretrained(self.model_path)
self.prompt = self._get_prompt(self.processor.tokenizer) self.prompt = self._get_prompt(self.processor.tokenizer)
@@ -338,7 +345,7 @@ class GraniteSpeechForConditionalGenerationIntegrationTest(unittest.TestCase):
return [x["array"] for x in speech_samples] return [x["array"] for x in speech_samples]
@slow @slow
@pytest.mark.skip("Public models not yet available") @pytest.mark.skipif(not is_peft_available(), reason="Outputs diverge without lora")
def test_small_model_integration_test_single(self): def test_small_model_integration_test_single(self):
model = GraniteSpeechForConditionalGeneration.from_pretrained(self.model_path).to(torch_device) model = GraniteSpeechForConditionalGeneration.from_pretrained(self.model_path).to(torch_device)
input_speech = self._load_datasamples(1) input_speech = self._load_datasamples(1)
@@ -364,9 +371,9 @@ class GraniteSpeechForConditionalGenerationIntegrationTest(unittest.TestCase):
) )
@slow @slow
@pytest.mark.skip("Public models not yet available") @pytest.mark.skipif(not is_peft_available(), reason="Outputs diverge without lora")
def test_small_model_integration_test_batch(self): def test_small_model_integration_test_batch(self):
model = GraniteSpeechForConditionalGeneration.from_pretrained(self.model_path) model = GraniteSpeechForConditionalGeneration.from_pretrained(self.model_path).to(torch_device)
input_speech = self._load_datasamples(2) input_speech = self._load_datasamples(2)
prompts = [self.prompt, self.prompt] prompts = [self.prompt, self.prompt]
@@ -384,7 +391,7 @@ class GraniteSpeechForConditionalGenerationIntegrationTest(unittest.TestCase):
EXPECTED_DECODED_TEXT = [ EXPECTED_DECODED_TEXT = [
"systemKnowledge Cutoff Date: April 2024.\nToday's Date: December 19, 2024.\nYou are Granite, developed by IBM. You are a helpful AI assistant\nusercan you transcribe the speech into a written format?\nassistantmister quilter is the apostle of the middle classes and we are glad to welcome his gospel", "systemKnowledge Cutoff Date: April 2024.\nToday's Date: December 19, 2024.\nYou are Granite, developed by IBM. You are a helpful AI assistant\nusercan you transcribe the speech into a written format?\nassistantmister quilter is the apostle of the middle classes and we are glad to welcome his gospel",
"systemKnowledge Cutoff Date: April 2024.\nToday's Date: December 19, 2024.\nYou are Granite, developed by IBM. You are a helpful AI assistant\nusercan you transcribe the speech into a written format?\nassistantnor is mister quilter's manner less interesting than his matter" "systemKnowledge Cutoff Date: April 2024.\nToday's Date: December 19, 2024.\nYou are Granite, developed by IBM. You are a helpful AI assistant\nusercan you transcribe the speech into a written format?\nassistantnor is mister quilp's manner less interesting than his matter"
] # fmt: skip ] # fmt: skip
self.assertEqual( self.assertEqual(

View File

@@ -33,14 +33,12 @@ if is_torchaudio_available():
from transformers import GraniteSpeechFeatureExtractor, GraniteSpeechProcessor from transformers import GraniteSpeechFeatureExtractor, GraniteSpeechProcessor
@pytest.skip("Public models not yet available", allow_module_level=True)
@require_torch @require_torch
@require_torchaudio @require_torchaudio
class GraniteSpeechProcessorTest(unittest.TestCase): class GraniteSpeechProcessorTest(unittest.TestCase):
def setUp(self): def setUp(self):
self.tmpdirname = tempfile.mkdtemp() self.tmpdirname = tempfile.mkdtemp()
# TODO - use the actual model path on HF hub after release. self.checkpoint = "ibm-granite/granite-speech-3.3-8b"
self.checkpoint = "ibm-granite/granite-speech"
processor = GraniteSpeechProcessor.from_pretrained(self.checkpoint) processor = GraniteSpeechProcessor.from_pretrained(self.checkpoint)
processor.save_pretrained(self.tmpdirname) processor.save_pretrained(self.tmpdirname)