[Kyutai-STT] correct model type + model id (#39035)

* correct model type + model id

* udpate doc

* init fix

* style !!!
This commit is contained in:
eustlb
2025-06-25 18:09:00 +02:00
committed by GitHub
parent dad0e87c79
commit 551e48f182
15 changed files with 29 additions and 23 deletions

View File

@@ -847,7 +847,7 @@
title: GraniteSpeech
- local: model_doc/hubert
title: Hubert
- local: model_doc/stt
- local: model_doc/kyutai_speech_to_text
title: Kyutai Speech-To-Text
- local: model_doc/mctct
title: MCTCT

View File

@@ -36,10 +36,10 @@ from transformers import KyutaiSpeechToTextProcessor, KyutaiSpeechToTextForCondi
# 1. load the model and the processor
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "kyutai/stt-2.6b-en"
model_id = "kyutai/stt-2.6b-en-trfs"
processor = KyutaiSpeechToTextProcessor.from_pretrained(model_id)
model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(model_id, device_map=torch_device)
model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(model_id, device_map=torch_device, torch_dtype="auto")
# 2. load audio samples
ds = load_dataset(
@@ -69,10 +69,10 @@ from transformers import KyutaiSpeechToTextProcessor, KyutaiSpeechToTextForCondi
# 1. load the model and the processor
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "kyutai/stt-2.6b-en"
model_id = "kyutai/stt-2.6b-en-trfs"
processor = KyutaiSpeechToTextProcessor.from_pretrained(model_id)
model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(model_id, device_map=torch_device)
model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(model_id, device_map=torch_device, torch_dtype="auto")
# 2. load audio samples
ds = load_dataset(

View File

@@ -158,6 +158,7 @@ if TYPE_CHECKING:
from .janus import *
from .jetmoe import *
from .kosmos2 import *
from .kyutai_speech_to_text import *
from .layoutlm import *
from .layoutlmv2 import *
from .layoutlmv3 import *
@@ -286,7 +287,6 @@ if TYPE_CHECKING:
from .squeezebert import *
from .stablelm import *
from .starcoder2 import *
from .stt import *
from .superglue import *
from .superpoint import *
from .swiftformer import *

View File

@@ -184,6 +184,7 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str](
("jetmoe", "JetMoeConfig"),
("jukebox", "JukeboxConfig"),
("kosmos-2", "Kosmos2Config"),
("kyutai_speech_to_text", "KyutaiSpeechToTextConfig"),
("layoutlm", "LayoutLMConfig"),
("layoutlmv2", "LayoutLMv2Config"),
("layoutlmv3", "LayoutLMv3Config"),
@@ -326,7 +327,6 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str](
("squeezebert", "SqueezeBertConfig"),
("stablelm", "StableLmConfig"),
("starcoder2", "Starcoder2Config"),
("stt", "KyutaiSpeechToTextConfig"),
("superglue", "SuperGlueConfig"),
("superpoint", "SuperPointConfig"),
("swiftformer", "SwiftFormerConfig"),
@@ -562,6 +562,7 @@ MODEL_NAMES_MAPPING = OrderedDict[str, str](
("jetmoe", "JetMoe"),
("jukebox", "Jukebox"),
("kosmos-2", "KOSMOS-2"),
("kyutai_speech_to_text", "KyutaiSpeechToText"),
("layoutlm", "LayoutLM"),
("layoutlmv2", "LayoutLMv2"),
("layoutlmv3", "LayoutLMv3"),
@@ -717,7 +718,6 @@ MODEL_NAMES_MAPPING = OrderedDict[str, str](
("squeezebert", "SqueezeBERT"),
("stablelm", "StableLm"),
("starcoder2", "Starcoder2"),
("stt", "KyutaiSpeechToText"),
("superglue", "SuperGlue"),
("superpoint", "SuperPoint"),
("swiftformer", "SwiftFormer"),

View File

@@ -65,6 +65,7 @@ FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
("groupvit", "CLIPFeatureExtractor"),
("hubert", "Wav2Vec2FeatureExtractor"),
("imagegpt", "ImageGPTFeatureExtractor"),
("kyutai_speech_to_text", "KyutaiSpeechToTextFeatureExtractor"),
("layoutlmv2", "LayoutLMv2FeatureExtractor"),
("layoutlmv3", "LayoutLMv3FeatureExtractor"),
("levit", "LevitFeatureExtractor"),
@@ -91,7 +92,6 @@ FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
("sew-d", "Wav2Vec2FeatureExtractor"),
("speech_to_text", "Speech2TextFeatureExtractor"),
("speecht5", "SpeechT5FeatureExtractor"),
("stt", "KyutaiSpeechToTextFeatureExtractor"),
("swiftformer", "ViTFeatureExtractor"),
("swin", "ViTFeatureExtractor"),
("swinv2", "ViTFeatureExtractor"),

View File

@@ -174,6 +174,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("jetmoe", "JetMoeModel"),
("jukebox", "JukeboxModel"),
("kosmos-2", "Kosmos2Model"),
("kyutai_speech_to_text", "KyutaiSpeechToTextModel"),
("layoutlm", "LayoutLMModel"),
("layoutlmv2", "LayoutLMv2Model"),
("layoutlmv3", "LayoutLMv3Model"),
@@ -304,7 +305,6 @@ MODEL_MAPPING_NAMES = OrderedDict(
("squeezebert", "SqueezeBertModel"),
("stablelm", "StableLmModel"),
("starcoder2", "Starcoder2Model"),
("stt", "KyutaiSpeechToTextModel"),
("superglue", "SuperGlueForKeypointMatching"),
("swiftformer", "SwiftFormerModel"),
("swin", "SwinModel"),
@@ -1060,6 +1060,7 @@ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
[
("granite_speech", "GraniteSpeechForConditionalGeneration"),
("kyutai_speech_to_text", "KyutaiSpeechToTextForConditionalGeneration"),
("moonshine", "MoonshineForConditionalGeneration"),
("pop2piano", "Pop2PianoForConditionalGeneration"),
("seamless_m4t", "SeamlessM4TForSpeechToText"),
@@ -1067,7 +1068,6 @@ MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
("speech-encoder-decoder", "SpeechEncoderDecoderModel"),
("speech_to_text", "Speech2TextForConditionalGeneration"),
("speecht5", "SpeechT5ForSpeechToText"),
("stt", "KyutaiSpeechToTextForConditionalGeneration"),
("whisper", "WhisperForConditionalGeneration"),
]
)

View File

@@ -80,6 +80,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
("internvl", "InternVLProcessor"),
("janus", "JanusProcessor"),
("kosmos-2", "Kosmos2Processor"),
("kyutai_speech_to_text", "KyutaiSpeechToTextProcessor"),
("layoutlmv2", "LayoutLMv2Processor"),
("layoutlmv3", "LayoutLMv3Processor"),
("llama4", "Llama4Processor"),
@@ -117,7 +118,6 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
("speech_to_text", "Speech2TextProcessor"),
("speech_to_text_2", "Speech2Text2Processor"),
("speecht5", "SpeechT5Processor"),
("stt", "KyutaiSpeechToTextProcessor"),
("trocr", "TrOCRProcessor"),
("tvlt", "TvltProcessor"),
("tvp", "TvpProcessor"),

View File

@@ -28,7 +28,7 @@ class KyutaiSpeechToTextConfig(PretrainedConfig):
architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the
2.6b-en model.
e.g. [kyutai/stt-2.6b-en](https://huggingface.co/kyutai/stt-2.6b-en)
e.g. [kyutai/stt-2.6b-en-trfs](https://huggingface.co/kyutai/stt-2.6b-en-trfs)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
@@ -110,8 +110,7 @@ class KyutaiSpeechToTextConfig(PretrainedConfig):
>>> configuration = model.config
```"""
# not the best naming here for `model_type`, but original codebase already uses model type:`stt` for in the config so we keep it to simplify
model_type = "stt"
model_type = "kyutai_speech_to_text"
keys_to_ignore_at_inference = ["past_key_values"]
sub_configs = {"codec_config": AutoConfig}

View File

@@ -190,7 +190,14 @@ def write_model(
print("Converting the model.")
os.makedirs(output_dir, exist_ok=True)
config = KyutaiSpeechToTextConfig()
config = KyutaiSpeechToTextConfig(
vocab_size=8001,
max_position_embeddings=375,
num_hidden_layers=16,
num_attention_heads=16,
num_key_value_heads=16,
head_dim=128,
)
config.use_cache = True
config.codec_config.sliding_window = 250

View File

@@ -1,5 +1,5 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/stt/modular_kyutai_speech_to_text.py.
# This file was automatically generated from src/transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_kyutai_speech_to_text.py file directly. One of our CI enforces this.

View File

@@ -1,5 +1,5 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/stt/modular_kyutai_speech_to_text.py.
# This file was automatically generated from src/transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_kyutai_speech_to_text.py file directly. One of our CI enforces this.
@@ -713,7 +713,7 @@ class KyutaiSpeechToTextSdpaAttention(KyutaiSpeechToTextAttention):
return attn_output, None, past_key_value
STT_ATTENTION_CLASSES = {
KYUTAI_SPEECH_TO_TEXT_ATTENTION_CLASSES = {
"eager": KyutaiSpeechToTextAttention,
"flash_attention_2": KyutaiSpeechToTextFlashAttention2,
"sdpa": KyutaiSpeechToTextSdpaAttention,
@@ -726,7 +726,7 @@ class KyutaiSpeechToTextDecoderLayer(GradientCheckpointingLayer):
self.hidden_size = config.hidden_size
self.use_flexible_linear = use_flexible_linear
self.self_attn = STT_ATTENTION_CLASSES[config._attn_implementation](
self.self_attn = KYUTAI_SPEECH_TO_TEXT_ATTENTION_CLASSES[config._attn_implementation](
config=config, layer_idx=layer_idx, use_flexible_linear=use_flexible_linear, use_rope=use_rope
)
@@ -1169,7 +1169,7 @@ class KyutaiSpeechToTextForConditionalGeneration(KyutaiSpeechToTextPreTrainedMod
>>> from transformers import KyutaiSpeechToTextProcessor, KyutaiSpeechToTextForConditionalGeneration
>>> torch_device = "cuda" if torch.cuda.is_available() else "cpu"
>>> model_id = "kyutai/stt-2.6b-en"
>>> model_id = "kyutai/stt-2.6b-en-trfs"
>>> processor = KyutaiSpeechToTextProcessor.from_pretrained(model_id)
>>> model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(model_id, device_map=torch_device)

View File

@@ -278,7 +278,7 @@ class KyutaiSpeechToTextForConditionalGeneration(LlamaForCausalLM, GenerationMix
>>> from transformers import KyutaiSpeechToTextProcessor, KyutaiSpeechToTextForConditionalGeneration
>>> torch_device = "cuda" if torch.cuda.is_available() else "cpu"
>>> model_id = "kyutai/stt-2.6b-en"
>>> model_id = "kyutai/stt-2.6b-en-trfs"
>>> processor = KyutaiSpeechToTextProcessor.from_pretrained(model_id)
>>> model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(model_id, device_map=torch_device)

View File

@@ -619,7 +619,7 @@ class KyutaiSpeechToTextForConditionalGenerationIntegrationTests(unittest.TestCa
_dataset = None
def setUp(self):
self.model_checkpoint = "kyutai/stt-2.6b-en"
self.model_checkpoint = "kyutai/stt-2.6b-en-trfs"
def tearDown(self):
cleanup(torch_device, gc_collect=True)