Add kyutai stt (#38909)
Some checks failed
Release - Conda / build_and_package (push) Has been cancelled
Secret Leaks / trufflehog (push) Has been cancelled

* first draft

* cleaner version

* udpate tests + modeling

* add tests

* init

* udpate test_modeling_common

* fix tests

* csm Processor draft

* convertion update

* mimi cache padding convolutions draft

* mimi streaming udpates

* update mimi padding cache test

* udpate cache padding mimi test

* make style mimi

* updates generate moshi asr

* moshi asr integration tests (single + batched)

* update tests

* update conversion script

* good default sliding window value

* udpdate generate

* update test checkpoint

* nit

* fix mimi

* fix codec prefix

* revert

* revert

* update config

* update config

* unnecessary mimi input restriction

* remove delay in tokens

* remove _prepare_4d_causal_attention_mask_with_cache_position and _update_causal_mask

* test update

* modular update

* make style

* nit

* rename

* create codec model generation config at init

* remove delay

* max_new_tokens/length warning

* correct conv1 padding cache import for modular

* nit

* fix on encoder_past_key_values

* convert modular

* move frame_size to config

* move frame_size to config

* update test name

* handle first token is bos

* better handling of max_new_tokens

* fix

* fix batch size in test input prep

* update docstring

* convert modular

* make style

* make style

* add feature extractor

* correct modular convention name for feature_extraction file

* update convertion script

* doc processor

* update doc

* udpate init

* update model type

* fixes

* update tests

* fix

* make

* add doc

* nit

* fix

* doc

* auto mappings

* doc

* nit

* convert modular

* doc

* nit

* extend _keep_in_fp32_modules to enforce fp32

* renaming to stt

* doc update + test update

* doc fixes

* doc fix

* doc fix

* fix musicgen tests

* fix musicgen tests

* make style

* fix musicgen tests

* correct frame_rate config param for mimi

* update mimi test

* revert update mimi test

* enforce cpu test

* move cache init in cache class

* convert modular

* docstring update

* update model id

* feature_extractor -> feature_extraction (SEW)

* convert modular

* update model id
This commit is contained in:
eustlb
2025-06-24 18:01:15 +02:00
committed by GitHub
parent 08bf7f1afe
commit 6bdd4ec952
23 changed files with 4000 additions and 200 deletions

View File

@@ -107,14 +107,21 @@ class MimiModelTester:
self.sliding_window = sliding_window
self.use_cache = use_cache
def prepare_config_and_inputs(self):
input_values = floats_tensor([self.batch_size, self.num_channels, self.intermediate_size], scale=1.0)
def prepare_config_and_inputs(self, input_values_length=None):
input_values = floats_tensor(
[
self.batch_size,
self.num_channels,
self.intermediate_size if input_values_length is None else input_values_length,
],
scale=1.0,
)
config = self.get_config()
inputs_dict = {"input_values": input_values}
return config, inputs_dict
def prepare_config_and_inputs_for_common(self):
config, inputs_dict = self.prepare_config_and_inputs()
def prepare_config_and_inputs_for_common(self, input_values_length=None):
config, inputs_dict = self.prepare_config_and_inputs(input_values_length=input_values_length)
return config, inputs_dict
def prepare_config_and_inputs_for_model_class(self, model_class):
@@ -508,6 +515,54 @@ class MimiIntegrationTest(unittest.TestCase):
)
self.assertTrue(rmse < 1e-3)
def test_integration_encode_with_padding_cache(self):
"""
We test here the possibility to run Mimi in a streaming manner, i.e. chunk by chunk.
1. we encode a first time the entire audio
2. we encode the audio chunk by chunk, each chunk being the smallest size possible for the model (i.e. the frame size)
This test must be run on CPU since GPU floating point operations accumulate rounding errors that cause test failures.
"""
librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
model_id = "kyutai/mimi"
model = MimiModel.from_pretrained(model_id, use_cache=True).to("cpu")
processor = AutoFeatureExtractor.from_pretrained(model_id)
librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=processor.sampling_rate))
audio_sample = librispeech_dummy[-1]["audio"]["array"]
inputs = processor(
raw_audio=audio_sample,
sampling_rate=processor.sampling_rate,
return_tensors="pt",
).to("cpu")
frame_size = model.config.frame_size
audio_codes = model.encode(inputs["input_values"]).audio_codes
# streaming chunk by chunk
encoder_past_key_values = None
padding_cache = None
encoded_frames_list = []
for start in range(0, inputs["input_values"].shape[-1], frame_size):
input_values_chunk = inputs["input_values"][:, :, start : start + frame_size]
encoder_outputs = model.encode(
input_values_chunk,
padding_cache=padding_cache,
encoder_past_key_values=encoder_past_key_values,
use_streaming=True,
)
encoder_past_key_values = encoder_outputs.encoder_past_key_values
padding_cache = encoder_outputs.padding_cache
encoded_frames_list.append(encoder_outputs.audio_codes)
streamed_audio_codes = torch.cat(encoded_frames_list, dim=-1)
torch.testing.assert_close(streamed_audio_codes, audio_codes)
def test_integration(self):
expected_rmses = {
"8": 0.0018785292,