[Kyutai-STT] correct model type + model id (#39035)

* correct model type + model id

* update doc

* init fix

* style !!!
This commit is contained in:
eustlb
2025-06-25 18:09:00 +02:00
committed by GitHub
parent dad0e87c79
commit 551e48f182
15 changed files with 29 additions and 23 deletions

View File

@@ -847,7 +847,7 @@
title: GraniteSpeech title: GraniteSpeech
- local: model_doc/hubert - local: model_doc/hubert
title: Hubert title: Hubert
- local: model_doc/stt - local: model_doc/kyutai_speech_to_text
title: Kyutai Speech-To-Text title: Kyutai Speech-To-Text
- local: model_doc/mctct - local: model_doc/mctct
title: MCTCT title: MCTCT

View File

@@ -36,10 +36,10 @@ from transformers import KyutaiSpeechToTextProcessor, KyutaiSpeechToTextForCondi
# 1. load the model and the processor # 1. load the model and the processor
torch_device = "cuda" if torch.cuda.is_available() else "cpu" torch_device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "kyutai/stt-2.6b-en" model_id = "kyutai/stt-2.6b-en-trfs"
processor = KyutaiSpeechToTextProcessor.from_pretrained(model_id) processor = KyutaiSpeechToTextProcessor.from_pretrained(model_id)
model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(model_id, device_map=torch_device) model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(model_id, device_map=torch_device, torch_dtype="auto")
# 2. load audio samples # 2. load audio samples
ds = load_dataset( ds = load_dataset(
@@ -69,10 +69,10 @@ from transformers import KyutaiSpeechToTextProcessor, KyutaiSpeechToTextForCondi
# 1. load the model and the processor # 1. load the model and the processor
torch_device = "cuda" if torch.cuda.is_available() else "cpu" torch_device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "kyutai/stt-2.6b-en" model_id = "kyutai/stt-2.6b-en-trfs"
processor = KyutaiSpeechToTextProcessor.from_pretrained(model_id) processor = KyutaiSpeechToTextProcessor.from_pretrained(model_id)
model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(model_id, device_map=torch_device) model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(model_id, device_map=torch_device, torch_dtype="auto")
# 2. load audio samples # 2. load audio samples
ds = load_dataset( ds = load_dataset(

View File

@@ -158,6 +158,7 @@ if TYPE_CHECKING:
from .janus import * from .janus import *
from .jetmoe import * from .jetmoe import *
from .kosmos2 import * from .kosmos2 import *
from .kyutai_speech_to_text import *
from .layoutlm import * from .layoutlm import *
from .layoutlmv2 import * from .layoutlmv2 import *
from .layoutlmv3 import * from .layoutlmv3 import *
@@ -286,7 +287,6 @@ if TYPE_CHECKING:
from .squeezebert import * from .squeezebert import *
from .stablelm import * from .stablelm import *
from .starcoder2 import * from .starcoder2 import *
from .stt import *
from .superglue import * from .superglue import *
from .superpoint import * from .superpoint import *
from .swiftformer import * from .swiftformer import *

View File

@@ -184,6 +184,7 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str](
("jetmoe", "JetMoeConfig"), ("jetmoe", "JetMoeConfig"),
("jukebox", "JukeboxConfig"), ("jukebox", "JukeboxConfig"),
("kosmos-2", "Kosmos2Config"), ("kosmos-2", "Kosmos2Config"),
("kyutai_speech_to_text", "KyutaiSpeechToTextConfig"),
("layoutlm", "LayoutLMConfig"), ("layoutlm", "LayoutLMConfig"),
("layoutlmv2", "LayoutLMv2Config"), ("layoutlmv2", "LayoutLMv2Config"),
("layoutlmv3", "LayoutLMv3Config"), ("layoutlmv3", "LayoutLMv3Config"),
@@ -326,7 +327,6 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str](
("squeezebert", "SqueezeBertConfig"), ("squeezebert", "SqueezeBertConfig"),
("stablelm", "StableLmConfig"), ("stablelm", "StableLmConfig"),
("starcoder2", "Starcoder2Config"), ("starcoder2", "Starcoder2Config"),
("stt", "KyutaiSpeechToTextConfig"),
("superglue", "SuperGlueConfig"), ("superglue", "SuperGlueConfig"),
("superpoint", "SuperPointConfig"), ("superpoint", "SuperPointConfig"),
("swiftformer", "SwiftFormerConfig"), ("swiftformer", "SwiftFormerConfig"),
@@ -562,6 +562,7 @@ MODEL_NAMES_MAPPING = OrderedDict[str, str](
("jetmoe", "JetMoe"), ("jetmoe", "JetMoe"),
("jukebox", "Jukebox"), ("jukebox", "Jukebox"),
("kosmos-2", "KOSMOS-2"), ("kosmos-2", "KOSMOS-2"),
("kyutai_speech_to_text", "KyutaiSpeechToText"),
("layoutlm", "LayoutLM"), ("layoutlm", "LayoutLM"),
("layoutlmv2", "LayoutLMv2"), ("layoutlmv2", "LayoutLMv2"),
("layoutlmv3", "LayoutLMv3"), ("layoutlmv3", "LayoutLMv3"),
@@ -717,7 +718,6 @@ MODEL_NAMES_MAPPING = OrderedDict[str, str](
("squeezebert", "SqueezeBERT"), ("squeezebert", "SqueezeBERT"),
("stablelm", "StableLm"), ("stablelm", "StableLm"),
("starcoder2", "Starcoder2"), ("starcoder2", "Starcoder2"),
("stt", "KyutaiSpeechToText"),
("superglue", "SuperGlue"), ("superglue", "SuperGlue"),
("superpoint", "SuperPoint"), ("superpoint", "SuperPoint"),
("swiftformer", "SwiftFormer"), ("swiftformer", "SwiftFormer"),

View File

@@ -65,6 +65,7 @@ FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
("groupvit", "CLIPFeatureExtractor"), ("groupvit", "CLIPFeatureExtractor"),
("hubert", "Wav2Vec2FeatureExtractor"), ("hubert", "Wav2Vec2FeatureExtractor"),
("imagegpt", "ImageGPTFeatureExtractor"), ("imagegpt", "ImageGPTFeatureExtractor"),
("kyutai_speech_to_text", "KyutaiSpeechToTextFeatureExtractor"),
("layoutlmv2", "LayoutLMv2FeatureExtractor"), ("layoutlmv2", "LayoutLMv2FeatureExtractor"),
("layoutlmv3", "LayoutLMv3FeatureExtractor"), ("layoutlmv3", "LayoutLMv3FeatureExtractor"),
("levit", "LevitFeatureExtractor"), ("levit", "LevitFeatureExtractor"),
@@ -91,7 +92,6 @@ FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
("sew-d", "Wav2Vec2FeatureExtractor"), ("sew-d", "Wav2Vec2FeatureExtractor"),
("speech_to_text", "Speech2TextFeatureExtractor"), ("speech_to_text", "Speech2TextFeatureExtractor"),
("speecht5", "SpeechT5FeatureExtractor"), ("speecht5", "SpeechT5FeatureExtractor"),
("stt", "KyutaiSpeechToTextFeatureExtractor"),
("swiftformer", "ViTFeatureExtractor"), ("swiftformer", "ViTFeatureExtractor"),
("swin", "ViTFeatureExtractor"), ("swin", "ViTFeatureExtractor"),
("swinv2", "ViTFeatureExtractor"), ("swinv2", "ViTFeatureExtractor"),

View File

@@ -174,6 +174,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("jetmoe", "JetMoeModel"), ("jetmoe", "JetMoeModel"),
("jukebox", "JukeboxModel"), ("jukebox", "JukeboxModel"),
("kosmos-2", "Kosmos2Model"), ("kosmos-2", "Kosmos2Model"),
("kyutai_speech_to_text", "KyutaiSpeechToTextModel"),
("layoutlm", "LayoutLMModel"), ("layoutlm", "LayoutLMModel"),
("layoutlmv2", "LayoutLMv2Model"), ("layoutlmv2", "LayoutLMv2Model"),
("layoutlmv3", "LayoutLMv3Model"), ("layoutlmv3", "LayoutLMv3Model"),
@@ -304,7 +305,6 @@ MODEL_MAPPING_NAMES = OrderedDict(
("squeezebert", "SqueezeBertModel"), ("squeezebert", "SqueezeBertModel"),
("stablelm", "StableLmModel"), ("stablelm", "StableLmModel"),
("starcoder2", "Starcoder2Model"), ("starcoder2", "Starcoder2Model"),
("stt", "KyutaiSpeechToTextModel"),
("superglue", "SuperGlueForKeypointMatching"), ("superglue", "SuperGlueForKeypointMatching"),
("swiftformer", "SwiftFormerModel"), ("swiftformer", "SwiftFormerModel"),
("swin", "SwinModel"), ("swin", "SwinModel"),
@@ -1060,6 +1060,7 @@ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict( MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
[ [
("granite_speech", "GraniteSpeechForConditionalGeneration"), ("granite_speech", "GraniteSpeechForConditionalGeneration"),
("kyutai_speech_to_text", "KyutaiSpeechToTextForConditionalGeneration"),
("moonshine", "MoonshineForConditionalGeneration"), ("moonshine", "MoonshineForConditionalGeneration"),
("pop2piano", "Pop2PianoForConditionalGeneration"), ("pop2piano", "Pop2PianoForConditionalGeneration"),
("seamless_m4t", "SeamlessM4TForSpeechToText"), ("seamless_m4t", "SeamlessM4TForSpeechToText"),
@@ -1067,7 +1068,6 @@ MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
("speech-encoder-decoder", "SpeechEncoderDecoderModel"), ("speech-encoder-decoder", "SpeechEncoderDecoderModel"),
("speech_to_text", "Speech2TextForConditionalGeneration"), ("speech_to_text", "Speech2TextForConditionalGeneration"),
("speecht5", "SpeechT5ForSpeechToText"), ("speecht5", "SpeechT5ForSpeechToText"),
("stt", "KyutaiSpeechToTextForConditionalGeneration"),
("whisper", "WhisperForConditionalGeneration"), ("whisper", "WhisperForConditionalGeneration"),
] ]
) )

View File

@@ -80,6 +80,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
("internvl", "InternVLProcessor"), ("internvl", "InternVLProcessor"),
("janus", "JanusProcessor"), ("janus", "JanusProcessor"),
("kosmos-2", "Kosmos2Processor"), ("kosmos-2", "Kosmos2Processor"),
("kyutai_speech_to_text", "KyutaiSpeechToTextProcessor"),
("layoutlmv2", "LayoutLMv2Processor"), ("layoutlmv2", "LayoutLMv2Processor"),
("layoutlmv3", "LayoutLMv3Processor"), ("layoutlmv3", "LayoutLMv3Processor"),
("llama4", "Llama4Processor"), ("llama4", "Llama4Processor"),
@@ -117,7 +118,6 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
("speech_to_text", "Speech2TextProcessor"), ("speech_to_text", "Speech2TextProcessor"),
("speech_to_text_2", "Speech2Text2Processor"), ("speech_to_text_2", "Speech2Text2Processor"),
("speecht5", "SpeechT5Processor"), ("speecht5", "SpeechT5Processor"),
("stt", "KyutaiSpeechToTextProcessor"),
("trocr", "TrOCRProcessor"), ("trocr", "TrOCRProcessor"),
("tvlt", "TvltProcessor"), ("tvlt", "TvltProcessor"),
("tvp", "TvpProcessor"), ("tvp", "TvpProcessor"),

View File

@@ -28,7 +28,7 @@ class KyutaiSpeechToTextConfig(PretrainedConfig):
architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the
2.6b-en model. 2.6b-en model.
e.g. [kyutai/stt-2.6b-en](https://huggingface.co/kyutai/stt-2.6b-en) e.g. [kyutai/stt-2.6b-en-trfs](https://huggingface.co/kyutai/stt-2.6b-en-trfs)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information. documentation from [`PretrainedConfig`] for more information.
@@ -110,8 +110,7 @@ class KyutaiSpeechToTextConfig(PretrainedConfig):
>>> configuration = model.config >>> configuration = model.config
```""" ```"""
# not the best naming here for `model_type`, but original codebase already uses model type:`stt` for in the config so we keep it to simplify model_type = "kyutai_speech_to_text"
model_type = "stt"
keys_to_ignore_at_inference = ["past_key_values"] keys_to_ignore_at_inference = ["past_key_values"]
sub_configs = {"codec_config": AutoConfig} sub_configs = {"codec_config": AutoConfig}

View File

@@ -190,7 +190,14 @@ def write_model(
print("Converting the model.") print("Converting the model.")
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
config = KyutaiSpeechToTextConfig() config = KyutaiSpeechToTextConfig(
vocab_size=8001,
max_position_embeddings=375,
num_hidden_layers=16,
num_attention_heads=16,
num_key_value_heads=16,
head_dim=128,
)
config.use_cache = True config.use_cache = True
config.codec_config.sliding_window = 250 config.codec_config.sliding_window = 250

View File

@@ -1,5 +1,5 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/stt/modular_kyutai_speech_to_text.py. # This file was automatically generated from src/transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of # Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the # the file from the modular. If any change should be done, please apply the change to the
# modular_kyutai_speech_to_text.py file directly. One of our CI enforces this. # modular_kyutai_speech_to_text.py file directly. One of our CI enforces this.

View File

@@ -1,5 +1,5 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/stt/modular_kyutai_speech_to_text.py. # This file was automatically generated from src/transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of # Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the # the file from the modular. If any change should be done, please apply the change to the
# modular_kyutai_speech_to_text.py file directly. One of our CI enforces this. # modular_kyutai_speech_to_text.py file directly. One of our CI enforces this.
@@ -713,7 +713,7 @@ class KyutaiSpeechToTextSdpaAttention(KyutaiSpeechToTextAttention):
return attn_output, None, past_key_value return attn_output, None, past_key_value
STT_ATTENTION_CLASSES = { KYUTAI_SPEECH_TO_TEXT_ATTENTION_CLASSES = {
"eager": KyutaiSpeechToTextAttention, "eager": KyutaiSpeechToTextAttention,
"flash_attention_2": KyutaiSpeechToTextFlashAttention2, "flash_attention_2": KyutaiSpeechToTextFlashAttention2,
"sdpa": KyutaiSpeechToTextSdpaAttention, "sdpa": KyutaiSpeechToTextSdpaAttention,
@@ -726,7 +726,7 @@ class KyutaiSpeechToTextDecoderLayer(GradientCheckpointingLayer):
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
self.use_flexible_linear = use_flexible_linear self.use_flexible_linear = use_flexible_linear
self.self_attn = STT_ATTENTION_CLASSES[config._attn_implementation]( self.self_attn = KYUTAI_SPEECH_TO_TEXT_ATTENTION_CLASSES[config._attn_implementation](
config=config, layer_idx=layer_idx, use_flexible_linear=use_flexible_linear, use_rope=use_rope config=config, layer_idx=layer_idx, use_flexible_linear=use_flexible_linear, use_rope=use_rope
) )
@@ -1169,7 +1169,7 @@ class KyutaiSpeechToTextForConditionalGeneration(KyutaiSpeechToTextPreTrainedMod
>>> from transformers import KyutaiSpeechToTextProcessor, KyutaiSpeechToTextForConditionalGeneration >>> from transformers import KyutaiSpeechToTextProcessor, KyutaiSpeechToTextForConditionalGeneration
>>> torch_device = "cuda" if torch.cuda.is_available() else "cpu" >>> torch_device = "cuda" if torch.cuda.is_available() else "cpu"
>>> model_id = "kyutai/stt-2.6b-en" >>> model_id = "kyutai/stt-2.6b-en-trfs"
>>> processor = KyutaiSpeechToTextProcessor.from_pretrained(model_id) >>> processor = KyutaiSpeechToTextProcessor.from_pretrained(model_id)
>>> model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(model_id, device_map=torch_device) >>> model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(model_id, device_map=torch_device)

View File

@@ -278,7 +278,7 @@ class KyutaiSpeechToTextForConditionalGeneration(LlamaForCausalLM, GenerationMix
>>> from transformers import KyutaiSpeechToTextProcessor, KyutaiSpeechToTextForConditionalGeneration >>> from transformers import KyutaiSpeechToTextProcessor, KyutaiSpeechToTextForConditionalGeneration
>>> torch_device = "cuda" if torch.cuda.is_available() else "cpu" >>> torch_device = "cuda" if torch.cuda.is_available() else "cpu"
>>> model_id = "kyutai/stt-2.6b-en" >>> model_id = "kyutai/stt-2.6b-en-trfs"
>>> processor = KyutaiSpeechToTextProcessor.from_pretrained(model_id) >>> processor = KyutaiSpeechToTextProcessor.from_pretrained(model_id)
>>> model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(model_id, device_map=torch_device) >>> model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(model_id, device_map=torch_device)

View File

@@ -619,7 +619,7 @@ class KyutaiSpeechToTextForConditionalGenerationIntegrationTests(unittest.TestCa
_dataset = None _dataset = None
def setUp(self): def setUp(self):
self.model_checkpoint = "kyutai/stt-2.6b-en" self.model_checkpoint = "kyutai/stt-2.6b-en-trfs"
def tearDown(self): def tearDown(self):
cleanup(torch_device, gc_collect=True) cleanup(torch_device, gc_collect=True)