Enable granite speech 3.3 tests (#37560)
* Enable granite speech 3.3 tests * skip sdpa test for granite speech * Explicitly move model to device * Use granite speech 2b in tests --------- Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -33,6 +33,7 @@ from transformers.testing_utils import (
|
|||||||
)
|
)
|
||||||
from transformers.utils import (
|
from transformers.utils import (
|
||||||
is_datasets_available,
|
is_datasets_available,
|
||||||
|
is_peft_available,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -306,11 +307,17 @@ class GraniteSpeechForConditionalGenerationModelTest(ModelTesterMixin, Generatio
|
|||||||
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
|
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
|
||||||
raise ValueError("The eager model should not have SDPA attention layers")
|
raise ValueError("The eager model should not have SDPA attention layers")
|
||||||
|
|
||||||
|
@pytest.mark.generate
|
||||||
|
@require_torch_sdpa
|
||||||
|
@slow
|
||||||
|
@unittest.skip(reason="Granite Speech doesn't support SDPA for all backbones")
|
||||||
|
def test_eager_matches_sdpa_generate(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class GraniteSpeechForConditionalGenerationIntegrationTest(unittest.TestCase):
|
class GraniteSpeechForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
# TODO - use the actual model path on HF hub after release.
|
self.model_path = "ibm-granite/granite-speech-3.3-2b"
|
||||||
self.model_path = "ibm-granite/granite-speech"
|
|
||||||
self.processor = AutoProcessor.from_pretrained(self.model_path)
|
self.processor = AutoProcessor.from_pretrained(self.model_path)
|
||||||
self.prompt = self._get_prompt(self.processor.tokenizer)
|
self.prompt = self._get_prompt(self.processor.tokenizer)
|
||||||
|
|
||||||
@@ -338,7 +345,7 @@ class GraniteSpeechForConditionalGenerationIntegrationTest(unittest.TestCase):
|
|||||||
return [x["array"] for x in speech_samples]
|
return [x["array"] for x in speech_samples]
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@pytest.mark.skip("Public models not yet available")
|
@pytest.mark.skipif(not is_peft_available(), reason="Outputs diverge without lora")
|
||||||
def test_small_model_integration_test_single(self):
|
def test_small_model_integration_test_single(self):
|
||||||
model = GraniteSpeechForConditionalGeneration.from_pretrained(self.model_path).to(torch_device)
|
model = GraniteSpeechForConditionalGeneration.from_pretrained(self.model_path).to(torch_device)
|
||||||
input_speech = self._load_datasamples(1)
|
input_speech = self._load_datasamples(1)
|
||||||
@@ -364,9 +371,9 @@ class GraniteSpeechForConditionalGenerationIntegrationTest(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@pytest.mark.skip("Public models not yet available")
|
@pytest.mark.skipif(not is_peft_available(), reason="Outputs diverge without lora")
|
||||||
def test_small_model_integration_test_batch(self):
|
def test_small_model_integration_test_batch(self):
|
||||||
model = GraniteSpeechForConditionalGeneration.from_pretrained(self.model_path)
|
model = GraniteSpeechForConditionalGeneration.from_pretrained(self.model_path).to(torch_device)
|
||||||
input_speech = self._load_datasamples(2)
|
input_speech = self._load_datasamples(2)
|
||||||
prompts = [self.prompt, self.prompt]
|
prompts = [self.prompt, self.prompt]
|
||||||
|
|
||||||
@@ -384,7 +391,7 @@ class GraniteSpeechForConditionalGenerationIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
EXPECTED_DECODED_TEXT = [
|
EXPECTED_DECODED_TEXT = [
|
||||||
"systemKnowledge Cutoff Date: April 2024.\nToday's Date: December 19, 2024.\nYou are Granite, developed by IBM. You are a helpful AI assistant\nusercan you transcribe the speech into a written format?\nassistantmister quilter is the apostle of the middle classes and we are glad to welcome his gospel",
|
"systemKnowledge Cutoff Date: April 2024.\nToday's Date: December 19, 2024.\nYou are Granite, developed by IBM. You are a helpful AI assistant\nusercan you transcribe the speech into a written format?\nassistantmister quilter is the apostle of the middle classes and we are glad to welcome his gospel",
|
||||||
"systemKnowledge Cutoff Date: April 2024.\nToday's Date: December 19, 2024.\nYou are Granite, developed by IBM. You are a helpful AI assistant\nusercan you transcribe the speech into a written format?\nassistantnor is mister quilter's manner less interesting than his matter"
|
"systemKnowledge Cutoff Date: April 2024.\nToday's Date: December 19, 2024.\nYou are Granite, developed by IBM. You are a helpful AI assistant\nusercan you transcribe the speech into a written format?\nassistantnor is mister quilp's manner less interesting than his matter"
|
||||||
] # fmt: skip
|
] # fmt: skip
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
|
|||||||
@@ -33,14 +33,12 @@ if is_torchaudio_available():
|
|||||||
from transformers import GraniteSpeechFeatureExtractor, GraniteSpeechProcessor
|
from transformers import GraniteSpeechFeatureExtractor, GraniteSpeechProcessor
|
||||||
|
|
||||||
|
|
||||||
@pytest.skip("Public models not yet available", allow_module_level=True)
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@require_torchaudio
|
@require_torchaudio
|
||||||
class GraniteSpeechProcessorTest(unittest.TestCase):
|
class GraniteSpeechProcessorTest(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.tmpdirname = tempfile.mkdtemp()
|
self.tmpdirname = tempfile.mkdtemp()
|
||||||
# TODO - use the actual model path on HF hub after release.
|
self.checkpoint = "ibm-granite/granite-speech-3.3-8b"
|
||||||
self.checkpoint = "ibm-granite/granite-speech"
|
|
||||||
processor = GraniteSpeechProcessor.from_pretrained(self.checkpoint)
|
processor = GraniteSpeechProcessor.from_pretrained(self.checkpoint)
|
||||||
processor.save_pretrained(self.tmpdirname)
|
processor.save_pretrained(self.tmpdirname)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user