From 6bdd4ec95264e5d8f219cfe4ee29ea9b42474bb7 Mon Sep 17 00:00:00 2001
From: eustlb <94853470+eustlb@users.noreply.github.com>
Date: Tue, 24 Jun 2025 18:01:15 +0200
Subject: [PATCH] Add kyutai stt (#38909)
* first draft
* cleaner version
* udpate tests + modeling
* add tests
* init
* udpate test_modeling_common
* fix tests
* csm Processor draft
* convertion update
* mimi cache padding convolutions draft
* mimi streaming udpates
* update mimi padding cache test
* udpate cache padding mimi test
* make style mimi
* updates generate moshi asr
* moshi asr integration tests (single + batched)
* update tests
* update conversion script
* good default sliding window value
* udpdate generate
* update test checkpoint
* nit
* fix mimi
* fix codec prefix
* revert
* revert
* update config
* update config
* unnecessary mimi input restriction
* remove delay in tokens
* remove _prepare_4d_causal_attention_mask_with_cache_position and _update_causal_mask
* test update
* modular update
* make style
* nit
* rename
* create codec model generation config at init
* remove delay
* max_new_tokens/length warning
* correct conv1 padding cache import for modular
* nit
* fix on encoder_past_key_values
* convert modular
* move frame_size to config
* move frame_size to config
* update test name
* handle first token is bos
* better handling of max_new_tokens
* fix
* fix batch size in test input prep
* update docstring
* convert modular
* make style
* make style
* add feature extractor
* correct modular convention name for feature_extraction file
* update convertion script
* doc processor
* update doc
* udpate init
* update model type
* fixes
* update tests
* fix
* make
* add doc
* nit
* fix
* doc
* auto mappings
* doc
* nit
* convert modular
* doc
* nit
* extend _keep_in_fp32_modules to enforce fp32
* renaming to stt
* doc update + test update
* doc fixes
* doc fix
* doc fix
* fix musicgen tests
* fix musicgen tests
* make style
* fix musicgen tests
* correct frame_rate config param for mimi
* update mimi test
* revert update mimi test
* enforce cpu test
* move cache init in cache class
* convert modular
* docstring update
* update model id
* feature_extractor -> feature_extraction (SEW)
* convert modular
* update model id
---
docs/source/en/_toctree.yml | 2 +
docs/source/en/model_doc/stt.md | 122 ++
src/transformers/modeling_utils.py | 5 +-
src/transformers/models/__init__.py | 1 +
.../models/auto/configuration_auto.py | 2 +
.../models/auto/feature_extraction_auto.py | 1 +
src/transformers/models/auto/modeling_auto.py | 2 +
.../models/auto/processing_auto.py | 1 +
.../models/mimi/configuration_mimi.py | 50 +-
src/transformers/models/mimi/modeling_mimi.py | 356 ++--
...actor_sew.py => feature_extraction_sew.py} | 0
src/transformers/models/stt/__init__.py | 29 +
.../configuration_kyutai_speech_to_text.py | 188 +++
.../convert_kyutai_speech_to_text_to_hf.py | 377 +++++
...eature_extraction_kyutai_speech_to_text.py | 237 +++
.../stt/modeling_kyutai_speech_to_text.py | 1434 +++++++++++++++++
.../stt/modular_kyutai_speech_to_text.py | 510 ++++++
.../stt/processing_kyutai_speech_to_text.py | 104 ++
.../models/kyutai_speech_to_text/__init__.py | 0
.../test_modeling_kyutai_speech_to_text.py | 704 ++++++++
tests/models/mimi/test_modeling_mimi.py | 63 +-
tests/test_modeling_common.py | 8 +-
utils/modular_model_converter.py | 4 +-
23 files changed, 4000 insertions(+), 200 deletions(-)
create mode 100644 docs/source/en/model_doc/stt.md
rename src/transformers/models/sew/{feature_extractor_sew.py => feature_extraction_sew.py} (100%)
create mode 100644 src/transformers/models/stt/__init__.py
create mode 100644 src/transformers/models/stt/configuration_kyutai_speech_to_text.py
create mode 100644 src/transformers/models/stt/convert_kyutai_speech_to_text_to_hf.py
create mode 100644 src/transformers/models/stt/feature_extraction_kyutai_speech_to_text.py
create mode 100644 src/transformers/models/stt/modeling_kyutai_speech_to_text.py
create mode 100644 src/transformers/models/stt/modular_kyutai_speech_to_text.py
create mode 100644 src/transformers/models/stt/processing_kyutai_speech_to_text.py
create mode 100644 tests/models/kyutai_speech_to_text/__init__.py
create mode 100644 tests/models/kyutai_speech_to_text/test_modeling_kyutai_speech_to_text.py
diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 6ebe8044ad..d8438a4165 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -843,6 +843,8 @@
title: GraniteSpeech
- local: model_doc/hubert
title: Hubert
+ - local: model_doc/stt
+ title: Kyutai Speech-To-Text
- local: model_doc/mctct
title: MCTCT
- local: model_doc/mimi
diff --git a/docs/source/en/model_doc/stt.md b/docs/source/en/model_doc/stt.md
new file mode 100644
index 0000000000..02428899df
--- /dev/null
+++ b/docs/source/en/model_doc/stt.md
@@ -0,0 +1,122 @@
+
+
+# Kyutai Speech-To-Text
+## Overview
+
+Kyutai STT is a speech-to-text model architecture based on the [Mimi codec](https://huggingface.co/docs/transformers/en/model_doc/mimi), which encodes audio into discrete tokens in a streaming fashion, and a [Moshi-like](https://huggingface.co/docs/transformers/en/model_doc/moshi) autoregressive decoder. Kyutai’s lab has released two model checkpoints:
+- [kyutai/stt-1b-en_fr](https://huggingface.co/kyutai/stt-1b-en_fr): a 1B-parameter model capable of transcribing both English and French
+- [kyutai/stt-2.6b-en](https://huggingface.co/kyutai/stt-2.6b-en): a 2.6B-parameter model focused solely on English, optimized for maximum transcription accuracy
+
+
+

+
+
+## Usage Tips
+
+### Inference
+
+```python
+import torch
+from datasets import load_dataset, Audio
+from transformers import KyutaiSpeechToTextProcessor, KyutaiSpeechToTextForConditionalGeneration
+
+# 1. load the model and the processor
+torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+model_id = "kyutai/stt-2.6b-en"
+
+processor = KyutaiSpeechToTextProcessor.from_pretrained(model_id)
+model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(model_id, device_map=torch_device)
+
+# 2. load audio samples
+ds = load_dataset(
+ "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
+)
+ds = ds.cast_column("audio", Audio(sampling_rate=24000))
+
+# 3. prepare the model inputs
+inputs = processor(
+ ds[0]["audio"]["array"],
+)
+inputs.to(torch_device)
+
+# 4. infer the model
+output_tokens = model.generate(**inputs)
+
+# 5. decode the generated tokens
+print(processor.batch_decode(output_tokens, skip_special_tokens=True))
+```
+
+### Batched Inference
+
+```python
+import torch
+from datasets import load_dataset, Audio
+from transformers import KyutaiSpeechToTextProcessor, KyutaiSpeechToTextForConditionalGeneration
+
+# 1. load the model and the processor
+torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+model_id = "kyutai/stt-2.6b-en"
+
+processor = KyutaiSpeechToTextProcessor.from_pretrained(model_id)
+model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(model_id, device_map=torch_device)
+
+# 2. load audio samples
+ds = load_dataset(
+ "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
+)
+ds = ds.cast_column("audio", Audio(sampling_rate=24000))
+
+# 3. prepare the model inputs
+audio_arrays = [ds[i]["audio"]["array"] for i in range(4)]
+inputs = processor(audio_arrays, return_tensors="pt", padding=True)
+inputs = inputs.to(torch_device)
+
+# 4. infer the model
+output_tokens = model.generate(**inputs)
+
+# 5. decode the generated tokens
+decoded_outputs = processor.batch_decode(output_tokens, skip_special_tokens=True)
+for output in decoded_outputs:
+ print(output)
+```
+
+This model was contributed by [Eustache Le Bihan](https://huggingface.co/eustlb).
+The original code can be found [here](https://github.com/kyutai-labs/moshi).
+
+
+## KyutaiSpeechToTextConfig
+
+[[autodoc]] KyutaiSpeechToTextConfig
+
+## KyutaiSpeechToTextProcessor
+
+[[autodoc]] KyutaiSpeechToTextProcessor
+ - __call__
+
+## KyutaiSpeechToTextFeatureExtractor
+
+[[autodoc]] KyutaiSpeechToTextFeatureExtractor
+
+## KyutaiSpeechToTextForConditionalGeneration
+
+[[autodoc]] KyutaiSpeechToTextForConditionalGeneration
+ - forward
+ - generate
+
+## KyutaiSpeechToTextModel
+
+[[autodoc]] KyutaiSpeechToTextModel
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 4774a72df7..4f6095a3ed 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -4658,8 +4658,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
# The _keep_in_fp32_modules flag is only used to avoid bf16 -> fp16 casting precision issues. It was introduced
# in case of force loading a model that should stay bf16 in fp16 (which includes a few quantizers as this is a pre-processing
# step for e.g. bitsandbytes). See https://github.com/huggingface/transformers/issues/20287 for details.
+ # Update: to extend _keep_in_fp32_modules flag feature, it can also be used to force modules that should stay in fp32
if model._keep_in_fp32_modules is not None and (
- torch_dtype == torch.float16 or getattr(hf_quantizer, "use_keep_in_fp32_modules", False)
+ torch_dtype == torch.float16
+ or torch_dtype == torch.bfloat16
+ or getattr(hf_quantizer, "use_keep_in_fp32_modules", False)
):
# We need to match exact layers, so we add either `.` on each side, or start/end of string
keep_in_fp32_regex = re.compile(
diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py
index 504fcc2684..8d36068353 100644
--- a/src/transformers/models/__init__.py
+++ b/src/transformers/models/__init__.py
@@ -285,6 +285,7 @@ if TYPE_CHECKING:
from .squeezebert import *
from .stablelm import *
from .starcoder2 import *
+ from .stt import *
from .superglue import *
from .superpoint import *
from .swiftformer import *
diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py
index d7529b2b63..54a285e3c6 100644
--- a/src/transformers/models/auto/configuration_auto.py
+++ b/src/transformers/models/auto/configuration_auto.py
@@ -322,6 +322,7 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str](
("squeezebert", "SqueezeBertConfig"),
("stablelm", "StableLmConfig"),
("starcoder2", "Starcoder2Config"),
+ ("stt", "KyutaiSpeechToTextConfig"),
("superglue", "SuperGlueConfig"),
("superpoint", "SuperPointConfig"),
("swiftformer", "SwiftFormerConfig"),
@@ -707,6 +708,7 @@ MODEL_NAMES_MAPPING = OrderedDict[str, str](
("squeezebert", "SqueezeBERT"),
("stablelm", "StableLm"),
("starcoder2", "Starcoder2"),
+ ("stt", "KyutaiSpeechToText"),
("superglue", "SuperGlue"),
("superpoint", "SuperPoint"),
("swiftformer", "SwiftFormer"),
diff --git a/src/transformers/models/auto/feature_extraction_auto.py b/src/transformers/models/auto/feature_extraction_auto.py
index e7db1944d3..5754b3bc1b 100644
--- a/src/transformers/models/auto/feature_extraction_auto.py
+++ b/src/transformers/models/auto/feature_extraction_auto.py
@@ -91,6 +91,7 @@ FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
("sew-d", "Wav2Vec2FeatureExtractor"),
("speech_to_text", "Speech2TextFeatureExtractor"),
("speecht5", "SpeechT5FeatureExtractor"),
+ ("stt", "KyutaiSpeechToTextFeatureExtractor"),
("swiftformer", "ViTFeatureExtractor"),
("swin", "ViTFeatureExtractor"),
("swinv2", "ViTFeatureExtractor"),
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index b3224b7d46..cbfc0f7647 100644
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -300,6 +300,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("squeezebert", "SqueezeBertModel"),
("stablelm", "StableLmModel"),
("starcoder2", "Starcoder2Model"),
+ ("stt", "KyutaiSpeechToTextModel"),
("superglue", "SuperGlueForKeypointMatching"),
("swiftformer", "SwiftFormerModel"),
("swin", "SwinModel"),
@@ -1055,6 +1056,7 @@ MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
("speech-encoder-decoder", "SpeechEncoderDecoderModel"),
("speech_to_text", "Speech2TextForConditionalGeneration"),
("speecht5", "SpeechT5ForSpeechToText"),
+ ("stt", "KyutaiSpeechToTextForConditionalGeneration"),
("whisper", "WhisperForConditionalGeneration"),
]
)
diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py
index b2e36bc4bc..478766e6ee 100644
--- a/src/transformers/models/auto/processing_auto.py
+++ b/src/transformers/models/auto/processing_auto.py
@@ -116,6 +116,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
("speech_to_text", "Speech2TextProcessor"),
("speech_to_text_2", "Speech2Text2Processor"),
("speecht5", "SpeechT5Processor"),
+ ("stt", "KyutaiSpeechToTextProcessor"),
("trocr", "TrOCRProcessor"),
("tvlt", "TvltProcessor"),
("tvp", "TvpProcessor"),
diff --git a/src/transformers/models/mimi/configuration_mimi.py b/src/transformers/models/mimi/configuration_mimi.py
index a36b5e7101..b213359886 100644
--- a/src/transformers/models/mimi/configuration_mimi.py
+++ b/src/transformers/models/mimi/configuration_mimi.py
@@ -38,8 +38,8 @@ class MimiConfig(PretrainedConfig):
Args:
sampling_rate (`int`, *optional*, defaults to 24000):
The sampling rate at which the audio waveform should be digitalized expressed in hertz (Hz).
- frame_rate (`float`, *optional*, defaults to 12.5):
- Framerate of the model.
+ frame_rate (`float`, *optional*):
+ Should be computed from the other parameters, yet kept for backward compatibility.
audio_channels (`int`, *optional*, defaults to 1):
Number of channels in the audio data. Either 1 for mono or 2 for stereo.
hidden_size (`int`, *optional*, defaults to 512):
@@ -111,6 +111,8 @@ class MimiConfig(PretrainedConfig):
use_cache (`bool`, *optional*, defaults to `False`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
+ use_streaming (`bool`, *optional*, defaults to `False`):
+ Whether to use streaming mode. If `True`, the model encode method will return the padding cache that can be used in a subsequent call to the encode method.
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
sliding_window (`int`, *optional*, defaults to 250):
@@ -141,7 +143,7 @@ class MimiConfig(PretrainedConfig):
def __init__(
self,
sampling_rate=24_000,
- frame_rate=12.5,
+ frame_rate=None,
audio_channels=1,
hidden_size=512,
num_filters=64,
@@ -172,6 +174,7 @@ class MimiConfig(PretrainedConfig):
initializer_range=0.02,
norm_eps=1e-5,
use_cache=False,
+ use_streaming=False,
rope_theta=10000.0,
sliding_window=250,
attention_dropout=0.0,
@@ -180,7 +183,6 @@ class MimiConfig(PretrainedConfig):
**kwargs,
):
self.sampling_rate = sampling_rate
- self.frame_rate = frame_rate
self.audio_channels = audio_channels
self.hidden_size = hidden_size
self.num_filters = num_filters
@@ -209,6 +211,7 @@ class MimiConfig(PretrainedConfig):
self.initializer_range = initializer_range
self.norm_eps = norm_eps
self.use_cache = use_cache
+ self.use_streaming = use_streaming
self.rope_theta = rope_theta
self.sliding_window = sliding_window
self.attention_dropout = attention_dropout
@@ -216,6 +219,14 @@ class MimiConfig(PretrainedConfig):
self.layer_scale_initial_scale = layer_scale_initial_scale
self.attention_bias = attention_bias
+ # Handle backward compatibility for frame_rate:
+ # If frame_rate is explicitly provided, use it (backward compatibility)
+ # Otherwise, compute it from other parameters (correctly)
+ if frame_rate is not None:
+ self._frame_rate = frame_rate
+ else:
+ self._frame_rate = None
+
if num_semantic_quantizers >= self.num_quantizers:
raise ValueError(
f"The number of semantic quantizers should be lower than the total number of quantizers {self.num_quantizers}, but is currently {num_semantic_quantizers}."
@@ -233,5 +244,36 @@ class MimiConfig(PretrainedConfig):
# alias to num_quantizers
return self.num_quantizers
+ @property
+ def frame_size(self) -> int:
+ # 1. we need each encoder conv stride
+ # first conv
+ strides = [1]
+
+ # layer convs
+ for ratio in reversed(self.upsampling_ratios):
+ for j in range(self.num_residual_layers):
+ len_kernel_sizes = len(self.residual_kernel_size) if isinstance(self.residual_kernel_size, list) else 1
+ strides.extend([1] * (len_kernel_sizes + 1))
+ if self.use_conv_shortcut: # skip connection
+ strides.append(1)
+
+ strides.append(ratio)
+
+ # last conv
+ strides.append(1)
+
+ # downsampling layer
+ strides.append(2)
+
+ return math.prod(strides)
+
+ @property
+ def frame_rate(self) -> float:
+ # handle backward compatibility
+ if self._frame_rate is not None:
+ return self._frame_rate
+ return self.sampling_rate / self.frame_size
+
__all__ = ["MimiConfig"]
diff --git a/src/transformers/models/mimi/modeling_mimi.py b/src/transformers/models/mimi/modeling_mimi.py
index f1363f7897..221388f858 100644
--- a/src/transformers/models/mimi/modeling_mimi.py
+++ b/src/transformers/models/mimi/modeling_mimi.py
@@ -23,25 +23,20 @@ import torch.utils.checkpoint
from torch import nn
from ...activations import ACT2FN
-from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
-from ...modeling_attn_mask_utils import AttentionMaskConverter
+from ...cache_utils import Cache, DynamicCache, StaticCache
+from ...masking_utils import create_causal_mask
from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import PreTrainedModel
-from ...utils import ModelOutput, auto_docstring, is_torch_flex_attn_available, logging
+from ...utils import ModelOutput, auto_docstring, logging
from .configuration_mimi import MimiConfig
if is_flash_attn_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
-if is_torch_flex_attn_available():
- from torch.nn.attention.flex_attention import BlockMask
-
- from ...integrations.flex_attention import make_flex_block_causal_mask
-
logger = logging.get_logger(__name__)
@@ -78,6 +73,91 @@ class MimiOutput(ModelOutput):
decoder_past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None
+class MimiConv1dPaddingCache:
+ """
+ Padding cache for MimiConv1d causal convolutions in order to support streaming via cache padding.
+ See: https://arxiv.org/pdf/2005.06720 & https://arxiv.org/pdf/2204.07064
+
+ A padding cache is a list of cached partial hidden states for each convolution layer.
+ Hidden states are cached from the previous call to the MimiConv1d forward pass, given the padding size.
+ """
+
+ def __init__(
+ self,
+ num_layers: int,
+ per_layer_padding: list[int],
+ per_layer_padding_mode: list[str],
+ per_layer_in_channels: list[int],
+ ):
+ # ensure correct number of layers for each arg
+ from_args_num_layers = {len(per_layer_padding), len(per_layer_padding_mode), len(per_layer_in_channels)}
+
+ if len(from_args_num_layers) != 1 or from_args_num_layers.pop() != num_layers:
+ raise ValueError(
+ f"Expected `num_layers` ({num_layers}) values in `per_layer_padding`, `per_layer_padding_mode` and `per_layer_in_channels`"
+ )
+ elif not all(mode in ["constant", "replicate"] for mode in per_layer_padding_mode):
+ raise NotImplementedError(
+ "`padding_cache` is not supported for convolutions using other than `constant` or `replicate` padding mode"
+ )
+
+ self.per_layer_padding = per_layer_padding
+ self.per_layer_padding_mode = per_layer_padding_mode
+ self.per_layer_in_channels = per_layer_in_channels
+ self.per_layer_is_init = [True] * num_layers
+
+ self.padding_cache = [None] * num_layers
+
+ def update(self, hidden_states: torch.Tensor, layer_idx: int):
+ """
+ Updates the padding cache with the new padding states for the layer `layer_idx` and returns the current cache.
+
+ Parameters:
+ hidden_states (`torch.Tensor`):
+ The hidden states to be partially cached.
+ layer_idx (`int`):
+ The index of the layer to cache the states for.
+ Returns:
+ `torch.Tensor` or `None`, the current padding cache.
+ """
+ batch_size, dtype, device = hidden_states.shape[0], hidden_states.dtype, hidden_states.device
+ padding = self.per_layer_padding[layer_idx]
+ padding_mode = self.per_layer_padding_mode[layer_idx]
+ in_channels = self.per_layer_in_channels[layer_idx]
+
+ if self.padding_cache[layer_idx] is None:
+ if padding_mode == "constant":
+ current_cache = torch.zeros(
+ batch_size,
+ in_channels,
+ padding,
+ device=device,
+ dtype=dtype,
+ )
+ elif padding_mode == "replicate":
+ current_cache = (
+ torch.ones(
+ batch_size,
+ in_channels,
+ padding,
+ device=device,
+ dtype=dtype,
+ )
+ * hidden_states[..., :1]
+ )
+ else:
+ current_cache = self.padding_cache[layer_idx]
+
+ # update the cache
+ if padding > 0:
+ padding_states = hidden_states[:, :, -padding:]
+ else:
+ padding_states = torch.empty(batch_size, in_channels, padding, dtype=dtype, device=device)
+ self.padding_cache[layer_idx] = padding_states
+
+ return current_cache
+
+
@dataclass
@auto_docstring
class MimiEncoderOutput(ModelOutput):
@@ -96,6 +176,7 @@ class MimiEncoderOutput(ModelOutput):
audio_codes: Optional[torch.LongTensor] = None
encoder_past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None
+ padding_cache: Optional[MimiConv1dPaddingCache] = None
@dataclass
@@ -130,12 +211,15 @@ class MimiConv1d(nn.Module):
stride: int = 1,
dilation: int = 1,
groups: int = 1,
- pad_mode=None,
+ pad_mode: Optional[str] = None,
bias: bool = True,
+ layer_idx: Optional[int] = None,
):
super().__init__()
self.causal = config.use_causal_conv
self.pad_mode = config.pad_mode if pad_mode is None else pad_mode
+ self.layer_idx = layer_idx
+ self.in_channels = in_channels
# warn user on unusual setup between dilation and stride
if stride > 1 and dilation > 1:
@@ -232,12 +316,20 @@ class MimiConv1d(nn.Module):
) // self.conv.stride[0] + 1
return output_lenght
- def forward(self, hidden_states):
+ def forward(self, hidden_states, padding_cache=None):
extra_padding = self._get_extra_padding_for_conv1d(hidden_states)
- if self.causal:
+ if not self.causal and padding_cache is not None:
+ raise ValueError("`padding_cache` is not supported for non-causal convolutions.")
+
+ if self.causal and padding_cache is not None:
+ layer_padding_cache = padding_cache.update(hidden_states, self.layer_idx)
+ hidden_states = torch.cat([layer_padding_cache, hidden_states], dim=2)
+
+ elif self.causal:
# Left padding for causal
hidden_states = self._pad1d(hidden_states, (self.padding_total, extra_padding), mode=self.pad_mode)
+
else:
hidden_states = self._pad1d(
hidden_states, (self.padding_left, self.padding_right + extra_padding), mode=self.pad_mode
@@ -305,7 +397,6 @@ class MimiConvTranspose1d(nn.Module):
return hidden_states
-# Copied from transformers.models.encodec.modeling_encodec.EncodecResnetBlock with Encodec->Mimi,EnCodec->Mimi
class MimiResnetBlock(nn.Module):
"""
Residual block from SEANet model as used by Mimi.
@@ -331,12 +422,21 @@ class MimiResnetBlock(nn.Module):
else:
self.shortcut = nn.Identity()
- def forward(self, hidden_states):
+ def forward(self, hidden_states, padding_cache=None):
residual = hidden_states
- for layer in self.block:
- hidden_states = layer(hidden_states)
- return self.shortcut(residual) + hidden_states
+ for layer in self.block:
+ if isinstance(layer, MimiConv1d):
+ hidden_states = layer(hidden_states, padding_cache=padding_cache)
+ else:
+ hidden_states = layer(hidden_states)
+
+ if isinstance(self.shortcut, MimiConv1d):
+ residual = self.shortcut(residual, padding_cache=padding_cache)
+ else:
+ residual = self.shortcut(residual)
+
+ return residual + hidden_states
class MimiEncoder(nn.Module):
@@ -370,10 +470,17 @@ class MimiEncoder(nn.Module):
self.layers = nn.ModuleList(model)
self._mimiconv1d_layer_names = mimiconv1d_layer_names
- # Copied from transformers.models.encodec.modeling_encodec.EncodecEncoder.forward
- def forward(self, hidden_states):
+ # initialize layer_idx for MimiConv1d submodules, necessary for padding_cache
+ for layer_idx, layername in enumerate(self._mimiconv1d_layer_names):
+ conv_layer = self.get_submodule(layername)
+ setattr(conv_layer, "layer_idx", layer_idx)
+
+ def forward(self, hidden_states, padding_cache=None):
for layer in self.layers:
- hidden_states = layer(hidden_states)
+ if isinstance(layer, (MimiConv1d, MimiResnetBlock)):
+ hidden_states = layer(hidden_states, padding_cache=padding_cache)
+ else:
+ hidden_states = layer(hidden_states)
return hidden_states
@@ -1005,11 +1112,13 @@ class MimiTransformerModel(nn.Module):
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
- causal_mask = None
- if attention_mask is not None:
- causal_mask = self._update_causal_mask(
- attention_mask, hidden_states, cache_position, past_key_values, output_attentions
- )
+ causal_mask = create_causal_mask(
+ config=self.config,
+ input_embeds=hidden_states,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ past_key_values=past_key_values,
+ )
# decoder layers
all_hidden_states = () if output_hidden_states else None
@@ -1054,163 +1163,6 @@ class MimiTransformerModel(nn.Module):
attentions=all_self_attns,
)
- # Copied from transformers.models.phimoe.modeling_phimoe.PhimoeModel._update_causal_mask with Phimoe->Mimi
- def _update_causal_mask(
- self,
- attention_mask: Union[torch.Tensor, "BlockMask"],
- input_tensor: torch.Tensor,
- cache_position: torch.Tensor,
- past_key_values: Cache,
- output_attentions: bool = False,
- ):
- if self.config._attn_implementation == "flash_attention_2":
- if attention_mask is not None and past_key_values is not None:
- is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
- if is_padding_right:
- raise ValueError(
- "You are attempting to perform batched generation with padding_side='right'"
- " this may lead to unexpected behaviour for Flash Attention version of Mimi. Make sure to "
- " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
- )
- if attention_mask is not None and 0.0 in attention_mask:
- return attention_mask
- return None
- if self.config._attn_implementation == "flex_attention":
- if isinstance(attention_mask, torch.Tensor):
- attention_mask = make_flex_block_causal_mask(attention_mask)
- return attention_mask
-
- # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
- # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
- # to infer the attention mask.
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
- using_static_cache = isinstance(past_key_values, StaticCache)
- using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
-
- # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
- if (
- self.config._attn_implementation == "sdpa"
- and not (using_static_cache or using_sliding_window_cache)
- and not output_attentions
- ):
- if AttentionMaskConverter._ignore_causal_mask_sdpa(
- attention_mask,
- inputs_embeds=input_tensor,
- past_key_values_length=past_seen_tokens,
- sliding_window=self.config.sliding_window,
- is_training=self.training,
- ):
- return None
-
- dtype = input_tensor.dtype
- min_dtype = torch.finfo(dtype).min
- sequence_length = input_tensor.shape[1]
- # SlidingWindowCache or StaticCache
- if using_sliding_window_cache or using_static_cache:
- target_length = past_key_values.get_max_cache_shape()
- # DynamicCache or no cache
- else:
- target_length = (
- attention_mask.shape[-1]
- if isinstance(attention_mask, torch.Tensor)
- else past_seen_tokens + sequence_length + 1
- )
-
- # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
- causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
- attention_mask,
- sequence_length=sequence_length,
- target_length=target_length,
- dtype=dtype,
- cache_position=cache_position,
- batch_size=input_tensor.shape[0],
- config=self.config,
- past_key_values=past_key_values,
- )
-
- if (
- self.config._attn_implementation == "sdpa"
- and attention_mask is not None
- and attention_mask.device.type in ["cuda", "xpu", "npu"]
- and not output_attentions
- ):
- # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
- # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
- # Details: https://github.com/pytorch/pytorch/issues/110213
- causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
-
- return causal_mask
-
- @staticmethod
- # Copied from transformers.models.phimoe.modeling_phimoe.PhimoeModel._prepare_4d_causal_attention_mask_with_cache_position with Phimoe->Mimi
- def _prepare_4d_causal_attention_mask_with_cache_position(
- attention_mask: torch.Tensor,
- sequence_length: int,
- target_length: int,
- dtype: torch.dtype,
- cache_position: torch.Tensor,
- batch_size: int,
- config: MimiConfig,
- past_key_values: Cache,
- ):
- """
- Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
- `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
-
- Args:
- attention_mask (`torch.Tensor`):
- A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
- sequence_length (`int`):
- The sequence length being processed.
- target_length (`int`):
- The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
- dtype (`torch.dtype`):
- The dtype to use for the 4D attention mask.
- cache_position (`torch.Tensor`):
- Indices depicting the position of the input sequence tokens in the sequence.
- batch_size (`torch.Tensor`):
- Batch size.
- config (`MimiConfig`):
- The model's configuration class
- past_key_values (`Cache`):
- The cache class that is being used currently to generate
- """
- if attention_mask is not None and attention_mask.dim() == 4:
- # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
- causal_mask = attention_mask
- else:
- min_dtype = torch.finfo(dtype).min
- causal_mask = torch.full(
- (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
- )
- diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
- -1, 1
- )
- text_config = config.get_text_config()
- if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
- # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
- # the check is needed to verify is current checkpoint was trained with sliding window or not
- if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
- sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
- cache_position.reshape(-1, 1) - text_config.sliding_window
- )
- diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
- causal_mask *= diagonal_attend_mask
- causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
- if attention_mask is not None:
- causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
- if attention_mask.shape[-1] > target_length:
- attention_mask = attention_mask[:, :target_length]
- mask_length = attention_mask.shape[-1]
- padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
- causal_mask.device
- )
- padding_mask = padding_mask == 0
- causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
- padding_mask, min_dtype
- )
- return causal_mask
-
class MimiDecoder(nn.Module):
"""SEANet decoder as used by Mimi."""
@@ -1269,7 +1221,7 @@ class MimiEuclideanCodebook(nn.Module):
def quantize(self, hidden_states):
# Projects each vector in `hidden_states` over the nearest centroid and return its index.
# `hidden_states` should be `[N, D]` with `N` the number of input vectors and `D` the dimension.
- dists = torch.cdist(hidden_states[None], self.embed[None], p=2)[0]
+ dists = torch.cdist(hidden_states[None].float(), self.embed[None].float(), p=2)[0]
embed_ind = dists.argmin(dim=-1)
return embed_ind
@@ -1476,6 +1428,7 @@ class MimiModel(MimiPreTrainedModel):
stride=2,
bias=False,
pad_mode="replicate",
+ layer_idx=len(self.encoder._mimiconv1d_layer_names),
)
self.upsample = MimiConvTranspose1d(
@@ -1512,12 +1465,17 @@ class MimiModel(MimiPreTrainedModel):
num_quantizers: int,
padding_mask: int,
past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
+ padding_cache: Optional[MimiConv1dPaddingCache] = None,
return_dict: Optional[bool] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
Encodes the given input using the underlying VQVAE. The padding mask is required to compute the correct scale.
"""
- embeddings = self.encoder(input_values)
+
+ # TODO: @eustlb, let's make the encoder support padding_mask so that batched inputs are supported.
+ embeddings = self.encoder(input_values, padding_cache=padding_cache)
+
+ # TODO: @eustlb, convert the padding mask to attention mask.
encoder_outputs = self.encoder_transformer(
embeddings.transpose(1, 2), past_key_values=past_key_values, return_dict=return_dict
)
@@ -1526,11 +1484,11 @@ class MimiModel(MimiPreTrainedModel):
elif len(encoder_outputs) > 1:
past_key_values = encoder_outputs[1]
embeddings = encoder_outputs[0].transpose(1, 2)
- embeddings = self.downsample(embeddings)
+ embeddings = self.downsample(embeddings, padding_cache=padding_cache)
codes = self.quantizer.encode(embeddings, num_quantizers)
codes = codes.transpose(0, 1)
- return codes, past_key_values
+ return codes, past_key_values, padding_cache
def get_encoded_length(self, input_length: torch.LongTensor) -> torch.LongTensor:
"""
@@ -1570,6 +1528,8 @@ class MimiModel(MimiPreTrainedModel):
padding_mask: Optional[torch.Tensor] = None,
num_quantizers: Optional[float] = None,
encoder_past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
+ padding_cache: Optional[MimiConv1dPaddingCache] = None,
+ use_streaming: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[tuple[torch.Tensor, Optional[torch.Tensor]], MimiEncoderOutput]:
"""
@@ -1598,6 +1558,7 @@ class MimiModel(MimiPreTrainedModel):
`codebook` of shape `[batch_size, num_codebooks, frames]`, the discrete encoded codes for the input audio waveform.
"""
return_dict = return_dict if return_dict is not None else self.config.return_dict
+ use_streaming = use_streaming if use_streaming is not None else self.config.use_streaming
num_quantizers = self.config.num_quantizers if num_quantizers is None else num_quantizers
@@ -1614,11 +1575,31 @@ class MimiModel(MimiPreTrainedModel):
if padding_mask is None:
padding_mask = torch.ones_like(input_values).bool()
- encoded_frames, encoder_past_key_values = self._encode_frame(
+ if use_streaming and padding_cache is None:
+ per_layer_padding, per_layer_padding_mode, per_layer_in_channels = [], [], []
+ for layer_name in self.encoder._mimiconv1d_layer_names:
+ per_layer_padding.append(self.encoder.get_submodule(layer_name).padding_total)
+ per_layer_padding_mode.append(self.encoder.get_submodule(layer_name).pad_mode)
+ per_layer_in_channels.append(self.encoder.get_submodule(layer_name).in_channels)
+
+ # downsample layer
+ per_layer_padding.append(self.downsample.padding_total)
+ per_layer_padding_mode.append(self.downsample.pad_mode)
+ per_layer_in_channels.append(self.downsample.in_channels)
+
+ padding_cache = MimiConv1dPaddingCache(
+ num_layers=len(self.encoder._mimiconv1d_layer_names) + 1,
+ per_layer_padding=per_layer_padding,
+ per_layer_padding_mode=per_layer_padding_mode,
+ per_layer_in_channels=per_layer_in_channels,
+ )
+
+ encoded_frames, encoder_past_key_values, padding_cache = self._encode_frame(
input_values,
num_quantizers,
padding_mask.bool(),
past_key_values=encoder_past_key_values,
+ padding_cache=padding_cache,
return_dict=return_dict,
)
@@ -1626,9 +1607,10 @@ class MimiModel(MimiPreTrainedModel):
return (
encoded_frames,
encoder_past_key_values,
+ padding_cache,
)
- return MimiEncoderOutput(encoded_frames, encoder_past_key_values)
+ return MimiEncoderOutput(encoded_frames, encoder_past_key_values, padding_cache)
def _decode_frame(
self,
diff --git a/src/transformers/models/sew/feature_extractor_sew.py b/src/transformers/models/sew/feature_extraction_sew.py
similarity index 100%
rename from src/transformers/models/sew/feature_extractor_sew.py
rename to src/transformers/models/sew/feature_extraction_sew.py
diff --git a/src/transformers/models/stt/__init__.py b/src/transformers/models/stt/__init__.py
new file mode 100644
index 0000000000..5823883c6c
--- /dev/null
+++ b/src/transformers/models/stt/__init__.py
@@ -0,0 +1,29 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_kyutai_speech_to_text import *
+ from .feature_extraction_kyutai_speech_to_text import *
+ from .modeling_kyutai_speech_to_text import *
+ from .processing_kyutai_speech_to_text import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/src/transformers/models/stt/configuration_kyutai_speech_to_text.py b/src/transformers/models/stt/configuration_kyutai_speech_to_text.py
new file mode 100644
index 0000000000..f9ea11a5f4
--- /dev/null
+++ b/src/transformers/models/stt/configuration_kyutai_speech_to_text.py
@@ -0,0 +1,188 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.s
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+from ..auto.configuration_auto import AutoConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class KyutaiSpeechToTextConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`KyutaiSpeechToTextForConditionalGeneration`].
+ It is used to instantiate a Kyutai Speech-to-Text model according to the specified arguments, defining the model
+ architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the
+ 2.6b-en model.
+
+ e.g. [kyutai/stt-2.6b-en](https://huggingface.co/kyutai/stt-2.6b-en)
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ codebook_vocab_size (`int`, *optional*, defaults to 2049):
+ Vocabulary size of the codebook. Defines the number of different audio tokens that can be represented by each codebook.
+ vocab_size (`int`, *optional*, defaults to 4001):
+ Vocabulary size of the model. Defines the number of different tokens that can be represented by the
+ `input_ids` passed when calling the model.
+ hidden_size (`int`, *optional*, defaults to 2048):
+ Dimensionality of the layers and the pooler layer of the main decoder.
+ num_hidden_layers (`int`, *optional*, defaults to 48):
+ Number of decoder layers.
+ num_attention_heads (`int`, *optional*, defaults to 32):
+ Number of attention heads for each attention layer in the main decoder block.
+ num_key_value_heads (`int`, *optional*):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details checkout [this
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `num_attention_heads`.
+ max_position_embeddings (`int`, *optional*, defaults to 750):
+ The maximum sequence length that this model might ever be used with. Typically, set this to something large
+ just in case (e.g., 512 or 1024 or 2048).
+ rope_theta (`float`, *optional*, defaults to 100000.0):
+ The base period of the RoPE embeddings.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ head_dim (`int`, *optional*, defaults to `hidden_size // num_attention_heads`):
+ The attention head dimension.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ sliding_window (`int`, *optional*, defaults to 375):
+ Sliding window attention window size. If not specified, will default to `3000`.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ ffn_dim (`int`, *optional*, defaults to 11264):
+ Dimensionality of the "intermediate" (often named feed-forward) layer in the main decoder block. Must be even.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-08):
+ The epsilon used by the rms normalization layers.
+ num_codebooks (`int`, *optional*, defaults to 32):
+ The number of audio codebooks for each audio channels.
+ audio_bos_token_id (`int`, *optional*, defaults to 2048):
+ Beginning of stream token id for codebook tokens.
+ audio_pad_token_id (`int`, *optional*, defaults to 69569):
+ Padding token id for codebook tokens.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether to tie weight embeddings.
+ pad_token_id (`int`, *optional*, defaults to 3):
+ Padding token id.
+ bos_token_id (`int`, *optional*, defaults to 48000):
+ Beginning of stream token id for text tokens.
+ codec_config (`PretrainedConfig`, *optional*):
+ Configuration for the codec.
+ kwargs (*optional*):
+ Dictionary of keyword arguments. Notably:
+ - **audio_encoder_config** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that
+ defines the audio encoder config.
+ - **depth__config** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that
+ defines the depth decoder config.
+
+
+ Example:
+ ```python
+ >>> from transformers import KyutaiSpeechToTextConfig, KyutaiSpeechToTextForConditionalGeneration
+
+ >>> # Initializing a KyutaiSpeechToTextConfig
+ >>> configuration = KyutaiSpeechToTextConfig()
+
+ >>> # Initializing a model
+ >>> model = KyutaiSpeechToTextForConditionalGeneration(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ # not the best naming here for `model_type`, but original codebase already uses model type:`stt` for in the config so we keep it to simplify
+ model_type = "stt"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ sub_configs = {"codec_config": AutoConfig}
+
+ def __init__(
+ self,
+ codebook_vocab_size=2049,
+ vocab_size=4001,
+ hidden_size=2048,
+ num_hidden_layers=48,
+ num_attention_heads=32,
+ num_key_value_heads=None,
+ max_position_embeddings=750,
+ rope_theta=100000.0,
+ hidden_act="silu",
+ head_dim=None,
+ initializer_range=0.02,
+ use_cache=True,
+ sliding_window=375,
+ attention_dropout=0.0,
+ ffn_dim=11264,
+ rms_norm_eps=1e-8,
+ num_codebooks=32,
+ audio_bos_token_id=2048,
+ audio_pad_token_id=69569,
+ tie_word_embeddings=False,
+ pad_token_id=3,
+ bos_token_id=48000,
+ codec_config=None,
+ **kwargs,
+ ):
+ super().__init__(
+ pad_token_id=pad_token_id, bos_token_id=bos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs
+ )
+
+ if codec_config is None:
+ self.codec_config = AutoConfig.for_model("mimi")
+ logger.info("codec_config is None, using default audio encoder config.")
+ elif isinstance(codec_config, dict):
+ self.codec_config = AutoConfig.for_model(**codec_config)
+ elif isinstance(codec_config, PretrainedConfig):
+ self.codec_config = codec_config
+
+ self.num_codebooks = num_codebooks
+ self.frame_size = self.codec_config.frame_size
+
+ self.audio_bos_token_id = audio_bos_token_id
+ self.audio_pad_token_id = audio_pad_token_id
+ self.codebook_vocab_size = codebook_vocab_size
+
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ if ffn_dim % 2 == 1:
+ raise ValueError(f"`ffn_dim={ffn_dim}` must be even.")
+ self.ffn_dim = ffn_dim
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.attention_dropout = attention_dropout
+ self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
+ self.sliding_window = sliding_window
+
+
+__all__ = ["KyutaiSpeechToTextConfig"]
diff --git a/src/transformers/models/stt/convert_kyutai_speech_to_text_to_hf.py b/src/transformers/models/stt/convert_kyutai_speech_to_text_to_hf.py
new file mode 100644
index 0000000000..fe4a5a6bc6
--- /dev/null
+++ b/src/transformers/models/stt/convert_kyutai_speech_to_text_to_hf.py
@@ -0,0 +1,377 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import gc
+import os
+import re
+
+import safetensors.torch
+import sentencepiece
+import torch
+
+from transformers import (
+ KyutaiSpeechToTextConfig,
+ KyutaiSpeechToTextFeatureExtractor,
+ KyutaiSpeechToTextForConditionalGeneration,
+ KyutaiSpeechToTextProcessor,
+ PreTrainedTokenizerFast,
+)
+from transformers.convert_slow_tokenizer import MoshiConverter
+from transformers.utils.hub import cached_file
+
+
+# fmt: off
+MOSHI_ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
+ r"out_norm": r"norm",
+ r"gating\.linear_in": r"mlp.fc1",
+ r"gating\.linear_out": r"mlp.fc2",
+ r"self_attn\.out_proj": r"self_attn.o_proj.linear",
+ r"norm1": r"input_layernorm",
+ r"norm2": r"post_attention_layernorm",
+ r"layer_scale_1": r"self_attn_layer_scale",
+ r"layer_scale_2": r"mlp_layer_scale",
+ r"alpha": r"weight",
+}
+# fmt: on
+
+
+# fmt: off
+MIMI_ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
+ r"conv\.conv\.conv": "conv",
+ r"convtr\.convtr\.convtr": "conv",
+ r"conv\.conv": "conv",
+ r"convtr\.convtr": "conv",
+ r"quantizer\.rvq_first\.vq": "quantizer.semantic_residual_vector_quantizer",
+ r"quantizer\.rvq_first": "quantizer.semantic_residual_vector_quantizer",
+ r"quantizer\.rvq_rest\.vq": "quantizer.acoustic_residual_vector_quantizer",
+ r"quantizer\.rvq_rest": "quantizer.acoustic_residual_vector_quantizer",
+ r"_codebook": "codebook",
+ r"_initialized": "initialized",
+ r"embedding_sum": "embed_sum",
+ r"encoder\.model": "encoder.layers",
+ r"decoder\.model": "decoder.layers",
+ r"encoder_transformer\.transformer": "encoder_transformer",
+ r"decoder_transformer\.transformer": "decoder_transformer",
+ r"linear1": "mlp.fc1",
+ r"linear2": "mlp.fc2",
+ r"self_attn\.out_proj": "self_attn.o_proj",
+ r"norm1": "input_layernorm",
+ r"norm2": "post_attention_layernorm",
+ r"layer_scale_1": "self_attn_layer_scale",
+ r"layer_scale_2": "mlp_layer_scale",
+}
+# fmt: on
+
+
+def permute_for_rope(input_tensor, n_heads, dim1, dim2):
+ """
+ When you go from the complex ROPE formulation to sin and cos one, you need
+ to permute the query and key weights (to avoid doing it on the fly)
+ """
+ return input_tensor.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)
+
+
+def convert_key(key, mapping):
+ for pattern, replacement in mapping.items():
+ key = re.sub(pattern, replacement, key)
+ return key
+
+
+def convert_kyutai_speech_to_text_state_dict(state_dict, config, unwanted_prefix="transformer."):
+ hidden_size = config.hidden_size
+ head_dim = config.head_dim
+ num_heads = int(config.hidden_size // config.head_dim)
+ num_key_value_heads = config.num_key_value_heads
+ key_value_head_dim = config.num_key_value_heads * head_dim
+
+ # concat embeddings
+ embed_tokens_weight = []
+ for i in range(32):
+ embed_tokens_weight.append(state_dict.pop(f"emb.{i}.weight"))
+
+ embed_tokens_weight = torch.cat(embed_tokens_weight, dim=0)
+ embed_tokens_weight = torch.cat([state_dict.pop("text_emb.weight"), embed_tokens_weight])
+ embed_tokens_weight = torch.cat([embed_tokens_weight, torch.zeros(1, config.hidden_size)], dim=0)
+ state_dict["embed_tokens.embed_tokens.weight"] = embed_tokens_weight
+
+ for key, value in list(state_dict.items()):
+ if unwanted_prefix is not None and unwanted_prefix in key:
+ new_key = key[len(unwanted_prefix) :]
+ else:
+ new_key = key
+
+ new_key = convert_key(new_key, MOSHI_ORIGINAL_TO_CONVERTED_KEY_MAPPING)
+
+ # Post-process the current_parameter.
+ if "alpha" in key:
+ state_dict[key] = state_dict[key].squeeze()
+
+ if "in_proj_weight" in new_key:
+ # split qkv into query key and value
+ mixed_qkv = state_dict.pop(key)
+ qkv_dim = mixed_qkv.size(0) // 3
+
+ query_layer = mixed_qkv[:qkv_dim]
+ key_layer = mixed_qkv[qkv_dim : qkv_dim * 2]
+ value_layer = mixed_qkv[qkv_dim * 2 :]
+ state_dict[new_key.replace("in_proj_weight", "q_proj.linear.weight")] = permute_for_rope(
+ query_layer, num_heads, hidden_size, hidden_size
+ )
+ state_dict[new_key.replace("in_proj_weight", "k_proj.linear.weight")] = permute_for_rope(
+ key_layer, num_key_value_heads, key_value_head_dim, hidden_size
+ )
+
+ state_dict[new_key.replace("in_proj_weight", "v_proj.linear.weight")] = value_layer
+ else:
+ state_dict[new_key] = state_dict.pop(key)
+
+ return state_dict
+
+
+def convert_mimi_state_dict(state_dict, config, unwanted_prefix=None):
+ hidden_size = config.hidden_size
+ head_dim = config.head_dim
+ num_heads = int(config.hidden_size // config.head_dim)
+ num_key_value_heads = config.num_key_value_heads
+ key_value_head_dim = config.num_key_value_heads * head_dim
+
+ for key, value in list(state_dict.items()):
+ if unwanted_prefix is not None and unwanted_prefix in key:
+ new_key = key[len(unwanted_prefix) :]
+ else:
+ new_key = key
+
+ new_key = convert_key(new_key, MIMI_ORIGINAL_TO_CONVERTED_KEY_MAPPING)
+
+ if "in_proj_weight" in new_key:
+ # split qkv into query key and value
+ mixed_qkv = state_dict.pop(key)
+ qkv_dim = mixed_qkv.size(0) // 3
+
+ query_layer = mixed_qkv[:qkv_dim]
+ key_layer = mixed_qkv[qkv_dim : qkv_dim * 2]
+ value_layer = mixed_qkv[qkv_dim * 2 :]
+
+ state_dict[new_key.replace("in_proj_weight", "q_proj.weight")] = permute_for_rope(
+ query_layer, num_heads, hidden_size, hidden_size
+ )
+ state_dict[new_key.replace("in_proj_weight", "k_proj.weight")] = permute_for_rope(
+ key_layer, num_key_value_heads, key_value_head_dim, hidden_size
+ )
+ state_dict[new_key.replace("in_proj_weight", "v_proj.weight")] = value_layer
+ else:
+ state_dict[new_key] = state_dict.pop(key)
+
+ return state_dict
+
+
+def write_model(
+ input_path_or_repo,
+ model_name,
+ codec_model_path_or_repo,
+ codec_model_name,
+ output_dir,
+ safe_serialization=True,
+ unwanted_prefix="transformer.",
+):
+ print("Converting the model.")
+ os.makedirs(output_dir, exist_ok=True)
+
+ config = KyutaiSpeechToTextConfig()
+ config.use_cache = True
+ config.codec_config.sliding_window = 250
+
+ model_path = cached_file(
+ input_path_or_repo,
+ model_name,
+ )
+
+ codec_path = cached_file(
+ codec_model_path_or_repo,
+ codec_model_name,
+ )
+
+ print(f"Fetching all parameters from the checkpoint at {model_path}...")
+ state_dict = safetensors.torch.load_file(model_path)
+
+ print(f"Fetching all parameters from the checkpoint at {codec_path}...")
+ codec_state_dict = safetensors.torch.load_file(codec_path)
+
+ print("Converting model...")
+ # -----------------------
+ # convert parameter names
+ # -----------------------
+ state_dict = convert_kyutai_speech_to_text_state_dict(state_dict, config, unwanted_prefix=unwanted_prefix)
+ codec_state_dict = convert_mimi_state_dict(codec_state_dict, config.codec_config, unwanted_prefix=None)
+
+ # -------------------------
+ # load the weights and save
+ # -------------------------
+ print("Loading the checkpoint in a Moshi ASR model.")
+ with torch.device("meta"):
+ model = KyutaiSpeechToTextForConditionalGeneration(config)
+
+ linear_weight = state_dict.pop("text_linear.weight")
+ model.model.load_state_dict(state_dict, strict=True, assign=True)
+
+ linear_weight = torch.cat([linear_weight, torch.zeros(1, config.hidden_size)])
+ model.lm_head.load_state_dict({"weight": linear_weight}, strict=True, assign=True)
+
+ model.codec_model.load_state_dict(codec_state_dict, strict=True, assign=True)
+
+ print("Checkpoint loaded successfully.")
+ del model.config._name_or_path
+ del model.config.codec_config._name_or_path
+
+ # default generation config
+ model.generation_config._from_model_config = False
+ model.generation_config.audio_window_size = 1
+ model.generation_config.cache_implementation = "sliding_window"
+
+ model.codec_model.generation_config._from_model_config = False
+ model.codec_model.generation_config.cache_implementation = "sliding_window"
+ model.codec_model.generation_config.use_cache = True
+
+ print("Saving the model.")
+ model.save_pretrained(output_dir, safe_serialization=safe_serialization)
+ del state_dict, model
+
+ # Safety check: reload the converted model
+ gc.collect()
+ print("Reloading the model to check if it's saved correctly.")
+ KyutaiSpeechToTextForConditionalGeneration.from_pretrained(
+ output_dir, torch_dtype=torch.bfloat16, device_map="auto"
+ )
+ print("Model reloaded successfully.")
+
+
+def write_processor(
+ input_path_or_repo,
+ tokenizer_model_name,
+ codec_model_path_or_repo,
+ output_dir,
+ audio_delay_seconds,
+ audio_silence_prefix_seconds,
+):
+ tokenizer_path = cached_file(
+ input_path_or_repo,
+ tokenizer_model_name,
+ )
+
+ tokenizer = MoshiConverter(tokenizer_path).converted()
+ original_tokenizer = sentencepiece.SentencePieceProcessor(tokenizer_path)
+
+ tokenizer = PreTrainedTokenizerFast(
+ tokenizer_object=tokenizer,
+ chat_template=None,
+ unk_token="",
+ model_input_names=["input_ids", "attention_mask"],
+ clean_up_tokenization_spaces=False,
+ bos_token_id=original_tokenizer.bos_id(),
+ eos_token_id=original_tokenizer.eos_id(),
+ pad_token_id=original_tokenizer.pad_id(),
+ )
+
+ feature_extractor = KyutaiSpeechToTextFeatureExtractor(
+ audio_delay_seconds=audio_delay_seconds,
+ audio_silence_prefix_seconds=audio_silence_prefix_seconds,
+ )
+
+ processor = KyutaiSpeechToTextProcessor(feature_extractor, tokenizer)
+ processor.save_pretrained(output_dir)
+ print(f"Processor saved successfully to {output_dir}")
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Convert Moshi ASR weights to HuggingFace format")
+ parser.add_argument(
+ "--input_path_or_repo",
+ type=str,
+ required=True,
+ help="Path or repo containing Moshi ASR weights",
+ )
+ parser.add_argument(
+ "--model_name",
+ type=str,
+ required=True,
+ help="Name of the model in input_path_or_repo",
+ )
+ parser.add_argument(
+ "--tokenizer_model_name",
+ type=str,
+ required=True,
+ help="Name of the tokenizer model in input_path_or_repo",
+ )
+ parser.add_argument(
+ "--codec_model_path_or_repo",
+ type=str,
+ required=True,
+ help="Path or repo containing the Mimi weights",
+ )
+ parser.add_argument(
+ "--mimi_name",
+ type=str,
+ required=True,
+ help="Name of the Mimi model in codec_model_path_or_repo",
+ )
+ parser.add_argument(
+ "--preprocessor_model_path_or_repo",
+ type=str,
+ required=True,
+ help="Path or repo containing the preprocessor config",
+ )
+ parser.add_argument(
+ "--output_dir",
+ help="Location to write HF model and tokenizer",
+ )
+ parser.add_argument(
+ "--safe_serialization", action="store_true", default=True, help="Whether or not to save using `safetensors`."
+ )
+ parser.add_argument(
+ "--audio_delay_seconds",
+ type=float,
+ required=True,
+ help="Audio delay in seconds to add to the right of the input",
+ )
+ parser.add_argument(
+ "--audio_silence_prefix_seconds",
+ type=float,
+ required=True,
+ help="Audio silence prefix in seconds to add to the left of the input",
+ )
+ args = parser.parse_args()
+
+ write_model(
+ args.input_path_or_repo,
+ args.model_name,
+ args.codec_model_path_or_repo,
+ args.mimi_name,
+ args.output_dir,
+ safe_serialization=args.safe_serialization,
+ )
+
+ write_processor(
+ args.input_path_or_repo,
+ args.tokenizer_model_name,
+ args.preprocessor_model_path_or_repo,
+ args.output_dir,
+ args.audio_delay_seconds,
+ args.audio_silence_prefix_seconds,
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/transformers/models/stt/feature_extraction_kyutai_speech_to_text.py b/src/transformers/models/stt/feature_extraction_kyutai_speech_to_text.py
new file mode 100644
index 0000000000..94ddb15daa
--- /dev/null
+++ b/src/transformers/models/stt/feature_extraction_kyutai_speech_to_text.py
@@ -0,0 +1,237 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/stt/modular_kyutai_speech_to_text.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_kyutai_speech_to_text.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 Kyutai and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Union
+
+import numpy as np
+
+from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
+from ...feature_extraction_utils import BatchFeature
+from ...utils import PaddingStrategy, TensorType, logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class KyutaiSpeechToTextFeatureExtractor(SequenceFeatureExtractor):
+ r"""
+ Constructs an KyutaiSpeechToText feature extractor.
+
+ This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
+ most of the main methods. Users should refer to this superclass for more information regarding those methods.
+
+ Args:
+ feature_size (`int`, *optional*, defaults to 1):
+ The feature dimension of the extracted features. Use 1 for mono, 2 for stereo.
+ sampling_rate (`int`, *optional*, defaults to 24000):
+ The sampling rate at which the audio waveform should be digitalized expressed in hertz (Hz).
+ padding_value (`float`, *optional*, defaults to 0.0):
+ The value that is used to fill the padding values.
+ chunk_length_s (`float`, *optional*):
+ If defined the audio is pre-processed into chunks of lengths `chunk_length_s` and then encoded.
+ overlap (`float`, *optional*):
+ Defines the overlap between each chunk. It is used to compute the `chunk_stride` using the following
+ formulae : `int((1.0 - self.overlap) * self.chunk_length)`.
+ audio_delay_seconds (`float`, *optional*, defaults to 0.0):
+ The delay in seconds to add after the audio (right padding).
+ audio_silence_prefix_seconds (`float`, *optional*, defaults to 0.0):
+ The silence prefix in seconds to add before the audio (left padding).
+ """
+
+ model_input_names = ["input_values", "padding_mask"]
+
+ def __init__(
+ self,
+ feature_size: int = 1,
+ sampling_rate: int = 24000,
+ padding_value: float = 0.0,
+ chunk_length_s: Optional[float] = None,
+ overlap: Optional[float] = None,
+ audio_delay_seconds: Optional[float] = 0.0,
+ audio_silence_prefix_seconds: Optional[float] = 0.0,
+ **kwargs,
+ ):
+ super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
+ self.chunk_length_s = chunk_length_s
+ self.overlap = overlap
+ self.audio_delay_seconds = audio_delay_seconds
+ self.audio_silence_prefix_seconds = audio_silence_prefix_seconds
+
+ # This is a property because you might want to change the chunk_length_s on the fly
+ @property
+ def chunk_length(self) -> Optional[int]:
+ if self.chunk_length_s is None:
+ return None
+ else:
+ return int(self.chunk_length_s * self.sampling_rate)
+
+ # This is a property because you might want to change the chunk_length_s on the fly
+ @property
+ def chunk_stride(self) -> Optional[int]:
+ if self.chunk_length_s is None or self.overlap is None:
+ return None
+ else:
+ return max(1, int((1.0 - self.overlap) * self.chunk_length))
+
+ def __call__(
+ self,
+ raw_audio: Union[np.ndarray, list[float], list[np.ndarray], list[list[float]]],
+ padding: Optional[Union[bool, str, PaddingStrategy]] = None,
+ truncation: Optional[bool] = False,
+ max_length: Optional[int] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ sampling_rate: Optional[int] = None,
+ ) -> BatchFeature:
+ """
+ Main method to featurize and prepare for the model one or several sequence(s).
+
+ Args:
+ raw_audio (`np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[list[float]]`):
+ The sequence or batch of sequences to be processed. Each sequence can be a numpy array, a list of float
+ values, a list of numpy arrays or a list of list of float values. The numpy array must be of shape
+ `(num_samples,)` for mono audio (`feature_size = 1`), or `(2, num_samples)` for stereo audio
+ (`feature_size = 2`).
+ padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding
+ index) among:
+
+ - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
+ sequence if provided).
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
+ acceptable input length for the model if that argument is not provided.
+ - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
+ lengths).
+ truncation (`bool`, *optional*, defaults to `False`):
+ Activates truncation to cut input sequences longer than `max_length` to `max_length`.
+ max_length (`int`, *optional*):
+ Maximum length of the returned list and optionally padding length (see above).
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
+ If set, will return tensors instead of list of python integers. Acceptable values are:
+
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return Numpy `np.ndarray` objects.
+ sampling_rate (`int`, *optional*):
+ The sampling rate at which the `audio` input was sampled. It is strongly recommended to pass
+ `sampling_rate` at the forward call to prevent silent errors.
+ """
+ if sampling_rate is not None:
+ if sampling_rate != self.sampling_rate:
+ raise ValueError(
+ f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of"
+ f" {self.sampling_rate}. Please make sure that the provided audio input was sampled with"
+ f" {self.sampling_rate} and not {sampling_rate}."
+ )
+ else:
+ logger.warning(
+ f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. "
+ "Failing to do so can result in silent errors that might be hard to debug."
+ )
+
+ if padding and truncation:
+ raise ValueError("Both padding and truncation were set. Make sure you only set one.")
+ elif padding is None:
+ # by default let's pad the inputs
+ padding = True
+
+ is_batched = bool(
+ isinstance(raw_audio, (list, tuple)) and (isinstance(raw_audio[0], (np.ndarray, tuple, list)))
+ )
+
+ if is_batched:
+ raw_audio = [np.asarray(audio, dtype=np.float32).T for audio in raw_audio]
+ elif not is_batched and not isinstance(raw_audio, np.ndarray):
+ raw_audio = np.asarray(raw_audio, dtype=np.float32)
+ elif isinstance(raw_audio, np.ndarray) and raw_audio.dtype is np.dtype(np.float64):
+ raw_audio = raw_audio.astype(np.float32)
+
+ # always return batch
+ if not is_batched:
+ raw_audio = [np.asarray(raw_audio).T]
+
+ # verify inputs are valid
+ for idx, example in enumerate(raw_audio):
+ if example.ndim > 2:
+ raise ValueError(f"Expected input shape (channels, length) but got shape {example.shape}")
+ if self.feature_size == 1 and example.ndim != 1:
+ raise ValueError(f"Expected mono audio but example has {example.shape[-1]} channels")
+ if self.feature_size == 2 and example.shape[-1] != 2:
+ raise ValueError(f"Expected stereo audio but example has {example.shape[-1]} channels")
+
+ padded_inputs = None
+ input_values = BatchFeature({"input_values": raw_audio})
+ if self.chunk_stride is not None and self.chunk_length is not None and max_length is None:
+ if truncation:
+ max_length = min(array.shape[0] for array in raw_audio)
+ nb_step = int(np.floor(max_length / self.chunk_stride))
+ max_length = (nb_step - 1) * self.chunk_stride + self.chunk_length
+ elif padding:
+ max_length = max(array.shape[0] for array in raw_audio)
+ nb_step = int(np.ceil(max_length / self.chunk_stride))
+ max_length = (nb_step - 1) * self.chunk_stride + self.chunk_length
+ padding = "max_length"
+ else:
+ padded_inputs = input_values
+
+ # normal padding on batch
+ if padded_inputs is None:
+ padded_inputs = self.pad(
+ input_values,
+ max_length=max_length,
+ truncation=truncation,
+ padding=padding,
+ return_attention_mask=padding,
+ )
+
+ if padding:
+ padded_inputs["padding_mask"] = padded_inputs.pop("attention_mask")
+
+ # now let's padd left and right
+ pad_left = int(self.audio_silence_prefix_seconds * self.sampling_rate)
+ pad_right = int((self.audio_delay_seconds + 1.0) * self.sampling_rate)
+ padded_inputs["input_values"] = np.pad(
+ padded_inputs["input_values"],
+ ((0, 0), (pad_left, pad_right)),
+ mode="constant",
+ constant_values=0.0,
+ )
+ if padding:
+ padded_inputs["padding_mask"] = np.pad(
+ padded_inputs["padding_mask"],
+ ((0, 0), (pad_left, pad_right)),
+ mode="constant",
+ constant_values=0,
+ )
+
+ input_values = []
+ for example in padded_inputs.pop("input_values"):
+ if self.feature_size == 1:
+ example = example[..., None]
+ input_values.append(example.T)
+
+ padded_inputs["input_values"] = input_values
+ if return_tensors is not None:
+ padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
+
+ return padded_inputs
+
+
+__all__ = ["KyutaiSpeechToTextFeatureExtractor"]
diff --git a/src/transformers/models/stt/modeling_kyutai_speech_to_text.py b/src/transformers/models/stt/modeling_kyutai_speech_to_text.py
new file mode 100644
index 0000000000..7a86cd440c
--- /dev/null
+++ b/src/transformers/models/stt/modeling_kyutai_speech_to_text.py
@@ -0,0 +1,1434 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/stt/modular_kyutai_speech_to_text.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_kyutai_speech_to_text.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 Kyutai and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import types
+from typing import Optional, Union
+
+import torch
+import torch.nn as nn
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
+from ...generation import GenerationConfig, GenerationMixin
+from ...modeling_attn_mask_utils import AttentionMaskConverter
+from ...modeling_flash_attention_utils import (
+ FlashAttentionKwargs,
+ flash_attn_supports_top_left_mask,
+ is_flash_attn_available,
+)
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from ...modeling_utils import PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
+from ..auto import AutoModel
+from .configuration_kyutai_speech_to_text import KyutaiSpeechToTextConfig
+
+
+if is_flash_attn_available():
+ from ...modeling_flash_attention_utils import _flash_attention_forward
+
+if is_torch_flex_attn_available():
+ from torch.nn.attention.flex_attention import BlockMask
+
+ from ...integrations.flex_attention import make_flex_block_causal_mask
+
+
+logger = logging.get_logger(__name__)
+
+
+class KyutaiSpeechToTextRMSNorm(nn.Module):
+ def __init__(self, dim: int, eps: float = 1e-6):
+ super().__init__()
+ self.eps = eps
+ self.weight = nn.Parameter(torch.ones(dim)) # Ignore copy
+
+ def _norm(self, x):
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+ # Ignore copy
+ def forward(self, x):
+ output = self._norm(x.float())
+ output = output * self.weight.float()
+ return output.type_as(x)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.eps}"
+
+
+class KyutaiSpeechToTextFlexibleLinear(nn.Module):
+ def __init__(self, input_size, output_size, num_layers):
+ super().__init__()
+ # Stack the weights for N layers into a single tensor (num_layers, output_size, input_size)
+ self.weight = nn.Parameter(torch.randn(num_layers, output_size, input_size))
+
+ def forward(self, x, layer_idx=None):
+ """
+ `KyutaiSpeechToTextFlexibleLinear` creates one linear layer per codebook. There's multiple ways to use it.
+ In the default case, `sequence_length=num_layers`, so each element of the sequence will be matmul to the weights corresponding to its index on the sequence.
+
+ For more advanced cases, one can specify which codebook's layer(s) to use with `layer_idx`.
+ If `layer_idx` indicates a single integer, all of the element of the sequence will be matmul to this single codebook's layer.
+ But if `layer_idx` is a tensor of shape `(seq_length,)`, it will matmul each i-th element of the input sequence to the corresponding layer `weight[i]`.
+
+
+ Args:
+ x (`torch.FloatTensor): input to the layer of shape `(batch, num_layers, embed_dim)` or of shape `(batch, seq_length, embed_dim)`
+ layer_idx (`torch.Tensor`, *optional*):
+ Can be used to specify which codebook's layers(s) to use.
+ If it's a tensor of shape `(seq_length,)`, will matmul each element of the sequence to the corresponding weights.
+ But if `layer_idx` is a tensor of shape `(seq_length,)`, it will matmul each i-th element of the input sequence to the corresponding layer `weight[i]`.
+ """
+
+ # Use torch.gather to select the corresponding weights for each sample
+ # (codebooks, output_size, hidden_size)
+ selected_weights = torch.index_select(self.weight, 0, layer_idx) if layer_idx is not None else self.weight
+
+ # (1, codebooks, hidden_size, output_size)
+ selected_weights = selected_weights.transpose(1, 2)[None, :, :, :]
+
+ # (batch_size, codebooks, 1, hidden_size) x (1, codebooks, hidden_size, output_size)
+ # -> (batch_size, codebooks, 1, output_size)
+ x = torch.matmul(x[:, :, None, :], selected_weights)
+
+ # (batch_size, codebooks, output_size)
+ return x.squeeze(2)
+
+
+@auto_docstring
+class KyutaiSpeechToTextPreTrainedModel(PreTrainedModel):
+ config_class = KyutaiSpeechToTextConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["KyutaiSpeechToTextDecoderLayer", "MimiTransformerLayer"]
+ _supports_flash_attn_2 = True
+ _supports_sdpa = True
+ _supports_cache_class = True
+ main_input_name = "input_ids"
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, KyutaiSpeechToTextFlexibleLinear):
+ module.weight.data.normal_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, KyutaiSpeechToTextRMSNorm):
+ module.weight.data.fill_(1.0)
+
+
+class KyutaiSpeechToTextConv1dPaddingCache:
+ """
+ Padding cache for KyutaiSpeechToTextConv1d causal convolutions in order to support streaming via cache padding.
+ See: https://arxiv.org/pdf/2005.06720 & https://arxiv.org/pdf/2204.07064
+
+ A padding cache is a list of cached partial hidden states for each convolution layer.
+ Hidden states are cached from the previous call to the KyutaiSpeechToTextConv1d forward pass, given the padding size.
+ """
+
+ def __init__(
+ self,
+ num_layers: int,
+ per_layer_padding: list[int],
+ per_layer_padding_mode: list[str],
+ per_layer_in_channels: list[int],
+ ):
+ # ensure correct number of layers for each arg
+ from_args_num_layers = {len(per_layer_padding), len(per_layer_padding_mode), len(per_layer_in_channels)}
+
+ if len(from_args_num_layers) != 1 or from_args_num_layers.pop() != num_layers:
+ raise ValueError(
+ f"Expected `num_layers` ({num_layers}) values in `per_layer_padding`, `per_layer_padding_mode` and `per_layer_in_channels`"
+ )
+ elif not all(mode in ["constant", "replicate"] for mode in per_layer_padding_mode):
+ raise NotImplementedError(
+ "`padding_cache` is not supported for convolutions using other than `constant` or `replicate` padding mode"
+ )
+
+ self.per_layer_padding = per_layer_padding
+ self.per_layer_padding_mode = per_layer_padding_mode
+ self.per_layer_in_channels = per_layer_in_channels
+ self.per_layer_is_init = [True] * num_layers
+
+ self.padding_cache = [None] * num_layers
+
+ def update(self, hidden_states: torch.Tensor, layer_idx: int):
+ """
+ Updates the padding cache with the new padding states for the layer `layer_idx` and returns the current cache.
+
+ Parameters:
+ hidden_states (`torch.Tensor`):
+ The hidden states to be partially cached.
+ layer_idx (`int`):
+ The index of the layer to cache the states for.
+ Returns:
+ `torch.Tensor` or `None`, the current padding cache.
+ """
+ batch_size, dtype, device = hidden_states.shape[0], hidden_states.dtype, hidden_states.device
+ padding = self.per_layer_padding[layer_idx]
+ padding_mode = self.per_layer_padding_mode[layer_idx]
+ in_channels = self.per_layer_in_channels[layer_idx]
+
+ if self.padding_cache[layer_idx] is None:
+ if padding_mode == "constant":
+ current_cache = torch.zeros(
+ batch_size,
+ in_channels,
+ padding,
+ device=device,
+ dtype=dtype,
+ )
+ elif padding_mode == "replicate":
+ current_cache = (
+ torch.ones(
+ batch_size,
+ in_channels,
+ padding,
+ device=device,
+ dtype=dtype,
+ )
+ * hidden_states[..., :1]
+ )
+ else:
+ current_cache = self.padding_cache[layer_idx]
+
+ # update the cache
+ if padding > 0:
+ padding_states = hidden_states[:, :, -padding:]
+ else:
+ padding_states = torch.empty(batch_size, in_channels, padding, dtype=dtype, device=device)
+ self.padding_cache[layer_idx] = padding_states
+
+ return current_cache
+
+
+class KyutaiSpeechToTextEmbeddings(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.embed_tokens = nn.Embedding(
+ config.vocab_size + (config.num_codebooks * config.codebook_vocab_size) + 1,
+ config.hidden_size,
+ padding_idx=config.audio_pad_token_id,
+ )
+ audio_tokens_offsets = torch.arange(config.num_codebooks) * config.codebook_vocab_size
+ audio_tokens_offsets += config.vocab_size
+ audio_tokens_offsets = nn.functional.pad(
+ audio_tokens_offsets, (1, 0)
+ ) # pad one 0 to the left for the text token
+ self.register_buffer("audio_tokens_offsets", audio_tokens_offsets, persistent=False)
+
+ def forward(self, input_ids):
+ input_ids = torch.where(
+ input_ids == self.embed_tokens.padding_idx, input_ids, input_ids + self.audio_tokens_offsets
+ )
+ inputs_embeds = self.embed_tokens(input_ids)
+ inputs_embeds = inputs_embeds.sum(dim=2)
+ return inputs_embeds
+
+
+class KyutaiSpeechToTextLinear(nn.Module):
+ def __init__(self, input_dim, output_dim, num_codebooks, use_flexible_linear=False):
+ super().__init__()
+
+ self.use_flexible_linear = use_flexible_linear
+
+ if not use_flexible_linear:
+ self.linear = nn.Linear(input_dim, output_dim, bias=False)
+ else:
+ self.linear = KyutaiSpeechToTextFlexibleLinear(input_dim, output_dim, num_layers=num_codebooks)
+
+ def forward(self, x, layer_idx=None):
+ if self.use_flexible_linear:
+ return self.linear(x, layer_idx)
+ else:
+ return self.linear(x)
+
+
+class KyutaiSpeechToTextRotaryEmbedding(nn.Module):
+ def __init__(self, config: KyutaiSpeechToTextConfig, device=None):
+ super().__init__()
+ # BC: "rope_type" was originally "type"
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
+
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def forward(self, x, position_ids):
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+ position_ids_expanded = position_ids[:, None, :].float()
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+class KyutaiSpeechToTextGatingMLP(nn.Module):
+ def __init__(self, config, use_flexible_linear=False):
+ super().__init__()
+
+ self.activation_fn = ACT2FN[config.hidden_act]
+ ffn_dim = config.ffn_dim
+ hidden_size = config.hidden_size
+ num_layers = config.num_codebooks if use_flexible_linear else 1
+ if num_layers == 1:
+ self.fc1 = nn.Linear(hidden_size, ffn_dim, bias=False)
+ self.fc2 = nn.Linear(ffn_dim // 2, hidden_size, bias=False)
+ else:
+ self.fc1 = KyutaiSpeechToTextFlexibleLinear(hidden_size, ffn_dim, num_layers)
+ self.fc2 = KyutaiSpeechToTextFlexibleLinear(ffn_dim // 2, hidden_size, num_layers)
+
+ def forward(self, hidden_states: torch.Tensor, layer_idx: Optional[int] = None) -> torch.Tensor:
+ hidden_states = self.fc1(hidden_states) if layer_idx is None else self.fc1(hidden_states, layer_idx)
+
+ batch_size, sequence_length, _ = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, sequence_length, 2, -1)
+ hidden_states = self.activation_fn(hidden_states[..., 0, :]) * hidden_states[..., 1, :]
+ hidden_states = self.fc2(hidden_states) if layer_idx is None else self.fc2(hidden_states, layer_idx)
+ return hidden_states
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+class KyutaiSpeechToTextAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(
+ self,
+ config: KyutaiSpeechToTextConfig,
+ layer_idx: Optional[int] = None,
+ use_flexible_linear=False,
+ use_rope=True,
+ ):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ if layer_idx is None:
+ logger.warning_once(
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
+ "when creating this class."
+ )
+
+ self.attention_dropout = config.attention_dropout
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = config.head_dim
+ self.num_key_value_heads = config.num_key_value_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.max_position_embeddings = config.max_position_embeddings
+ self.is_causal = True
+ self.scaling = 1 / math.sqrt(self.head_dim)
+
+ if self.hidden_size % self.num_heads != 0:
+ raise ValueError(
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+ f" and `num_heads`: {self.num_heads})."
+ )
+
+ self.q_proj = KyutaiSpeechToTextLinear(
+ self.hidden_size, self.num_heads * self.head_dim, config.num_codebooks, use_flexible_linear
+ )
+ self.k_proj = KyutaiSpeechToTextLinear(
+ self.hidden_size, self.num_key_value_heads * self.head_dim, config.num_codebooks, use_flexible_linear
+ )
+ self.v_proj = KyutaiSpeechToTextLinear(
+ self.hidden_size, self.num_key_value_heads * self.head_dim, config.num_codebooks, use_flexible_linear
+ )
+ self.o_proj = KyutaiSpeechToTextLinear(
+ self.num_heads * self.head_dim, self.hidden_size, config.num_codebooks, use_flexible_linear
+ )
+
+ # rotary embeddings are not used in the depth decoder
+ self.rotary_emb = None
+ if use_rope:
+ self.rope_theta = config.rope_theta
+ self.rotary_emb = KyutaiSpeechToTextRotaryEmbedding(config)
+
+ # copied from transformers.models.gemma.modeling_gemma.GemmaAttention.forward
+ # no longer copied after attention refactors
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states, cache_position) # Ignore copy
+ key_states = self.k_proj(hidden_states, cache_position) # Ignore copy
+ value_states = self.v_proj(hidden_states, cache_position) # Ignore copy
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ if self.rotary_emb is not None: # Ignore copy
+ cos, sin = self.rotary_emb(value_states, position_ids) # Ignore copy
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) # Ignore copy
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = (
+ {"sin": sin, "cos": cos, "cache_position": cache_position}
+ if self.rotary_emb is not None
+ else {"cache_position": cache_position}
+ ) # Ignore copy
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling
+
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ attn_output = attn_output.view(bsz, q_len, -1)
+ attn_output = self.o_proj(attn_output, cache_position) # Ignore copy
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+# NO LONGER EXIST Copied from transformers.models.gemma.modeling_gemma.GemmaFlashAttention2 with Gemma->KyutaiSpeechToText
+# TODO cyril: modular
+class KyutaiSpeechToTextFlashAttention2(KyutaiSpeechToTextAttention):
+ """
+ KyutaiSpeechToText flash attention module. This module inherits from `KyutaiSpeechToTextAttention` as the weights of the module stays
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+ self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+ if isinstance(past_key_value, StaticCache):
+ raise ValueError(
+ "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
+ "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
+ )
+
+ output_attentions = False
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states, cache_position) # Ignore copy
+ key_states = self.k_proj(hidden_states, cache_position) # Ignore copy
+ value_states = self.v_proj(hidden_states, cache_position) # Ignore copy
+
+ # Flash attention requires the input to have the shape
+ # batch_size x seq_length x head_dim x hidden_dim
+ # therefore we just need to keep the original shape
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ if self.rotary_emb is not None: # Ignore copy
+ cos, sin = self.rotary_emb(value_states, position_ids) # Ignore copy
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) # Ignore copy
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = (
+ {"sin": sin, "cos": cos, "cache_position": cache_position}
+ if self.rotary_emb is not None
+ else {"cache_position": cache_position}
+ ) # Ignore copy
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+ # to be able to avoid many of these transpose/reshape/view.
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ dropout_rate = self.attention_dropout if self.training else 0.0
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
+ # cast them back in the correct dtype just to be sure everything works as expected.
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
+ # in fp32. (KyutaiSpeechToTextRMSNorm handles it correctly)
+
+ input_dtype = query_states.dtype
+ device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = (
+ torch.get_autocast_dtype(device_type)
+ if hasattr(torch, "get_autocast_dtype")
+ else torch.get_autocast_gpu_dtype()
+ )
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ attn_output = _flash_attention_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ position_ids=position_ids,
+ dropout=dropout_rate,
+ sliding_window=getattr(self, "sliding_window", None),
+ is_causal=self.is_causal,
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
+ attn_output = self.o_proj(attn_output, cache_position) # Ignore copy
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+# NO LONGER EXIST Copied from transformers.models.gemma.modeling_gemma.GemmaSdpaAttention with Gemma->KyutaiSpeechToText
+# TODO cyril: modular
+class KyutaiSpeechToTextSdpaAttention(KyutaiSpeechToTextAttention):
+ """
+ KyutaiSpeechToText attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+ `KyutaiSpeechToTextAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
+ SDPA API.
+ """
+
+ # Adapted from KyutaiSpeechToTextAttention.forward
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+ if output_attentions:
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+ logger.warning_once(
+ "KyutaiSpeechToTextModel is using KyutaiSpeechToTextSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ return super().forward(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ )
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states, cache_position) # Ignore copy
+ key_states = self.k_proj(hidden_states, cache_position) # Ignore copy
+ value_states = self.v_proj(hidden_states, cache_position) # Ignore copy
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ if self.rotary_emb is not None: # Ignore copy
+ cos, sin = self.rotary_emb(value_states, position_ids) # Ignore copy
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) # Ignore copy
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = (
+ {"sin": sin, "cos": cos, "cache_position": cache_position}
+ if self.rotary_emb is not None
+ else {"cache_position": cache_position}
+ ) # Ignore copy
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ causal_mask = attention_mask
+ if attention_mask is not None:
+ causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
+
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
+ if query_states.device.type == "cuda" and causal_mask is not None:
+ query_states = query_states.contiguous()
+ key_states = key_states.contiguous()
+ value_states = value_states.contiguous()
+
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+ is_causal = True if causal_mask is None and q_len > 1 else False
+
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
+ query_states,
+ key_states,
+ value_states,
+ attn_mask=causal_mask,
+ dropout_p=self.attention_dropout if self.training else 0.0,
+ is_causal=is_causal,
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.view(bsz, q_len, -1)
+
+ attn_output = self.o_proj(attn_output, cache_position) # Ignore copy
+
+ return attn_output, None, past_key_value
+
+
+STT_ATTENTION_CLASSES = {
+ "eager": KyutaiSpeechToTextAttention,
+ "flash_attention_2": KyutaiSpeechToTextFlashAttention2,
+ "sdpa": KyutaiSpeechToTextSdpaAttention,
+}
+
+
+class KyutaiSpeechToTextDecoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: KyutaiSpeechToTextConfig, layer_idx: int, use_flexible_linear: bool, use_rope=True):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.use_flexible_linear = use_flexible_linear
+
+ self.self_attn = STT_ATTENTION_CLASSES[config._attn_implementation](
+ config=config, layer_idx=layer_idx, use_flexible_linear=use_flexible_linear, use_rope=use_rope
+ )
+
+ self.mlp = KyutaiSpeechToTextGatingMLP(config, use_flexible_linear)
+ self.input_layernorm = KyutaiSpeechToTextRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = KyutaiSpeechToTextRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
+ self.sliding_window = config.sliding_window
+
+ self._attn_implementation = config._attn_implementation
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*):
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
+ query_sequence_length, key_sequence_length)` if default attention is used.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence
+ kwargs (`dict`, *optional*):
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
+ into the model
+ """
+ residual = hidden_states
+
+ hidden_states = self.input_layernorm(hidden_states)
+
+ # Self Attention
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = (
+ self.mlp(hidden_states) if not self.use_flexible_linear else self.mlp(hidden_states, cache_position)
+ )
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
+
+
+@auto_docstring
+class KyutaiSpeechToTextModel(KyutaiSpeechToTextPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+ self.embed_tokens = KyutaiSpeechToTextEmbeddings(config)
+ self.layers = nn.ModuleList(
+ [
+ KyutaiSpeechToTextDecoderLayer(config, layer_idx, use_flexible_linear=False)
+ for layer_idx in range(config.num_hidden_layers)
+ ]
+ )
+ self.norm = KyutaiSpeechToTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.gradient_checkpointing = False
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.embed_tokens = value
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Union[tuple, BaseModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if self.gradient_checkpointing and self.training and use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+ )
+ use_cache = False
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ return_legacy_cache = False # noqa: F841
+ if (
+ use_cache and not isinstance(past_key_values, Cache) and not self.training
+ ): # kept for BC (non `Cache` `past_key_values` inputs)
+ return_legacy_cache = True # noqa: F841
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask = None
+ if attention_mask is not None:
+ causal_mask = self._update_causal_mask(
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+ )
+
+ # embed positions
+ hidden_states = inputs_embeds
+
+ if (
+ use_cache and not isinstance(past_key_values, Cache) and not self.training
+ ): # kept for BC (non `Cache` `past_key_values` inputs)
+ return_legacy_cache = True
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+ logger.warning_once(
+ "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. "
+ "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)"
+ )
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = None
+
+ for decoder_layer in self.layers:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = next_decoder_cache if use_cache else None
+ if return_legacy_cache:
+ next_cache = next_cache.to_legacy_cache()
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+ def _update_causal_mask(
+ self,
+ attention_mask: Union[torch.Tensor, "BlockMask"],
+ input_tensor: torch.Tensor,
+ cache_position: torch.Tensor,
+ past_key_values: Cache,
+ output_attentions: bool = False,
+ ):
+ if self.config._attn_implementation == "flash_attention_2":
+ if attention_mask is not None and past_key_values is not None:
+ is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
+ if is_padding_right:
+ raise ValueError(
+ "You are attempting to perform batched generation with padding_side='right'"
+ " this may lead to unexpected behaviour for Flash Attention version of KyutaiSpeechToText. Make sure to "
+ " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
+ )
+ if attention_mask is not None and 0.0 in attention_mask:
+ return attention_mask
+ return None
+ if self.config._attn_implementation == "flex_attention":
+ if isinstance(attention_mask, torch.Tensor):
+ attention_mask = make_flex_block_causal_mask(attention_mask)
+ return attention_mask
+
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+ # to infer the attention mask.
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ using_static_cache = isinstance(past_key_values, StaticCache)
+ using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
+
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+ if (
+ self.config._attn_implementation == "sdpa"
+ and not (using_static_cache or using_sliding_window_cache)
+ and not output_attentions
+ ):
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
+ attention_mask,
+ inputs_embeds=input_tensor,
+ past_key_values_length=past_seen_tokens,
+ sliding_window=self.config.sliding_window,
+ is_training=self.training,
+ ):
+ return None
+
+ dtype = input_tensor.dtype
+ min_dtype = torch.finfo(dtype).min
+ sequence_length = input_tensor.shape[1]
+ # SlidingWindowCache or StaticCache
+ if using_sliding_window_cache or using_static_cache:
+ target_length = past_key_values.get_max_cache_shape()
+ # DynamicCache or no cache
+ else:
+ target_length = (
+ attention_mask.shape[-1]
+ if isinstance(attention_mask, torch.Tensor)
+ else past_seen_tokens + sequence_length + 1
+ )
+
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask,
+ sequence_length=sequence_length,
+ target_length=target_length,
+ dtype=dtype,
+ cache_position=cache_position,
+ batch_size=input_tensor.shape[0],
+ config=self.config,
+ past_key_values=past_key_values,
+ )
+
+ if (
+ self.config._attn_implementation == "sdpa"
+ and attention_mask is not None
+ and attention_mask.device.type in ["cuda", "xpu", "npu"]
+ and not output_attentions
+ ):
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+ # Details: https://github.com/pytorch/pytorch/issues/110213
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+ return causal_mask
+
+ @staticmethod
+ def _prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask: torch.Tensor,
+ sequence_length: int,
+ target_length: int,
+ dtype: torch.dtype,
+ cache_position: torch.Tensor,
+ batch_size: int,
+ config: KyutaiSpeechToTextConfig,
+ past_key_values: Cache,
+ ):
+ """
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+ Args:
+ attention_mask (`torch.Tensor`):
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
+ sequence_length (`int`):
+ The sequence length being processed.
+ target_length (`int`):
+ The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
+ dtype (`torch.dtype`):
+ The dtype to use for the 4D attention mask.
+ cache_position (`torch.Tensor`):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ batch_size (`torch.Tensor`):
+ Batch size.
+ config (`KyutaiSpeechToTextConfig`):
+ The model's configuration class
+ past_key_values (`Cache`):
+ The cache class that is being used currently to generate
+ """
+ if attention_mask is not None and attention_mask.dim() == 4:
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+ causal_mask = attention_mask
+ else:
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = torch.full(
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
+ )
+ diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
+ -1, 1
+ )
+ text_config = config.get_text_config()
+ if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
+ # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
+ # the check is needed to verify is current checkpoint was trained with sliding window or not
+ if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
+ sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
+ cache_position.reshape(-1, 1) - text_config.sliding_window
+ )
+ diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
+ causal_mask *= diagonal_attend_mask
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+ if attention_mask is not None:
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
+ if attention_mask.shape[-1] > target_length:
+ attention_mask = attention_mask[:, :target_length]
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+ causal_mask.device
+ )
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+ padding_mask, min_dtype
+ )
+ return causal_mask
+
+
+class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
+
+
+@auto_docstring
+class KyutaiSpeechToTextForConditionalGeneration(KyutaiSpeechToTextPreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["lm_head.weight"]
+ _tp_plan = {"lm_head": "colwise_rep"}
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
+ _keep_in_fp32_modules = ["codec_model"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = KyutaiSpeechToTextModel(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+ self.codec_model = AutoModel.from_config(config.codec_config)
+
+ # we are in an edge case where for the codec_model self.can_generate is False, setting self.codec_model.generation_config to None
+ # yet the codec_model needs a generation config to initalize it's cache for streaming inference
+ # we therefore initialize a generation config for the codec model
+ self.codec_model.generation_config = GenerationConfig.from_model_config(config.codec_config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model = decoder
+
+ def get_decoder(self):
+ return self.model
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs: Unpack[KwargsForCausalLM],
+ ) -> CausalLMOutputWithPast:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Example:
+
+ ```python
+ >>> import torch
+ >>> from datasets import load_dataset, Audio
+ >>> from transformers import KyutaiSpeechToTextProcessor, KyutaiSpeechToTextForConditionalGeneration
+
+ >>> torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+ >>> model_id = "kyutai/stt-2.6b-en"
+
+ >>> processor = KyutaiSpeechToTextProcessor.from_pretrained(model_id)
+ >>> model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(model_id, device_map=torch_device)
+
+ >>> ds = load_dataset(
+ ... "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
+ ... )
+
+ >>> ds = ds.cast_column("audio", Audio(sampling_rate=24000))
+ >>> inputs = processor(
+ ... ds[0]["audio"]["array"],
+ ... )
+ >>> inputs.to(torch_device)
+
+ >>> output_tokens = model.generate(**inputs)
+ >>> print(processor.batch_decode(output_tokens, skip_special_tokens=True))
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs: BaseModelOutputWithPast = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs.last_hidden_state
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def _prepare_generation_config(self, *args, **kwargs):
+ generation_config, model_kwargs = super()._prepare_generation_config(*args, **kwargs)
+ # this should be passed to the model kwargs for the input preparation
+ model_kwargs["audio_window_size"] = (
+ generation_config.audio_window_size if hasattr(generation_config, "audio_window_size") else None
+ )
+ return generation_config, model_kwargs
+
+ def _prepare_model_inputs(
+ self,
+ inputs: Optional[torch.Tensor] = None,
+ bos_token_id: Optional[torch.Tensor] = None,
+ model_kwargs: Optional[dict[str, torch.Tensor]] = None,
+ ) -> tuple[torch.Tensor, Optional[str], dict[str, torch.Tensor]]:
+ inputs, input_name, model_kwargs = super()._prepare_model_inputs(
+ inputs=inputs,
+ bos_token_id=bos_token_id,
+ model_kwargs=model_kwargs,
+ )
+
+ audio_window_size = model_kwargs.get("audio_window_size", None)
+ if audio_window_size is None:
+ audio_window_size = self.codec_model.get_encoded_length(model_kwargs["input_values"].shape[-1]).item()
+ model_kwargs["audio_window_size"] = audio_window_size
+
+ batch_size = inputs.shape[0]
+ device = inputs.device
+
+ # initialize audio tokens
+ model_kwargs["audio_tokens"] = torch.zeros(
+ (batch_size, audio_window_size, self.config.num_codebooks),
+ device=device,
+ dtype=torch.long,
+ )
+ model_kwargs["current_window"] = (
+ torch.tensor([0, 0], device=device, dtype=torch.long).expand(batch_size, -1).contiguous()
+ )
+
+ # let's use generate's cache preparation to prepare the cache for the codec model
+ temporary_model_kwargs = {}
+
+ # monkey patching the codec model with cache preparation methods since we don't want it to inherit fully from GenerationMixin
+ # Add cache-related methods from GenerationMixin to codec model
+ cache_methods = [
+ "_prepare_cache_for_generation",
+ "_get_cache",
+ "_supports_default_dynamic_cache",
+ "_get_layer_device_map_for_cache_init",
+ ]
+ for method in cache_methods:
+ setattr(self.codec_model, method, types.MethodType(getattr(self, method).__func__, self.codec_model))
+
+ self.codec_model._prepare_cache_for_generation(
+ generation_config=self.codec_model.generation_config,
+ model_kwargs=temporary_model_kwargs,
+ assistant_model=None,
+ batch_size=batch_size,
+ max_cache_length=self.config.codec_config.sliding_window,
+ device=device,
+ )
+
+ if "past_key_values" in temporary_model_kwargs:
+ model_kwargs["encoder_past_key_values"] = temporary_model_kwargs["past_key_values"]
+
+ # initialize the padding cache for the codec model
+ per_layer_padding, per_layer_padding_mode, per_layer_in_channels = [], [], []
+ for layer_name in self.codec_model.encoder._mimiconv1d_layer_names:
+ per_layer_padding.append(self.codec_model.encoder.get_submodule(layer_name).padding_total)
+ per_layer_padding_mode.append(self.codec_model.encoder.get_submodule(layer_name).pad_mode)
+ per_layer_in_channels.append(self.codec_model.encoder.get_submodule(layer_name).in_channels)
+
+ # downsample layer
+ per_layer_padding.append(self.codec_model.downsample.padding_total)
+ per_layer_padding_mode.append(self.codec_model.downsample.pad_mode)
+ per_layer_in_channels.append(self.codec_model.downsample.in_channels)
+
+ model_kwargs["padding_cache"] = KyutaiSpeechToTextConv1dPaddingCache(
+ num_layers=len(self.codec_model.encoder._mimiconv1d_layer_names) + 1,
+ per_layer_padding=per_layer_padding,
+ per_layer_padding_mode=per_layer_padding_mode,
+ per_layer_in_channels=per_layer_in_channels,
+ )
+
+ return inputs, input_name, model_kwargs
+
+ def prepare_inputs_for_generation(
+ self,
+ *args,
+ audio_tokens: Optional[torch.LongTensor] = None,
+ input_values: Optional[torch.FloatTensor] = None,
+ padding_mask: Optional[torch.Tensor] = None,
+ audio_window_size: Optional[int] = None,
+ current_window: Optional[tuple[int, int]] = None,
+ encoder_past_key_values: Optional[Cache] = None,
+ padding_cache: Optional[KyutaiSpeechToTextConv1dPaddingCache] = None,
+ **kwargs,
+ ):
+ model_inputs = super().prepare_inputs_for_generation(*args, **kwargs)
+
+ if input_values is not None:
+ cache_position = model_inputs["cache_position"]
+ start, end = current_window[0]
+
+ # first cache position is for bos token, so we need to offset by -1
+ if cache_position[-1] - 1 >= end:
+ # we need to encode the new audio tokens
+ with torch.no_grad():
+ input_values_start_idx = start * self.config.frame_size
+ input_values_end_idx = (start + audio_window_size) * self.config.frame_size
+ current_input_values = input_values[..., input_values_start_idx:input_values_end_idx]
+ codec_model_output = self.codec_model.encode(
+ current_input_values,
+ encoder_past_key_values=encoder_past_key_values,
+ padding_cache=padding_cache,
+ )
+ new_audio_tokens = codec_model_output.audio_codes.transpose(1, 2)
+
+ audio_tokens.copy_(new_audio_tokens)
+
+ start = end.clone()
+ end = end + audio_window_size
+ current_window.copy_(
+ torch.tensor([start, end], device=current_window.device).expand(current_window.shape[0], -1)
+ )
+
+ # first cache position is for bos token, so we need to offset by -1
+ current_audio_tokens_idxs = (cache_position - start - 1).clamp(min=0)
+ current_audio_tokens = audio_tokens[:, current_audio_tokens_idxs, :]
+
+ current_audio_tokens[:, cache_position == 0, :] = self.config.audio_bos_token_id
+
+ input_ids = model_inputs.pop("input_ids")
+ input_ids = torch.cat(
+ [input_ids.unsqueeze(2), current_audio_tokens],
+ dim=2,
+ )
+ model_inputs["input_ids"] = input_ids
+
+ return model_inputs
+
+ # TODO: @eustlb, this should be standardized
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ if kwargs.get("output_loading_info", False):
+ model, loading_info = super().from_pretrained(*args, **kwargs)
+ else:
+ model = super().from_pretrained(*args, **kwargs)
+
+ # copy depth decoder generation conf attr to the depth decoder generation config
+ prefix = "codec_"
+ prefix_len = len(prefix)
+ codec_model_attrs = {
+ attr[prefix_len:]: value
+ for attr, value in vars(model.generation_config).items()
+ if attr.startswith(prefix)
+ }
+
+ vars(model.codec_model.generation_config).update({"_from_model_config": False, **codec_model_attrs})
+
+ # remove the depth decoder generation conf attr from the model generation config
+ for attr in codec_model_attrs:
+ delattr(model.generation_config, prefix + attr)
+
+ if "output_loading_info" in kwargs:
+ return model, loading_info
+ else:
+ return model
+
+ # TODO: @eustlb, this should be standardized
+ def save_pretrained(self, *args, **kwargs):
+ prefix = "codec_"
+ codec_model_attrs = self.codec_model.generation_config.to_diff_dict()
+ codec_model_attrs.pop("transformers_version", None)
+ for attr, value in codec_model_attrs.items():
+ setattr(self.generation_config, prefix + attr, value)
+
+ super().save_pretrained(*args, **kwargs)
+
+ def generate(self, *args, **kwargs):
+ r"""
+ This method forwards all its arguments to GenerationMixin's [`~GenerationMixin.generate`]. Please refer to the docstring of this method for more information.
+ """
+ max_new_tokens = kwargs.pop("max_new_tokens", None)
+ input_values = kwargs.get("input_values")
+
+ # TODO: @eustlb, we should have per-batch-idx values
+ # here we do not use padding_mask to be aligned to what's done in the original codebase
+ max_audio_frames = input_values.shape[-1] // self.config.codec_config.frame_size
+
+ if max_new_tokens is None or max_new_tokens > max_audio_frames:
+ if max_new_tokens is not None:
+ logger.warning(
+ f"`max_new_tokens` ({max_new_tokens}) is greater than the maximum number of audio frames ({max_audio_frames})."
+ f"Setting `max_new_tokens` to {max_audio_frames}."
+ )
+ max_new_tokens = max_audio_frames
+
+ return super().generate(
+ *args,
+ max_new_tokens=max_new_tokens,
+ **kwargs,
+ )
+
+
+__all__ = [
+ "KyutaiSpeechToTextPreTrainedModel",
+ "KyutaiSpeechToTextModel",
+ "KyutaiSpeechToTextForConditionalGeneration",
+]
diff --git a/src/transformers/models/stt/modular_kyutai_speech_to_text.py b/src/transformers/models/stt/modular_kyutai_speech_to_text.py
new file mode 100644
index 0000000000..8cc0c9d2a7
--- /dev/null
+++ b/src/transformers/models/stt/modular_kyutai_speech_to_text.py
@@ -0,0 +1,510 @@
+# coding=utf-8
+# Copyright 2025 Kyutai and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import types
+from typing import Optional, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from ...cache_utils import Cache
+from ...feature_extraction_utils import BatchFeature
+from ...generation import GenerationConfig, GenerationMixin
+from ...modeling_utils import PreTrainedModel
+from ...utils import PaddingStrategy, TensorType, logging
+from ..auto import AutoModel
+from ..encodec.feature_extraction_encodec import EncodecFeatureExtractor
+from ..llama.modeling_llama import LlamaForCausalLM
+from ..mimi.modeling_mimi import MimiConv1dPaddingCache
+from ..moshi.modeling_moshi import MoshiModel, MoshiPreTrainedModel
+
+
+logger = logging.get_logger(__name__)
+
+
+class KyutaiSpeechToTextFeatureExtractor(EncodecFeatureExtractor):
+ r"""
+ Constructs an KyutaiSpeechToText feature extractor.
+
+ This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
+ most of the main methods. Users should refer to this superclass for more information regarding those methods.
+
+ Args:
+ feature_size (`int`, *optional*, defaults to 1):
+ The feature dimension of the extracted features. Use 1 for mono, 2 for stereo.
+ sampling_rate (`int`, *optional*, defaults to 24000):
+ The sampling rate at which the audio waveform should be digitalized expressed in hertz (Hz).
+ padding_value (`float`, *optional*, defaults to 0.0):
+ The value that is used to fill the padding values.
+ chunk_length_s (`float`, *optional*):
+ If defined the audio is pre-processed into chunks of lengths `chunk_length_s` and then encoded.
+ overlap (`float`, *optional*):
+ Defines the overlap between each chunk. It is used to compute the `chunk_stride` using the following
+ formulae : `int((1.0 - self.overlap) * self.chunk_length)`.
+ audio_delay_seconds (`float`, *optional*, defaults to 0.0):
+ The delay in seconds to add after the audio (right padding).
+ audio_silence_prefix_seconds (`float`, *optional*, defaults to 0.0):
+ The silence prefix in seconds to add before the audio (left padding).
+ """
+
+ def __init__(
+ self,
+ audio_delay_seconds: Optional[float] = 0.0,
+ audio_silence_prefix_seconds: Optional[float] = 0.0,
+ **super_kwargs,
+ ):
+ super().__init__(**super_kwargs)
+ self.audio_delay_seconds = audio_delay_seconds
+ self.audio_silence_prefix_seconds = audio_silence_prefix_seconds
+
+ def __call__(
+ self,
+ raw_audio: Union[np.ndarray, list[float], list[np.ndarray], list[list[float]]],
+ padding: Optional[Union[bool, str, PaddingStrategy]] = None,
+ truncation: Optional[bool] = False,
+ max_length: Optional[int] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ sampling_rate: Optional[int] = None,
+ ) -> BatchFeature:
+ """
+ Main method to featurize and prepare for the model one or several sequence(s).
+
+ Args:
+ raw_audio (`np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[list[float]]`):
+ The sequence or batch of sequences to be processed. Each sequence can be a numpy array, a list of float
+ values, a list of numpy arrays or a list of list of float values. The numpy array must be of shape
+ `(num_samples,)` for mono audio (`feature_size = 1`), or `(2, num_samples)` for stereo audio
+ (`feature_size = 2`).
+ padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding
+ index) among:
+
+ - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
+ sequence if provided).
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
+ acceptable input length for the model if that argument is not provided.
+ - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
+ lengths).
+ truncation (`bool`, *optional*, defaults to `False`):
+ Activates truncation to cut input sequences longer than `max_length` to `max_length`.
+ max_length (`int`, *optional*):
+ Maximum length of the returned list and optionally padding length (see above).
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
+ If set, will return tensors instead of list of python integers. Acceptable values are:
+
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return Numpy `np.ndarray` objects.
+ sampling_rate (`int`, *optional*):
+ The sampling rate at which the `audio` input was sampled. It is strongly recommended to pass
+ `sampling_rate` at the forward call to prevent silent errors.
+ """
+ if sampling_rate is not None:
+ if sampling_rate != self.sampling_rate:
+ raise ValueError(
+ f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of"
+ f" {self.sampling_rate}. Please make sure that the provided audio input was sampled with"
+ f" {self.sampling_rate} and not {sampling_rate}."
+ )
+ else:
+ logger.warning(
+ f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. "
+ "Failing to do so can result in silent errors that might be hard to debug."
+ )
+
+ if padding and truncation:
+ raise ValueError("Both padding and truncation were set. Make sure you only set one.")
+ elif padding is None:
+ # by default let's pad the inputs
+ padding = True
+
+ is_batched = bool(
+ isinstance(raw_audio, (list, tuple)) and (isinstance(raw_audio[0], (np.ndarray, tuple, list)))
+ )
+
+ if is_batched:
+ raw_audio = [np.asarray(audio, dtype=np.float32).T for audio in raw_audio]
+ elif not is_batched and not isinstance(raw_audio, np.ndarray):
+ raw_audio = np.asarray(raw_audio, dtype=np.float32)
+ elif isinstance(raw_audio, np.ndarray) and raw_audio.dtype is np.dtype(np.float64):
+ raw_audio = raw_audio.astype(np.float32)
+
+ # always return batch
+ if not is_batched:
+ raw_audio = [np.asarray(raw_audio).T]
+
+ # verify inputs are valid
+ for idx, example in enumerate(raw_audio):
+ if example.ndim > 2:
+ raise ValueError(f"Expected input shape (channels, length) but got shape {example.shape}")
+ if self.feature_size == 1 and example.ndim != 1:
+ raise ValueError(f"Expected mono audio but example has {example.shape[-1]} channels")
+ if self.feature_size == 2 and example.shape[-1] != 2:
+ raise ValueError(f"Expected stereo audio but example has {example.shape[-1]} channels")
+
+ padded_inputs = None
+ input_values = BatchFeature({"input_values": raw_audio})
+ if self.chunk_stride is not None and self.chunk_length is not None and max_length is None:
+ if truncation:
+ max_length = min(array.shape[0] for array in raw_audio)
+ nb_step = int(np.floor(max_length / self.chunk_stride))
+ max_length = (nb_step - 1) * self.chunk_stride + self.chunk_length
+ elif padding:
+ max_length = max(array.shape[0] for array in raw_audio)
+ nb_step = int(np.ceil(max_length / self.chunk_stride))
+ max_length = (nb_step - 1) * self.chunk_stride + self.chunk_length
+ padding = "max_length"
+ else:
+ padded_inputs = input_values
+
+ # normal padding on batch
+ if padded_inputs is None:
+ padded_inputs = self.pad(
+ input_values,
+ max_length=max_length,
+ truncation=truncation,
+ padding=padding,
+ return_attention_mask=padding,
+ )
+
+ if padding:
+ padded_inputs["padding_mask"] = padded_inputs.pop("attention_mask")
+
+ # now let's padd left and right
+ pad_left = int(self.audio_silence_prefix_seconds * self.sampling_rate)
+ pad_right = int((self.audio_delay_seconds + 1.0) * self.sampling_rate)
+ padded_inputs["input_values"] = np.pad(
+ padded_inputs["input_values"],
+ ((0, 0), (pad_left, pad_right)),
+ mode="constant",
+ constant_values=0.0,
+ )
+ if padding:
+ padded_inputs["padding_mask"] = np.pad(
+ padded_inputs["padding_mask"],
+ ((0, 0), (pad_left, pad_right)),
+ mode="constant",
+ constant_values=0,
+ )
+
+ input_values = []
+ for example in padded_inputs.pop("input_values"):
+ if self.feature_size == 1:
+ example = example[..., None]
+ input_values.append(example.T)
+
+ padded_inputs["input_values"] = input_values
+ if return_tensors is not None:
+ padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
+
+ return padded_inputs
+
+
+class KyutaiSpeechToTextPreTrainedModel(MoshiPreTrainedModel):
+ pass
+
+
+class KyutaiSpeechToTextConv1dPaddingCache(MimiConv1dPaddingCache):
+ pass
+
+
+class KyutaiSpeechToTextEmbeddings(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.embed_tokens = nn.Embedding(
+ config.vocab_size + (config.num_codebooks * config.codebook_vocab_size) + 1,
+ config.hidden_size,
+ padding_idx=config.audio_pad_token_id,
+ )
+ audio_tokens_offsets = torch.arange(config.num_codebooks) * config.codebook_vocab_size
+ audio_tokens_offsets += config.vocab_size
+ audio_tokens_offsets = nn.functional.pad(
+ audio_tokens_offsets, (1, 0)
+ ) # pad one 0 to the left for the text token
+ self.register_buffer("audio_tokens_offsets", audio_tokens_offsets, persistent=False)
+
+ def forward(self, input_ids):
+ input_ids = torch.where(
+ input_ids == self.embed_tokens.padding_idx, input_ids, input_ids + self.audio_tokens_offsets
+ )
+ inputs_embeds = self.embed_tokens(input_ids)
+ inputs_embeds = inputs_embeds.sum(dim=2)
+ return inputs_embeds
+
+
+class KyutaiSpeechToTextModel(MoshiModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.embed_tokens = KyutaiSpeechToTextEmbeddings(config)
+
+
+class KyutaiSpeechToTextForConditionalGeneration(LlamaForCausalLM, GenerationMixin, PreTrainedModel):
+ _keep_in_fp32_modules = ["codec_model"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.codec_model = AutoModel.from_config(config.codec_config)
+
+ # we are in an edge case where for the codec_model self.can_generate is False, setting self.codec_model.generation_config to None
+ # yet the codec_model needs a generation config to initalize it's cache for streaming inference
+ # we therefore initialize a generation config for the codec model
+ self.codec_model.generation_config = GenerationConfig.from_model_config(config.codec_config)
+
+ def forward(self, **super_kwargs):
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Example:
+
+ ```python
+ >>> import torch
+ >>> from datasets import load_dataset, Audio
+ >>> from transformers import KyutaiSpeechToTextProcessor, KyutaiSpeechToTextForConditionalGeneration
+
+ >>> torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+ >>> model_id = "kyutai/stt-2.6b-en"
+
+ >>> processor = KyutaiSpeechToTextProcessor.from_pretrained(model_id)
+ >>> model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(model_id, device_map=torch_device)
+
+ >>> ds = load_dataset(
+ ... "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
+ ... )
+
+ >>> ds = ds.cast_column("audio", Audio(sampling_rate=24000))
+ >>> inputs = processor(
+ ... ds[0]["audio"]["array"],
+ ... )
+ >>> inputs.to(torch_device)
+
+ >>> output_tokens = model.generate(**inputs)
+ >>> print(processor.batch_decode(output_tokens, skip_special_tokens=True))
+ ```"""
+ super().forward(**super_kwargs)
+
+ def _prepare_generation_config(self, *args, **kwargs):
+ generation_config, model_kwargs = GenerationMixin._prepare_generation_config(*args, **kwargs)
+ # this should be passed to the model kwargs for the input preparation
+ model_kwargs["audio_window_size"] = (
+ generation_config.audio_window_size if hasattr(generation_config, "audio_window_size") else None
+ )
+ return generation_config, model_kwargs
+
+ def _prepare_model_inputs(
+ self,
+ inputs: Optional[torch.Tensor] = None,
+ bos_token_id: Optional[torch.Tensor] = None,
+ model_kwargs: Optional[dict[str, torch.Tensor]] = None,
+ ) -> tuple[torch.Tensor, Optional[str], dict[str, torch.Tensor]]:
+ inputs, input_name, model_kwargs = GenerationMixin._prepare_model_inputs(
+ inputs=inputs,
+ bos_token_id=bos_token_id,
+ model_kwargs=model_kwargs,
+ )
+
+ audio_window_size = model_kwargs.get("audio_window_size", None)
+ if audio_window_size is None:
+ audio_window_size = self.codec_model.get_encoded_length(model_kwargs["input_values"].shape[-1]).item()
+ model_kwargs["audio_window_size"] = audio_window_size
+
+ batch_size = inputs.shape[0]
+ device = inputs.device
+
+ # initialize audio tokens
+ model_kwargs["audio_tokens"] = torch.zeros(
+ (batch_size, audio_window_size, self.config.num_codebooks),
+ device=device,
+ dtype=torch.long,
+ )
+ model_kwargs["current_window"] = (
+ torch.tensor([0, 0], device=device, dtype=torch.long).expand(batch_size, -1).contiguous()
+ )
+
+ # let's use generate's cache preparation to prepare the cache for the codec model
+ temporary_model_kwargs = {}
+
+ # monkey patching the codec model with cache preparation methods since we don't want it to inherit fully from GenerationMixin
+ # Add cache-related methods from GenerationMixin to codec model
+ cache_methods = [
+ "_prepare_cache_for_generation",
+ "_get_cache",
+ "_supports_default_dynamic_cache",
+ "_get_layer_device_map_for_cache_init",
+ ]
+ for method in cache_methods:
+ setattr(self.codec_model, method, types.MethodType(getattr(self, method).__func__, self.codec_model))
+
+ self.codec_model._prepare_cache_for_generation(
+ generation_config=self.codec_model.generation_config,
+ model_kwargs=temporary_model_kwargs,
+ assistant_model=None,
+ batch_size=batch_size,
+ max_cache_length=self.config.codec_config.sliding_window,
+ device=device,
+ )
+
+ if "past_key_values" in temporary_model_kwargs:
+ model_kwargs["encoder_past_key_values"] = temporary_model_kwargs["past_key_values"]
+
+ # initialize the padding cache for the codec model
+ per_layer_padding, per_layer_padding_mode, per_layer_in_channels = [], [], []
+ for layer_name in self.codec_model.encoder._mimiconv1d_layer_names:
+ per_layer_padding.append(self.codec_model.encoder.get_submodule(layer_name).padding_total)
+ per_layer_padding_mode.append(self.codec_model.encoder.get_submodule(layer_name).pad_mode)
+ per_layer_in_channels.append(self.codec_model.encoder.get_submodule(layer_name).in_channels)
+
+ # downsample layer
+ per_layer_padding.append(self.codec_model.downsample.padding_total)
+ per_layer_padding_mode.append(self.codec_model.downsample.pad_mode)
+ per_layer_in_channels.append(self.codec_model.downsample.in_channels)
+
+ model_kwargs["padding_cache"] = KyutaiSpeechToTextConv1dPaddingCache(
+ num_layers=len(self.codec_model.encoder._mimiconv1d_layer_names) + 1,
+ per_layer_padding=per_layer_padding,
+ per_layer_padding_mode=per_layer_padding_mode,
+ per_layer_in_channels=per_layer_in_channels,
+ )
+
+ return inputs, input_name, model_kwargs
+
+ def prepare_inputs_for_generation(
+ self,
+ *args,
+ audio_tokens: Optional[torch.LongTensor] = None,
+ input_values: Optional[torch.FloatTensor] = None,
+ padding_mask: Optional[torch.Tensor] = None,
+ audio_window_size: Optional[int] = None,
+ current_window: Optional[tuple[int, int]] = None,
+ encoder_past_key_values: Optional[Cache] = None,
+ padding_cache: Optional[KyutaiSpeechToTextConv1dPaddingCache] = None,
+ **kwargs,
+ ):
+ model_inputs = GenerationMixin.prepare_inputs_for_generation(*args, **kwargs)
+
+ if input_values is not None:
+ cache_position = model_inputs["cache_position"]
+ start, end = current_window[0]
+
+ # first cache position is for bos token, so we need to offset by -1
+ if cache_position[-1] - 1 >= end:
+ # we need to encode the new audio tokens
+ with torch.no_grad():
+ input_values_start_idx = start * self.config.frame_size
+ input_values_end_idx = (start + audio_window_size) * self.config.frame_size
+ current_input_values = input_values[..., input_values_start_idx:input_values_end_idx]
+ codec_model_output = self.codec_model.encode(
+ current_input_values,
+ encoder_past_key_values=encoder_past_key_values,
+ padding_cache=padding_cache,
+ )
+ new_audio_tokens = codec_model_output.audio_codes.transpose(1, 2)
+
+ audio_tokens.copy_(new_audio_tokens)
+
+ start = end.clone()
+ end = end + audio_window_size
+ current_window.copy_(
+ torch.tensor([start, end], device=current_window.device).expand(current_window.shape[0], -1)
+ )
+
+ # first cache position is for bos token, so we need to offset by -1
+ current_audio_tokens_idxs = (cache_position - start - 1).clamp(min=0)
+ current_audio_tokens = audio_tokens[:, current_audio_tokens_idxs, :]
+
+ current_audio_tokens[:, cache_position == 0, :] = self.config.audio_bos_token_id
+
+ input_ids = model_inputs.pop("input_ids")
+ input_ids = torch.cat(
+ [input_ids.unsqueeze(2), current_audio_tokens],
+ dim=2,
+ )
+ model_inputs["input_ids"] = input_ids
+
+ return model_inputs
+
+ # TODO: @eustlb, this should be standardized
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ if kwargs.get("output_loading_info", False):
+ model, loading_info = PreTrainedModel.from_pretrained(*args, **kwargs)
+ else:
+ model = PreTrainedModel.from_pretrained(*args, **kwargs)
+
+ # copy depth decoder generation conf attr to the depth decoder generation config
+ prefix = "codec_"
+ prefix_len = len(prefix)
+ codec_model_attrs = {
+ attr[prefix_len:]: value
+ for attr, value in vars(model.generation_config).items()
+ if attr.startswith(prefix)
+ }
+
+ vars(model.codec_model.generation_config).update({"_from_model_config": False, **codec_model_attrs})
+
+ # remove the depth decoder generation conf attr from the model generation config
+ for attr in codec_model_attrs:
+ delattr(model.generation_config, prefix + attr)
+
+ if "output_loading_info" in kwargs:
+ return model, loading_info
+ else:
+ return model
+
+ # TODO: @eustlb, this should be standardized
+ def save_pretrained(self, *args, **kwargs):
+ prefix = "codec_"
+ codec_model_attrs = self.codec_model.generation_config.to_diff_dict()
+ codec_model_attrs.pop("transformers_version", None)
+ for attr, value in codec_model_attrs.items():
+ setattr(self.generation_config, prefix + attr, value)
+
+ PreTrainedModel.save_pretrained(self, *args, **kwargs)
+
+ def generate(self, *args, **kwargs):
+ r"""
+ This method forwards all its arguments to GenerationMixin's [`~GenerationMixin.generate`]. Please refer to the docstring of this method for more information.
+ """
+ max_new_tokens = kwargs.pop("max_new_tokens", None)
+ input_values = kwargs.get("input_values")
+
+ # TODO: @eustlb, we should have per-batch-idx values
+ # here we do not use padding_mask to be aligned to what's done in the original codebase
+ max_audio_frames = input_values.shape[-1] // self.config.codec_config.frame_size
+
+ if max_new_tokens is None or max_new_tokens > max_audio_frames:
+ if max_new_tokens is not None:
+ logger.warning(
+ f"`max_new_tokens` ({max_new_tokens}) is greater than the maximum number of audio frames ({max_audio_frames})."
+ f"Setting `max_new_tokens` to {max_audio_frames}."
+ )
+ max_new_tokens = max_audio_frames
+
+ return GenerationMixin.generate(
+ *args,
+ max_new_tokens=max_new_tokens,
+ **kwargs,
+ )
+
+
+__all__ = [
+ "KyutaiSpeechToTextPreTrainedModel",
+ "KyutaiSpeechToTextModel",
+ "KyutaiSpeechToTextForConditionalGeneration",
+ "KyutaiSpeechToTextFeatureExtractor",
+]
diff --git a/src/transformers/models/stt/processing_kyutai_speech_to_text.py b/src/transformers/models/stt/processing_kyutai_speech_to_text.py
new file mode 100644
index 0000000000..0b3a021712
--- /dev/null
+++ b/src/transformers/models/stt/processing_kyutai_speech_to_text.py
@@ -0,0 +1,104 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+from ...audio_utils import AudioInput, make_list_of_audio
+from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
+
+
+class KyutaiSpeechToTextProcessorKwargs(ProcessingKwargs, total=False):
+ _defaults = {
+ "audio_kwargs": {
+ "sampling_rate": 24000,
+ },
+ "common_kwargs": {"return_tensors": "pt"},
+ }
+
+
+class KyutaiSpeechToTextProcessor(ProcessorMixin):
+ r"""
+ Constructs a Moshi ASR processor which wraps [`EncodecFeatureExtractor`] and
+ [`PreTrainedTokenizerFast`] into a single processor that inherits both the audio feature extraction and
+ tokenizer functionalities. See the [`~KyutaiSpeechToTextProcessor.__call__`] for more
+ information.
+ """
+
+ feature_extractor_class = "KyutaiSpeechToTextFeatureExtractor"
+ tokenizer_class = "PreTrainedTokenizerFast"
+
+ def __call__(
+ self,
+ audio: Optional[AudioInput] = None,
+ **kwargs: Unpack[KyutaiSpeechToTextProcessorKwargs],
+ ):
+ r"""
+ Main method to prepare audio to be fed as input to the model. This method forwards the `audio`
+ arguments to KyutaiSpeechToTextFeatureExtractor's [`~KyutaiSpeechToTextFeatureExtractor.__call__`]. Please refer
+ to the docstring of the above method for more information.
+
+ Args:
+ audio (`np.ndarray`, `torch.Tensor`, `list[np.ndarray]`, `list[torch.Tensor]`):
+ The audio or batch of audio to be prepared. Each audio can be a NumPy array or PyTorch
+ tensor.
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
+ If set, will return tensors of a particular framework. Acceptable values are:
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return NumPy `np.ndarray` objects.
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
+ Returns:
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+
+ - **input_values** -- List of audio values to be fed to a model. Returned when `audio` is not `None`.
+ - **padding_mask** -- List of indices specifying which input values should be ignored by the model.
+ """
+
+ if audio is None:
+ raise ValueError("`audio` is required.")
+
+ output_kwargs = self._merge_kwargs(
+ KyutaiSpeechToTextProcessorKwargs,
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+ **kwargs,
+ )
+ audio_kwargs = output_kwargs["audio_kwargs"]
+
+ # ensure audio in correct format
+ audio = make_list_of_audio(audio)
+
+ inputs = self.feature_extractor(
+ audio,
+ **audio_kwargs,
+ )
+
+ return inputs
+
+ def decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to KyutaiSpeechToTextTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to
+ the docstring of this method for more information.
+ """
+ return self.tokenizer.decode(*args, **kwargs)
+
+ def batch_decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to KyutaiSpeechToTextTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
+ refer to the docstring of this method for more information.
+ """
+ return self.tokenizer.batch_decode(*args, **kwargs)
+
+
+__all__ = ["KyutaiSpeechToTextProcessor"]
diff --git a/tests/models/kyutai_speech_to_text/__init__.py b/tests/models/kyutai_speech_to_text/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/models/kyutai_speech_to_text/test_modeling_kyutai_speech_to_text.py b/tests/models/kyutai_speech_to_text/test_modeling_kyutai_speech_to_text.py
new file mode 100644
index 0000000000..a6e08f714f
--- /dev/null
+++ b/tests/models/kyutai_speech_to_text/test_modeling_kyutai_speech_to_text.py
@@ -0,0 +1,704 @@
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Testing suite for the PyTorch Moshi ASR model."""
+
+import gc
+import inspect
+import tempfile
+import unittest
+
+import datasets
+import pytest
+from parameterized import parameterized
+
+from transformers import (
+ KyutaiSpeechToTextConfig,
+ KyutaiSpeechToTextForConditionalGeneration,
+ KyutaiSpeechToTextProcessor,
+ is_torch_available,
+)
+from transformers.testing_utils import (
+ cleanup,
+ require_torch,
+ require_torch_accelerator,
+ require_torch_sdpa,
+ slow,
+ torch_device,
+)
+
+from ...generation.test_utils import GenerationTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import (
+ TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION,
+ ModelTesterMixin,
+ _config_zero_init,
+ floats_tensor,
+ ids_tensor,
+)
+from ...test_pipeline_mixin import PipelineTesterMixin
+
+
+if is_torch_available():
+ import torch
+
+ from transformers import (
+ KyutaiSpeechToTextForConditionalGeneration,
+ KyutaiSpeechToTextModel,
+ )
+
+
+class KyutaiSpeechToTextModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=13,
+ seq_length=7,
+ text_seq_length=1,
+ input_values_length=192, # gives 3 audio tokens, corresponding to the default in GenerationTesterMixin
+ is_training=False,
+ use_input_mask=True,
+ use_token_type_ids=False,
+ use_labels=True,
+ codebook_vocab_size=2049,
+ vocab_size=99,
+ hidden_size=32,
+ num_hidden_layers=2,
+ num_attention_heads=4,
+ num_key_value_heads=None,
+ max_position_embeddings=512,
+ rope_theta=10000.0,
+ hidden_act="silu",
+ head_dim=None,
+ initializer_range=0.02,
+ use_cache=True,
+ sliding_window=512,
+ attention_dropout=0.1,
+ ffn_dim=38,
+ rms_norm_eps=1e-6,
+ num_codebooks=8,
+ frame_size=64,
+ delay_in_tokens=5,
+ audio_bos_token_id=2048,
+ audio_pad_token_id=2048,
+ tie_word_embeddings=False,
+ pad_token_id=0,
+ bos_token_id=1,
+ codec_config={
+ "model_type": "mimi",
+ "num_quantizers": 8,
+ "audio_channels": 1,
+ "chunk_in_sec": None,
+ "hidden_size": 16,
+ "num_filters": 8,
+ "num_residual_layers": 1,
+ "upsampling_ratios": [8, 4],
+ "codebook_size": 16,
+ "vector_quantization_hidden_dimension": 16,
+ "upsample_groups": 16,
+ "num_hidden_layers": 2,
+ "num_attention_heads": 2,
+ "num_key_value_heads": 2,
+ "sliding_window": 4,
+ "codebook_dim": 16,
+ "use_cache": False,
+ },
+ scope=None,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.seq_length = seq_length
+ self.text_seq_length = text_seq_length
+ self.is_training = is_training
+ self.use_input_mask = use_input_mask
+ self.use_token_type_ids = use_token_type_ids
+ self.use_labels = use_labels
+ self.codebook_vocab_size = codebook_vocab_size
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.num_key_value_heads = num_key_value_heads
+ self.max_position_embeddings = max_position_embeddings
+ self.rope_theta = rope_theta
+ self.hidden_act = hidden_act
+ self.head_dim = head_dim
+ self.initializer_range = initializer_range
+ self.use_cache = use_cache
+ self.sliding_window = sliding_window
+ self.attention_dropout = attention_dropout
+ self.ffn_dim = ffn_dim
+ self.rms_norm_eps = rms_norm_eps
+ self.num_codebooks = num_codebooks
+ self.frame_size = frame_size
+ self.delay_in_tokens = delay_in_tokens
+ self.audio_bos_token_id = audio_bos_token_id
+ self.audio_pad_token_id = audio_pad_token_id
+ self.tie_word_embeddings = tie_word_embeddings
+ self.pad_token_id = pad_token_id
+ self.bos_token_id = bos_token_id
+ self.codec_config = codec_config
+ self.scope = scope
+ self.input_values_length = input_values_length
+
+ def get_config(self):
+ return KyutaiSpeechToTextConfig(
+ codebook_vocab_size=self.codebook_vocab_size,
+ vocab_size=self.vocab_size,
+ hidden_size=self.hidden_size,
+ num_hidden_layers=self.num_hidden_layers,
+ num_attention_heads=self.num_attention_heads,
+ num_key_value_heads=self.num_key_value_heads,
+ max_position_embeddings=self.max_position_embeddings,
+ rope_theta=self.rope_theta,
+ hidden_act=self.hidden_act,
+ head_dim=self.head_dim,
+ initializer_range=self.initializer_range,
+ use_cache=self.use_cache,
+ sliding_window=self.sliding_window,
+ attention_dropout=self.attention_dropout,
+ ffn_dim=self.ffn_dim,
+ rms_norm_eps=self.rms_norm_eps,
+ num_codebooks=self.num_codebooks,
+ frame_size=self.frame_size,
+ delay_in_tokens=self.delay_in_tokens,
+ audio_bos_token_id=self.audio_bos_token_id,
+ audio_pad_token_id=self.audio_pad_token_id,
+ tie_word_embeddings=self.tie_word_embeddings,
+ pad_token_id=self.pad_token_id,
+ bos_token_id=self.bos_token_id,
+ codec_config=self.codec_config,
+ )
+
+ def create_and_check_model(self, config, input_ids, input_mask):
+ model = KyutaiSpeechToTextModel(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(input_ids, attention_mask=input_mask)
+ result = model(input_ids)
+ self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
+
+ def prepare_config_and_inputs(self):
+ config = self.get_config()
+
+ text_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size - 1) + 1
+ codebook_input_ids = (
+ ids_tensor([self.batch_size, self.seq_length, self.num_codebooks], self.codebook_vocab_size - 1) + 1
+ )
+
+ input_ids = torch.cat([text_input_ids.unsqueeze(2), codebook_input_ids], dim=2)
+ attention_mask = text_input_ids.ne(1).to(torch_device)
+
+ return config, input_ids, attention_mask
+
+ def prepare_config_and_inputs_generate(self):
+ config = self.get_config()
+
+ input_ids = torch.ones([self.batch_size, 1], dtype=torch.long, device=torch_device)
+ input_values = floats_tensor([self.batch_size, 1, self.input_values_length])
+ padding_mask = torch.ones_like(input_values, dtype=torch.int32, device=torch_device)
+
+ return config, input_ids, input_values, padding_mask
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ (
+ config,
+ input_ids,
+ attention_mask,
+ ) = config_and_inputs
+ inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask}
+ return config, inputs_dict
+
+ def prepare_config_and_inputs_for_common_generate(self):
+ config_and_inputs = self.prepare_config_and_inputs_generate()
+ (
+ config,
+ input_ids,
+ input_values,
+ padding_mask,
+ ) = config_and_inputs
+ inputs_dict = {
+ "input_ids": input_ids,
+ "input_values": input_values,
+ "padding_mask": padding_mask,
+ }
+ return config, inputs_dict
+
+
+@require_torch
+class KyutaiSpeechToTextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
+ all_model_classes = (
+ (
+ KyutaiSpeechToTextModel,
+ KyutaiSpeechToTextForConditionalGeneration,
+ )
+ if is_torch_available()
+ else ()
+ )
+ pipeline_model_mapping = (
+ {
+ "feature-extraction": KyutaiSpeechToTextModel,
+ "automatic-speech-recognition": KyutaiSpeechToTextForConditionalGeneration,
+ }
+ if is_torch_available()
+ else {}
+ )
+ test_headmasking = False
+ test_pruning = False
+ fx_compatible = False # Broken by attention refactor cc @Cyrilvallez
+
+ # Need to use `0.8` instead of `0.9` for `test_cpu_offload`
+ # This is because we are hitting edge cases with the causal_mask buffer
+ model_split_percents = [0.5, 0.7, 0.8]
+
+ def setUp(self):
+ self.model_tester = KyutaiSpeechToTextModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=KyutaiSpeechToTextConfig, hidden_size=37)
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
+ inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels)
+
+ return inputs_dict
+
+ def prepare_config_and_inputs_for_generate(self, batch_size=2):
+ # monkey patch prepare_config_and_inputs_for_common
+
+ prepare_config_and_inputs_for_common = self.model_tester.prepare_config_and_inputs_for_common
+ original_batch_size = self.model_tester.batch_size
+
+ self.model_tester.prepare_config_and_inputs_for_common = (
+ self.model_tester.prepare_config_and_inputs_for_common_generate
+ )
+ self.model_tester.batch_size = batch_size
+
+ config, filtered_inputs_dict = super().prepare_config_and_inputs_for_generate()
+ self.model_tester.prepare_config_and_inputs_for_common = prepare_config_and_inputs_for_common
+
+ self.model_tester.batch_size = original_batch_size
+ return config, filtered_inputs_dict
+
+ @pytest.mark.skip(reason="Moshi ASR has custom embedding approach (text and audio embeddings).")
+ def test_model_get_set_embeddings(self):
+ pass
+
+ @pytest.mark.skip(reason="Moshi ASR has custom embedding approach (text and audio embeddings).")
+ def test_tie_model_weights(self):
+ pass
+
+ @pytest.mark.skip(reason="Moshi ASR has custom embedding approach (text and audio embeddings).")
+ def test_resize_embeddings_untied(self):
+ pass
+
+ @pytest.mark.skip(reason="Moshi ASR has custom embedding approach (text and audio embeddings).")
+ def test_resize_tokens_embeddings(self):
+ pass
+
+ @pytest.mark.skip(reason="Moshi ASR has custom embedding approach (text and audio embeddings).")
+ def test_tied_weights_keys(self):
+ pass
+
+ @pytest.mark.skip(reason="Does not apply to Moshi ASR that requires input_values.")
+ def test_generate_without_input_ids(self):
+ pass
+
+ def test_initialization(self):
+ """
+ Overrides [ModelTesterMixin.test_initialization] because of specificities of Mimi codec model.
+ See https://github.com/huggingface/transformers/blob/1077603410cd73ba71d64a522033574d66d64b55/tests/models/mimi/test_modeling_mimi.py#L384-L397
+ """
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ configs_no_init = _config_zero_init(config)
+ for model_class in self.all_model_classes:
+ model = model_class(config=configs_no_init)
+ for name, param in model.named_parameters():
+ uniform_init_parms = ["conv", "input_proj", "output_proj"]
+ if param.requires_grad:
+ if any(x in name for x in uniform_init_parms):
+ self.assertTrue(
+ -1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0,
+ msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+ )
+
+ @parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION)
+ @require_torch_sdpa
+ def test_eager_matches_sdpa_inference(
+ self, name, torch_dtype, padding_side, use_attention_mask, output_attentions, enable_kernels
+ ):
+ if use_attention_mask or (not use_attention_mask and torch_dtype == "fp32" and not output_attentions):
+ self.skipTest("Test is failing, fix me :) ")
+ parent_parameterized_test = getattr(ModelTesterMixin, self._testMethodName)
+ parent_parameterized_test(self)
+
+ @unittest.skip(reason="Some undefined behavior encountered with test versions of this model. Skip for now.")
+ def test_cpu_offload(self):
+ pass
+
+ @unittest.skip(reason="Some undefined behavior encountered with test versions of this model. Skip for now.")
+ def test_disk_offload_bin(self):
+ pass
+
+ @unittest.skip(reason="Some undefined behavior encountered with test versions of this model. Skip for now.")
+ def test_disk_offload_safetensors(self):
+ pass
+
+ @pytest.mark.generate
+ def test_left_padding_compatibility(self):
+ # NOTE: left-padding results in small numerical differences. This is expected.
+ # See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535
+
+ # First, filter out models that don't support left padding
+ # - The model must have generative capabilities
+ if len(self.all_generative_model_classes) == 0:
+ self.skipTest(reason="No generative architecture available for this model.")
+
+ # - The model must support padding
+ if not self.has_attentions:
+ self.skipTest(reason="This model doesn't support padding.")
+
+ # - The model must be a decoder-only architecture (encoder-based architectures use right-padding)
+ decoder_only_classes = []
+ for model_class in self.all_generative_model_classes:
+ config, _ = self.prepare_config_and_inputs_for_generate()
+ if config.is_encoder_decoder:
+ continue
+ else:
+ decoder_only_classes.append(model_class)
+ if len(decoder_only_classes) == 0:
+ self.skipTest(reason="No decoder-only architecture available for this model.")
+
+ # - Decoder-only architectures derived from encoder-decoder models could support it in theory, but we haven't
+ # added support for it yet. We skip these models for now.
+ has_encoder_attributes = any(
+ attr_name
+ for attr_name in config.to_dict().keys()
+ if attr_name.startswith("encoder") and attr_name != "encoder_no_repeat_ngram_size"
+ )
+ if has_encoder_attributes:
+ self.skipTest(
+ reason="The decoder-only derived from encoder-decoder models are not expected to support left-padding."
+ )
+
+ # Then, test left-padding
+ def _prepare_model_kwargs(input_ids, attention_mask, signature):
+ model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask}
+ if "position_ids" in signature:
+ position_ids = torch.cumsum(attention_mask, dim=-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ model_kwargs["position_ids"] = position_ids
+ if "cache_position" in signature:
+ cache_position = torch.arange(input_ids.shape[1], device=torch_device)
+ model_kwargs["cache_position"] = cache_position
+ return model_kwargs
+
+ for model_class in decoder_only_classes:
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ input_ids = inputs_dict["input_ids"]
+ attention_mask = inputs_dict.get("attention_mask")
+ if attention_mask is None:
+ attention_mask = torch.ones_like(input_ids)
+
+ model = model_class(config).to(torch_device).eval()
+ signature = inspect.signature(model.forward).parameters.keys()
+
+ # no cache as some models require special cache classes to be init outside forward
+ model.generation_config.use_cache = False
+
+ # Without padding
+ model_kwargs = _prepare_model_kwargs(input_ids, attention_mask, signature)
+ next_logits_wo_padding = model(**model_kwargs).logits[:, -1, :]
+
+ # With left-padding (length 32)
+ # can hardcode pad_token to be 0 as we'll do attn masking anyway
+ pad_token_id = (
+ config.get_text_config().pad_token_id if config.get_text_config().pad_token_id is not None else 0
+ )
+ pad_size = (input_ids.shape[0], 32, *input_ids.shape[2:])
+ padding = torch.ones(pad_size, dtype=input_ids.dtype, device=torch_device) * pad_token_id
+ padded_input_ids = torch.cat((padding, input_ids), dim=1)
+ padded_attention_mask = torch.cat(
+ (torch.zeros(pad_size[:2], dtype=input_ids.dtype, device=torch_device), attention_mask), dim=1
+ )
+ model_kwargs = _prepare_model_kwargs(padded_input_ids, padded_attention_mask, signature)
+ next_logits_with_padding = model(**model_kwargs).logits[:, -1, :]
+
+ # They should result in very similar logits
+ torch.testing.assert_close(next_logits_wo_padding, next_logits_with_padding, rtol=1e-5, atol=1e-5)
+
+ def test_generate_continue_from_past_key_values(self):
+ # Tests that we can continue generating from past key values, returned from a previous `generate` call
+ for model_class in self.all_generative_model_classes:
+ if any(model_name in model_class.__name__.lower() for model_name in ["imagegpt", "mllama"]):
+ self.skipTest(reason="Won't fix: old model with unique inputs/caches/other")
+ if any(model_name in model_class.__name__.lower() for model_name in ["umt5"]):
+ self.skipTest(reason="TODO: needs modeling or test input preparation fixes for compatibility")
+
+ config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
+
+ if not hasattr(config.get_text_config(), "use_cache"):
+ self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
+
+ # Let's make it always:
+ # 1. use cache (for obvious reasons)
+ # 2. generate to max length (which can be achieved by setting the eos token to an invalid value), which
+ # would make the test flaky (e.g. EOS is generated on iteration 1 on both generations, but the
+ # continuation would force it to generate beyond an EOS token)
+ # 3. ignore `token_type_ids` for simplicity
+ # 4. ignore `forced_eos_token_id`, which requires further manipulation of the continuation inputs and is
+ # active by default on some models
+ # 5. ignore `encoder_no_repeat_ngram_size`, which is set by default in some encoder-decoder models. When
+ # we use their decoder as a stand-alone model, `encoder_no_repeat_ngram_size` actually prevents
+ # repetition exclusively from the prompt. This test relies on comparing one call vs 2 calls
+ # with cache, what is considered a prompt is different in the two cases.
+
+ if "token_type_ids" in inputs:
+ del inputs["token_type_ids"]
+
+ model = model_class(config).to(torch_device)
+ model.eval()
+
+ # If "past_key_values" is not returned, skip the test (e.g. RWKV uses a different cache name and format)
+ outputs = model(**inputs)
+ if "past_key_values" not in outputs:
+ self.skipTest(reason="This model doesn't return `past_key_values`")
+
+ generate_kwargs = {
+ "pad_token_id": -1,
+ "eos_token_id": -1,
+ "forced_eos_token_id": None,
+ "encoder_no_repeat_ngram_size": 0,
+ "use_cache": True,
+ "do_sample": False,
+ "return_dict_in_generate": True,
+ "output_scores": True,
+ }
+
+ # Traditional way of generating text, with `return_dict_in_generate` to return the past key values
+ _, inputs = self.prepare_config_and_inputs_for_generate()
+ outputs = model.generate(**inputs, **generate_kwargs, max_new_tokens=3)
+
+ # Let's generate again, but passing the past key values in between (2 + 1 = 3 tokens). Note that the
+ # inputs may need to be tweaked across `generate` calls (like the attention mask).
+ outputs_cached = model.generate(**inputs, **generate_kwargs, max_new_tokens=2)
+
+ # Continue from the tokens generated above, preparing the inputs accordingly
+ inputs["past_key_values"] = outputs_cached.past_key_values
+ new_attention_len = outputs_cached.sequences.shape[-1]
+ if config.is_encoder_decoder:
+ inputs["decoder_input_ids"] = outputs_cached.sequences
+ if "decoder_attention_mask" in inputs:
+ inputs["decoder_attention_mask"] = torch.nn.functional.pad(
+ inputs["decoder_attention_mask"],
+ (0, new_attention_len - inputs["decoder_attention_mask"].shape[1]),
+ mode="constant",
+ value=1,
+ )
+ else:
+ inputs["input_ids"] = outputs_cached.sequences
+ if "attention_mask" in inputs:
+ inputs["attention_mask"] = torch.nn.functional.pad(
+ inputs["attention_mask"],
+ (0, new_attention_len - inputs["attention_mask"].shape[1]),
+ mode="constant",
+ value=1,
+ )
+ first_caches_scores = outputs_cached.scores
+ outputs_cached = model.generate(**inputs, **generate_kwargs, max_new_tokens=1)
+ full_cached_scores = first_caches_scores + outputs_cached.scores
+ outputs_cached.scores = full_cached_scores
+
+ # The two sets of generated text and past kv should be equal to each other
+ self._check_similar_generate_outputs(outputs, outputs_cached)
+ for layer_idx in range(len(outputs_cached.past_key_values)):
+ for kv_idx in range(len(outputs_cached.past_key_values[layer_idx])):
+ self.assertTrue(
+ torch.allclose(
+ outputs.past_key_values[layer_idx][kv_idx],
+ outputs_cached.past_key_values[layer_idx][kv_idx],
+ )
+ )
+
+ # needs to be overridden to avoid to avoid casting of input_values to float16
+ # indeed, the codec model is kept in fp32, so we need to avoid casting input_values to float16
+ def _test_attention_implementation(self, attn_implementation):
+ """
+ Compares the output of generate with the eager attention implementation against other implementations.
+ NOTE: despite the test logic being the same, different implementations actually need different decorators, hence
+ this separate function.
+ """
+ max_new_tokens = 30
+ support_flag = {
+ "sdpa": "_supports_sdpa",
+ "flash_attention_2": "_supports_flash_attn_2",
+ }
+
+ for model_class in self.all_generative_model_classes:
+ if not getattr(model_class, support_flag[attn_implementation]):
+ self.skipTest(f"{model_class.__name__} does not support `attn_implementation={attn_implementation}`")
+
+ config, original_inputs_dict = self.prepare_config_and_inputs_for_generate()
+ inputs_dict = {}
+ for input_name, input_data in original_inputs_dict.items():
+ if (
+ isinstance(input_data, torch.Tensor)
+ and input_data.dtype in [torch.float32, torch.bfloat16]
+ and input_name != "input_values"
+ ):
+ inputs_dict[input_name] = input_data.to(torch.float16)
+ else:
+ inputs_dict[input_name] = input_data
+ main_input = inputs_dict[model_class.main_input_name]
+
+ # FA2 doesn't accept masking in the middle of the sequence for now. We usually generate right-padded
+ # attention masks at test time and, with generate, the mask will be appended with 1s on the right,
+ # resulting in a mask with holes (not supported properly by FA2).
+ if attn_implementation == "flash_attention_2":
+ for input_name in ("attention_mask", "decoder_attention_mask", "encoder_attention_mask"):
+ if input_name in inputs_dict:
+ inputs_dict[input_name] = torch.ones_like(inputs_dict[input_name])
+
+ # make sure that all models have enough positions for generation
+ if hasattr(config, "max_position_embeddings"):
+ config.max_position_embeddings = max_new_tokens + main_input.shape[1] + 1
+
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ del model
+ gc.collect()
+
+ generate_kwargs = {
+ "max_new_tokens": max_new_tokens,
+ "do_sample": False,
+ "return_dict_in_generate": True,
+ "output_scores": True,
+ "use_cache": True,
+ }
+
+ model_eager = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="eager",
+ ).to(torch_device)
+ res_eager = model_eager.generate(**inputs_dict, **generate_kwargs)
+ del model_eager
+ gc.collect()
+
+ model_attn = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation=attn_implementation,
+ ).to(torch_device)
+ res_attn = model_attn.generate(**inputs_dict, **generate_kwargs)
+ del model_attn
+ gc.collect()
+
+ self._check_similar_generate_outputs(res_eager, res_attn, atol=1e-3, rtol=1e-3)
+
+
+class KyutaiSpeechToTextForConditionalGenerationIntegrationTests(unittest.TestCase):
+ _dataset = None
+
+ def setUp(self):
+ self.model_checkpoint = "kyutai/stt-2.6b-en"
+
+ def tearDown(self):
+ cleanup(torch_device, gc_collect=True)
+
+ @classmethod
+ def _load_dataset(cls):
+ # Lazy loading of the dataset. Because it is a class method, it will only be loaded once per pytest process.
+ if cls._dataset is None:
+ cls._dataset = datasets.load_dataset(
+ "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
+ )
+ # using 24000 here for simplicity, should rather be processor.feature_extractor.sampling_rate
+ cls._dataset = cls._dataset.cast_column("audio", datasets.Audio(sampling_rate=24000))
+
+ def _load_datasamples(self, num_samples):
+ self._load_dataset()
+ ds = self._dataset
+ speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
+ return [x["array"] for x in speech_samples]
+
+ @slow
+ @require_torch_accelerator
+ def test_generation(self):
+ """
+ reproduce test expected outputs using original codebase: https://gist.github.com/eustlb/7a9aa6139d11e0103c6b65bac103da52
+
+ DISCLAIMER: we are testing for pretty short inputs. Indeed, reproducing correct expected outputs for longer is not possible
+ as implementation choices (qkv matrix in one linear for original code vs three for hf) create growing divergence with context lenght,
+ ultimately giving different outputs.
+ """
+ processor = KyutaiSpeechToTextProcessor.from_pretrained(self.model_checkpoint)
+ model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(
+ self.model_checkpoint, device_map=torch_device
+ )
+
+ samples = self._load_datasamples(1)
+ inputs = processor(
+ samples,
+ ).to(torch_device)
+
+ out = model.generate(**inputs)
+
+ # fmt: off
+ EXPECTED_TOKENS = torch.tensor([
+ [48000, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 1519, 263, 3, 3, 0, 3635, 428, 641, 0, 277, 3, 0, 265, 0, 267, 1162, 261, 274, 410, 0, 272, 3, 0, 265, 0, 260, 1621, 0, 1174, 371, 262, 3, 3, 3, 0, 269, 0, 281, 0, 304, 0, 2433, 3, 0, 266, 3, 0, 281, 1661, 3, 0, 376, 3, 3, 0, 350, 261, 401, 516, 263, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3]],
+ )
+ # fmt: on
+
+ torch.testing.assert_close(out.cpu(), EXPECTED_TOKENS)
+
+ @slow
+ @require_torch_accelerator
+ def test_generation_batched(self):
+ """
+ reproduce test expected outputs using original codebase: https://gist.github.com/eustlb/b58c217c75124d405ec1c13877c7ece8
+
+ DISCLAIMER: we are testing for pretty short inputs. Indeed, reproducing correct expected outputs for longer is not possible
+ as implementation choices (qkv matrix in one linear for original code vs three for hf) create growing divergence with context lenght,
+ ultimately giving different outputs.
+ """
+ processor = KyutaiSpeechToTextProcessor.from_pretrained(self.model_checkpoint)
+ model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(
+ self.model_checkpoint, device_map=torch_device
+ )
+
+ samples = self._load_datasamples(4)
+ inputs = processor(
+ samples,
+ ).to(torch_device)
+
+ out = model.generate(**inputs)
+
+ # fmt: off
+ EXPECTED_TOKENS = torch.tensor([
+ [48000, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 1519, 263, 3, 3, 0, 3635, 428, 641, 0, 277, 3, 0, 265, 0, 267, 1162, 261, 274, 410, 0, 272, 3, 0, 265, 0, 260, 1621, 0, 1174, 371, 262, 3, 3, 3, 0, 269, 0, 281, 0, 304, 0, 2433, 3, 0, 266, 3, 0, 281, 1661, 3, 0, 376, 3, 3, 0, 350, 261, 401, 516, 263, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
+ [48000, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 500, 334, 0, 277, 3, 0, 1519, 263, 3, 3, 0, 3635, 428, 641, 264, 261, 0, 511, 1109, 3, 0, 1138, 3, 3, 3, 0, 508, 827, 3, 3, 3, 3, 0, 468, 3, 3, 0, 376, 3, 3, 3, 0, 260, 978, 263, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
+ [48000, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 414, 0, 527, 261, 3, 0, 409, 3, 3, 3, 0, 271, 3, 0, 309, 3, 0, 285, 3, 0, 521, 371, 609, 3, 3, 0, 260, 959, 3, 3, 3, 0, 272, 3, 0, 265, 0, 546, 262, 3, 3, 3, 3, 3, 3, 0, 291, 3, 0, 975, 2203, 3, 3, 3, 3, 0, 269, 3, 0, 260, 489, 651, 274, 279, 1870, 3, 0, 1084, 873, 273, 3, 0, 260, 531, 3, 3, 0, 409, 262, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 1502, 1005, 836, 3, 3, 0, 1666, 306, 3, 0, 340, 3, 0, 260, 3232, 3, 0, 269, 3, 3, 0, 275, 261, 0, 260, 1379, 261, 0, 3324, 3, 3, 3, 3, 0, 549, 3, 3, 0, 693, 405, 323, 3, 0, 266, 3, 3, 0, 265, 0, 699, 263, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
+ [48000, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 414, 0, 392, 3, 3, 0, 1269, 314, 0, 2607, 261, 3, 3, 3, 0, 1098, 295, 3, 3, 3, 0, 446, 625, 3, 0, 496, 280, 1205, 485, 1071, 1627, 449, 264, 261, 3, 0, 400, 0, 277, 3, 3, 3, 0, 260, 342, 3, 0, 618, 280, 1866, 3, 3, 0, 554, 3, 3, 3, 3, 0, 317, 262, 3, 3, 3, 3, 3, 3, 3, 3, 0, 269, 0, 303, 3, 0, 573, 2615, 3, 3, 0, 276, 3, 0, 275, 0, 305, 3, 0, 260, 415, 3, 3, 0, 272, 3, 3, 3, 3, 0, 1631, 327, 3, 3, 0, 333, 739, 841, 263, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
+ ])
+ # fmt: on
+
+ torch.testing.assert_close(out.cpu(), EXPECTED_TOKENS)
diff --git a/tests/models/mimi/test_modeling_mimi.py b/tests/models/mimi/test_modeling_mimi.py
index bf48f34ce1..d9b0216b15 100644
--- a/tests/models/mimi/test_modeling_mimi.py
+++ b/tests/models/mimi/test_modeling_mimi.py
@@ -107,14 +107,21 @@ class MimiModelTester:
self.sliding_window = sliding_window
self.use_cache = use_cache
- def prepare_config_and_inputs(self):
- input_values = floats_tensor([self.batch_size, self.num_channels, self.intermediate_size], scale=1.0)
+ def prepare_config_and_inputs(self, input_values_length=None):
+ input_values = floats_tensor(
+ [
+ self.batch_size,
+ self.num_channels,
+ self.intermediate_size if input_values_length is None else input_values_length,
+ ],
+ scale=1.0,
+ )
config = self.get_config()
inputs_dict = {"input_values": input_values}
return config, inputs_dict
- def prepare_config_and_inputs_for_common(self):
- config, inputs_dict = self.prepare_config_and_inputs()
+ def prepare_config_and_inputs_for_common(self, input_values_length=None):
+ config, inputs_dict = self.prepare_config_and_inputs(input_values_length=input_values_length)
return config, inputs_dict
def prepare_config_and_inputs_for_model_class(self, model_class):
@@ -508,6 +515,54 @@ class MimiIntegrationTest(unittest.TestCase):
)
self.assertTrue(rmse < 1e-3)
+ def test_integration_encode_with_padding_cache(self):
+ """
+ We test here the possibility to run Mimi in a streaming manner, i.e. chunk by chunk.
+ 1. we encode a first time the entire audio
+ 2. we encode the audio chunk by chunk, each chunk being the smallest size possible for the model (i.e. the frame size)
+
+ This test must be run on CPU since GPU floating point operations accumulate rounding errors that cause test failures.
+ """
+ librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+
+ model_id = "kyutai/mimi"
+
+ model = MimiModel.from_pretrained(model_id, use_cache=True).to("cpu")
+ processor = AutoFeatureExtractor.from_pretrained(model_id)
+
+ librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=processor.sampling_rate))
+ audio_sample = librispeech_dummy[-1]["audio"]["array"]
+
+ inputs = processor(
+ raw_audio=audio_sample,
+ sampling_rate=processor.sampling_rate,
+ return_tensors="pt",
+ ).to("cpu")
+
+ frame_size = model.config.frame_size
+ audio_codes = model.encode(inputs["input_values"]).audio_codes
+
+ # streaming chunk by chunk
+ encoder_past_key_values = None
+ padding_cache = None
+ encoded_frames_list = []
+
+ for start in range(0, inputs["input_values"].shape[-1], frame_size):
+ input_values_chunk = inputs["input_values"][:, :, start : start + frame_size]
+ encoder_outputs = model.encode(
+ input_values_chunk,
+ padding_cache=padding_cache,
+ encoder_past_key_values=encoder_past_key_values,
+ use_streaming=True,
+ )
+ encoder_past_key_values = encoder_outputs.encoder_past_key_values
+ padding_cache = encoder_outputs.padding_cache
+ encoded_frames_list.append(encoder_outputs.audio_codes)
+
+ streamed_audio_codes = torch.cat(encoded_frames_list, dim=-1)
+
+ torch.testing.assert_close(streamed_audio_codes, audio_codes)
+
def test_integration(self):
expected_rmses = {
"8": 0.0018785292,
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index c482cd43d4..f404f99628 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -3566,7 +3566,11 @@ class ModelTesterMixin:
# TODO: if we can also check with `batch_size=1` without being flaky?
for batch_size in [7]:
# musicgen decoder models; TODO: find better abstraction
- if hasattr(self.model_tester, "num_codebooks") and not hasattr(model_eager, "text_encoder"):
+ if (
+ model.__class__.__name__.startswith("Musicgen")
+ and hasattr(self.model_tester, "num_codebooks")
+ and not hasattr(model_eager, "text_encoder")
+ ):
input_data_batch_size = batch_size * self.model_tester.num_codebooks
else:
input_data_batch_size = batch_size
@@ -3626,7 +3630,7 @@ class ModelTesterMixin:
if is_encoder_decoder:
# musicgen encoder-decoder models; TODO: find better abstraction
- if hasattr(self.model_tester, "num_codebooks"):
+ if model.__class__.__name__.startswith("Musicgen") and hasattr(self.model_tester, "num_codebooks"):
input_data_batch_size = batch_size * self.model_tester.num_codebooks
else:
input_data_batch_size = batch_size
diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py
index 7630dc2387..a930e63e99 100644
--- a/utils/modular_model_converter.py
+++ b/utils/modular_model_converter.py
@@ -619,7 +619,7 @@ ALL_FILE_TYPES = (
"processing",
"image_processing",
"video_processing",
- "feature_extractor",
+ "feature_extraction",
)
@@ -1137,7 +1137,7 @@ TYPE_TO_FILE_TYPE = {
"VideoProcessor": "video_processing",
"VideoProcessorInitKwargs": "video_processing",
"FastImageProcessorKwargs": "image_processing*_fast",
- "FeatureExtractor": "feature_extractor",
+ "FeatureExtractor": "feature_extraction",
"ProcessorKwargs": "processing",
"VideosKwargs": "processing",
"ImagesKwargs": "processing",