Add kyutai stt (#38909)
* first draft * cleaner version * udpate tests + modeling * add tests * init * udpate test_modeling_common * fix tests * csm Processor draft * convertion update * mimi cache padding convolutions draft * mimi streaming udpates * update mimi padding cache test * udpate cache padding mimi test * make style mimi * updates generate moshi asr * moshi asr integration tests (single + batched) * update tests * update conversion script * good default sliding window value * udpdate generate * update test checkpoint * nit * fix mimi * fix codec prefix * revert * revert * update config * update config * unnecessary mimi input restriction * remove delay in tokens * remove _prepare_4d_causal_attention_mask_with_cache_position and _update_causal_mask * test update * modular update * make style * nit * rename * create codec model generation config at init * remove delay * max_new_tokens/length warning * correct conv1 padding cache import for modular * nit * fix on encoder_past_key_values * convert modular * move frame_size to config * move frame_size to config * update test name * handle first token is bos * better handling of max_new_tokens * fix * fix batch size in test input prep * update docstring * convert modular * make style * make style * add feature extractor * correct modular convention name for feature_extraction file * update convertion script * doc processor * update doc * udpate init * update model type * fixes * update tests * fix * make * add doc * nit * fix * doc * auto mappings * doc * nit * convert modular * doc * nit * extend _keep_in_fp32_modules to enforce fp32 * renaming to stt * doc update + test update * doc fixes * doc fix * doc fix * fix musicgen tests * fix musicgen tests * make style * fix musicgen tests * correct frame_rate config param for mimi * update mimi test * revert update mimi test * enforce cpu test * move cache init in cache class * convert modular * docstring update * update model id * feature_extractor -> feature_extraction (SEW) * convert modular * update model id
This commit is contained in:
@@ -107,14 +107,21 @@ class MimiModelTester:
|
||||
self.sliding_window = sliding_window
|
||||
self.use_cache = use_cache
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_values = floats_tensor([self.batch_size, self.num_channels, self.intermediate_size], scale=1.0)
|
||||
def prepare_config_and_inputs(self, input_values_length=None):
|
||||
input_values = floats_tensor(
|
||||
[
|
||||
self.batch_size,
|
||||
self.num_channels,
|
||||
self.intermediate_size if input_values_length is None else input_values_length,
|
||||
],
|
||||
scale=1.0,
|
||||
)
|
||||
config = self.get_config()
|
||||
inputs_dict = {"input_values": input_values}
|
||||
return config, inputs_dict
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config, inputs_dict = self.prepare_config_and_inputs()
|
||||
def prepare_config_and_inputs_for_common(self, input_values_length=None):
|
||||
config, inputs_dict = self.prepare_config_and_inputs(input_values_length=input_values_length)
|
||||
return config, inputs_dict
|
||||
|
||||
def prepare_config_and_inputs_for_model_class(self, model_class):
|
||||
@@ -508,6 +515,54 @@ class MimiIntegrationTest(unittest.TestCase):
|
||||
)
|
||||
self.assertTrue(rmse < 1e-3)
|
||||
|
||||
def test_integration_encode_with_padding_cache(self):
|
||||
"""
|
||||
We test here the possibility to run Mimi in a streaming manner, i.e. chunk by chunk.
|
||||
1. we encode a first time the entire audio
|
||||
2. we encode the audio chunk by chunk, each chunk being the smallest size possible for the model (i.e. the frame size)
|
||||
|
||||
This test must be run on CPU since GPU floating point operations accumulate rounding errors that cause test failures.
|
||||
"""
|
||||
librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
|
||||
model_id = "kyutai/mimi"
|
||||
|
||||
model = MimiModel.from_pretrained(model_id, use_cache=True).to("cpu")
|
||||
processor = AutoFeatureExtractor.from_pretrained(model_id)
|
||||
|
||||
librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=processor.sampling_rate))
|
||||
audio_sample = librispeech_dummy[-1]["audio"]["array"]
|
||||
|
||||
inputs = processor(
|
||||
raw_audio=audio_sample,
|
||||
sampling_rate=processor.sampling_rate,
|
||||
return_tensors="pt",
|
||||
).to("cpu")
|
||||
|
||||
frame_size = model.config.frame_size
|
||||
audio_codes = model.encode(inputs["input_values"]).audio_codes
|
||||
|
||||
# streaming chunk by chunk
|
||||
encoder_past_key_values = None
|
||||
padding_cache = None
|
||||
encoded_frames_list = []
|
||||
|
||||
for start in range(0, inputs["input_values"].shape[-1], frame_size):
|
||||
input_values_chunk = inputs["input_values"][:, :, start : start + frame_size]
|
||||
encoder_outputs = model.encode(
|
||||
input_values_chunk,
|
||||
padding_cache=padding_cache,
|
||||
encoder_past_key_values=encoder_past_key_values,
|
||||
use_streaming=True,
|
||||
)
|
||||
encoder_past_key_values = encoder_outputs.encoder_past_key_values
|
||||
padding_cache = encoder_outputs.padding_cache
|
||||
encoded_frames_list.append(encoder_outputs.audio_codes)
|
||||
|
||||
streamed_audio_codes = torch.cat(encoded_frames_list, dim=-1)
|
||||
|
||||
torch.testing.assert_close(streamed_audio_codes, audio_codes)
|
||||
|
||||
def test_integration(self):
|
||||
expected_rmses = {
|
||||
"8": 0.0018785292,
|
||||
|
||||
Reference in New Issue
Block a user