Add kyutai stt (#38909)
Some checks failed
Release - Conda / build_and_package (push) Has been cancelled
Secret Leaks / trufflehog (push) Has been cancelled

* first draft

* cleaner version

* udpate tests + modeling

* add tests

* init

* udpate test_modeling_common

* fix tests

* csm Processor draft

* convertion update

* mimi cache padding convolutions draft

* mimi streaming udpates

* update mimi padding cache test

* udpate cache padding mimi test

* make style mimi

* updates generate moshi asr

* moshi asr integration tests (single + batched)

* update tests

* update conversion script

* good default sliding window value

* udpdate generate

* update test checkpoint

* nit

* fix mimi

* fix codec prefix

* revert

* revert

* update config

* update config

* unnecessary mimi input restriction

* remove delay in tokens

* remove _prepare_4d_causal_attention_mask_with_cache_position and _update_causal_mask

* test update

* modular update

* make style

* nit

* rename

* create codec model generation config at init

* remove delay

* max_new_tokens/length warning

* correct conv1 padding cache import for modular

* nit

* fix on encoder_past_key_values

* convert modular

* move frame_size to config

* move frame_size to config

* update test name

* handle first token is bos

* better handling of max_new_tokens

* fix

* fix batch size in test input prep

* update docstring

* convert modular

* make style

* make style

* add feature extractor

* correct modular convention name for feature_extraction file

* update convertion script

* doc processor

* update doc

* udpate init

* update model type

* fixes

* update tests

* fix

* make

* add doc

* nit

* fix

* doc

* auto mappings

* doc

* nit

* convert modular

* doc

* nit

* extend _keep_in_fp32_modules to enforce fp32

* renaming to stt

* doc update + test update

* doc fixes

* doc fix

* doc fix

* fix musicgen tests

* fix musicgen tests

* make style

* fix musicgen tests

* correct frame_rate config param for mimi

* update mimi test

* revert update mimi test

* enforce cpu test

* move cache init in cache class

* convert modular

* docstring update

* update model id

* feature_extractor -> feature_extraction (SEW)

* convert modular

* update model id
This commit is contained in:
eustlb
2025-06-24 18:01:15 +02:00
committed by GitHub
parent 08bf7f1afe
commit 6bdd4ec952
23 changed files with 4000 additions and 200 deletions

View File

@@ -0,0 +1,704 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch Moshi ASR model."""
import gc
import inspect
import tempfile
import unittest
import datasets
import pytest
from parameterized import parameterized
from transformers import (
KyutaiSpeechToTextConfig,
KyutaiSpeechToTextForConditionalGeneration,
KyutaiSpeechToTextProcessor,
is_torch_available,
)
from transformers.testing_utils import (
cleanup,
require_torch,
require_torch_accelerator,
require_torch_sdpa,
slow,
torch_device,
)
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import (
TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION,
ModelTesterMixin,
_config_zero_init,
floats_tensor,
ids_tensor,
)
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
import torch
from transformers import (
KyutaiSpeechToTextForConditionalGeneration,
KyutaiSpeechToTextModel,
)
class KyutaiSpeechToTextModelTester:
def __init__(
self,
parent,
batch_size=13,
seq_length=7,
text_seq_length=1,
input_values_length=192, # gives 3 audio tokens, corresponding to the default in GenerationTesterMixin
is_training=False,
use_input_mask=True,
use_token_type_ids=False,
use_labels=True,
codebook_vocab_size=2049,
vocab_size=99,
hidden_size=32,
num_hidden_layers=2,
num_attention_heads=4,
num_key_value_heads=None,
max_position_embeddings=512,
rope_theta=10000.0,
hidden_act="silu",
head_dim=None,
initializer_range=0.02,
use_cache=True,
sliding_window=512,
attention_dropout=0.1,
ffn_dim=38,
rms_norm_eps=1e-6,
num_codebooks=8,
frame_size=64,
delay_in_tokens=5,
audio_bos_token_id=2048,
audio_pad_token_id=2048,
tie_word_embeddings=False,
pad_token_id=0,
bos_token_id=1,
codec_config={
"model_type": "mimi",
"num_quantizers": 8,
"audio_channels": 1,
"chunk_in_sec": None,
"hidden_size": 16,
"num_filters": 8,
"num_residual_layers": 1,
"upsampling_ratios": [8, 4],
"codebook_size": 16,
"vector_quantization_hidden_dimension": 16,
"upsample_groups": 16,
"num_hidden_layers": 2,
"num_attention_heads": 2,
"num_key_value_heads": 2,
"sliding_window": 4,
"codebook_dim": 16,
"use_cache": False,
},
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.text_seq_length = text_seq_length
self.is_training = is_training
self.use_input_mask = use_input_mask
self.use_token_type_ids = use_token_type_ids
self.use_labels = use_labels
self.codebook_vocab_size = codebook_vocab_size
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.max_position_embeddings = max_position_embeddings
self.rope_theta = rope_theta
self.hidden_act = hidden_act
self.head_dim = head_dim
self.initializer_range = initializer_range
self.use_cache = use_cache
self.sliding_window = sliding_window
self.attention_dropout = attention_dropout
self.ffn_dim = ffn_dim
self.rms_norm_eps = rms_norm_eps
self.num_codebooks = num_codebooks
self.frame_size = frame_size
self.delay_in_tokens = delay_in_tokens
self.audio_bos_token_id = audio_bos_token_id
self.audio_pad_token_id = audio_pad_token_id
self.tie_word_embeddings = tie_word_embeddings
self.pad_token_id = pad_token_id
self.bos_token_id = bos_token_id
self.codec_config = codec_config
self.scope = scope
self.input_values_length = input_values_length
def get_config(self):
return KyutaiSpeechToTextConfig(
codebook_vocab_size=self.codebook_vocab_size,
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
num_key_value_heads=self.num_key_value_heads,
max_position_embeddings=self.max_position_embeddings,
rope_theta=self.rope_theta,
hidden_act=self.hidden_act,
head_dim=self.head_dim,
initializer_range=self.initializer_range,
use_cache=self.use_cache,
sliding_window=self.sliding_window,
attention_dropout=self.attention_dropout,
ffn_dim=self.ffn_dim,
rms_norm_eps=self.rms_norm_eps,
num_codebooks=self.num_codebooks,
frame_size=self.frame_size,
delay_in_tokens=self.delay_in_tokens,
audio_bos_token_id=self.audio_bos_token_id,
audio_pad_token_id=self.audio_pad_token_id,
tie_word_embeddings=self.tie_word_embeddings,
pad_token_id=self.pad_token_id,
bos_token_id=self.bos_token_id,
codec_config=self.codec_config,
)
def create_and_check_model(self, config, input_ids, input_mask):
model = KyutaiSpeechToTextModel(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask)
result = model(input_ids)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
def prepare_config_and_inputs(self):
config = self.get_config()
text_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size - 1) + 1
codebook_input_ids = (
ids_tensor([self.batch_size, self.seq_length, self.num_codebooks], self.codebook_vocab_size - 1) + 1
)
input_ids = torch.cat([text_input_ids.unsqueeze(2), codebook_input_ids], dim=2)
attention_mask = text_input_ids.ne(1).to(torch_device)
return config, input_ids, attention_mask
def prepare_config_and_inputs_generate(self):
config = self.get_config()
input_ids = torch.ones([self.batch_size, 1], dtype=torch.long, device=torch_device)
input_values = floats_tensor([self.batch_size, 1, self.input_values_length])
padding_mask = torch.ones_like(input_values, dtype=torch.int32, device=torch_device)
return config, input_ids, input_values, padding_mask
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
config,
input_ids,
attention_mask,
) = config_and_inputs
inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask}
return config, inputs_dict
def prepare_config_and_inputs_for_common_generate(self):
config_and_inputs = self.prepare_config_and_inputs_generate()
(
config,
input_ids,
input_values,
padding_mask,
) = config_and_inputs
inputs_dict = {
"input_ids": input_ids,
"input_values": input_values,
"padding_mask": padding_mask,
}
return config, inputs_dict
@require_torch
class KyutaiSpeechToTextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (
(
KyutaiSpeechToTextModel,
KyutaiSpeechToTextForConditionalGeneration,
)
if is_torch_available()
else ()
)
pipeline_model_mapping = (
{
"feature-extraction": KyutaiSpeechToTextModel,
"automatic-speech-recognition": KyutaiSpeechToTextForConditionalGeneration,
}
if is_torch_available()
else {}
)
test_headmasking = False
test_pruning = False
fx_compatible = False # Broken by attention refactor cc @Cyrilvallez
# Need to use `0.8` instead of `0.9` for `test_cpu_offload`
# This is because we are hitting edge cases with the causal_mask buffer
model_split_percents = [0.5, 0.7, 0.8]
def setUp(self):
self.model_tester = KyutaiSpeechToTextModelTester(self)
self.config_tester = ConfigTester(self, config_class=KyutaiSpeechToTextConfig, hidden_size=37)
def test_config(self):
self.config_tester.run_common_tests()
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels)
return inputs_dict
def prepare_config_and_inputs_for_generate(self, batch_size=2):
# monkey patch prepare_config_and_inputs_for_common
prepare_config_and_inputs_for_common = self.model_tester.prepare_config_and_inputs_for_common
original_batch_size = self.model_tester.batch_size
self.model_tester.prepare_config_and_inputs_for_common = (
self.model_tester.prepare_config_and_inputs_for_common_generate
)
self.model_tester.batch_size = batch_size
config, filtered_inputs_dict = super().prepare_config_and_inputs_for_generate()
self.model_tester.prepare_config_and_inputs_for_common = prepare_config_and_inputs_for_common
self.model_tester.batch_size = original_batch_size
return config, filtered_inputs_dict
@pytest.mark.skip(reason="Moshi ASR has custom embedding approach (text and audio embeddings).")
def test_model_get_set_embeddings(self):
pass
@pytest.mark.skip(reason="Moshi ASR has custom embedding approach (text and audio embeddings).")
def test_tie_model_weights(self):
pass
@pytest.mark.skip(reason="Moshi ASR has custom embedding approach (text and audio embeddings).")
def test_resize_embeddings_untied(self):
pass
@pytest.mark.skip(reason="Moshi ASR has custom embedding approach (text and audio embeddings).")
def test_resize_tokens_embeddings(self):
pass
@pytest.mark.skip(reason="Moshi ASR has custom embedding approach (text and audio embeddings).")
def test_tied_weights_keys(self):
pass
@pytest.mark.skip(reason="Does not apply to Moshi ASR that requires input_values.")
def test_generate_without_input_ids(self):
pass
def test_initialization(self):
"""
Overrides [ModelTesterMixin.test_initialization] because of specificities of Mimi codec model.
See https://github.com/huggingface/transformers/blob/1077603410cd73ba71d64a522033574d66d64b55/tests/models/mimi/test_modeling_mimi.py#L384-L397
"""
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
configs_no_init = _config_zero_init(config)
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
for name, param in model.named_parameters():
uniform_init_parms = ["conv", "input_proj", "output_proj"]
if param.requires_grad:
if any(x in name for x in uniform_init_parms):
self.assertTrue(
-1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0,
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
@parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION)
@require_torch_sdpa
def test_eager_matches_sdpa_inference(
self, name, torch_dtype, padding_side, use_attention_mask, output_attentions, enable_kernels
):
if use_attention_mask or (not use_attention_mask and torch_dtype == "fp32" and not output_attentions):
self.skipTest("Test is failing, fix me :) ")
parent_parameterized_test = getattr(ModelTesterMixin, self._testMethodName)
parent_parameterized_test(self)
@unittest.skip(reason="Some undefined behavior encountered with test versions of this model. Skip for now.")
def test_cpu_offload(self):
pass
@unittest.skip(reason="Some undefined behavior encountered with test versions of this model. Skip for now.")
def test_disk_offload_bin(self):
pass
@unittest.skip(reason="Some undefined behavior encountered with test versions of this model. Skip for now.")
def test_disk_offload_safetensors(self):
pass
@pytest.mark.generate
def test_left_padding_compatibility(self):
# NOTE: left-padding results in small numerical differences. This is expected.
# See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535
# First, filter out models that don't support left padding
# - The model must have generative capabilities
if len(self.all_generative_model_classes) == 0:
self.skipTest(reason="No generative architecture available for this model.")
# - The model must support padding
if not self.has_attentions:
self.skipTest(reason="This model doesn't support padding.")
# - The model must be a decoder-only architecture (encoder-based architectures use right-padding)
decoder_only_classes = []
for model_class in self.all_generative_model_classes:
config, _ = self.prepare_config_and_inputs_for_generate()
if config.is_encoder_decoder:
continue
else:
decoder_only_classes.append(model_class)
if len(decoder_only_classes) == 0:
self.skipTest(reason="No decoder-only architecture available for this model.")
# - Decoder-only architectures derived from encoder-decoder models could support it in theory, but we haven't
# added support for it yet. We skip these models for now.
has_encoder_attributes = any(
attr_name
for attr_name in config.to_dict().keys()
if attr_name.startswith("encoder") and attr_name != "encoder_no_repeat_ngram_size"
)
if has_encoder_attributes:
self.skipTest(
reason="The decoder-only derived from encoder-decoder models are not expected to support left-padding."
)
# Then, test left-padding
def _prepare_model_kwargs(input_ids, attention_mask, signature):
model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask}
if "position_ids" in signature:
position_ids = torch.cumsum(attention_mask, dim=-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
model_kwargs["position_ids"] = position_ids
if "cache_position" in signature:
cache_position = torch.arange(input_ids.shape[1], device=torch_device)
model_kwargs["cache_position"] = cache_position
return model_kwargs
for model_class in decoder_only_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_ids = inputs_dict["input_ids"]
attention_mask = inputs_dict.get("attention_mask")
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
model = model_class(config).to(torch_device).eval()
signature = inspect.signature(model.forward).parameters.keys()
# no cache as some models require special cache classes to be init outside forward
model.generation_config.use_cache = False
# Without padding
model_kwargs = _prepare_model_kwargs(input_ids, attention_mask, signature)
next_logits_wo_padding = model(**model_kwargs).logits[:, -1, :]
# With left-padding (length 32)
# can hardcode pad_token to be 0 as we'll do attn masking anyway
pad_token_id = (
config.get_text_config().pad_token_id if config.get_text_config().pad_token_id is not None else 0
)
pad_size = (input_ids.shape[0], 32, *input_ids.shape[2:])
padding = torch.ones(pad_size, dtype=input_ids.dtype, device=torch_device) * pad_token_id
padded_input_ids = torch.cat((padding, input_ids), dim=1)
padded_attention_mask = torch.cat(
(torch.zeros(pad_size[:2], dtype=input_ids.dtype, device=torch_device), attention_mask), dim=1
)
model_kwargs = _prepare_model_kwargs(padded_input_ids, padded_attention_mask, signature)
next_logits_with_padding = model(**model_kwargs).logits[:, -1, :]
# They should result in very similar logits
torch.testing.assert_close(next_logits_wo_padding, next_logits_with_padding, rtol=1e-5, atol=1e-5)
def test_generate_continue_from_past_key_values(self):
# Tests that we can continue generating from past key values, returned from a previous `generate` call
for model_class in self.all_generative_model_classes:
if any(model_name in model_class.__name__.lower() for model_name in ["imagegpt", "mllama"]):
self.skipTest(reason="Won't fix: old model with unique inputs/caches/other")
if any(model_name in model_class.__name__.lower() for model_name in ["umt5"]):
self.skipTest(reason="TODO: needs modeling or test input preparation fixes for compatibility")
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
if not hasattr(config.get_text_config(), "use_cache"):
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
# Let's make it always:
# 1. use cache (for obvious reasons)
# 2. generate to max length (which can be achieved by setting the eos token to an invalid value), which
# would make the test flaky (e.g. EOS is generated on iteration 1 on both generations, but the
# continuation would force it to generate beyond an EOS token)
# 3. ignore `token_type_ids` for simplicity
# 4. ignore `forced_eos_token_id`, which requires further manipulation of the continuation inputs and is
# active by default on some models
# 5. ignore `encoder_no_repeat_ngram_size`, which is set by default in some encoder-decoder models. When
# we use their decoder as a stand-alone model, `encoder_no_repeat_ngram_size` actually prevents
# repetition exclusively from the prompt. This test relies on comparing one call vs 2 calls
# with cache, what is considered a prompt is different in the two cases.
if "token_type_ids" in inputs:
del inputs["token_type_ids"]
model = model_class(config).to(torch_device)
model.eval()
# If "past_key_values" is not returned, skip the test (e.g. RWKV uses a different cache name and format)
outputs = model(**inputs)
if "past_key_values" not in outputs:
self.skipTest(reason="This model doesn't return `past_key_values`")
generate_kwargs = {
"pad_token_id": -1,
"eos_token_id": -1,
"forced_eos_token_id": None,
"encoder_no_repeat_ngram_size": 0,
"use_cache": True,
"do_sample": False,
"return_dict_in_generate": True,
"output_scores": True,
}
# Traditional way of generating text, with `return_dict_in_generate` to return the past key values
_, inputs = self.prepare_config_and_inputs_for_generate()
outputs = model.generate(**inputs, **generate_kwargs, max_new_tokens=3)
# Let's generate again, but passing the past key values in between (2 + 1 = 3 tokens). Note that the
# inputs may need to be tweaked across `generate` calls (like the attention mask).
outputs_cached = model.generate(**inputs, **generate_kwargs, max_new_tokens=2)
# Continue from the tokens generated above, preparing the inputs accordingly
inputs["past_key_values"] = outputs_cached.past_key_values
new_attention_len = outputs_cached.sequences.shape[-1]
if config.is_encoder_decoder:
inputs["decoder_input_ids"] = outputs_cached.sequences
if "decoder_attention_mask" in inputs:
inputs["decoder_attention_mask"] = torch.nn.functional.pad(
inputs["decoder_attention_mask"],
(0, new_attention_len - inputs["decoder_attention_mask"].shape[1]),
mode="constant",
value=1,
)
else:
inputs["input_ids"] = outputs_cached.sequences
if "attention_mask" in inputs:
inputs["attention_mask"] = torch.nn.functional.pad(
inputs["attention_mask"],
(0, new_attention_len - inputs["attention_mask"].shape[1]),
mode="constant",
value=1,
)
first_caches_scores = outputs_cached.scores
outputs_cached = model.generate(**inputs, **generate_kwargs, max_new_tokens=1)
full_cached_scores = first_caches_scores + outputs_cached.scores
outputs_cached.scores = full_cached_scores
# The two sets of generated text and past kv should be equal to each other
self._check_similar_generate_outputs(outputs, outputs_cached)
for layer_idx in range(len(outputs_cached.past_key_values)):
for kv_idx in range(len(outputs_cached.past_key_values[layer_idx])):
self.assertTrue(
torch.allclose(
outputs.past_key_values[layer_idx][kv_idx],
outputs_cached.past_key_values[layer_idx][kv_idx],
)
)
# needs to be overridden to avoid to avoid casting of input_values to float16
# indeed, the codec model is kept in fp32, so we need to avoid casting input_values to float16
def _test_attention_implementation(self, attn_implementation):
"""
Compares the output of generate with the eager attention implementation against other implementations.
NOTE: despite the test logic being the same, different implementations actually need different decorators, hence
this separate function.
"""
max_new_tokens = 30
support_flag = {
"sdpa": "_supports_sdpa",
"flash_attention_2": "_supports_flash_attn_2",
}
for model_class in self.all_generative_model_classes:
if not getattr(model_class, support_flag[attn_implementation]):
self.skipTest(f"{model_class.__name__} does not support `attn_implementation={attn_implementation}`")
config, original_inputs_dict = self.prepare_config_and_inputs_for_generate()
inputs_dict = {}
for input_name, input_data in original_inputs_dict.items():
if (
isinstance(input_data, torch.Tensor)
and input_data.dtype in [torch.float32, torch.bfloat16]
and input_name != "input_values"
):
inputs_dict[input_name] = input_data.to(torch.float16)
else:
inputs_dict[input_name] = input_data
main_input = inputs_dict[model_class.main_input_name]
# FA2 doesn't accept masking in the middle of the sequence for now. We usually generate right-padded
# attention masks at test time and, with generate, the mask will be appended with 1s on the right,
# resulting in a mask with holes (not supported properly by FA2).
if attn_implementation == "flash_attention_2":
for input_name in ("attention_mask", "decoder_attention_mask", "encoder_attention_mask"):
if input_name in inputs_dict:
inputs_dict[input_name] = torch.ones_like(inputs_dict[input_name])
# make sure that all models have enough positions for generation
if hasattr(config, "max_position_embeddings"):
config.max_position_embeddings = max_new_tokens + main_input.shape[1] + 1
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
del model
gc.collect()
generate_kwargs = {
"max_new_tokens": max_new_tokens,
"do_sample": False,
"return_dict_in_generate": True,
"output_scores": True,
"use_cache": True,
}
model_eager = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="eager",
).to(torch_device)
res_eager = model_eager.generate(**inputs_dict, **generate_kwargs)
del model_eager
gc.collect()
model_attn = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation=attn_implementation,
).to(torch_device)
res_attn = model_attn.generate(**inputs_dict, **generate_kwargs)
del model_attn
gc.collect()
self._check_similar_generate_outputs(res_eager, res_attn, atol=1e-3, rtol=1e-3)
class KyutaiSpeechToTextForConditionalGenerationIntegrationTests(unittest.TestCase):
_dataset = None
def setUp(self):
self.model_checkpoint = "kyutai/stt-2.6b-en"
def tearDown(self):
cleanup(torch_device, gc_collect=True)
@classmethod
def _load_dataset(cls):
# Lazy loading of the dataset. Because it is a class method, it will only be loaded once per pytest process.
if cls._dataset is None:
cls._dataset = datasets.load_dataset(
"hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
)
# using 24000 here for simplicity, should rather be processor.feature_extractor.sampling_rate
cls._dataset = cls._dataset.cast_column("audio", datasets.Audio(sampling_rate=24000))
def _load_datasamples(self, num_samples):
self._load_dataset()
ds = self._dataset
speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
return [x["array"] for x in speech_samples]
@slow
@require_torch_accelerator
def test_generation(self):
"""
reproduce test expected outputs using original codebase: https://gist.github.com/eustlb/7a9aa6139d11e0103c6b65bac103da52
DISCLAIMER: we are testing for pretty short inputs. Indeed, reproducing correct expected outputs for longer is not possible
as implementation choices (qkv matrix in one linear for original code vs three for hf) create growing divergence with context lenght,
ultimately giving different outputs.
"""
processor = KyutaiSpeechToTextProcessor.from_pretrained(self.model_checkpoint)
model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(
self.model_checkpoint, device_map=torch_device
)
samples = self._load_datasamples(1)
inputs = processor(
samples,
).to(torch_device)
out = model.generate(**inputs)
# fmt: off
EXPECTED_TOKENS = torch.tensor([
[48000, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 1519, 263, 3, 3, 0, 3635, 428, 641, 0, 277, 3, 0, 265, 0, 267, 1162, 261, 274, 410, 0, 272, 3, 0, 265, 0, 260, 1621, 0, 1174, 371, 262, 3, 3, 3, 0, 269, 0, 281, 0, 304, 0, 2433, 3, 0, 266, 3, 0, 281, 1661, 3, 0, 376, 3, 3, 0, 350, 261, 401, 516, 263, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3]],
)
# fmt: on
torch.testing.assert_close(out.cpu(), EXPECTED_TOKENS)
@slow
@require_torch_accelerator
def test_generation_batched(self):
"""
reproduce test expected outputs using original codebase: https://gist.github.com/eustlb/b58c217c75124d405ec1c13877c7ece8
DISCLAIMER: we are testing for pretty short inputs. Indeed, reproducing correct expected outputs for longer is not possible
as implementation choices (qkv matrix in one linear for original code vs three for hf) create growing divergence with context lenght,
ultimately giving different outputs.
"""
processor = KyutaiSpeechToTextProcessor.from_pretrained(self.model_checkpoint)
model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(
self.model_checkpoint, device_map=torch_device
)
samples = self._load_datasamples(4)
inputs = processor(
samples,
).to(torch_device)
out = model.generate(**inputs)
# fmt: off
EXPECTED_TOKENS = torch.tensor([
[48000, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 1519, 263, 3, 3, 0, 3635, 428, 641, 0, 277, 3, 0, 265, 0, 267, 1162, 261, 274, 410, 0, 272, 3, 0, 265, 0, 260, 1621, 0, 1174, 371, 262, 3, 3, 3, 0, 269, 0, 281, 0, 304, 0, 2433, 3, 0, 266, 3, 0, 281, 1661, 3, 0, 376, 3, 3, 0, 350, 261, 401, 516, 263, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
[48000, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 500, 334, 0, 277, 3, 0, 1519, 263, 3, 3, 0, 3635, 428, 641, 264, 261, 0, 511, 1109, 3, 0, 1138, 3, 3, 3, 0, 508, 827, 3, 3, 3, 3, 0, 468, 3, 3, 0, 376, 3, 3, 3, 0, 260, 978, 263, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
[48000, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 414, 0, 527, 261, 3, 0, 409, 3, 3, 3, 0, 271, 3, 0, 309, 3, 0, 285, 3, 0, 521, 371, 609, 3, 3, 0, 260, 959, 3, 3, 3, 0, 272, 3, 0, 265, 0, 546, 262, 3, 3, 3, 3, 3, 3, 0, 291, 3, 0, 975, 2203, 3, 3, 3, 3, 0, 269, 3, 0, 260, 489, 651, 274, 279, 1870, 3, 0, 1084, 873, 273, 3, 0, 260, 531, 3, 3, 0, 409, 262, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 1502, 1005, 836, 3, 3, 0, 1666, 306, 3, 0, 340, 3, 0, 260, 3232, 3, 0, 269, 3, 3, 0, 275, 261, 0, 260, 1379, 261, 0, 3324, 3, 3, 3, 3, 0, 549, 3, 3, 0, 693, 405, 323, 3, 0, 266, 3, 3, 0, 265, 0, 699, 263, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
[48000, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 414, 0, 392, 3, 3, 0, 1269, 314, 0, 2607, 261, 3, 3, 3, 0, 1098, 295, 3, 3, 3, 0, 446, 625, 3, 0, 496, 280, 1205, 485, 1071, 1627, 449, 264, 261, 3, 0, 400, 0, 277, 3, 3, 3, 0, 260, 342, 3, 0, 618, 280, 1866, 3, 3, 0, 554, 3, 3, 3, 3, 0, 317, 262, 3, 3, 3, 3, 3, 3, 3, 3, 0, 269, 0, 303, 3, 0, 573, 2615, 3, 3, 0, 276, 3, 0, 275, 0, 305, 3, 0, 260, 415, 3, 3, 0, 272, 3, 3, 3, 3, 0, 1631, 327, 3, 3, 0, 333, 739, 841, 263, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
])
# fmt: on
torch.testing.assert_close(out.cpu(), EXPECTED_TOKENS)

View File

@@ -107,14 +107,21 @@ class MimiModelTester:
self.sliding_window = sliding_window
self.use_cache = use_cache
def prepare_config_and_inputs(self):
input_values = floats_tensor([self.batch_size, self.num_channels, self.intermediate_size], scale=1.0)
def prepare_config_and_inputs(self, input_values_length=None):
input_values = floats_tensor(
[
self.batch_size,
self.num_channels,
self.intermediate_size if input_values_length is None else input_values_length,
],
scale=1.0,
)
config = self.get_config()
inputs_dict = {"input_values": input_values}
return config, inputs_dict
def prepare_config_and_inputs_for_common(self):
config, inputs_dict = self.prepare_config_and_inputs()
def prepare_config_and_inputs_for_common(self, input_values_length=None):
config, inputs_dict = self.prepare_config_and_inputs(input_values_length=input_values_length)
return config, inputs_dict
def prepare_config_and_inputs_for_model_class(self, model_class):
@@ -508,6 +515,54 @@ class MimiIntegrationTest(unittest.TestCase):
)
self.assertTrue(rmse < 1e-3)
def test_integration_encode_with_padding_cache(self):
"""
We test here the possibility to run Mimi in a streaming manner, i.e. chunk by chunk.
1. we encode a first time the entire audio
2. we encode the audio chunk by chunk, each chunk being the smallest size possible for the model (i.e. the frame size)
This test must be run on CPU since GPU floating point operations accumulate rounding errors that cause test failures.
"""
librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
model_id = "kyutai/mimi"
model = MimiModel.from_pretrained(model_id, use_cache=True).to("cpu")
processor = AutoFeatureExtractor.from_pretrained(model_id)
librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=processor.sampling_rate))
audio_sample = librispeech_dummy[-1]["audio"]["array"]
inputs = processor(
raw_audio=audio_sample,
sampling_rate=processor.sampling_rate,
return_tensors="pt",
).to("cpu")
frame_size = model.config.frame_size
audio_codes = model.encode(inputs["input_values"]).audio_codes
# streaming chunk by chunk
encoder_past_key_values = None
padding_cache = None
encoded_frames_list = []
for start in range(0, inputs["input_values"].shape[-1], frame_size):
input_values_chunk = inputs["input_values"][:, :, start : start + frame_size]
encoder_outputs = model.encode(
input_values_chunk,
padding_cache=padding_cache,
encoder_past_key_values=encoder_past_key_values,
use_streaming=True,
)
encoder_past_key_values = encoder_outputs.encoder_past_key_values
padding_cache = encoder_outputs.padding_cache
encoded_frames_list.append(encoder_outputs.audio_codes)
streamed_audio_codes = torch.cat(encoded_frames_list, dim=-1)
torch.testing.assert_close(streamed_audio_codes, audio_codes)
def test_integration(self):
expected_rmses = {
"8": 0.0018785292,

View File

@@ -3566,7 +3566,11 @@ class ModelTesterMixin:
# TODO: if we can also check with `batch_size=1` without being flaky?
for batch_size in [7]:
# musicgen decoder models; TODO: find better abstraction
if hasattr(self.model_tester, "num_codebooks") and not hasattr(model_eager, "text_encoder"):
if (
model.__class__.__name__.startswith("Musicgen")
and hasattr(self.model_tester, "num_codebooks")
and not hasattr(model_eager, "text_encoder")
):
input_data_batch_size = batch_size * self.model_tester.num_codebooks
else:
input_data_batch_size = batch_size
@@ -3626,7 +3630,7 @@ class ModelTesterMixin:
if is_encoder_decoder:
# musicgen encoder-decoder models; TODO: find better abstraction
if hasattr(self.model_tester, "num_codebooks"):
if model.__class__.__name__.startswith("Musicgen") and hasattr(self.model_tester, "num_codebooks"):
input_data_batch_size = batch_size * self.model_tester.num_codebooks
else:
input_data_batch_size = batch_size