From 6bdd4ec95264e5d8f219cfe4ee29ea9b42474bb7 Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Tue, 24 Jun 2025 18:01:15 +0200 Subject: [PATCH] Add kyutai stt (#38909) * first draft * cleaner version * udpate tests + modeling * add tests * init * udpate test_modeling_common * fix tests * csm Processor draft * convertion update * mimi cache padding convolutions draft * mimi streaming udpates * update mimi padding cache test * udpate cache padding mimi test * make style mimi * updates generate moshi asr * moshi asr integration tests (single + batched) * update tests * update conversion script * good default sliding window value * udpdate generate * update test checkpoint * nit * fix mimi * fix codec prefix * revert * revert * update config * update config * unnecessary mimi input restriction * remove delay in tokens * remove _prepare_4d_causal_attention_mask_with_cache_position and _update_causal_mask * test update * modular update * make style * nit * rename * create codec model generation config at init * remove delay * max_new_tokens/length warning * correct conv1 padding cache import for modular * nit * fix on encoder_past_key_values * convert modular * move frame_size to config * move frame_size to config * update test name * handle first token is bos * better handling of max_new_tokens * fix * fix batch size in test input prep * update docstring * convert modular * make style * make style * add feature extractor * correct modular convention name for feature_extraction file * update convertion script * doc processor * update doc * udpate init * update model type * fixes * update tests * fix * make * add doc * nit * fix * doc * auto mappings * doc * nit * convert modular * doc * nit * extend _keep_in_fp32_modules to enforce fp32 * renaming to stt * doc update + test update * doc fixes * doc fix * doc fix * fix musicgen tests * fix musicgen tests * make style * fix musicgen tests * correct frame_rate config param for mimi * update mimi test * revert update mimi test * enforce cpu test * move cache init in cache class * convert modular * docstring update * update model id * feature_extractor -> feature_extraction (SEW) * convert modular * update model id --- docs/source/en/_toctree.yml | 2 + docs/source/en/model_doc/stt.md | 122 ++ src/transformers/modeling_utils.py | 5 +- src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 2 + .../models/auto/feature_extraction_auto.py | 1 + src/transformers/models/auto/modeling_auto.py | 2 + .../models/auto/processing_auto.py | 1 + .../models/mimi/configuration_mimi.py | 50 +- src/transformers/models/mimi/modeling_mimi.py | 356 ++-- ...actor_sew.py => feature_extraction_sew.py} | 0 src/transformers/models/stt/__init__.py | 29 + .../configuration_kyutai_speech_to_text.py | 188 +++ .../convert_kyutai_speech_to_text_to_hf.py | 377 +++++ ...eature_extraction_kyutai_speech_to_text.py | 237 +++ .../stt/modeling_kyutai_speech_to_text.py | 1434 +++++++++++++++++ .../stt/modular_kyutai_speech_to_text.py | 510 ++++++ .../stt/processing_kyutai_speech_to_text.py | 104 ++ .../models/kyutai_speech_to_text/__init__.py | 0 .../test_modeling_kyutai_speech_to_text.py | 704 ++++++++ tests/models/mimi/test_modeling_mimi.py | 63 +- tests/test_modeling_common.py | 8 +- utils/modular_model_converter.py | 4 +- 23 files changed, 4000 insertions(+), 200 deletions(-) create mode 100644 docs/source/en/model_doc/stt.md rename src/transformers/models/sew/{feature_extractor_sew.py => feature_extraction_sew.py} (100%) create mode 100644 src/transformers/models/stt/__init__.py create mode 100644 src/transformers/models/stt/configuration_kyutai_speech_to_text.py create mode 100644 src/transformers/models/stt/convert_kyutai_speech_to_text_to_hf.py create mode 100644 src/transformers/models/stt/feature_extraction_kyutai_speech_to_text.py create mode 100644 src/transformers/models/stt/modeling_kyutai_speech_to_text.py create mode 100644 src/transformers/models/stt/modular_kyutai_speech_to_text.py create mode 100644 src/transformers/models/stt/processing_kyutai_speech_to_text.py create mode 100644 tests/models/kyutai_speech_to_text/__init__.py create mode 100644 tests/models/kyutai_speech_to_text/test_modeling_kyutai_speech_to_text.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 6ebe8044ad..d8438a4165 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -843,6 +843,8 @@ title: GraniteSpeech - local: model_doc/hubert title: Hubert + - local: model_doc/stt + title: Kyutai Speech-To-Text - local: model_doc/mctct title: MCTCT - local: model_doc/mimi diff --git a/docs/source/en/model_doc/stt.md b/docs/source/en/model_doc/stt.md new file mode 100644 index 0000000000..02428899df --- /dev/null +++ b/docs/source/en/model_doc/stt.md @@ -0,0 +1,122 @@ + + +# Kyutai Speech-To-Text +## Overview + +Kyutai STT is a speech-to-text model architecture based on the [Mimi codec](https://huggingface.co/docs/transformers/en/model_doc/mimi), which encodes audio into discrete tokens in a streaming fashion, and a [Moshi-like](https://huggingface.co/docs/transformers/en/model_doc/moshi) autoregressive decoder. Kyutai’s lab has released two model checkpoints: +- [kyutai/stt-1b-en_fr](https://huggingface.co/kyutai/stt-1b-en_fr): a 1B-parameter model capable of transcribing both English and French +- [kyutai/stt-2.6b-en](https://huggingface.co/kyutai/stt-2.6b-en): a 2.6B-parameter model focused solely on English, optimized for maximum transcription accuracy + +
+ +
+ +## Usage Tips + +### Inference + +```python +import torch +from datasets import load_dataset, Audio +from transformers import KyutaiSpeechToTextProcessor, KyutaiSpeechToTextForConditionalGeneration + +# 1. load the model and the processor +torch_device = "cuda" if torch.cuda.is_available() else "cpu" +model_id = "kyutai/stt-2.6b-en" + +processor = KyutaiSpeechToTextProcessor.from_pretrained(model_id) +model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(model_id, device_map=torch_device) + +# 2. load audio samples +ds = load_dataset( + "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation" +) +ds = ds.cast_column("audio", Audio(sampling_rate=24000)) + +# 3. prepare the model inputs +inputs = processor( + ds[0]["audio"]["array"], +) +inputs.to(torch_device) + +# 4. infer the model +output_tokens = model.generate(**inputs) + +# 5. decode the generated tokens +print(processor.batch_decode(output_tokens, skip_special_tokens=True)) +``` + +### Batched Inference + +```python +import torch +from datasets import load_dataset, Audio +from transformers import KyutaiSpeechToTextProcessor, KyutaiSpeechToTextForConditionalGeneration + +# 1. load the model and the processor +torch_device = "cuda" if torch.cuda.is_available() else "cpu" +model_id = "kyutai/stt-2.6b-en" + +processor = KyutaiSpeechToTextProcessor.from_pretrained(model_id) +model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(model_id, device_map=torch_device) + +# 2. load audio samples +ds = load_dataset( + "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation" +) +ds = ds.cast_column("audio", Audio(sampling_rate=24000)) + +# 3. prepare the model inputs +audio_arrays = [ds[i]["audio"]["array"] for i in range(4)] +inputs = processor(audio_arrays, return_tensors="pt", padding=True) +inputs = inputs.to(torch_device) + +# 4. infer the model +output_tokens = model.generate(**inputs) + +# 5. decode the generated tokens +decoded_outputs = processor.batch_decode(output_tokens, skip_special_tokens=True) +for output in decoded_outputs: + print(output) +``` + +This model was contributed by [Eustache Le Bihan](https://huggingface.co/eustlb). +The original code can be found [here](https://github.com/kyutai-labs/moshi). + + +## KyutaiSpeechToTextConfig + +[[autodoc]] KyutaiSpeechToTextConfig + +## KyutaiSpeechToTextProcessor + +[[autodoc]] KyutaiSpeechToTextProcessor + - __call__ + +## KyutaiSpeechToTextFeatureExtractor + +[[autodoc]] KyutaiSpeechToTextFeatureExtractor + +## KyutaiSpeechToTextForConditionalGeneration + +[[autodoc]] KyutaiSpeechToTextForConditionalGeneration + - forward + - generate + +## KyutaiSpeechToTextModel + +[[autodoc]] KyutaiSpeechToTextModel diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 4774a72df7..4f6095a3ed 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -4658,8 +4658,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi # The _keep_in_fp32_modules flag is only used to avoid bf16 -> fp16 casting precision issues. It was introduced # in case of force loading a model that should stay bf16 in fp16 (which includes a few quantizers as this is a pre-processing # step for e.g. bitsandbytes). See https://github.com/huggingface/transformers/issues/20287 for details. + # Update: to extend _keep_in_fp32_modules flag feature, it can also be used to force modules that should stay in fp32 if model._keep_in_fp32_modules is not None and ( - torch_dtype == torch.float16 or getattr(hf_quantizer, "use_keep_in_fp32_modules", False) + torch_dtype == torch.float16 + or torch_dtype == torch.bfloat16 + or getattr(hf_quantizer, "use_keep_in_fp32_modules", False) ): # We need to match exact layers, so we add either `.` on each side, or start/end of string keep_in_fp32_regex = re.compile( diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 504fcc2684..8d36068353 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -285,6 +285,7 @@ if TYPE_CHECKING: from .squeezebert import * from .stablelm import * from .starcoder2 import * + from .stt import * from .superglue import * from .superpoint import * from .swiftformer import * diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index d7529b2b63..54a285e3c6 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -322,6 +322,7 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str]( ("squeezebert", "SqueezeBertConfig"), ("stablelm", "StableLmConfig"), ("starcoder2", "Starcoder2Config"), + ("stt", "KyutaiSpeechToTextConfig"), ("superglue", "SuperGlueConfig"), ("superpoint", "SuperPointConfig"), ("swiftformer", "SwiftFormerConfig"), @@ -707,6 +708,7 @@ MODEL_NAMES_MAPPING = OrderedDict[str, str]( ("squeezebert", "SqueezeBERT"), ("stablelm", "StableLm"), ("starcoder2", "Starcoder2"), + ("stt", "KyutaiSpeechToText"), ("superglue", "SuperGlue"), ("superpoint", "SuperPoint"), ("swiftformer", "SwiftFormer"), diff --git a/src/transformers/models/auto/feature_extraction_auto.py b/src/transformers/models/auto/feature_extraction_auto.py index e7db1944d3..5754b3bc1b 100644 --- a/src/transformers/models/auto/feature_extraction_auto.py +++ b/src/transformers/models/auto/feature_extraction_auto.py @@ -91,6 +91,7 @@ FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict( ("sew-d", "Wav2Vec2FeatureExtractor"), ("speech_to_text", "Speech2TextFeatureExtractor"), ("speecht5", "SpeechT5FeatureExtractor"), + ("stt", "KyutaiSpeechToTextFeatureExtractor"), ("swiftformer", "ViTFeatureExtractor"), ("swin", "ViTFeatureExtractor"), ("swinv2", "ViTFeatureExtractor"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index b3224b7d46..cbfc0f7647 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -300,6 +300,7 @@ MODEL_MAPPING_NAMES = OrderedDict( ("squeezebert", "SqueezeBertModel"), ("stablelm", "StableLmModel"), ("starcoder2", "Starcoder2Model"), + ("stt", "KyutaiSpeechToTextModel"), ("superglue", "SuperGlueForKeypointMatching"), ("swiftformer", "SwiftFormerModel"), ("swin", "SwinModel"), @@ -1055,6 +1056,7 @@ MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict( ("speech-encoder-decoder", "SpeechEncoderDecoderModel"), ("speech_to_text", "Speech2TextForConditionalGeneration"), ("speecht5", "SpeechT5ForSpeechToText"), + ("stt", "KyutaiSpeechToTextForConditionalGeneration"), ("whisper", "WhisperForConditionalGeneration"), ] ) diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index b2e36bc4bc..478766e6ee 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -116,6 +116,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict( ("speech_to_text", "Speech2TextProcessor"), ("speech_to_text_2", "Speech2Text2Processor"), ("speecht5", "SpeechT5Processor"), + ("stt", "KyutaiSpeechToTextProcessor"), ("trocr", "TrOCRProcessor"), ("tvlt", "TvltProcessor"), ("tvp", "TvpProcessor"), diff --git a/src/transformers/models/mimi/configuration_mimi.py b/src/transformers/models/mimi/configuration_mimi.py index a36b5e7101..b213359886 100644 --- a/src/transformers/models/mimi/configuration_mimi.py +++ b/src/transformers/models/mimi/configuration_mimi.py @@ -38,8 +38,8 @@ class MimiConfig(PretrainedConfig): Args: sampling_rate (`int`, *optional*, defaults to 24000): The sampling rate at which the audio waveform should be digitalized expressed in hertz (Hz). - frame_rate (`float`, *optional*, defaults to 12.5): - Framerate of the model. + frame_rate (`float`, *optional*): + Should be computed from the other parameters, yet kept for backward compatibility. audio_channels (`int`, *optional*, defaults to 1): Number of channels in the audio data. Either 1 for mono or 2 for stereo. hidden_size (`int`, *optional*, defaults to 512): @@ -111,6 +111,8 @@ class MimiConfig(PretrainedConfig): use_cache (`bool`, *optional*, defaults to `False`): Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if `config.is_decoder=True`. + use_streaming (`bool`, *optional*, defaults to `False`): + Whether to use streaming mode. If `True`, the model encode method will return the padding cache that can be used in a subsequent call to the encode method. rope_theta (`float`, *optional*, defaults to 10000.0): The base period of the RoPE embeddings. sliding_window (`int`, *optional*, defaults to 250): @@ -141,7 +143,7 @@ class MimiConfig(PretrainedConfig): def __init__( self, sampling_rate=24_000, - frame_rate=12.5, + frame_rate=None, audio_channels=1, hidden_size=512, num_filters=64, @@ -172,6 +174,7 @@ class MimiConfig(PretrainedConfig): initializer_range=0.02, norm_eps=1e-5, use_cache=False, + use_streaming=False, rope_theta=10000.0, sliding_window=250, attention_dropout=0.0, @@ -180,7 +183,6 @@ class MimiConfig(PretrainedConfig): **kwargs, ): self.sampling_rate = sampling_rate - self.frame_rate = frame_rate self.audio_channels = audio_channels self.hidden_size = hidden_size self.num_filters = num_filters @@ -209,6 +211,7 @@ class MimiConfig(PretrainedConfig): self.initializer_range = initializer_range self.norm_eps = norm_eps self.use_cache = use_cache + self.use_streaming = use_streaming self.rope_theta = rope_theta self.sliding_window = sliding_window self.attention_dropout = attention_dropout @@ -216,6 +219,14 @@ class MimiConfig(PretrainedConfig): self.layer_scale_initial_scale = layer_scale_initial_scale self.attention_bias = attention_bias + # Handle backward compatibility for frame_rate: + # If frame_rate is explicitly provided, use it (backward compatibility) + # Otherwise, compute it from other parameters (correctly) + if frame_rate is not None: + self._frame_rate = frame_rate + else: + self._frame_rate = None + if num_semantic_quantizers >= self.num_quantizers: raise ValueError( f"The number of semantic quantizers should be lower than the total number of quantizers {self.num_quantizers}, but is currently {num_semantic_quantizers}." @@ -233,5 +244,36 @@ class MimiConfig(PretrainedConfig): # alias to num_quantizers return self.num_quantizers + @property + def frame_size(self) -> int: + # 1. we need each encoder conv stride + # first conv + strides = [1] + + # layer convs + for ratio in reversed(self.upsampling_ratios): + for j in range(self.num_residual_layers): + len_kernel_sizes = len(self.residual_kernel_size) if isinstance(self.residual_kernel_size, list) else 1 + strides.extend([1] * (len_kernel_sizes + 1)) + if self.use_conv_shortcut: # skip connection + strides.append(1) + + strides.append(ratio) + + # last conv + strides.append(1) + + # downsampling layer + strides.append(2) + + return math.prod(strides) + + @property + def frame_rate(self) -> float: + # handle backward compatibility + if self._frame_rate is not None: + return self._frame_rate + return self.sampling_rate / self.frame_size + __all__ = ["MimiConfig"] diff --git a/src/transformers/models/mimi/modeling_mimi.py b/src/transformers/models/mimi/modeling_mimi.py index f1363f7897..221388f858 100644 --- a/src/transformers/models/mimi/modeling_mimi.py +++ b/src/transformers/models/mimi/modeling_mimi.py @@ -23,25 +23,20 @@ import torch.utils.checkpoint from torch import nn from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache -from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...cache_utils import Cache, DynamicCache, StaticCache +from ...masking_utils import create_causal_mask from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import PreTrainedModel -from ...utils import ModelOutput, auto_docstring, is_torch_flex_attn_available, logging +from ...utils import ModelOutput, auto_docstring, logging from .configuration_mimi import MimiConfig if is_flash_attn_available(): from ...modeling_flash_attention_utils import _flash_attention_forward -if is_torch_flex_attn_available(): - from torch.nn.attention.flex_attention import BlockMask - - from ...integrations.flex_attention import make_flex_block_causal_mask - logger = logging.get_logger(__name__) @@ -78,6 +73,91 @@ class MimiOutput(ModelOutput): decoder_past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None +class MimiConv1dPaddingCache: + """ + Padding cache for MimiConv1d causal convolutions in order to support streaming via cache padding. + See: https://arxiv.org/pdf/2005.06720 & https://arxiv.org/pdf/2204.07064 + + A padding cache is a list of cached partial hidden states for each convolution layer. + Hidden states are cached from the previous call to the MimiConv1d forward pass, given the padding size. + """ + + def __init__( + self, + num_layers: int, + per_layer_padding: list[int], + per_layer_padding_mode: list[str], + per_layer_in_channels: list[int], + ): + # ensure correct number of layers for each arg + from_args_num_layers = {len(per_layer_padding), len(per_layer_padding_mode), len(per_layer_in_channels)} + + if len(from_args_num_layers) != 1 or from_args_num_layers.pop() != num_layers: + raise ValueError( + f"Expected `num_layers` ({num_layers}) values in `per_layer_padding`, `per_layer_padding_mode` and `per_layer_in_channels`" + ) + elif not all(mode in ["constant", "replicate"] for mode in per_layer_padding_mode): + raise NotImplementedError( + "`padding_cache` is not supported for convolutions using other than `constant` or `replicate` padding mode" + ) + + self.per_layer_padding = per_layer_padding + self.per_layer_padding_mode = per_layer_padding_mode + self.per_layer_in_channels = per_layer_in_channels + self.per_layer_is_init = [True] * num_layers + + self.padding_cache = [None] * num_layers + + def update(self, hidden_states: torch.Tensor, layer_idx: int): + """ + Updates the padding cache with the new padding states for the layer `layer_idx` and returns the current cache. + + Parameters: + hidden_states (`torch.Tensor`): + The hidden states to be partially cached. + layer_idx (`int`): + The index of the layer to cache the states for. + Returns: + `torch.Tensor` or `None`, the current padding cache. + """ + batch_size, dtype, device = hidden_states.shape[0], hidden_states.dtype, hidden_states.device + padding = self.per_layer_padding[layer_idx] + padding_mode = self.per_layer_padding_mode[layer_idx] + in_channels = self.per_layer_in_channels[layer_idx] + + if self.padding_cache[layer_idx] is None: + if padding_mode == "constant": + current_cache = torch.zeros( + batch_size, + in_channels, + padding, + device=device, + dtype=dtype, + ) + elif padding_mode == "replicate": + current_cache = ( + torch.ones( + batch_size, + in_channels, + padding, + device=device, + dtype=dtype, + ) + * hidden_states[..., :1] + ) + else: + current_cache = self.padding_cache[layer_idx] + + # update the cache + if padding > 0: + padding_states = hidden_states[:, :, -padding:] + else: + padding_states = torch.empty(batch_size, in_channels, padding, dtype=dtype, device=device) + self.padding_cache[layer_idx] = padding_states + + return current_cache + + @dataclass @auto_docstring class MimiEncoderOutput(ModelOutput): @@ -96,6 +176,7 @@ class MimiEncoderOutput(ModelOutput): audio_codes: Optional[torch.LongTensor] = None encoder_past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None + padding_cache: Optional[MimiConv1dPaddingCache] = None @dataclass @@ -130,12 +211,15 @@ class MimiConv1d(nn.Module): stride: int = 1, dilation: int = 1, groups: int = 1, - pad_mode=None, + pad_mode: Optional[str] = None, bias: bool = True, + layer_idx: Optional[int] = None, ): super().__init__() self.causal = config.use_causal_conv self.pad_mode = config.pad_mode if pad_mode is None else pad_mode + self.layer_idx = layer_idx + self.in_channels = in_channels # warn user on unusual setup between dilation and stride if stride > 1 and dilation > 1: @@ -232,12 +316,20 @@ class MimiConv1d(nn.Module): ) // self.conv.stride[0] + 1 return output_lenght - def forward(self, hidden_states): + def forward(self, hidden_states, padding_cache=None): extra_padding = self._get_extra_padding_for_conv1d(hidden_states) - if self.causal: + if not self.causal and padding_cache is not None: + raise ValueError("`padding_cache` is not supported for non-causal convolutions.") + + if self.causal and padding_cache is not None: + layer_padding_cache = padding_cache.update(hidden_states, self.layer_idx) + hidden_states = torch.cat([layer_padding_cache, hidden_states], dim=2) + + elif self.causal: # Left padding for causal hidden_states = self._pad1d(hidden_states, (self.padding_total, extra_padding), mode=self.pad_mode) + else: hidden_states = self._pad1d( hidden_states, (self.padding_left, self.padding_right + extra_padding), mode=self.pad_mode @@ -305,7 +397,6 @@ class MimiConvTranspose1d(nn.Module): return hidden_states -# Copied from transformers.models.encodec.modeling_encodec.EncodecResnetBlock with Encodec->Mimi,EnCodec->Mimi class MimiResnetBlock(nn.Module): """ Residual block from SEANet model as used by Mimi. @@ -331,12 +422,21 @@ class MimiResnetBlock(nn.Module): else: self.shortcut = nn.Identity() - def forward(self, hidden_states): + def forward(self, hidden_states, padding_cache=None): residual = hidden_states - for layer in self.block: - hidden_states = layer(hidden_states) - return self.shortcut(residual) + hidden_states + for layer in self.block: + if isinstance(layer, MimiConv1d): + hidden_states = layer(hidden_states, padding_cache=padding_cache) + else: + hidden_states = layer(hidden_states) + + if isinstance(self.shortcut, MimiConv1d): + residual = self.shortcut(residual, padding_cache=padding_cache) + else: + residual = self.shortcut(residual) + + return residual + hidden_states class MimiEncoder(nn.Module): @@ -370,10 +470,17 @@ class MimiEncoder(nn.Module): self.layers = nn.ModuleList(model) self._mimiconv1d_layer_names = mimiconv1d_layer_names - # Copied from transformers.models.encodec.modeling_encodec.EncodecEncoder.forward - def forward(self, hidden_states): + # initialize layer_idx for MimiConv1d submodules, necessary for padding_cache + for layer_idx, layername in enumerate(self._mimiconv1d_layer_names): + conv_layer = self.get_submodule(layername) + setattr(conv_layer, "layer_idx", layer_idx) + + def forward(self, hidden_states, padding_cache=None): for layer in self.layers: - hidden_states = layer(hidden_states) + if isinstance(layer, (MimiConv1d, MimiResnetBlock)): + hidden_states = layer(hidden_states, padding_cache=padding_cache) + else: + hidden_states = layer(hidden_states) return hidden_states @@ -1005,11 +1112,13 @@ class MimiTransformerModel(nn.Module): if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = None - if attention_mask is not None: - causal_mask = self._update_causal_mask( - attention_mask, hidden_states, cache_position, past_key_values, output_attentions - ) + causal_mask = create_causal_mask( + config=self.config, + input_embeds=hidden_states, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + ) # decoder layers all_hidden_states = () if output_hidden_states else None @@ -1054,163 +1163,6 @@ class MimiTransformerModel(nn.Module): attentions=all_self_attns, ) - # Copied from transformers.models.phimoe.modeling_phimoe.PhimoeModel._update_causal_mask with Phimoe->Mimi - def _update_causal_mask( - self, - attention_mask: Union[torch.Tensor, "BlockMask"], - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: Cache, - output_attentions: bool = False, - ): - if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and past_key_values is not None: - is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] - if is_padding_right: - raise ValueError( - "You are attempting to perform batched generation with padding_side='right'" - " this may lead to unexpected behaviour for Flash Attention version of Mimi. Make sure to " - " call `tokenizer.padding_side = 'left'` before tokenizing the input. " - ) - if attention_mask is not None and 0.0 in attention_mask: - return attention_mask - return None - if self.config._attn_implementation == "flex_attention": - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - return attention_mask - - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in - # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail - # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) - using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) - - # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if ( - self.config._attn_implementation == "sdpa" - and not (using_static_cache or using_sliding_window_cache) - and not output_attentions - ): - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, - sliding_window=self.config.sliding_window, - is_training=self.training, - ): - return None - - dtype = input_tensor.dtype - min_dtype = torch.finfo(dtype).min - sequence_length = input_tensor.shape[1] - # SlidingWindowCache or StaticCache - if using_sliding_window_cache or using_static_cache: - target_length = past_key_values.get_max_cache_shape() - # DynamicCache or no cache - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 - ) - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - config=self.config, - past_key_values=past_key_values, - ) - - if ( - self.config._attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not output_attentions - ): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. - # Details: https://github.com/pytorch/pytorch/issues/110213 - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) - - return causal_mask - - @staticmethod - # Copied from transformers.models.phimoe.modeling_phimoe.PhimoeModel._prepare_4d_causal_attention_mask_with_cache_position with Phimoe->Mimi - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - cache_position: torch.Tensor, - batch_size: int, - config: MimiConfig, - past_key_values: Cache, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - config (`MimiConfig`): - The model's configuration class - past_key_values (`Cache`): - The cache class that is being used currently to generate - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device - ) - diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape( - -1, 1 - ) - text_config = config.get_text_config() - if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None: - # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also - # the check is needed to verify is current checkpoint was trained with sliding window or not - if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: - sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= ( - cache_position.reshape(-1, 1) - text_config.sliding_window - ) - diagonal_attend_mask.bitwise_or_(sliding_attend_mask) - causal_mask *= diagonal_attend_mask - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - if attention_mask.shape[-1] > target_length: - attention_mask = attention_mask[:, :target_length] - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - return causal_mask - class MimiDecoder(nn.Module): """SEANet decoder as used by Mimi.""" @@ -1269,7 +1221,7 @@ class MimiEuclideanCodebook(nn.Module): def quantize(self, hidden_states): # Projects each vector in `hidden_states` over the nearest centroid and return its index. # `hidden_states` should be `[N, D]` with `N` the number of input vectors and `D` the dimension. - dists = torch.cdist(hidden_states[None], self.embed[None], p=2)[0] + dists = torch.cdist(hidden_states[None].float(), self.embed[None].float(), p=2)[0] embed_ind = dists.argmin(dim=-1) return embed_ind @@ -1476,6 +1428,7 @@ class MimiModel(MimiPreTrainedModel): stride=2, bias=False, pad_mode="replicate", + layer_idx=len(self.encoder._mimiconv1d_layer_names), ) self.upsample = MimiConvTranspose1d( @@ -1512,12 +1465,17 @@ class MimiModel(MimiPreTrainedModel): num_quantizers: int, padding_mask: int, past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None, + padding_cache: Optional[MimiConv1dPaddingCache] = None, return_dict: Optional[bool] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: """ Encodes the given input using the underlying VQVAE. The padding mask is required to compute the correct scale. """ - embeddings = self.encoder(input_values) + + # TODO: @eustlb, let's make the encoder support padding_mask so that batched inputs are supported. + embeddings = self.encoder(input_values, padding_cache=padding_cache) + + # TODO: @eustlb, convert the padding mask to attention mask. encoder_outputs = self.encoder_transformer( embeddings.transpose(1, 2), past_key_values=past_key_values, return_dict=return_dict ) @@ -1526,11 +1484,11 @@ class MimiModel(MimiPreTrainedModel): elif len(encoder_outputs) > 1: past_key_values = encoder_outputs[1] embeddings = encoder_outputs[0].transpose(1, 2) - embeddings = self.downsample(embeddings) + embeddings = self.downsample(embeddings, padding_cache=padding_cache) codes = self.quantizer.encode(embeddings, num_quantizers) codes = codes.transpose(0, 1) - return codes, past_key_values + return codes, past_key_values, padding_cache def get_encoded_length(self, input_length: torch.LongTensor) -> torch.LongTensor: """ @@ -1570,6 +1528,8 @@ class MimiModel(MimiPreTrainedModel): padding_mask: Optional[torch.Tensor] = None, num_quantizers: Optional[float] = None, encoder_past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None, + padding_cache: Optional[MimiConv1dPaddingCache] = None, + use_streaming: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[tuple[torch.Tensor, Optional[torch.Tensor]], MimiEncoderOutput]: """ @@ -1598,6 +1558,7 @@ class MimiModel(MimiPreTrainedModel): `codebook` of shape `[batch_size, num_codebooks, frames]`, the discrete encoded codes for the input audio waveform. """ return_dict = return_dict if return_dict is not None else self.config.return_dict + use_streaming = use_streaming if use_streaming is not None else self.config.use_streaming num_quantizers = self.config.num_quantizers if num_quantizers is None else num_quantizers @@ -1614,11 +1575,31 @@ class MimiModel(MimiPreTrainedModel): if padding_mask is None: padding_mask = torch.ones_like(input_values).bool() - encoded_frames, encoder_past_key_values = self._encode_frame( + if use_streaming and padding_cache is None: + per_layer_padding, per_layer_padding_mode, per_layer_in_channels = [], [], [] + for layer_name in self.encoder._mimiconv1d_layer_names: + per_layer_padding.append(self.encoder.get_submodule(layer_name).padding_total) + per_layer_padding_mode.append(self.encoder.get_submodule(layer_name).pad_mode) + per_layer_in_channels.append(self.encoder.get_submodule(layer_name).in_channels) + + # downsample layer + per_layer_padding.append(self.downsample.padding_total) + per_layer_padding_mode.append(self.downsample.pad_mode) + per_layer_in_channels.append(self.downsample.in_channels) + + padding_cache = MimiConv1dPaddingCache( + num_layers=len(self.encoder._mimiconv1d_layer_names) + 1, + per_layer_padding=per_layer_padding, + per_layer_padding_mode=per_layer_padding_mode, + per_layer_in_channels=per_layer_in_channels, + ) + + encoded_frames, encoder_past_key_values, padding_cache = self._encode_frame( input_values, num_quantizers, padding_mask.bool(), past_key_values=encoder_past_key_values, + padding_cache=padding_cache, return_dict=return_dict, ) @@ -1626,9 +1607,10 @@ class MimiModel(MimiPreTrainedModel): return ( encoded_frames, encoder_past_key_values, + padding_cache, ) - return MimiEncoderOutput(encoded_frames, encoder_past_key_values) + return MimiEncoderOutput(encoded_frames, encoder_past_key_values, padding_cache) def _decode_frame( self, diff --git a/src/transformers/models/sew/feature_extractor_sew.py b/src/transformers/models/sew/feature_extraction_sew.py similarity index 100% rename from src/transformers/models/sew/feature_extractor_sew.py rename to src/transformers/models/sew/feature_extraction_sew.py diff --git a/src/transformers/models/stt/__init__.py b/src/transformers/models/stt/__init__.py new file mode 100644 index 0000000000..5823883c6c --- /dev/null +++ b/src/transformers/models/stt/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_kyutai_speech_to_text import * + from .feature_extraction_kyutai_speech_to_text import * + from .modeling_kyutai_speech_to_text import * + from .processing_kyutai_speech_to_text import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/stt/configuration_kyutai_speech_to_text.py b/src/transformers/models/stt/configuration_kyutai_speech_to_text.py new file mode 100644 index 0000000000..f9ea11a5f4 --- /dev/null +++ b/src/transformers/models/stt/configuration_kyutai_speech_to_text.py @@ -0,0 +1,188 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.s + +from ...configuration_utils import PretrainedConfig +from ...utils import logging +from ..auto.configuration_auto import AutoConfig + + +logger = logging.get_logger(__name__) + + +class KyutaiSpeechToTextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`KyutaiSpeechToTextForConditionalGeneration`]. + It is used to instantiate a Kyutai Speech-to-Text model according to the specified arguments, defining the model + architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the + 2.6b-en model. + + e.g. [kyutai/stt-2.6b-en](https://huggingface.co/kyutai/stt-2.6b-en) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + codebook_vocab_size (`int`, *optional*, defaults to 2049): + Vocabulary size of the codebook. Defines the number of different audio tokens that can be represented by each codebook. + vocab_size (`int`, *optional*, defaults to 4001): + Vocabulary size of the model. Defines the number of different tokens that can be represented by the + `input_ids` passed when calling the model. + hidden_size (`int`, *optional*, defaults to 2048): + Dimensionality of the layers and the pooler layer of the main decoder. + num_hidden_layers (`int`, *optional*, defaults to 48): + Number of decoder layers. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the main decoder block. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `num_attention_heads`. + max_position_embeddings (`int`, *optional*, defaults to 750): + The maximum sequence length that this model might ever be used with. Typically, set this to something large + just in case (e.g., 512 or 1024 or 2048). + rope_theta (`float`, *optional*, defaults to 100000.0): + The base period of the RoPE embeddings. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + head_dim (`int`, *optional*, defaults to `hidden_size // num_attention_heads`): + The attention head dimension. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + sliding_window (`int`, *optional*, defaults to 375): + Sliding window attention window size. If not specified, will default to `3000`. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + ffn_dim (`int`, *optional*, defaults to 11264): + Dimensionality of the "intermediate" (often named feed-forward) layer in the main decoder block. Must be even. + rms_norm_eps (`float`, *optional*, defaults to 1e-08): + The epsilon used by the rms normalization layers. + num_codebooks (`int`, *optional*, defaults to 32): + The number of audio codebooks for each audio channels. + audio_bos_token_id (`int`, *optional*, defaults to 2048): + Beginning of stream token id for codebook tokens. + audio_pad_token_id (`int`, *optional*, defaults to 69569): + Padding token id for codebook tokens. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings. + pad_token_id (`int`, *optional*, defaults to 3): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 48000): + Beginning of stream token id for text tokens. + codec_config (`PretrainedConfig`, *optional*): + Configuration for the codec. + kwargs (*optional*): + Dictionary of keyword arguments. Notably: + - **audio_encoder_config** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that + defines the audio encoder config. + - **depth__config** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that + defines the depth decoder config. + + + Example: + ```python + >>> from transformers import KyutaiSpeechToTextConfig, KyutaiSpeechToTextForConditionalGeneration + + >>> # Initializing a KyutaiSpeechToTextConfig + >>> configuration = KyutaiSpeechToTextConfig() + + >>> # Initializing a model + >>> model = KyutaiSpeechToTextForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + # not the best naming here for `model_type`, but original codebase already uses model type:`stt` for in the config so we keep it to simplify + model_type = "stt" + keys_to_ignore_at_inference = ["past_key_values"] + sub_configs = {"codec_config": AutoConfig} + + def __init__( + self, + codebook_vocab_size=2049, + vocab_size=4001, + hidden_size=2048, + num_hidden_layers=48, + num_attention_heads=32, + num_key_value_heads=None, + max_position_embeddings=750, + rope_theta=100000.0, + hidden_act="silu", + head_dim=None, + initializer_range=0.02, + use_cache=True, + sliding_window=375, + attention_dropout=0.0, + ffn_dim=11264, + rms_norm_eps=1e-8, + num_codebooks=32, + audio_bos_token_id=2048, + audio_pad_token_id=69569, + tie_word_embeddings=False, + pad_token_id=3, + bos_token_id=48000, + codec_config=None, + **kwargs, + ): + super().__init__( + pad_token_id=pad_token_id, bos_token_id=bos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs + ) + + if codec_config is None: + self.codec_config = AutoConfig.for_model("mimi") + logger.info("codec_config is None, using default audio encoder config.") + elif isinstance(codec_config, dict): + self.codec_config = AutoConfig.for_model(**codec_config) + elif isinstance(codec_config, PretrainedConfig): + self.codec_config = codec_config + + self.num_codebooks = num_codebooks + self.frame_size = self.codec_config.frame_size + + self.audio_bos_token_id = audio_bos_token_id + self.audio_pad_token_id = audio_pad_token_id + self.codebook_vocab_size = codebook_vocab_size + + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + if ffn_dim % 2 == 1: + raise ValueError(f"`ffn_dim={ffn_dim}` must be even.") + self.ffn_dim = ffn_dim + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_dropout = attention_dropout + self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads + self.sliding_window = sliding_window + + +__all__ = ["KyutaiSpeechToTextConfig"] diff --git a/src/transformers/models/stt/convert_kyutai_speech_to_text_to_hf.py b/src/transformers/models/stt/convert_kyutai_speech_to_text_to_hf.py new file mode 100644 index 0000000000..fe4a5a6bc6 --- /dev/null +++ b/src/transformers/models/stt/convert_kyutai_speech_to_text_to_hf.py @@ -0,0 +1,377 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import gc +import os +import re + +import safetensors.torch +import sentencepiece +import torch + +from transformers import ( + KyutaiSpeechToTextConfig, + KyutaiSpeechToTextFeatureExtractor, + KyutaiSpeechToTextForConditionalGeneration, + KyutaiSpeechToTextProcessor, + PreTrainedTokenizerFast, +) +from transformers.convert_slow_tokenizer import MoshiConverter +from transformers.utils.hub import cached_file + + +# fmt: off +MOSHI_ORIGINAL_TO_CONVERTED_KEY_MAPPING = { + r"out_norm": r"norm", + r"gating\.linear_in": r"mlp.fc1", + r"gating\.linear_out": r"mlp.fc2", + r"self_attn\.out_proj": r"self_attn.o_proj.linear", + r"norm1": r"input_layernorm", + r"norm2": r"post_attention_layernorm", + r"layer_scale_1": r"self_attn_layer_scale", + r"layer_scale_2": r"mlp_layer_scale", + r"alpha": r"weight", +} +# fmt: on + + +# fmt: off +MIMI_ORIGINAL_TO_CONVERTED_KEY_MAPPING = { + r"conv\.conv\.conv": "conv", + r"convtr\.convtr\.convtr": "conv", + r"conv\.conv": "conv", + r"convtr\.convtr": "conv", + r"quantizer\.rvq_first\.vq": "quantizer.semantic_residual_vector_quantizer", + r"quantizer\.rvq_first": "quantizer.semantic_residual_vector_quantizer", + r"quantizer\.rvq_rest\.vq": "quantizer.acoustic_residual_vector_quantizer", + r"quantizer\.rvq_rest": "quantizer.acoustic_residual_vector_quantizer", + r"_codebook": "codebook", + r"_initialized": "initialized", + r"embedding_sum": "embed_sum", + r"encoder\.model": "encoder.layers", + r"decoder\.model": "decoder.layers", + r"encoder_transformer\.transformer": "encoder_transformer", + r"decoder_transformer\.transformer": "decoder_transformer", + r"linear1": "mlp.fc1", + r"linear2": "mlp.fc2", + r"self_attn\.out_proj": "self_attn.o_proj", + r"norm1": "input_layernorm", + r"norm2": "post_attention_layernorm", + r"layer_scale_1": "self_attn_layer_scale", + r"layer_scale_2": "mlp_layer_scale", +} +# fmt: on + + +def permute_for_rope(input_tensor, n_heads, dim1, dim2): + """ + When you go from the complex ROPE formulation to sin and cos one, you need + to permute the query and key weights (to avoid doing it on the fly) + """ + return input_tensor.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2) + + +def convert_key(key, mapping): + for pattern, replacement in mapping.items(): + key = re.sub(pattern, replacement, key) + return key + + +def convert_kyutai_speech_to_text_state_dict(state_dict, config, unwanted_prefix="transformer."): + hidden_size = config.hidden_size + head_dim = config.head_dim + num_heads = int(config.hidden_size // config.head_dim) + num_key_value_heads = config.num_key_value_heads + key_value_head_dim = config.num_key_value_heads * head_dim + + # concat embeddings + embed_tokens_weight = [] + for i in range(32): + embed_tokens_weight.append(state_dict.pop(f"emb.{i}.weight")) + + embed_tokens_weight = torch.cat(embed_tokens_weight, dim=0) + embed_tokens_weight = torch.cat([state_dict.pop("text_emb.weight"), embed_tokens_weight]) + embed_tokens_weight = torch.cat([embed_tokens_weight, torch.zeros(1, config.hidden_size)], dim=0) + state_dict["embed_tokens.embed_tokens.weight"] = embed_tokens_weight + + for key, value in list(state_dict.items()): + if unwanted_prefix is not None and unwanted_prefix in key: + new_key = key[len(unwanted_prefix) :] + else: + new_key = key + + new_key = convert_key(new_key, MOSHI_ORIGINAL_TO_CONVERTED_KEY_MAPPING) + + # Post-process the current_parameter. + if "alpha" in key: + state_dict[key] = state_dict[key].squeeze() + + if "in_proj_weight" in new_key: + # split qkv into query key and value + mixed_qkv = state_dict.pop(key) + qkv_dim = mixed_qkv.size(0) // 3 + + query_layer = mixed_qkv[:qkv_dim] + key_layer = mixed_qkv[qkv_dim : qkv_dim * 2] + value_layer = mixed_qkv[qkv_dim * 2 :] + state_dict[new_key.replace("in_proj_weight", "q_proj.linear.weight")] = permute_for_rope( + query_layer, num_heads, hidden_size, hidden_size + ) + state_dict[new_key.replace("in_proj_weight", "k_proj.linear.weight")] = permute_for_rope( + key_layer, num_key_value_heads, key_value_head_dim, hidden_size + ) + + state_dict[new_key.replace("in_proj_weight", "v_proj.linear.weight")] = value_layer + else: + state_dict[new_key] = state_dict.pop(key) + + return state_dict + + +def convert_mimi_state_dict(state_dict, config, unwanted_prefix=None): + hidden_size = config.hidden_size + head_dim = config.head_dim + num_heads = int(config.hidden_size // config.head_dim) + num_key_value_heads = config.num_key_value_heads + key_value_head_dim = config.num_key_value_heads * head_dim + + for key, value in list(state_dict.items()): + if unwanted_prefix is not None and unwanted_prefix in key: + new_key = key[len(unwanted_prefix) :] + else: + new_key = key + + new_key = convert_key(new_key, MIMI_ORIGINAL_TO_CONVERTED_KEY_MAPPING) + + if "in_proj_weight" in new_key: + # split qkv into query key and value + mixed_qkv = state_dict.pop(key) + qkv_dim = mixed_qkv.size(0) // 3 + + query_layer = mixed_qkv[:qkv_dim] + key_layer = mixed_qkv[qkv_dim : qkv_dim * 2] + value_layer = mixed_qkv[qkv_dim * 2 :] + + state_dict[new_key.replace("in_proj_weight", "q_proj.weight")] = permute_for_rope( + query_layer, num_heads, hidden_size, hidden_size + ) + state_dict[new_key.replace("in_proj_weight", "k_proj.weight")] = permute_for_rope( + key_layer, num_key_value_heads, key_value_head_dim, hidden_size + ) + state_dict[new_key.replace("in_proj_weight", "v_proj.weight")] = value_layer + else: + state_dict[new_key] = state_dict.pop(key) + + return state_dict + + +def write_model( + input_path_or_repo, + model_name, + codec_model_path_or_repo, + codec_model_name, + output_dir, + safe_serialization=True, + unwanted_prefix="transformer.", +): + print("Converting the model.") + os.makedirs(output_dir, exist_ok=True) + + config = KyutaiSpeechToTextConfig() + config.use_cache = True + config.codec_config.sliding_window = 250 + + model_path = cached_file( + input_path_or_repo, + model_name, + ) + + codec_path = cached_file( + codec_model_path_or_repo, + codec_model_name, + ) + + print(f"Fetching all parameters from the checkpoint at {model_path}...") + state_dict = safetensors.torch.load_file(model_path) + + print(f"Fetching all parameters from the checkpoint at {codec_path}...") + codec_state_dict = safetensors.torch.load_file(codec_path) + + print("Converting model...") + # ----------------------- + # convert parameter names + # ----------------------- + state_dict = convert_kyutai_speech_to_text_state_dict(state_dict, config, unwanted_prefix=unwanted_prefix) + codec_state_dict = convert_mimi_state_dict(codec_state_dict, config.codec_config, unwanted_prefix=None) + + # ------------------------- + # load the weights and save + # ------------------------- + print("Loading the checkpoint in a Moshi ASR model.") + with torch.device("meta"): + model = KyutaiSpeechToTextForConditionalGeneration(config) + + linear_weight = state_dict.pop("text_linear.weight") + model.model.load_state_dict(state_dict, strict=True, assign=True) + + linear_weight = torch.cat([linear_weight, torch.zeros(1, config.hidden_size)]) + model.lm_head.load_state_dict({"weight": linear_weight}, strict=True, assign=True) + + model.codec_model.load_state_dict(codec_state_dict, strict=True, assign=True) + + print("Checkpoint loaded successfully.") + del model.config._name_or_path + del model.config.codec_config._name_or_path + + # default generation config + model.generation_config._from_model_config = False + model.generation_config.audio_window_size = 1 + model.generation_config.cache_implementation = "sliding_window" + + model.codec_model.generation_config._from_model_config = False + model.codec_model.generation_config.cache_implementation = "sliding_window" + model.codec_model.generation_config.use_cache = True + + print("Saving the model.") + model.save_pretrained(output_dir, safe_serialization=safe_serialization) + del state_dict, model + + # Safety check: reload the converted model + gc.collect() + print("Reloading the model to check if it's saved correctly.") + KyutaiSpeechToTextForConditionalGeneration.from_pretrained( + output_dir, torch_dtype=torch.bfloat16, device_map="auto" + ) + print("Model reloaded successfully.") + + +def write_processor( + input_path_or_repo, + tokenizer_model_name, + codec_model_path_or_repo, + output_dir, + audio_delay_seconds, + audio_silence_prefix_seconds, +): + tokenizer_path = cached_file( + input_path_or_repo, + tokenizer_model_name, + ) + + tokenizer = MoshiConverter(tokenizer_path).converted() + original_tokenizer = sentencepiece.SentencePieceProcessor(tokenizer_path) + + tokenizer = PreTrainedTokenizerFast( + tokenizer_object=tokenizer, + chat_template=None, + unk_token="", + model_input_names=["input_ids", "attention_mask"], + clean_up_tokenization_spaces=False, + bos_token_id=original_tokenizer.bos_id(), + eos_token_id=original_tokenizer.eos_id(), + pad_token_id=original_tokenizer.pad_id(), + ) + + feature_extractor = KyutaiSpeechToTextFeatureExtractor( + audio_delay_seconds=audio_delay_seconds, + audio_silence_prefix_seconds=audio_silence_prefix_seconds, + ) + + processor = KyutaiSpeechToTextProcessor(feature_extractor, tokenizer) + processor.save_pretrained(output_dir) + print(f"Processor saved successfully to {output_dir}") + + +def main(): + parser = argparse.ArgumentParser(description="Convert Moshi ASR weights to HuggingFace format") + parser.add_argument( + "--input_path_or_repo", + type=str, + required=True, + help="Path or repo containing Moshi ASR weights", + ) + parser.add_argument( + "--model_name", + type=str, + required=True, + help="Name of the model in input_path_or_repo", + ) + parser.add_argument( + "--tokenizer_model_name", + type=str, + required=True, + help="Name of the tokenizer model in input_path_or_repo", + ) + parser.add_argument( + "--codec_model_path_or_repo", + type=str, + required=True, + help="Path or repo containing the Mimi weights", + ) + parser.add_argument( + "--mimi_name", + type=str, + required=True, + help="Name of the Mimi model in codec_model_path_or_repo", + ) + parser.add_argument( + "--preprocessor_model_path_or_repo", + type=str, + required=True, + help="Path or repo containing the preprocessor config", + ) + parser.add_argument( + "--output_dir", + help="Location to write HF model and tokenizer", + ) + parser.add_argument( + "--safe_serialization", action="store_true", default=True, help="Whether or not to save using `safetensors`." + ) + parser.add_argument( + "--audio_delay_seconds", + type=float, + required=True, + help="Audio delay in seconds to add to the right of the input", + ) + parser.add_argument( + "--audio_silence_prefix_seconds", + type=float, + required=True, + help="Audio silence prefix in seconds to add to the left of the input", + ) + args = parser.parse_args() + + write_model( + args.input_path_or_repo, + args.model_name, + args.codec_model_path_or_repo, + args.mimi_name, + args.output_dir, + safe_serialization=args.safe_serialization, + ) + + write_processor( + args.input_path_or_repo, + args.tokenizer_model_name, + args.preprocessor_model_path_or_repo, + args.output_dir, + args.audio_delay_seconds, + args.audio_silence_prefix_seconds, + ) + + +if __name__ == "__main__": + main() diff --git a/src/transformers/models/stt/feature_extraction_kyutai_speech_to_text.py b/src/transformers/models/stt/feature_extraction_kyutai_speech_to_text.py new file mode 100644 index 0000000000..94ddb15daa --- /dev/null +++ b/src/transformers/models/stt/feature_extraction_kyutai_speech_to_text.py @@ -0,0 +1,237 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/stt/modular_kyutai_speech_to_text.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_kyutai_speech_to_text.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 Kyutai and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Union + +import numpy as np + +from ...feature_extraction_sequence_utils import SequenceFeatureExtractor +from ...feature_extraction_utils import BatchFeature +from ...utils import PaddingStrategy, TensorType, logging + + +logger = logging.get_logger(__name__) + + +class KyutaiSpeechToTextFeatureExtractor(SequenceFeatureExtractor): + r""" + Constructs an KyutaiSpeechToText feature extractor. + + This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains + most of the main methods. Users should refer to this superclass for more information regarding those methods. + + Args: + feature_size (`int`, *optional*, defaults to 1): + The feature dimension of the extracted features. Use 1 for mono, 2 for stereo. + sampling_rate (`int`, *optional*, defaults to 24000): + The sampling rate at which the audio waveform should be digitalized expressed in hertz (Hz). + padding_value (`float`, *optional*, defaults to 0.0): + The value that is used to fill the padding values. + chunk_length_s (`float`, *optional*): + If defined the audio is pre-processed into chunks of lengths `chunk_length_s` and then encoded. + overlap (`float`, *optional*): + Defines the overlap between each chunk. It is used to compute the `chunk_stride` using the following + formulae : `int((1.0 - self.overlap) * self.chunk_length)`. + audio_delay_seconds (`float`, *optional*, defaults to 0.0): + The delay in seconds to add after the audio (right padding). + audio_silence_prefix_seconds (`float`, *optional*, defaults to 0.0): + The silence prefix in seconds to add before the audio (left padding). + """ + + model_input_names = ["input_values", "padding_mask"] + + def __init__( + self, + feature_size: int = 1, + sampling_rate: int = 24000, + padding_value: float = 0.0, + chunk_length_s: Optional[float] = None, + overlap: Optional[float] = None, + audio_delay_seconds: Optional[float] = 0.0, + audio_silence_prefix_seconds: Optional[float] = 0.0, + **kwargs, + ): + super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs) + self.chunk_length_s = chunk_length_s + self.overlap = overlap + self.audio_delay_seconds = audio_delay_seconds + self.audio_silence_prefix_seconds = audio_silence_prefix_seconds + + # This is a property because you might want to change the chunk_length_s on the fly + @property + def chunk_length(self) -> Optional[int]: + if self.chunk_length_s is None: + return None + else: + return int(self.chunk_length_s * self.sampling_rate) + + # This is a property because you might want to change the chunk_length_s on the fly + @property + def chunk_stride(self) -> Optional[int]: + if self.chunk_length_s is None or self.overlap is None: + return None + else: + return max(1, int((1.0 - self.overlap) * self.chunk_length)) + + def __call__( + self, + raw_audio: Union[np.ndarray, list[float], list[np.ndarray], list[list[float]]], + padding: Optional[Union[bool, str, PaddingStrategy]] = None, + truncation: Optional[bool] = False, + max_length: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + sampling_rate: Optional[int] = None, + ) -> BatchFeature: + """ + Main method to featurize and prepare for the model one or several sequence(s). + + Args: + raw_audio (`np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[list[float]]`): + The sequence or batch of sequences to be processed. Each sequence can be a numpy array, a list of float + values, a list of numpy arrays or a list of list of float values. The numpy array must be of shape + `(num_samples,)` for mono audio (`feature_size = 1`), or `(2, num_samples)` for stereo audio + (`feature_size = 2`). + padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`): + Select a strategy to pad the returned sequences (according to the model's padding side and padding + index) among: + + - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different + lengths). + truncation (`bool`, *optional*, defaults to `False`): + Activates truncation to cut input sequences longer than `max_length` to `max_length`. + max_length (`int`, *optional*): + Maximum length of the returned list and optionally padding length (see above). + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors instead of list of python integers. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return Numpy `np.ndarray` objects. + sampling_rate (`int`, *optional*): + The sampling rate at which the `audio` input was sampled. It is strongly recommended to pass + `sampling_rate` at the forward call to prevent silent errors. + """ + if sampling_rate is not None: + if sampling_rate != self.sampling_rate: + raise ValueError( + f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of" + f" {self.sampling_rate}. Please make sure that the provided audio input was sampled with" + f" {self.sampling_rate} and not {sampling_rate}." + ) + else: + logger.warning( + f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. " + "Failing to do so can result in silent errors that might be hard to debug." + ) + + if padding and truncation: + raise ValueError("Both padding and truncation were set. Make sure you only set one.") + elif padding is None: + # by default let's pad the inputs + padding = True + + is_batched = bool( + isinstance(raw_audio, (list, tuple)) and (isinstance(raw_audio[0], (np.ndarray, tuple, list))) + ) + + if is_batched: + raw_audio = [np.asarray(audio, dtype=np.float32).T for audio in raw_audio] + elif not is_batched and not isinstance(raw_audio, np.ndarray): + raw_audio = np.asarray(raw_audio, dtype=np.float32) + elif isinstance(raw_audio, np.ndarray) and raw_audio.dtype is np.dtype(np.float64): + raw_audio = raw_audio.astype(np.float32) + + # always return batch + if not is_batched: + raw_audio = [np.asarray(raw_audio).T] + + # verify inputs are valid + for idx, example in enumerate(raw_audio): + if example.ndim > 2: + raise ValueError(f"Expected input shape (channels, length) but got shape {example.shape}") + if self.feature_size == 1 and example.ndim != 1: + raise ValueError(f"Expected mono audio but example has {example.shape[-1]} channels") + if self.feature_size == 2 and example.shape[-1] != 2: + raise ValueError(f"Expected stereo audio but example has {example.shape[-1]} channels") + + padded_inputs = None + input_values = BatchFeature({"input_values": raw_audio}) + if self.chunk_stride is not None and self.chunk_length is not None and max_length is None: + if truncation: + max_length = min(array.shape[0] for array in raw_audio) + nb_step = int(np.floor(max_length / self.chunk_stride)) + max_length = (nb_step - 1) * self.chunk_stride + self.chunk_length + elif padding: + max_length = max(array.shape[0] for array in raw_audio) + nb_step = int(np.ceil(max_length / self.chunk_stride)) + max_length = (nb_step - 1) * self.chunk_stride + self.chunk_length + padding = "max_length" + else: + padded_inputs = input_values + + # normal padding on batch + if padded_inputs is None: + padded_inputs = self.pad( + input_values, + max_length=max_length, + truncation=truncation, + padding=padding, + return_attention_mask=padding, + ) + + if padding: + padded_inputs["padding_mask"] = padded_inputs.pop("attention_mask") + + # now let's padd left and right + pad_left = int(self.audio_silence_prefix_seconds * self.sampling_rate) + pad_right = int((self.audio_delay_seconds + 1.0) * self.sampling_rate) + padded_inputs["input_values"] = np.pad( + padded_inputs["input_values"], + ((0, 0), (pad_left, pad_right)), + mode="constant", + constant_values=0.0, + ) + if padding: + padded_inputs["padding_mask"] = np.pad( + padded_inputs["padding_mask"], + ((0, 0), (pad_left, pad_right)), + mode="constant", + constant_values=0, + ) + + input_values = [] + for example in padded_inputs.pop("input_values"): + if self.feature_size == 1: + example = example[..., None] + input_values.append(example.T) + + padded_inputs["input_values"] = input_values + if return_tensors is not None: + padded_inputs = padded_inputs.convert_to_tensors(return_tensors) + + return padded_inputs + + +__all__ = ["KyutaiSpeechToTextFeatureExtractor"] diff --git a/src/transformers/models/stt/modeling_kyutai_speech_to_text.py b/src/transformers/models/stt/modeling_kyutai_speech_to_text.py new file mode 100644 index 0000000000..7a86cd440c --- /dev/null +++ b/src/transformers/models/stt/modeling_kyutai_speech_to_text.py @@ -0,0 +1,1434 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/stt/modular_kyutai_speech_to_text.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_kyutai_speech_to_text.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 Kyutai and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import types +from typing import Optional, Union + +import torch +import torch.nn as nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from ...generation import GenerationConfig, GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import ( + FlashAttentionKwargs, + flash_attn_supports_top_left_mask, + is_flash_attn_available, +) +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import PreTrainedModel +from ...processing_utils import Unpack +from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging +from ..auto import AutoModel +from .configuration_kyutai_speech_to_text import KyutaiSpeechToTextConfig + + +if is_flash_attn_available(): + from ...modeling_flash_attention_utils import _flash_attention_forward + +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + +logger = logging.get_logger(__name__) + + +class KyutaiSpeechToTextRMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) # Ignore copy + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + # Ignore copy + def forward(self, x): + output = self._norm(x.float()) + output = output * self.weight.float() + return output.type_as(x) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.eps}" + + +class KyutaiSpeechToTextFlexibleLinear(nn.Module): + def __init__(self, input_size, output_size, num_layers): + super().__init__() + # Stack the weights for N layers into a single tensor (num_layers, output_size, input_size) + self.weight = nn.Parameter(torch.randn(num_layers, output_size, input_size)) + + def forward(self, x, layer_idx=None): + """ + `KyutaiSpeechToTextFlexibleLinear` creates one linear layer per codebook. There's multiple ways to use it. + In the default case, `sequence_length=num_layers`, so each element of the sequence will be matmul to the weights corresponding to its index on the sequence. + + For more advanced cases, one can specify which codebook's layer(s) to use with `layer_idx`. + If `layer_idx` indicates a single integer, all of the element of the sequence will be matmul to this single codebook's layer. + But if `layer_idx` is a tensor of shape `(seq_length,)`, it will matmul each i-th element of the input sequence to the corresponding layer `weight[i]`. + + + Args: + x (`torch.FloatTensor): input to the layer of shape `(batch, num_layers, embed_dim)` or of shape `(batch, seq_length, embed_dim)` + layer_idx (`torch.Tensor`, *optional*): + Can be used to specify which codebook's layers(s) to use. + If it's a tensor of shape `(seq_length,)`, will matmul each element of the sequence to the corresponding weights. + But if `layer_idx` is a tensor of shape `(seq_length,)`, it will matmul each i-th element of the input sequence to the corresponding layer `weight[i]`. + """ + + # Use torch.gather to select the corresponding weights for each sample + # (codebooks, output_size, hidden_size) + selected_weights = torch.index_select(self.weight, 0, layer_idx) if layer_idx is not None else self.weight + + # (1, codebooks, hidden_size, output_size) + selected_weights = selected_weights.transpose(1, 2)[None, :, :, :] + + # (batch_size, codebooks, 1, hidden_size) x (1, codebooks, hidden_size, output_size) + # -> (batch_size, codebooks, 1, output_size) + x = torch.matmul(x[:, :, None, :], selected_weights) + + # (batch_size, codebooks, output_size) + return x.squeeze(2) + + +@auto_docstring +class KyutaiSpeechToTextPreTrainedModel(PreTrainedModel): + config_class = KyutaiSpeechToTextConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["KyutaiSpeechToTextDecoderLayer", "MimiTransformerLayer"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + main_input_name = "input_ids" + + def _init_weights(self, module): + std = self.config.initializer_range + + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, KyutaiSpeechToTextFlexibleLinear): + module.weight.data.normal_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, KyutaiSpeechToTextRMSNorm): + module.weight.data.fill_(1.0) + + +class KyutaiSpeechToTextConv1dPaddingCache: + """ + Padding cache for KyutaiSpeechToTextConv1d causal convolutions in order to support streaming via cache padding. + See: https://arxiv.org/pdf/2005.06720 & https://arxiv.org/pdf/2204.07064 + + A padding cache is a list of cached partial hidden states for each convolution layer. + Hidden states are cached from the previous call to the KyutaiSpeechToTextConv1d forward pass, given the padding size. + """ + + def __init__( + self, + num_layers: int, + per_layer_padding: list[int], + per_layer_padding_mode: list[str], + per_layer_in_channels: list[int], + ): + # ensure correct number of layers for each arg + from_args_num_layers = {len(per_layer_padding), len(per_layer_padding_mode), len(per_layer_in_channels)} + + if len(from_args_num_layers) != 1 or from_args_num_layers.pop() != num_layers: + raise ValueError( + f"Expected `num_layers` ({num_layers}) values in `per_layer_padding`, `per_layer_padding_mode` and `per_layer_in_channels`" + ) + elif not all(mode in ["constant", "replicate"] for mode in per_layer_padding_mode): + raise NotImplementedError( + "`padding_cache` is not supported for convolutions using other than `constant` or `replicate` padding mode" + ) + + self.per_layer_padding = per_layer_padding + self.per_layer_padding_mode = per_layer_padding_mode + self.per_layer_in_channels = per_layer_in_channels + self.per_layer_is_init = [True] * num_layers + + self.padding_cache = [None] * num_layers + + def update(self, hidden_states: torch.Tensor, layer_idx: int): + """ + Updates the padding cache with the new padding states for the layer `layer_idx` and returns the current cache. + + Parameters: + hidden_states (`torch.Tensor`): + The hidden states to be partially cached. + layer_idx (`int`): + The index of the layer to cache the states for. + Returns: + `torch.Tensor` or `None`, the current padding cache. + """ + batch_size, dtype, device = hidden_states.shape[0], hidden_states.dtype, hidden_states.device + padding = self.per_layer_padding[layer_idx] + padding_mode = self.per_layer_padding_mode[layer_idx] + in_channels = self.per_layer_in_channels[layer_idx] + + if self.padding_cache[layer_idx] is None: + if padding_mode == "constant": + current_cache = torch.zeros( + batch_size, + in_channels, + padding, + device=device, + dtype=dtype, + ) + elif padding_mode == "replicate": + current_cache = ( + torch.ones( + batch_size, + in_channels, + padding, + device=device, + dtype=dtype, + ) + * hidden_states[..., :1] + ) + else: + current_cache = self.padding_cache[layer_idx] + + # update the cache + if padding > 0: + padding_states = hidden_states[:, :, -padding:] + else: + padding_states = torch.empty(batch_size, in_channels, padding, dtype=dtype, device=device) + self.padding_cache[layer_idx] = padding_states + + return current_cache + + +class KyutaiSpeechToTextEmbeddings(nn.Module): + def __init__(self, config): + super().__init__() + self.embed_tokens = nn.Embedding( + config.vocab_size + (config.num_codebooks * config.codebook_vocab_size) + 1, + config.hidden_size, + padding_idx=config.audio_pad_token_id, + ) + audio_tokens_offsets = torch.arange(config.num_codebooks) * config.codebook_vocab_size + audio_tokens_offsets += config.vocab_size + audio_tokens_offsets = nn.functional.pad( + audio_tokens_offsets, (1, 0) + ) # pad one 0 to the left for the text token + self.register_buffer("audio_tokens_offsets", audio_tokens_offsets, persistent=False) + + def forward(self, input_ids): + input_ids = torch.where( + input_ids == self.embed_tokens.padding_idx, input_ids, input_ids + self.audio_tokens_offsets + ) + inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = inputs_embeds.sum(dim=2) + return inputs_embeds + + +class KyutaiSpeechToTextLinear(nn.Module): + def __init__(self, input_dim, output_dim, num_codebooks, use_flexible_linear=False): + super().__init__() + + self.use_flexible_linear = use_flexible_linear + + if not use_flexible_linear: + self.linear = nn.Linear(input_dim, output_dim, bias=False) + else: + self.linear = KyutaiSpeechToTextFlexibleLinear(input_dim, output_dim, num_layers=num_codebooks) + + def forward(self, x, layer_idx=None): + if self.use_flexible_linear: + return self.linear(x, layer_idx) + else: + return self.linear(x) + + +class KyutaiSpeechToTextRotaryEmbedding(nn.Module): + def __init__(self, config: KyutaiSpeechToTextConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class KyutaiSpeechToTextGatingMLP(nn.Module): + def __init__(self, config, use_flexible_linear=False): + super().__init__() + + self.activation_fn = ACT2FN[config.hidden_act] + ffn_dim = config.ffn_dim + hidden_size = config.hidden_size + num_layers = config.num_codebooks if use_flexible_linear else 1 + if num_layers == 1: + self.fc1 = nn.Linear(hidden_size, ffn_dim, bias=False) + self.fc2 = nn.Linear(ffn_dim // 2, hidden_size, bias=False) + else: + self.fc1 = KyutaiSpeechToTextFlexibleLinear(hidden_size, ffn_dim, num_layers) + self.fc2 = KyutaiSpeechToTextFlexibleLinear(ffn_dim // 2, hidden_size, num_layers) + + def forward(self, hidden_states: torch.Tensor, layer_idx: Optional[int] = None) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) if layer_idx is None else self.fc1(hidden_states, layer_idx) + + batch_size, sequence_length, _ = hidden_states.shape + hidden_states = hidden_states.view(batch_size, sequence_length, 2, -1) + hidden_states = self.activation_fn(hidden_states[..., 0, :]) * hidden_states[..., 1, :] + hidden_states = self.fc2(hidden_states) if layer_idx is None else self.fc2(hidden_states, layer_idx) + return hidden_states + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class KyutaiSpeechToTextAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + config: KyutaiSpeechToTextConfig, + layer_idx: Optional[int] = None, + use_flexible_linear=False, + use_rope=True, + ): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = config.head_dim + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.is_causal = True + self.scaling = 1 / math.sqrt(self.head_dim) + + if self.hidden_size % self.num_heads != 0: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + + self.q_proj = KyutaiSpeechToTextLinear( + self.hidden_size, self.num_heads * self.head_dim, config.num_codebooks, use_flexible_linear + ) + self.k_proj = KyutaiSpeechToTextLinear( + self.hidden_size, self.num_key_value_heads * self.head_dim, config.num_codebooks, use_flexible_linear + ) + self.v_proj = KyutaiSpeechToTextLinear( + self.hidden_size, self.num_key_value_heads * self.head_dim, config.num_codebooks, use_flexible_linear + ) + self.o_proj = KyutaiSpeechToTextLinear( + self.num_heads * self.head_dim, self.hidden_size, config.num_codebooks, use_flexible_linear + ) + + # rotary embeddings are not used in the depth decoder + self.rotary_emb = None + if use_rope: + self.rope_theta = config.rope_theta + self.rotary_emb = KyutaiSpeechToTextRotaryEmbedding(config) + + # copied from transformers.models.gemma.modeling_gemma.GemmaAttention.forward + # no longer copied after attention refactors + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states, cache_position) # Ignore copy + key_states = self.k_proj(hidden_states, cache_position) # Ignore copy + value_states = self.v_proj(hidden_states, cache_position) # Ignore copy + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if self.rotary_emb is not None: # Ignore copy + cos, sin = self.rotary_emb(value_states, position_ids) # Ignore copy + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) # Ignore copy + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = ( + {"sin": sin, "cos": cos, "cache_position": cache_position} + if self.rotary_emb is not None + else {"cache_position": cache_position} + ) # Ignore copy + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.view(bsz, q_len, -1) + attn_output = self.o_proj(attn_output, cache_position) # Ignore copy + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# NO LONGER EXIST Copied from transformers.models.gemma.modeling_gemma.GemmaFlashAttention2 with Gemma->KyutaiSpeechToText +# TODO cyril: modular +class KyutaiSpeechToTextFlashAttention2(KyutaiSpeechToTextAttention): + """ + KyutaiSpeechToText flash attention module. This module inherits from `KyutaiSpeechToTextAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states, cache_position) # Ignore copy + key_states = self.k_proj(hidden_states, cache_position) # Ignore copy + value_states = self.v_proj(hidden_states, cache_position) # Ignore copy + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if self.rotary_emb is not None: # Ignore copy + cos, sin = self.rotary_emb(value_states, position_ids) # Ignore copy + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) # Ignore copy + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = ( + {"sin": sin, "cos": cos, "cache_position": cache_position} + if self.rotary_emb is not None + else {"cache_position": cache_position} + ) # Ignore copy + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (KyutaiSpeechToTextRMSNorm handles it correctly) + + input_dtype = query_states.dtype + device_type = query_states.device.type if query_states.device.type != "mps" else "cpu" + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = ( + torch.get_autocast_dtype(device_type) + if hasattr(torch, "get_autocast_dtype") + else torch.get_autocast_gpu_dtype() + ) + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + dropout=dropout_rate, + sliding_window=getattr(self, "sliding_window", None), + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output, cache_position) # Ignore copy + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# NO LONGER EXIST Copied from transformers.models.gemma.modeling_gemma.GemmaSdpaAttention with Gemma->KyutaiSpeechToText +# TODO cyril: modular +class KyutaiSpeechToTextSdpaAttention(KyutaiSpeechToTextAttention): + """ + KyutaiSpeechToText attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `KyutaiSpeechToTextAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from KyutaiSpeechToTextAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "KyutaiSpeechToTextModel is using KyutaiSpeechToTextSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states, cache_position) # Ignore copy + key_states = self.k_proj(hidden_states, cache_position) # Ignore copy + value_states = self.v_proj(hidden_states, cache_position) # Ignore copy + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if self.rotary_emb is not None: # Ignore copy + cos, sin = self.rotary_emb(value_states, position_ids) # Ignore copy + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) # Ignore copy + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = ( + {"sin": sin, "cos": cos, "cache_position": cache_position} + if self.rotary_emb is not None + else {"cache_position": cache_position} + ) # Ignore copy + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output, cache_position) # Ignore copy + + return attn_output, None, past_key_value + + +STT_ATTENTION_CLASSES = { + "eager": KyutaiSpeechToTextAttention, + "flash_attention_2": KyutaiSpeechToTextFlashAttention2, + "sdpa": KyutaiSpeechToTextSdpaAttention, +} + + +class KyutaiSpeechToTextDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: KyutaiSpeechToTextConfig, layer_idx: int, use_flexible_linear: bool, use_rope=True): + super().__init__() + self.hidden_size = config.hidden_size + self.use_flexible_linear = use_flexible_linear + + self.self_attn = STT_ATTENTION_CLASSES[config._attn_implementation]( + config=config, layer_idx=layer_idx, use_flexible_linear=use_flexible_linear, use_rope=use_rope + ) + + self.mlp = KyutaiSpeechToTextGatingMLP(config, use_flexible_linear) + self.input_layernorm = KyutaiSpeechToTextRMSNorm(self.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = KyutaiSpeechToTextRMSNorm(self.hidden_size, eps=config.rms_norm_eps) + self.sliding_window = config.sliding_window + + self._attn_implementation = config._attn_implementation + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = ( + self.mlp(hidden_states) if not self.use_flexible_linear else self.mlp(hidden_states, cache_position) + ) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +@auto_docstring +class KyutaiSpeechToTextModel(KyutaiSpeechToTextPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.embed_tokens = KyutaiSpeechToTextEmbeddings(config) + self.layers = nn.ModuleList( + [ + KyutaiSpeechToTextDecoderLayer(config, layer_idx, use_flexible_linear=False) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self.norm = KyutaiSpeechToTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + return_legacy_cache = False # noqa: F841 + if ( + use_cache and not isinstance(past_key_values, Cache) and not self.training + ): # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = True # noqa: F841 + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = None + if attention_mask is not None: + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + # embed positions + hidden_states = inputs_embeds + + if ( + use_cache and not isinstance(past_key_values, Cache) and not self.training + ): # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = True + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + logger.warning_once( + "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. " + "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)" + ) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, "BlockMask"], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool = False, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and past_key_values is not None: + is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of KyutaiSpeechToText. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if ( + self.config._attn_implementation == "sdpa" + and not (using_static_cache or using_sliding_window_cache) + and not output_attentions + ): + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + sliding_window=self.config.sliding_window, + is_training=self.training, + ): + return None + + dtype = input_tensor.dtype + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + # SlidingWindowCache or StaticCache + if using_sliding_window_cache or using_static_cache: + target_length = past_key_values.get_max_cache_shape() + # DynamicCache or no cache + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + config=self.config, + past_key_values=past_key_values, + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + config: KyutaiSpeechToTextConfig, + past_key_values: Cache, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + config (`KyutaiSpeechToTextConfig`): + The model's configuration class + past_key_values (`Cache`): + The cache class that is being used currently to generate + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape( + -1, 1 + ) + text_config = config.get_text_config() + if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None: + # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also + # the check is needed to verify is current checkpoint was trained with sliding window or not + if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: + sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= ( + cache_position.reshape(-1, 1) - text_config.sliding_window + ) + diagonal_attend_mask.bitwise_or_(sliding_attend_mask) + causal_mask *= diagonal_attend_mask + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.shape[-1] > target_length: + attention_mask = attention_mask[:, :target_length] + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + return causal_mask + + +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + +@auto_docstring +class KyutaiSpeechToTextForConditionalGeneration(KyutaiSpeechToTextPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + _keep_in_fp32_modules = ["codec_model"] + + def __init__(self, config): + super().__init__(config) + self.model = KyutaiSpeechToTextModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.codec_model = AutoModel.from_config(config.codec_config) + + # we are in an edge case where for the codec_model self.can_generate is False, setting self.codec_model.generation_config to None + # yet the codec_model needs a generation config to initalize it's cache for streaming inference + # we therefore initialize a generation config for the codec model + self.codec_model.generation_config = GenerationConfig.from_model_config(config.codec_config) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[KwargsForCausalLM], + ) -> CausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> import torch + >>> from datasets import load_dataset, Audio + >>> from transformers import KyutaiSpeechToTextProcessor, KyutaiSpeechToTextForConditionalGeneration + + >>> torch_device = "cuda" if torch.cuda.is_available() else "cpu" + >>> model_id = "kyutai/stt-2.6b-en" + + >>> processor = KyutaiSpeechToTextProcessor.from_pretrained(model_id) + >>> model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(model_id, device_map=torch_device) + + >>> ds = load_dataset( + ... "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation" + ... ) + + >>> ds = ds.cast_column("audio", Audio(sampling_rate=24000)) + >>> inputs = processor( + ... ds[0]["audio"]["array"], + ... ) + >>> inputs.to(torch_device) + + >>> output_tokens = model.generate(**inputs) + >>> print(processor.batch_decode(output_tokens, skip_special_tokens=True)) + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def _prepare_generation_config(self, *args, **kwargs): + generation_config, model_kwargs = super()._prepare_generation_config(*args, **kwargs) + # this should be passed to the model kwargs for the input preparation + model_kwargs["audio_window_size"] = ( + generation_config.audio_window_size if hasattr(generation_config, "audio_window_size") else None + ) + return generation_config, model_kwargs + + def _prepare_model_inputs( + self, + inputs: Optional[torch.Tensor] = None, + bos_token_id: Optional[torch.Tensor] = None, + model_kwargs: Optional[dict[str, torch.Tensor]] = None, + ) -> tuple[torch.Tensor, Optional[str], dict[str, torch.Tensor]]: + inputs, input_name, model_kwargs = super()._prepare_model_inputs( + inputs=inputs, + bos_token_id=bos_token_id, + model_kwargs=model_kwargs, + ) + + audio_window_size = model_kwargs.get("audio_window_size", None) + if audio_window_size is None: + audio_window_size = self.codec_model.get_encoded_length(model_kwargs["input_values"].shape[-1]).item() + model_kwargs["audio_window_size"] = audio_window_size + + batch_size = inputs.shape[0] + device = inputs.device + + # initialize audio tokens + model_kwargs["audio_tokens"] = torch.zeros( + (batch_size, audio_window_size, self.config.num_codebooks), + device=device, + dtype=torch.long, + ) + model_kwargs["current_window"] = ( + torch.tensor([0, 0], device=device, dtype=torch.long).expand(batch_size, -1).contiguous() + ) + + # let's use generate's cache preparation to prepare the cache for the codec model + temporary_model_kwargs = {} + + # monkey patching the codec model with cache preparation methods since we don't want it to inherit fully from GenerationMixin + # Add cache-related methods from GenerationMixin to codec model + cache_methods = [ + "_prepare_cache_for_generation", + "_get_cache", + "_supports_default_dynamic_cache", + "_get_layer_device_map_for_cache_init", + ] + for method in cache_methods: + setattr(self.codec_model, method, types.MethodType(getattr(self, method).__func__, self.codec_model)) + + self.codec_model._prepare_cache_for_generation( + generation_config=self.codec_model.generation_config, + model_kwargs=temporary_model_kwargs, + assistant_model=None, + batch_size=batch_size, + max_cache_length=self.config.codec_config.sliding_window, + device=device, + ) + + if "past_key_values" in temporary_model_kwargs: + model_kwargs["encoder_past_key_values"] = temporary_model_kwargs["past_key_values"] + + # initialize the padding cache for the codec model + per_layer_padding, per_layer_padding_mode, per_layer_in_channels = [], [], [] + for layer_name in self.codec_model.encoder._mimiconv1d_layer_names: + per_layer_padding.append(self.codec_model.encoder.get_submodule(layer_name).padding_total) + per_layer_padding_mode.append(self.codec_model.encoder.get_submodule(layer_name).pad_mode) + per_layer_in_channels.append(self.codec_model.encoder.get_submodule(layer_name).in_channels) + + # downsample layer + per_layer_padding.append(self.codec_model.downsample.padding_total) + per_layer_padding_mode.append(self.codec_model.downsample.pad_mode) + per_layer_in_channels.append(self.codec_model.downsample.in_channels) + + model_kwargs["padding_cache"] = KyutaiSpeechToTextConv1dPaddingCache( + num_layers=len(self.codec_model.encoder._mimiconv1d_layer_names) + 1, + per_layer_padding=per_layer_padding, + per_layer_padding_mode=per_layer_padding_mode, + per_layer_in_channels=per_layer_in_channels, + ) + + return inputs, input_name, model_kwargs + + def prepare_inputs_for_generation( + self, + *args, + audio_tokens: Optional[torch.LongTensor] = None, + input_values: Optional[torch.FloatTensor] = None, + padding_mask: Optional[torch.Tensor] = None, + audio_window_size: Optional[int] = None, + current_window: Optional[tuple[int, int]] = None, + encoder_past_key_values: Optional[Cache] = None, + padding_cache: Optional[KyutaiSpeechToTextConv1dPaddingCache] = None, + **kwargs, + ): + model_inputs = super().prepare_inputs_for_generation(*args, **kwargs) + + if input_values is not None: + cache_position = model_inputs["cache_position"] + start, end = current_window[0] + + # first cache position is for bos token, so we need to offset by -1 + if cache_position[-1] - 1 >= end: + # we need to encode the new audio tokens + with torch.no_grad(): + input_values_start_idx = start * self.config.frame_size + input_values_end_idx = (start + audio_window_size) * self.config.frame_size + current_input_values = input_values[..., input_values_start_idx:input_values_end_idx] + codec_model_output = self.codec_model.encode( + current_input_values, + encoder_past_key_values=encoder_past_key_values, + padding_cache=padding_cache, + ) + new_audio_tokens = codec_model_output.audio_codes.transpose(1, 2) + + audio_tokens.copy_(new_audio_tokens) + + start = end.clone() + end = end + audio_window_size + current_window.copy_( + torch.tensor([start, end], device=current_window.device).expand(current_window.shape[0], -1) + ) + + # first cache position is for bos token, so we need to offset by -1 + current_audio_tokens_idxs = (cache_position - start - 1).clamp(min=0) + current_audio_tokens = audio_tokens[:, current_audio_tokens_idxs, :] + + current_audio_tokens[:, cache_position == 0, :] = self.config.audio_bos_token_id + + input_ids = model_inputs.pop("input_ids") + input_ids = torch.cat( + [input_ids.unsqueeze(2), current_audio_tokens], + dim=2, + ) + model_inputs["input_ids"] = input_ids + + return model_inputs + + # TODO: @eustlb, this should be standardized + @classmethod + def from_pretrained(cls, *args, **kwargs): + if kwargs.get("output_loading_info", False): + model, loading_info = super().from_pretrained(*args, **kwargs) + else: + model = super().from_pretrained(*args, **kwargs) + + # copy depth decoder generation conf attr to the depth decoder generation config + prefix = "codec_" + prefix_len = len(prefix) + codec_model_attrs = { + attr[prefix_len:]: value + for attr, value in vars(model.generation_config).items() + if attr.startswith(prefix) + } + + vars(model.codec_model.generation_config).update({"_from_model_config": False, **codec_model_attrs}) + + # remove the depth decoder generation conf attr from the model generation config + for attr in codec_model_attrs: + delattr(model.generation_config, prefix + attr) + + if "output_loading_info" in kwargs: + return model, loading_info + else: + return model + + # TODO: @eustlb, this should be standardized + def save_pretrained(self, *args, **kwargs): + prefix = "codec_" + codec_model_attrs = self.codec_model.generation_config.to_diff_dict() + codec_model_attrs.pop("transformers_version", None) + for attr, value in codec_model_attrs.items(): + setattr(self.generation_config, prefix + attr, value) + + super().save_pretrained(*args, **kwargs) + + def generate(self, *args, **kwargs): + r""" + This method forwards all its arguments to GenerationMixin's [`~GenerationMixin.generate`]. Please refer to the docstring of this method for more information. + """ + max_new_tokens = kwargs.pop("max_new_tokens", None) + input_values = kwargs.get("input_values") + + # TODO: @eustlb, we should have per-batch-idx values + # here we do not use padding_mask to be aligned to what's done in the original codebase + max_audio_frames = input_values.shape[-1] // self.config.codec_config.frame_size + + if max_new_tokens is None or max_new_tokens > max_audio_frames: + if max_new_tokens is not None: + logger.warning( + f"`max_new_tokens` ({max_new_tokens}) is greater than the maximum number of audio frames ({max_audio_frames})." + f"Setting `max_new_tokens` to {max_audio_frames}." + ) + max_new_tokens = max_audio_frames + + return super().generate( + *args, + max_new_tokens=max_new_tokens, + **kwargs, + ) + + +__all__ = [ + "KyutaiSpeechToTextPreTrainedModel", + "KyutaiSpeechToTextModel", + "KyutaiSpeechToTextForConditionalGeneration", +] diff --git a/src/transformers/models/stt/modular_kyutai_speech_to_text.py b/src/transformers/models/stt/modular_kyutai_speech_to_text.py new file mode 100644 index 0000000000..8cc0c9d2a7 --- /dev/null +++ b/src/transformers/models/stt/modular_kyutai_speech_to_text.py @@ -0,0 +1,510 @@ +# coding=utf-8 +# Copyright 2025 Kyutai and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import types +from typing import Optional, Union + +import numpy as np +import torch +import torch.nn as nn + +from ...cache_utils import Cache +from ...feature_extraction_utils import BatchFeature +from ...generation import GenerationConfig, GenerationMixin +from ...modeling_utils import PreTrainedModel +from ...utils import PaddingStrategy, TensorType, logging +from ..auto import AutoModel +from ..encodec.feature_extraction_encodec import EncodecFeatureExtractor +from ..llama.modeling_llama import LlamaForCausalLM +from ..mimi.modeling_mimi import MimiConv1dPaddingCache +from ..moshi.modeling_moshi import MoshiModel, MoshiPreTrainedModel + + +logger = logging.get_logger(__name__) + + +class KyutaiSpeechToTextFeatureExtractor(EncodecFeatureExtractor): + r""" + Constructs an KyutaiSpeechToText feature extractor. + + This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains + most of the main methods. Users should refer to this superclass for more information regarding those methods. + + Args: + feature_size (`int`, *optional*, defaults to 1): + The feature dimension of the extracted features. Use 1 for mono, 2 for stereo. + sampling_rate (`int`, *optional*, defaults to 24000): + The sampling rate at which the audio waveform should be digitalized expressed in hertz (Hz). + padding_value (`float`, *optional*, defaults to 0.0): + The value that is used to fill the padding values. + chunk_length_s (`float`, *optional*): + If defined the audio is pre-processed into chunks of lengths `chunk_length_s` and then encoded. + overlap (`float`, *optional*): + Defines the overlap between each chunk. It is used to compute the `chunk_stride` using the following + formulae : `int((1.0 - self.overlap) * self.chunk_length)`. + audio_delay_seconds (`float`, *optional*, defaults to 0.0): + The delay in seconds to add after the audio (right padding). + audio_silence_prefix_seconds (`float`, *optional*, defaults to 0.0): + The silence prefix in seconds to add before the audio (left padding). + """ + + def __init__( + self, + audio_delay_seconds: Optional[float] = 0.0, + audio_silence_prefix_seconds: Optional[float] = 0.0, + **super_kwargs, + ): + super().__init__(**super_kwargs) + self.audio_delay_seconds = audio_delay_seconds + self.audio_silence_prefix_seconds = audio_silence_prefix_seconds + + def __call__( + self, + raw_audio: Union[np.ndarray, list[float], list[np.ndarray], list[list[float]]], + padding: Optional[Union[bool, str, PaddingStrategy]] = None, + truncation: Optional[bool] = False, + max_length: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + sampling_rate: Optional[int] = None, + ) -> BatchFeature: + """ + Main method to featurize and prepare for the model one or several sequence(s). + + Args: + raw_audio (`np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[list[float]]`): + The sequence or batch of sequences to be processed. Each sequence can be a numpy array, a list of float + values, a list of numpy arrays or a list of list of float values. The numpy array must be of shape + `(num_samples,)` for mono audio (`feature_size = 1`), or `(2, num_samples)` for stereo audio + (`feature_size = 2`). + padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`): + Select a strategy to pad the returned sequences (according to the model's padding side and padding + index) among: + + - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different + lengths). + truncation (`bool`, *optional*, defaults to `False`): + Activates truncation to cut input sequences longer than `max_length` to `max_length`. + max_length (`int`, *optional*): + Maximum length of the returned list and optionally padding length (see above). + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors instead of list of python integers. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return Numpy `np.ndarray` objects. + sampling_rate (`int`, *optional*): + The sampling rate at which the `audio` input was sampled. It is strongly recommended to pass + `sampling_rate` at the forward call to prevent silent errors. + """ + if sampling_rate is not None: + if sampling_rate != self.sampling_rate: + raise ValueError( + f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of" + f" {self.sampling_rate}. Please make sure that the provided audio input was sampled with" + f" {self.sampling_rate} and not {sampling_rate}." + ) + else: + logger.warning( + f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. " + "Failing to do so can result in silent errors that might be hard to debug." + ) + + if padding and truncation: + raise ValueError("Both padding and truncation were set. Make sure you only set one.") + elif padding is None: + # by default let's pad the inputs + padding = True + + is_batched = bool( + isinstance(raw_audio, (list, tuple)) and (isinstance(raw_audio[0], (np.ndarray, tuple, list))) + ) + + if is_batched: + raw_audio = [np.asarray(audio, dtype=np.float32).T for audio in raw_audio] + elif not is_batched and not isinstance(raw_audio, np.ndarray): + raw_audio = np.asarray(raw_audio, dtype=np.float32) + elif isinstance(raw_audio, np.ndarray) and raw_audio.dtype is np.dtype(np.float64): + raw_audio = raw_audio.astype(np.float32) + + # always return batch + if not is_batched: + raw_audio = [np.asarray(raw_audio).T] + + # verify inputs are valid + for idx, example in enumerate(raw_audio): + if example.ndim > 2: + raise ValueError(f"Expected input shape (channels, length) but got shape {example.shape}") + if self.feature_size == 1 and example.ndim != 1: + raise ValueError(f"Expected mono audio but example has {example.shape[-1]} channels") + if self.feature_size == 2 and example.shape[-1] != 2: + raise ValueError(f"Expected stereo audio but example has {example.shape[-1]} channels") + + padded_inputs = None + input_values = BatchFeature({"input_values": raw_audio}) + if self.chunk_stride is not None and self.chunk_length is not None and max_length is None: + if truncation: + max_length = min(array.shape[0] for array in raw_audio) + nb_step = int(np.floor(max_length / self.chunk_stride)) + max_length = (nb_step - 1) * self.chunk_stride + self.chunk_length + elif padding: + max_length = max(array.shape[0] for array in raw_audio) + nb_step = int(np.ceil(max_length / self.chunk_stride)) + max_length = (nb_step - 1) * self.chunk_stride + self.chunk_length + padding = "max_length" + else: + padded_inputs = input_values + + # normal padding on batch + if padded_inputs is None: + padded_inputs = self.pad( + input_values, + max_length=max_length, + truncation=truncation, + padding=padding, + return_attention_mask=padding, + ) + + if padding: + padded_inputs["padding_mask"] = padded_inputs.pop("attention_mask") + + # now let's padd left and right + pad_left = int(self.audio_silence_prefix_seconds * self.sampling_rate) + pad_right = int((self.audio_delay_seconds + 1.0) * self.sampling_rate) + padded_inputs["input_values"] = np.pad( + padded_inputs["input_values"], + ((0, 0), (pad_left, pad_right)), + mode="constant", + constant_values=0.0, + ) + if padding: + padded_inputs["padding_mask"] = np.pad( + padded_inputs["padding_mask"], + ((0, 0), (pad_left, pad_right)), + mode="constant", + constant_values=0, + ) + + input_values = [] + for example in padded_inputs.pop("input_values"): + if self.feature_size == 1: + example = example[..., None] + input_values.append(example.T) + + padded_inputs["input_values"] = input_values + if return_tensors is not None: + padded_inputs = padded_inputs.convert_to_tensors(return_tensors) + + return padded_inputs + + +class KyutaiSpeechToTextPreTrainedModel(MoshiPreTrainedModel): + pass + + +class KyutaiSpeechToTextConv1dPaddingCache(MimiConv1dPaddingCache): + pass + + +class KyutaiSpeechToTextEmbeddings(nn.Module): + def __init__(self, config): + super().__init__() + self.embed_tokens = nn.Embedding( + config.vocab_size + (config.num_codebooks * config.codebook_vocab_size) + 1, + config.hidden_size, + padding_idx=config.audio_pad_token_id, + ) + audio_tokens_offsets = torch.arange(config.num_codebooks) * config.codebook_vocab_size + audio_tokens_offsets += config.vocab_size + audio_tokens_offsets = nn.functional.pad( + audio_tokens_offsets, (1, 0) + ) # pad one 0 to the left for the text token + self.register_buffer("audio_tokens_offsets", audio_tokens_offsets, persistent=False) + + def forward(self, input_ids): + input_ids = torch.where( + input_ids == self.embed_tokens.padding_idx, input_ids, input_ids + self.audio_tokens_offsets + ) + inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = inputs_embeds.sum(dim=2) + return inputs_embeds + + +class KyutaiSpeechToTextModel(MoshiModel): + def __init__(self, config): + super().__init__(config) + self.embed_tokens = KyutaiSpeechToTextEmbeddings(config) + + +class KyutaiSpeechToTextForConditionalGeneration(LlamaForCausalLM, GenerationMixin, PreTrainedModel): + _keep_in_fp32_modules = ["codec_model"] + + def __init__(self, config): + super().__init__(config) + self.codec_model = AutoModel.from_config(config.codec_config) + + # we are in an edge case where for the codec_model self.can_generate is False, setting self.codec_model.generation_config to None + # yet the codec_model needs a generation config to initalize it's cache for streaming inference + # we therefore initialize a generation config for the codec model + self.codec_model.generation_config = GenerationConfig.from_model_config(config.codec_config) + + def forward(self, **super_kwargs): + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> import torch + >>> from datasets import load_dataset, Audio + >>> from transformers import KyutaiSpeechToTextProcessor, KyutaiSpeechToTextForConditionalGeneration + + >>> torch_device = "cuda" if torch.cuda.is_available() else "cpu" + >>> model_id = "kyutai/stt-2.6b-en" + + >>> processor = KyutaiSpeechToTextProcessor.from_pretrained(model_id) + >>> model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(model_id, device_map=torch_device) + + >>> ds = load_dataset( + ... "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation" + ... ) + + >>> ds = ds.cast_column("audio", Audio(sampling_rate=24000)) + >>> inputs = processor( + ... ds[0]["audio"]["array"], + ... ) + >>> inputs.to(torch_device) + + >>> output_tokens = model.generate(**inputs) + >>> print(processor.batch_decode(output_tokens, skip_special_tokens=True)) + ```""" + super().forward(**super_kwargs) + + def _prepare_generation_config(self, *args, **kwargs): + generation_config, model_kwargs = GenerationMixin._prepare_generation_config(*args, **kwargs) + # this should be passed to the model kwargs for the input preparation + model_kwargs["audio_window_size"] = ( + generation_config.audio_window_size if hasattr(generation_config, "audio_window_size") else None + ) + return generation_config, model_kwargs + + def _prepare_model_inputs( + self, + inputs: Optional[torch.Tensor] = None, + bos_token_id: Optional[torch.Tensor] = None, + model_kwargs: Optional[dict[str, torch.Tensor]] = None, + ) -> tuple[torch.Tensor, Optional[str], dict[str, torch.Tensor]]: + inputs, input_name, model_kwargs = GenerationMixin._prepare_model_inputs( + inputs=inputs, + bos_token_id=bos_token_id, + model_kwargs=model_kwargs, + ) + + audio_window_size = model_kwargs.get("audio_window_size", None) + if audio_window_size is None: + audio_window_size = self.codec_model.get_encoded_length(model_kwargs["input_values"].shape[-1]).item() + model_kwargs["audio_window_size"] = audio_window_size + + batch_size = inputs.shape[0] + device = inputs.device + + # initialize audio tokens + model_kwargs["audio_tokens"] = torch.zeros( + (batch_size, audio_window_size, self.config.num_codebooks), + device=device, + dtype=torch.long, + ) + model_kwargs["current_window"] = ( + torch.tensor([0, 0], device=device, dtype=torch.long).expand(batch_size, -1).contiguous() + ) + + # let's use generate's cache preparation to prepare the cache for the codec model + temporary_model_kwargs = {} + + # monkey patching the codec model with cache preparation methods since we don't want it to inherit fully from GenerationMixin + # Add cache-related methods from GenerationMixin to codec model + cache_methods = [ + "_prepare_cache_for_generation", + "_get_cache", + "_supports_default_dynamic_cache", + "_get_layer_device_map_for_cache_init", + ] + for method in cache_methods: + setattr(self.codec_model, method, types.MethodType(getattr(self, method).__func__, self.codec_model)) + + self.codec_model._prepare_cache_for_generation( + generation_config=self.codec_model.generation_config, + model_kwargs=temporary_model_kwargs, + assistant_model=None, + batch_size=batch_size, + max_cache_length=self.config.codec_config.sliding_window, + device=device, + ) + + if "past_key_values" in temporary_model_kwargs: + model_kwargs["encoder_past_key_values"] = temporary_model_kwargs["past_key_values"] + + # initialize the padding cache for the codec model + per_layer_padding, per_layer_padding_mode, per_layer_in_channels = [], [], [] + for layer_name in self.codec_model.encoder._mimiconv1d_layer_names: + per_layer_padding.append(self.codec_model.encoder.get_submodule(layer_name).padding_total) + per_layer_padding_mode.append(self.codec_model.encoder.get_submodule(layer_name).pad_mode) + per_layer_in_channels.append(self.codec_model.encoder.get_submodule(layer_name).in_channels) + + # downsample layer + per_layer_padding.append(self.codec_model.downsample.padding_total) + per_layer_padding_mode.append(self.codec_model.downsample.pad_mode) + per_layer_in_channels.append(self.codec_model.downsample.in_channels) + + model_kwargs["padding_cache"] = KyutaiSpeechToTextConv1dPaddingCache( + num_layers=len(self.codec_model.encoder._mimiconv1d_layer_names) + 1, + per_layer_padding=per_layer_padding, + per_layer_padding_mode=per_layer_padding_mode, + per_layer_in_channels=per_layer_in_channels, + ) + + return inputs, input_name, model_kwargs + + def prepare_inputs_for_generation( + self, + *args, + audio_tokens: Optional[torch.LongTensor] = None, + input_values: Optional[torch.FloatTensor] = None, + padding_mask: Optional[torch.Tensor] = None, + audio_window_size: Optional[int] = None, + current_window: Optional[tuple[int, int]] = None, + encoder_past_key_values: Optional[Cache] = None, + padding_cache: Optional[KyutaiSpeechToTextConv1dPaddingCache] = None, + **kwargs, + ): + model_inputs = GenerationMixin.prepare_inputs_for_generation(*args, **kwargs) + + if input_values is not None: + cache_position = model_inputs["cache_position"] + start, end = current_window[0] + + # first cache position is for bos token, so we need to offset by -1 + if cache_position[-1] - 1 >= end: + # we need to encode the new audio tokens + with torch.no_grad(): + input_values_start_idx = start * self.config.frame_size + input_values_end_idx = (start + audio_window_size) * self.config.frame_size + current_input_values = input_values[..., input_values_start_idx:input_values_end_idx] + codec_model_output = self.codec_model.encode( + current_input_values, + encoder_past_key_values=encoder_past_key_values, + padding_cache=padding_cache, + ) + new_audio_tokens = codec_model_output.audio_codes.transpose(1, 2) + + audio_tokens.copy_(new_audio_tokens) + + start = end.clone() + end = end + audio_window_size + current_window.copy_( + torch.tensor([start, end], device=current_window.device).expand(current_window.shape[0], -1) + ) + + # first cache position is for bos token, so we need to offset by -1 + current_audio_tokens_idxs = (cache_position - start - 1).clamp(min=0) + current_audio_tokens = audio_tokens[:, current_audio_tokens_idxs, :] + + current_audio_tokens[:, cache_position == 0, :] = self.config.audio_bos_token_id + + input_ids = model_inputs.pop("input_ids") + input_ids = torch.cat( + [input_ids.unsqueeze(2), current_audio_tokens], + dim=2, + ) + model_inputs["input_ids"] = input_ids + + return model_inputs + + # TODO: @eustlb, this should be standardized + @classmethod + def from_pretrained(cls, *args, **kwargs): + if kwargs.get("output_loading_info", False): + model, loading_info = PreTrainedModel.from_pretrained(*args, **kwargs) + else: + model = PreTrainedModel.from_pretrained(*args, **kwargs) + + # copy depth decoder generation conf attr to the depth decoder generation config + prefix = "codec_" + prefix_len = len(prefix) + codec_model_attrs = { + attr[prefix_len:]: value + for attr, value in vars(model.generation_config).items() + if attr.startswith(prefix) + } + + vars(model.codec_model.generation_config).update({"_from_model_config": False, **codec_model_attrs}) + + # remove the depth decoder generation conf attr from the model generation config + for attr in codec_model_attrs: + delattr(model.generation_config, prefix + attr) + + if "output_loading_info" in kwargs: + return model, loading_info + else: + return model + + # TODO: @eustlb, this should be standardized + def save_pretrained(self, *args, **kwargs): + prefix = "codec_" + codec_model_attrs = self.codec_model.generation_config.to_diff_dict() + codec_model_attrs.pop("transformers_version", None) + for attr, value in codec_model_attrs.items(): + setattr(self.generation_config, prefix + attr, value) + + PreTrainedModel.save_pretrained(self, *args, **kwargs) + + def generate(self, *args, **kwargs): + r""" + This method forwards all its arguments to GenerationMixin's [`~GenerationMixin.generate`]. Please refer to the docstring of this method for more information. + """ + max_new_tokens = kwargs.pop("max_new_tokens", None) + input_values = kwargs.get("input_values") + + # TODO: @eustlb, we should have per-batch-idx values + # here we do not use padding_mask to be aligned to what's done in the original codebase + max_audio_frames = input_values.shape[-1] // self.config.codec_config.frame_size + + if max_new_tokens is None or max_new_tokens > max_audio_frames: + if max_new_tokens is not None: + logger.warning( + f"`max_new_tokens` ({max_new_tokens}) is greater than the maximum number of audio frames ({max_audio_frames})." + f"Setting `max_new_tokens` to {max_audio_frames}." + ) + max_new_tokens = max_audio_frames + + return GenerationMixin.generate( + *args, + max_new_tokens=max_new_tokens, + **kwargs, + ) + + +__all__ = [ + "KyutaiSpeechToTextPreTrainedModel", + "KyutaiSpeechToTextModel", + "KyutaiSpeechToTextForConditionalGeneration", + "KyutaiSpeechToTextFeatureExtractor", +] diff --git a/src/transformers/models/stt/processing_kyutai_speech_to_text.py b/src/transformers/models/stt/processing_kyutai_speech_to_text.py new file mode 100644 index 0000000000..0b3a021712 --- /dev/null +++ b/src/transformers/models/stt/processing_kyutai_speech_to_text.py @@ -0,0 +1,104 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +from ...audio_utils import AudioInput, make_list_of_audio +from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack + + +class KyutaiSpeechToTextProcessorKwargs(ProcessingKwargs, total=False): + _defaults = { + "audio_kwargs": { + "sampling_rate": 24000, + }, + "common_kwargs": {"return_tensors": "pt"}, + } + + +class KyutaiSpeechToTextProcessor(ProcessorMixin): + r""" + Constructs a Moshi ASR processor which wraps [`EncodecFeatureExtractor`] and + [`PreTrainedTokenizerFast`] into a single processor that inherits both the audio feature extraction and + tokenizer functionalities. See the [`~KyutaiSpeechToTextProcessor.__call__`] for more + information. + """ + + feature_extractor_class = "KyutaiSpeechToTextFeatureExtractor" + tokenizer_class = "PreTrainedTokenizerFast" + + def __call__( + self, + audio: Optional[AudioInput] = None, + **kwargs: Unpack[KyutaiSpeechToTextProcessorKwargs], + ): + r""" + Main method to prepare audio to be fed as input to the model. This method forwards the `audio` + arguments to KyutaiSpeechToTextFeatureExtractor's [`~KyutaiSpeechToTextFeatureExtractor.__call__`]. Please refer + to the docstring of the above method for more information. + + Args: + audio (`np.ndarray`, `torch.Tensor`, `list[np.ndarray]`, `list[torch.Tensor]`): + The audio or batch of audio to be prepared. Each audio can be a NumPy array or PyTorch + tensor. + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_values** -- List of audio values to be fed to a model. Returned when `audio` is not `None`. + - **padding_mask** -- List of indices specifying which input values should be ignored by the model. + """ + + if audio is None: + raise ValueError("`audio` is required.") + + output_kwargs = self._merge_kwargs( + KyutaiSpeechToTextProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + audio_kwargs = output_kwargs["audio_kwargs"] + + # ensure audio in correct format + audio = make_list_of_audio(audio) + + inputs = self.feature_extractor( + audio, + **audio_kwargs, + ) + + return inputs + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to KyutaiSpeechToTextTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to KyutaiSpeechToTextTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + +__all__ = ["KyutaiSpeechToTextProcessor"] diff --git a/tests/models/kyutai_speech_to_text/__init__.py b/tests/models/kyutai_speech_to_text/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/models/kyutai_speech_to_text/test_modeling_kyutai_speech_to_text.py b/tests/models/kyutai_speech_to_text/test_modeling_kyutai_speech_to_text.py new file mode 100644 index 0000000000..a6e08f714f --- /dev/null +++ b/tests/models/kyutai_speech_to_text/test_modeling_kyutai_speech_to_text.py @@ -0,0 +1,704 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the PyTorch Moshi ASR model.""" + +import gc +import inspect +import tempfile +import unittest + +import datasets +import pytest +from parameterized import parameterized + +from transformers import ( + KyutaiSpeechToTextConfig, + KyutaiSpeechToTextForConditionalGeneration, + KyutaiSpeechToTextProcessor, + is_torch_available, +) +from transformers.testing_utils import ( + cleanup, + require_torch, + require_torch_accelerator, + require_torch_sdpa, + slow, + torch_device, +) + +from ...generation.test_utils import GenerationTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ( + TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION, + ModelTesterMixin, + _config_zero_init, + floats_tensor, + ids_tensor, +) +from ...test_pipeline_mixin import PipelineTesterMixin + + +if is_torch_available(): + import torch + + from transformers import ( + KyutaiSpeechToTextForConditionalGeneration, + KyutaiSpeechToTextModel, + ) + + +class KyutaiSpeechToTextModelTester: + def __init__( + self, + parent, + batch_size=13, + seq_length=7, + text_seq_length=1, + input_values_length=192, # gives 3 audio tokens, corresponding to the default in GenerationTesterMixin + is_training=False, + use_input_mask=True, + use_token_type_ids=False, + use_labels=True, + codebook_vocab_size=2049, + vocab_size=99, + hidden_size=32, + num_hidden_layers=2, + num_attention_heads=4, + num_key_value_heads=None, + max_position_embeddings=512, + rope_theta=10000.0, + hidden_act="silu", + head_dim=None, + initializer_range=0.02, + use_cache=True, + sliding_window=512, + attention_dropout=0.1, + ffn_dim=38, + rms_norm_eps=1e-6, + num_codebooks=8, + frame_size=64, + delay_in_tokens=5, + audio_bos_token_id=2048, + audio_pad_token_id=2048, + tie_word_embeddings=False, + pad_token_id=0, + bos_token_id=1, + codec_config={ + "model_type": "mimi", + "num_quantizers": 8, + "audio_channels": 1, + "chunk_in_sec": None, + "hidden_size": 16, + "num_filters": 8, + "num_residual_layers": 1, + "upsampling_ratios": [8, 4], + "codebook_size": 16, + "vector_quantization_hidden_dimension": 16, + "upsample_groups": 16, + "num_hidden_layers": 2, + "num_attention_heads": 2, + "num_key_value_heads": 2, + "sliding_window": 4, + "codebook_dim": 16, + "use_cache": False, + }, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.text_seq_length = text_seq_length + self.is_training = is_training + self.use_input_mask = use_input_mask + self.use_token_type_ids = use_token_type_ids + self.use_labels = use_labels + self.codebook_vocab_size = codebook_vocab_size + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.max_position_embeddings = max_position_embeddings + self.rope_theta = rope_theta + self.hidden_act = hidden_act + self.head_dim = head_dim + self.initializer_range = initializer_range + self.use_cache = use_cache + self.sliding_window = sliding_window + self.attention_dropout = attention_dropout + self.ffn_dim = ffn_dim + self.rms_norm_eps = rms_norm_eps + self.num_codebooks = num_codebooks + self.frame_size = frame_size + self.delay_in_tokens = delay_in_tokens + self.audio_bos_token_id = audio_bos_token_id + self.audio_pad_token_id = audio_pad_token_id + self.tie_word_embeddings = tie_word_embeddings + self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + self.codec_config = codec_config + self.scope = scope + self.input_values_length = input_values_length + + def get_config(self): + return KyutaiSpeechToTextConfig( + codebook_vocab_size=self.codebook_vocab_size, + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + num_key_value_heads=self.num_key_value_heads, + max_position_embeddings=self.max_position_embeddings, + rope_theta=self.rope_theta, + hidden_act=self.hidden_act, + head_dim=self.head_dim, + initializer_range=self.initializer_range, + use_cache=self.use_cache, + sliding_window=self.sliding_window, + attention_dropout=self.attention_dropout, + ffn_dim=self.ffn_dim, + rms_norm_eps=self.rms_norm_eps, + num_codebooks=self.num_codebooks, + frame_size=self.frame_size, + delay_in_tokens=self.delay_in_tokens, + audio_bos_token_id=self.audio_bos_token_id, + audio_pad_token_id=self.audio_pad_token_id, + tie_word_embeddings=self.tie_word_embeddings, + pad_token_id=self.pad_token_id, + bos_token_id=self.bos_token_id, + codec_config=self.codec_config, + ) + + def create_and_check_model(self, config, input_ids, input_mask): + model = KyutaiSpeechToTextModel(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask) + result = model(input_ids) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + + def prepare_config_and_inputs(self): + config = self.get_config() + + text_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size - 1) + 1 + codebook_input_ids = ( + ids_tensor([self.batch_size, self.seq_length, self.num_codebooks], self.codebook_vocab_size - 1) + 1 + ) + + input_ids = torch.cat([text_input_ids.unsqueeze(2), codebook_input_ids], dim=2) + attention_mask = text_input_ids.ne(1).to(torch_device) + + return config, input_ids, attention_mask + + def prepare_config_and_inputs_generate(self): + config = self.get_config() + + input_ids = torch.ones([self.batch_size, 1], dtype=torch.long, device=torch_device) + input_values = floats_tensor([self.batch_size, 1, self.input_values_length]) + padding_mask = torch.ones_like(input_values, dtype=torch.int32, device=torch_device) + + return config, input_ids, input_values, padding_mask + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + attention_mask, + ) = config_and_inputs + inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask} + return config, inputs_dict + + def prepare_config_and_inputs_for_common_generate(self): + config_and_inputs = self.prepare_config_and_inputs_generate() + ( + config, + input_ids, + input_values, + padding_mask, + ) = config_and_inputs + inputs_dict = { + "input_ids": input_ids, + "input_values": input_values, + "padding_mask": padding_mask, + } + return config, inputs_dict + + +@require_torch +class KyutaiSpeechToTextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): + all_model_classes = ( + ( + KyutaiSpeechToTextModel, + KyutaiSpeechToTextForConditionalGeneration, + ) + if is_torch_available() + else () + ) + pipeline_model_mapping = ( + { + "feature-extraction": KyutaiSpeechToTextModel, + "automatic-speech-recognition": KyutaiSpeechToTextForConditionalGeneration, + } + if is_torch_available() + else {} + ) + test_headmasking = False + test_pruning = False + fx_compatible = False # Broken by attention refactor cc @Cyrilvallez + + # Need to use `0.8` instead of `0.9` for `test_cpu_offload` + # This is because we are hitting edge cases with the causal_mask buffer + model_split_percents = [0.5, 0.7, 0.8] + + def setUp(self): + self.model_tester = KyutaiSpeechToTextModelTester(self) + self.config_tester = ConfigTester(self, config_class=KyutaiSpeechToTextConfig, hidden_size=37) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): + inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels) + + return inputs_dict + + def prepare_config_and_inputs_for_generate(self, batch_size=2): + # monkey patch prepare_config_and_inputs_for_common + + prepare_config_and_inputs_for_common = self.model_tester.prepare_config_and_inputs_for_common + original_batch_size = self.model_tester.batch_size + + self.model_tester.prepare_config_and_inputs_for_common = ( + self.model_tester.prepare_config_and_inputs_for_common_generate + ) + self.model_tester.batch_size = batch_size + + config, filtered_inputs_dict = super().prepare_config_and_inputs_for_generate() + self.model_tester.prepare_config_and_inputs_for_common = prepare_config_and_inputs_for_common + + self.model_tester.batch_size = original_batch_size + return config, filtered_inputs_dict + + @pytest.mark.skip(reason="Moshi ASR has custom embedding approach (text and audio embeddings).") + def test_model_get_set_embeddings(self): + pass + + @pytest.mark.skip(reason="Moshi ASR has custom embedding approach (text and audio embeddings).") + def test_tie_model_weights(self): + pass + + @pytest.mark.skip(reason="Moshi ASR has custom embedding approach (text and audio embeddings).") + def test_resize_embeddings_untied(self): + pass + + @pytest.mark.skip(reason="Moshi ASR has custom embedding approach (text and audio embeddings).") + def test_resize_tokens_embeddings(self): + pass + + @pytest.mark.skip(reason="Moshi ASR has custom embedding approach (text and audio embeddings).") + def test_tied_weights_keys(self): + pass + + @pytest.mark.skip(reason="Does not apply to Moshi ASR that requires input_values.") + def test_generate_without_input_ids(self): + pass + + def test_initialization(self): + """ + Overrides [ModelTesterMixin.test_initialization] because of specificities of Mimi codec model. + See https://github.com/huggingface/transformers/blob/1077603410cd73ba71d64a522033574d66d64b55/tests/models/mimi/test_modeling_mimi.py#L384-L397 + """ + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + configs_no_init = _config_zero_init(config) + for model_class in self.all_model_classes: + model = model_class(config=configs_no_init) + for name, param in model.named_parameters(): + uniform_init_parms = ["conv", "input_proj", "output_proj"] + if param.requires_grad: + if any(x in name for x in uniform_init_parms): + self.assertTrue( + -1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0, + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + + @parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION) + @require_torch_sdpa + def test_eager_matches_sdpa_inference( + self, name, torch_dtype, padding_side, use_attention_mask, output_attentions, enable_kernels + ): + if use_attention_mask or (not use_attention_mask and torch_dtype == "fp32" and not output_attentions): + self.skipTest("Test is failing, fix me :) ") + parent_parameterized_test = getattr(ModelTesterMixin, self._testMethodName) + parent_parameterized_test(self) + + @unittest.skip(reason="Some undefined behavior encountered with test versions of this model. Skip for now.") + def test_cpu_offload(self): + pass + + @unittest.skip(reason="Some undefined behavior encountered with test versions of this model. Skip for now.") + def test_disk_offload_bin(self): + pass + + @unittest.skip(reason="Some undefined behavior encountered with test versions of this model. Skip for now.") + def test_disk_offload_safetensors(self): + pass + + @pytest.mark.generate + def test_left_padding_compatibility(self): + # NOTE: left-padding results in small numerical differences. This is expected. + # See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535 + + # First, filter out models that don't support left padding + # - The model must have generative capabilities + if len(self.all_generative_model_classes) == 0: + self.skipTest(reason="No generative architecture available for this model.") + + # - The model must support padding + if not self.has_attentions: + self.skipTest(reason="This model doesn't support padding.") + + # - The model must be a decoder-only architecture (encoder-based architectures use right-padding) + decoder_only_classes = [] + for model_class in self.all_generative_model_classes: + config, _ = self.prepare_config_and_inputs_for_generate() + if config.is_encoder_decoder: + continue + else: + decoder_only_classes.append(model_class) + if len(decoder_only_classes) == 0: + self.skipTest(reason="No decoder-only architecture available for this model.") + + # - Decoder-only architectures derived from encoder-decoder models could support it in theory, but we haven't + # added support for it yet. We skip these models for now. + has_encoder_attributes = any( + attr_name + for attr_name in config.to_dict().keys() + if attr_name.startswith("encoder") and attr_name != "encoder_no_repeat_ngram_size" + ) + if has_encoder_attributes: + self.skipTest( + reason="The decoder-only derived from encoder-decoder models are not expected to support left-padding." + ) + + # Then, test left-padding + def _prepare_model_kwargs(input_ids, attention_mask, signature): + model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask} + if "position_ids" in signature: + position_ids = torch.cumsum(attention_mask, dim=-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + model_kwargs["position_ids"] = position_ids + if "cache_position" in signature: + cache_position = torch.arange(input_ids.shape[1], device=torch_device) + model_kwargs["cache_position"] = cache_position + return model_kwargs + + for model_class in decoder_only_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + input_ids = inputs_dict["input_ids"] + attention_mask = inputs_dict.get("attention_mask") + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + + model = model_class(config).to(torch_device).eval() + signature = inspect.signature(model.forward).parameters.keys() + + # no cache as some models require special cache classes to be init outside forward + model.generation_config.use_cache = False + + # Without padding + model_kwargs = _prepare_model_kwargs(input_ids, attention_mask, signature) + next_logits_wo_padding = model(**model_kwargs).logits[:, -1, :] + + # With left-padding (length 32) + # can hardcode pad_token to be 0 as we'll do attn masking anyway + pad_token_id = ( + config.get_text_config().pad_token_id if config.get_text_config().pad_token_id is not None else 0 + ) + pad_size = (input_ids.shape[0], 32, *input_ids.shape[2:]) + padding = torch.ones(pad_size, dtype=input_ids.dtype, device=torch_device) * pad_token_id + padded_input_ids = torch.cat((padding, input_ids), dim=1) + padded_attention_mask = torch.cat( + (torch.zeros(pad_size[:2], dtype=input_ids.dtype, device=torch_device), attention_mask), dim=1 + ) + model_kwargs = _prepare_model_kwargs(padded_input_ids, padded_attention_mask, signature) + next_logits_with_padding = model(**model_kwargs).logits[:, -1, :] + + # They should result in very similar logits + torch.testing.assert_close(next_logits_wo_padding, next_logits_with_padding, rtol=1e-5, atol=1e-5) + + def test_generate_continue_from_past_key_values(self): + # Tests that we can continue generating from past key values, returned from a previous `generate` call + for model_class in self.all_generative_model_classes: + if any(model_name in model_class.__name__.lower() for model_name in ["imagegpt", "mllama"]): + self.skipTest(reason="Won't fix: old model with unique inputs/caches/other") + if any(model_name in model_class.__name__.lower() for model_name in ["umt5"]): + self.skipTest(reason="TODO: needs modeling or test input preparation fixes for compatibility") + + config, inputs = self.model_tester.prepare_config_and_inputs_for_common() + + if not hasattr(config.get_text_config(), "use_cache"): + self.skipTest(reason=f"{model_class.__name__} doesn't support caching") + + # Let's make it always: + # 1. use cache (for obvious reasons) + # 2. generate to max length (which can be achieved by setting the eos token to an invalid value), which + # would make the test flaky (e.g. EOS is generated on iteration 1 on both generations, but the + # continuation would force it to generate beyond an EOS token) + # 3. ignore `token_type_ids` for simplicity + # 4. ignore `forced_eos_token_id`, which requires further manipulation of the continuation inputs and is + # active by default on some models + # 5. ignore `encoder_no_repeat_ngram_size`, which is set by default in some encoder-decoder models. When + # we use their decoder as a stand-alone model, `encoder_no_repeat_ngram_size` actually prevents + # repetition exclusively from the prompt. This test relies on comparing one call vs 2 calls + # with cache, what is considered a prompt is different in the two cases. + + if "token_type_ids" in inputs: + del inputs["token_type_ids"] + + model = model_class(config).to(torch_device) + model.eval() + + # If "past_key_values" is not returned, skip the test (e.g. RWKV uses a different cache name and format) + outputs = model(**inputs) + if "past_key_values" not in outputs: + self.skipTest(reason="This model doesn't return `past_key_values`") + + generate_kwargs = { + "pad_token_id": -1, + "eos_token_id": -1, + "forced_eos_token_id": None, + "encoder_no_repeat_ngram_size": 0, + "use_cache": True, + "do_sample": False, + "return_dict_in_generate": True, + "output_scores": True, + } + + # Traditional way of generating text, with `return_dict_in_generate` to return the past key values + _, inputs = self.prepare_config_and_inputs_for_generate() + outputs = model.generate(**inputs, **generate_kwargs, max_new_tokens=3) + + # Let's generate again, but passing the past key values in between (2 + 1 = 3 tokens). Note that the + # inputs may need to be tweaked across `generate` calls (like the attention mask). + outputs_cached = model.generate(**inputs, **generate_kwargs, max_new_tokens=2) + + # Continue from the tokens generated above, preparing the inputs accordingly + inputs["past_key_values"] = outputs_cached.past_key_values + new_attention_len = outputs_cached.sequences.shape[-1] + if config.is_encoder_decoder: + inputs["decoder_input_ids"] = outputs_cached.sequences + if "decoder_attention_mask" in inputs: + inputs["decoder_attention_mask"] = torch.nn.functional.pad( + inputs["decoder_attention_mask"], + (0, new_attention_len - inputs["decoder_attention_mask"].shape[1]), + mode="constant", + value=1, + ) + else: + inputs["input_ids"] = outputs_cached.sequences + if "attention_mask" in inputs: + inputs["attention_mask"] = torch.nn.functional.pad( + inputs["attention_mask"], + (0, new_attention_len - inputs["attention_mask"].shape[1]), + mode="constant", + value=1, + ) + first_caches_scores = outputs_cached.scores + outputs_cached = model.generate(**inputs, **generate_kwargs, max_new_tokens=1) + full_cached_scores = first_caches_scores + outputs_cached.scores + outputs_cached.scores = full_cached_scores + + # The two sets of generated text and past kv should be equal to each other + self._check_similar_generate_outputs(outputs, outputs_cached) + for layer_idx in range(len(outputs_cached.past_key_values)): + for kv_idx in range(len(outputs_cached.past_key_values[layer_idx])): + self.assertTrue( + torch.allclose( + outputs.past_key_values[layer_idx][kv_idx], + outputs_cached.past_key_values[layer_idx][kv_idx], + ) + ) + + # needs to be overridden to avoid to avoid casting of input_values to float16 + # indeed, the codec model is kept in fp32, so we need to avoid casting input_values to float16 + def _test_attention_implementation(self, attn_implementation): + """ + Compares the output of generate with the eager attention implementation against other implementations. + NOTE: despite the test logic being the same, different implementations actually need different decorators, hence + this separate function. + """ + max_new_tokens = 30 + support_flag = { + "sdpa": "_supports_sdpa", + "flash_attention_2": "_supports_flash_attn_2", + } + + for model_class in self.all_generative_model_classes: + if not getattr(model_class, support_flag[attn_implementation]): + self.skipTest(f"{model_class.__name__} does not support `attn_implementation={attn_implementation}`") + + config, original_inputs_dict = self.prepare_config_and_inputs_for_generate() + inputs_dict = {} + for input_name, input_data in original_inputs_dict.items(): + if ( + isinstance(input_data, torch.Tensor) + and input_data.dtype in [torch.float32, torch.bfloat16] + and input_name != "input_values" + ): + inputs_dict[input_name] = input_data.to(torch.float16) + else: + inputs_dict[input_name] = input_data + main_input = inputs_dict[model_class.main_input_name] + + # FA2 doesn't accept masking in the middle of the sequence for now. We usually generate right-padded + # attention masks at test time and, with generate, the mask will be appended with 1s on the right, + # resulting in a mask with holes (not supported properly by FA2). + if attn_implementation == "flash_attention_2": + for input_name in ("attention_mask", "decoder_attention_mask", "encoder_attention_mask"): + if input_name in inputs_dict: + inputs_dict[input_name] = torch.ones_like(inputs_dict[input_name]) + + # make sure that all models have enough positions for generation + if hasattr(config, "max_position_embeddings"): + config.max_position_embeddings = max_new_tokens + main_input.shape[1] + 1 + + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + del model + gc.collect() + + generate_kwargs = { + "max_new_tokens": max_new_tokens, + "do_sample": False, + "return_dict_in_generate": True, + "output_scores": True, + "use_cache": True, + } + + model_eager = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + attn_implementation="eager", + ).to(torch_device) + res_eager = model_eager.generate(**inputs_dict, **generate_kwargs) + del model_eager + gc.collect() + + model_attn = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + attn_implementation=attn_implementation, + ).to(torch_device) + res_attn = model_attn.generate(**inputs_dict, **generate_kwargs) + del model_attn + gc.collect() + + self._check_similar_generate_outputs(res_eager, res_attn, atol=1e-3, rtol=1e-3) + + +class KyutaiSpeechToTextForConditionalGenerationIntegrationTests(unittest.TestCase): + _dataset = None + + def setUp(self): + self.model_checkpoint = "kyutai/stt-2.6b-en" + + def tearDown(self): + cleanup(torch_device, gc_collect=True) + + @classmethod + def _load_dataset(cls): + # Lazy loading of the dataset. Because it is a class method, it will only be loaded once per pytest process. + if cls._dataset is None: + cls._dataset = datasets.load_dataset( + "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation" + ) + # using 24000 here for simplicity, should rather be processor.feature_extractor.sampling_rate + cls._dataset = cls._dataset.cast_column("audio", datasets.Audio(sampling_rate=24000)) + + def _load_datasamples(self, num_samples): + self._load_dataset() + ds = self._dataset + speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"] + return [x["array"] for x in speech_samples] + + @slow + @require_torch_accelerator + def test_generation(self): + """ + reproduce test expected outputs using original codebase: https://gist.github.com/eustlb/7a9aa6139d11e0103c6b65bac103da52 + + DISCLAIMER: we are testing for pretty short inputs. Indeed, reproducing correct expected outputs for longer is not possible + as implementation choices (qkv matrix in one linear for original code vs three for hf) create growing divergence with context lenght, + ultimately giving different outputs. + """ + processor = KyutaiSpeechToTextProcessor.from_pretrained(self.model_checkpoint) + model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained( + self.model_checkpoint, device_map=torch_device + ) + + samples = self._load_datasamples(1) + inputs = processor( + samples, + ).to(torch_device) + + out = model.generate(**inputs) + + # fmt: off + EXPECTED_TOKENS = torch.tensor([ + [48000, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 1519, 263, 3, 3, 0, 3635, 428, 641, 0, 277, 3, 0, 265, 0, 267, 1162, 261, 274, 410, 0, 272, 3, 0, 265, 0, 260, 1621, 0, 1174, 371, 262, 3, 3, 3, 0, 269, 0, 281, 0, 304, 0, 2433, 3, 0, 266, 3, 0, 281, 1661, 3, 0, 376, 3, 3, 0, 350, 261, 401, 516, 263, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3]], + ) + # fmt: on + + torch.testing.assert_close(out.cpu(), EXPECTED_TOKENS) + + @slow + @require_torch_accelerator + def test_generation_batched(self): + """ + reproduce test expected outputs using original codebase: https://gist.github.com/eustlb/b58c217c75124d405ec1c13877c7ece8 + + DISCLAIMER: we are testing for pretty short inputs. Indeed, reproducing correct expected outputs for longer is not possible + as implementation choices (qkv matrix in one linear for original code vs three for hf) create growing divergence with context lenght, + ultimately giving different outputs. + """ + processor = KyutaiSpeechToTextProcessor.from_pretrained(self.model_checkpoint) + model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained( + self.model_checkpoint, device_map=torch_device + ) + + samples = self._load_datasamples(4) + inputs = processor( + samples, + ).to(torch_device) + + out = model.generate(**inputs) + + # fmt: off + EXPECTED_TOKENS = torch.tensor([ + [48000, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 1519, 263, 3, 3, 0, 3635, 428, 641, 0, 277, 3, 0, 265, 0, 267, 1162, 261, 274, 410, 0, 272, 3, 0, 265, 0, 260, 1621, 0, 1174, 371, 262, 3, 3, 3, 0, 269, 0, 281, 0, 304, 0, 2433, 3, 0, 266, 3, 0, 281, 1661, 3, 0, 376, 3, 3, 0, 350, 261, 401, 516, 263, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3], + [48000, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 500, 334, 0, 277, 3, 0, 1519, 263, 3, 3, 0, 3635, 428, 641, 264, 261, 0, 511, 1109, 3, 0, 1138, 3, 3, 3, 0, 508, 827, 3, 3, 3, 3, 0, 468, 3, 3, 0, 376, 3, 3, 3, 0, 260, 978, 263, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3], + [48000, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 414, 0, 527, 261, 3, 0, 409, 3, 3, 3, 0, 271, 3, 0, 309, 3, 0, 285, 3, 0, 521, 371, 609, 3, 3, 0, 260, 959, 3, 3, 3, 0, 272, 3, 0, 265, 0, 546, 262, 3, 3, 3, 3, 3, 3, 0, 291, 3, 0, 975, 2203, 3, 3, 3, 3, 0, 269, 3, 0, 260, 489, 651, 274, 279, 1870, 3, 0, 1084, 873, 273, 3, 0, 260, 531, 3, 3, 0, 409, 262, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 1502, 1005, 836, 3, 3, 0, 1666, 306, 3, 0, 340, 3, 0, 260, 3232, 3, 0, 269, 3, 3, 0, 275, 261, 0, 260, 1379, 261, 0, 3324, 3, 3, 3, 3, 0, 549, 3, 3, 0, 693, 405, 323, 3, 0, 266, 3, 3, 0, 265, 0, 699, 263, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3], + [48000, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 414, 0, 392, 3, 3, 0, 1269, 314, 0, 2607, 261, 3, 3, 3, 0, 1098, 295, 3, 3, 3, 0, 446, 625, 3, 0, 496, 280, 1205, 485, 1071, 1627, 449, 264, 261, 3, 0, 400, 0, 277, 3, 3, 3, 0, 260, 342, 3, 0, 618, 280, 1866, 3, 3, 0, 554, 3, 3, 3, 3, 0, 317, 262, 3, 3, 3, 3, 3, 3, 3, 3, 0, 269, 0, 303, 3, 0, 573, 2615, 3, 3, 0, 276, 3, 0, 275, 0, 305, 3, 0, 260, 415, 3, 3, 0, 272, 3, 3, 3, 3, 0, 1631, 327, 3, 3, 0, 333, 739, 841, 263, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3], + ]) + # fmt: on + + torch.testing.assert_close(out.cpu(), EXPECTED_TOKENS) diff --git a/tests/models/mimi/test_modeling_mimi.py b/tests/models/mimi/test_modeling_mimi.py index bf48f34ce1..d9b0216b15 100644 --- a/tests/models/mimi/test_modeling_mimi.py +++ b/tests/models/mimi/test_modeling_mimi.py @@ -107,14 +107,21 @@ class MimiModelTester: self.sliding_window = sliding_window self.use_cache = use_cache - def prepare_config_and_inputs(self): - input_values = floats_tensor([self.batch_size, self.num_channels, self.intermediate_size], scale=1.0) + def prepare_config_and_inputs(self, input_values_length=None): + input_values = floats_tensor( + [ + self.batch_size, + self.num_channels, + self.intermediate_size if input_values_length is None else input_values_length, + ], + scale=1.0, + ) config = self.get_config() inputs_dict = {"input_values": input_values} return config, inputs_dict - def prepare_config_and_inputs_for_common(self): - config, inputs_dict = self.prepare_config_and_inputs() + def prepare_config_and_inputs_for_common(self, input_values_length=None): + config, inputs_dict = self.prepare_config_and_inputs(input_values_length=input_values_length) return config, inputs_dict def prepare_config_and_inputs_for_model_class(self, model_class): @@ -508,6 +515,54 @@ class MimiIntegrationTest(unittest.TestCase): ) self.assertTrue(rmse < 1e-3) + def test_integration_encode_with_padding_cache(self): + """ + We test here the possibility to run Mimi in a streaming manner, i.e. chunk by chunk. + 1. we encode a first time the entire audio + 2. we encode the audio chunk by chunk, each chunk being the smallest size possible for the model (i.e. the frame size) + + This test must be run on CPU since GPU floating point operations accumulate rounding errors that cause test failures. + """ + librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + + model_id = "kyutai/mimi" + + model = MimiModel.from_pretrained(model_id, use_cache=True).to("cpu") + processor = AutoFeatureExtractor.from_pretrained(model_id) + + librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=processor.sampling_rate)) + audio_sample = librispeech_dummy[-1]["audio"]["array"] + + inputs = processor( + raw_audio=audio_sample, + sampling_rate=processor.sampling_rate, + return_tensors="pt", + ).to("cpu") + + frame_size = model.config.frame_size + audio_codes = model.encode(inputs["input_values"]).audio_codes + + # streaming chunk by chunk + encoder_past_key_values = None + padding_cache = None + encoded_frames_list = [] + + for start in range(0, inputs["input_values"].shape[-1], frame_size): + input_values_chunk = inputs["input_values"][:, :, start : start + frame_size] + encoder_outputs = model.encode( + input_values_chunk, + padding_cache=padding_cache, + encoder_past_key_values=encoder_past_key_values, + use_streaming=True, + ) + encoder_past_key_values = encoder_outputs.encoder_past_key_values + padding_cache = encoder_outputs.padding_cache + encoded_frames_list.append(encoder_outputs.audio_codes) + + streamed_audio_codes = torch.cat(encoded_frames_list, dim=-1) + + torch.testing.assert_close(streamed_audio_codes, audio_codes) + def test_integration(self): expected_rmses = { "8": 0.0018785292, diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index c482cd43d4..f404f99628 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -3566,7 +3566,11 @@ class ModelTesterMixin: # TODO: if we can also check with `batch_size=1` without being flaky? for batch_size in [7]: # musicgen decoder models; TODO: find better abstraction - if hasattr(self.model_tester, "num_codebooks") and not hasattr(model_eager, "text_encoder"): + if ( + model.__class__.__name__.startswith("Musicgen") + and hasattr(self.model_tester, "num_codebooks") + and not hasattr(model_eager, "text_encoder") + ): input_data_batch_size = batch_size * self.model_tester.num_codebooks else: input_data_batch_size = batch_size @@ -3626,7 +3630,7 @@ class ModelTesterMixin: if is_encoder_decoder: # musicgen encoder-decoder models; TODO: find better abstraction - if hasattr(self.model_tester, "num_codebooks"): + if model.__class__.__name__.startswith("Musicgen") and hasattr(self.model_tester, "num_codebooks"): input_data_batch_size = batch_size * self.model_tester.num_codebooks else: input_data_batch_size = batch_size diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index 7630dc2387..a930e63e99 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -619,7 +619,7 @@ ALL_FILE_TYPES = ( "processing", "image_processing", "video_processing", - "feature_extractor", + "feature_extraction", ) @@ -1137,7 +1137,7 @@ TYPE_TO_FILE_TYPE = { "VideoProcessor": "video_processing", "VideoProcessorInitKwargs": "video_processing", "FastImageProcessorKwargs": "image_processing*_fast", - "FeatureExtractor": "feature_extractor", + "FeatureExtractor": "feature_extraction", "ProcessorKwargs": "processing", "VideosKwargs": "processing", "ImagesKwargs": "processing",