[Kyutai-STT] correct model type + model id (#39035)
* correct model type + model id * udpate doc * init fix * style !!!
This commit is contained in:
@@ -847,7 +847,7 @@
|
||||
title: GraniteSpeech
|
||||
- local: model_doc/hubert
|
||||
title: Hubert
|
||||
- local: model_doc/stt
|
||||
- local: model_doc/kyutai_speech_to_text
|
||||
title: Kyutai Speech-To-Text
|
||||
- local: model_doc/mctct
|
||||
title: MCTCT
|
||||
|
||||
@@ -36,10 +36,10 @@ from transformers import KyutaiSpeechToTextProcessor, KyutaiSpeechToTextForCondi
|
||||
|
||||
# 1. load the model and the processor
|
||||
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
model_id = "kyutai/stt-2.6b-en"
|
||||
model_id = "kyutai/stt-2.6b-en-trfs"
|
||||
|
||||
processor = KyutaiSpeechToTextProcessor.from_pretrained(model_id)
|
||||
model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(model_id, device_map=torch_device)
|
||||
model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(model_id, device_map=torch_device, torch_dtype="auto")
|
||||
|
||||
# 2. load audio samples
|
||||
ds = load_dataset(
|
||||
@@ -69,10 +69,10 @@ from transformers import KyutaiSpeechToTextProcessor, KyutaiSpeechToTextForCondi
|
||||
|
||||
# 1. load the model and the processor
|
||||
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
model_id = "kyutai/stt-2.6b-en"
|
||||
model_id = "kyutai/stt-2.6b-en-trfs"
|
||||
|
||||
processor = KyutaiSpeechToTextProcessor.from_pretrained(model_id)
|
||||
model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(model_id, device_map=torch_device)
|
||||
model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(model_id, device_map=torch_device, torch_dtype="auto")
|
||||
|
||||
# 2. load audio samples
|
||||
ds = load_dataset(
|
||||
@@ -158,6 +158,7 @@ if TYPE_CHECKING:
|
||||
from .janus import *
|
||||
from .jetmoe import *
|
||||
from .kosmos2 import *
|
||||
from .kyutai_speech_to_text import *
|
||||
from .layoutlm import *
|
||||
from .layoutlmv2 import *
|
||||
from .layoutlmv3 import *
|
||||
@@ -286,7 +287,6 @@ if TYPE_CHECKING:
|
||||
from .squeezebert import *
|
||||
from .stablelm import *
|
||||
from .starcoder2 import *
|
||||
from .stt import *
|
||||
from .superglue import *
|
||||
from .superpoint import *
|
||||
from .swiftformer import *
|
||||
|
||||
@@ -184,6 +184,7 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str](
|
||||
("jetmoe", "JetMoeConfig"),
|
||||
("jukebox", "JukeboxConfig"),
|
||||
("kosmos-2", "Kosmos2Config"),
|
||||
("kyutai_speech_to_text", "KyutaiSpeechToTextConfig"),
|
||||
("layoutlm", "LayoutLMConfig"),
|
||||
("layoutlmv2", "LayoutLMv2Config"),
|
||||
("layoutlmv3", "LayoutLMv3Config"),
|
||||
@@ -326,7 +327,6 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str](
|
||||
("squeezebert", "SqueezeBertConfig"),
|
||||
("stablelm", "StableLmConfig"),
|
||||
("starcoder2", "Starcoder2Config"),
|
||||
("stt", "KyutaiSpeechToTextConfig"),
|
||||
("superglue", "SuperGlueConfig"),
|
||||
("superpoint", "SuperPointConfig"),
|
||||
("swiftformer", "SwiftFormerConfig"),
|
||||
@@ -562,6 +562,7 @@ MODEL_NAMES_MAPPING = OrderedDict[str, str](
|
||||
("jetmoe", "JetMoe"),
|
||||
("jukebox", "Jukebox"),
|
||||
("kosmos-2", "KOSMOS-2"),
|
||||
("kyutai_speech_to_text", "KyutaiSpeechToText"),
|
||||
("layoutlm", "LayoutLM"),
|
||||
("layoutlmv2", "LayoutLMv2"),
|
||||
("layoutlmv3", "LayoutLMv3"),
|
||||
@@ -717,7 +718,6 @@ MODEL_NAMES_MAPPING = OrderedDict[str, str](
|
||||
("squeezebert", "SqueezeBERT"),
|
||||
("stablelm", "StableLm"),
|
||||
("starcoder2", "Starcoder2"),
|
||||
("stt", "KyutaiSpeechToText"),
|
||||
("superglue", "SuperGlue"),
|
||||
("superpoint", "SuperPoint"),
|
||||
("swiftformer", "SwiftFormer"),
|
||||
|
||||
@@ -65,6 +65,7 @@ FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
|
||||
("groupvit", "CLIPFeatureExtractor"),
|
||||
("hubert", "Wav2Vec2FeatureExtractor"),
|
||||
("imagegpt", "ImageGPTFeatureExtractor"),
|
||||
("kyutai_speech_to_text", "KyutaiSpeechToTextFeatureExtractor"),
|
||||
("layoutlmv2", "LayoutLMv2FeatureExtractor"),
|
||||
("layoutlmv3", "LayoutLMv3FeatureExtractor"),
|
||||
("levit", "LevitFeatureExtractor"),
|
||||
@@ -91,7 +92,6 @@ FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
|
||||
("sew-d", "Wav2Vec2FeatureExtractor"),
|
||||
("speech_to_text", "Speech2TextFeatureExtractor"),
|
||||
("speecht5", "SpeechT5FeatureExtractor"),
|
||||
("stt", "KyutaiSpeechToTextFeatureExtractor"),
|
||||
("swiftformer", "ViTFeatureExtractor"),
|
||||
("swin", "ViTFeatureExtractor"),
|
||||
("swinv2", "ViTFeatureExtractor"),
|
||||
|
||||
@@ -174,6 +174,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
||||
("jetmoe", "JetMoeModel"),
|
||||
("jukebox", "JukeboxModel"),
|
||||
("kosmos-2", "Kosmos2Model"),
|
||||
("kyutai_speech_to_text", "KyutaiSpeechToTextModel"),
|
||||
("layoutlm", "LayoutLMModel"),
|
||||
("layoutlmv2", "LayoutLMv2Model"),
|
||||
("layoutlmv3", "LayoutLMv3Model"),
|
||||
@@ -304,7 +305,6 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
||||
("squeezebert", "SqueezeBertModel"),
|
||||
("stablelm", "StableLmModel"),
|
||||
("starcoder2", "Starcoder2Model"),
|
||||
("stt", "KyutaiSpeechToTextModel"),
|
||||
("superglue", "SuperGlueForKeypointMatching"),
|
||||
("swiftformer", "SwiftFormerModel"),
|
||||
("swin", "SwinModel"),
|
||||
@@ -1060,6 +1060,7 @@ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
||||
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
("granite_speech", "GraniteSpeechForConditionalGeneration"),
|
||||
("kyutai_speech_to_text", "KyutaiSpeechToTextForConditionalGeneration"),
|
||||
("moonshine", "MoonshineForConditionalGeneration"),
|
||||
("pop2piano", "Pop2PianoForConditionalGeneration"),
|
||||
("seamless_m4t", "SeamlessM4TForSpeechToText"),
|
||||
@@ -1067,7 +1068,6 @@ MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
|
||||
("speech-encoder-decoder", "SpeechEncoderDecoderModel"),
|
||||
("speech_to_text", "Speech2TextForConditionalGeneration"),
|
||||
("speecht5", "SpeechT5ForSpeechToText"),
|
||||
("stt", "KyutaiSpeechToTextForConditionalGeneration"),
|
||||
("whisper", "WhisperForConditionalGeneration"),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -80,6 +80,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
|
||||
("internvl", "InternVLProcessor"),
|
||||
("janus", "JanusProcessor"),
|
||||
("kosmos-2", "Kosmos2Processor"),
|
||||
("kyutai_speech_to_text", "KyutaiSpeechToTextProcessor"),
|
||||
("layoutlmv2", "LayoutLMv2Processor"),
|
||||
("layoutlmv3", "LayoutLMv3Processor"),
|
||||
("llama4", "Llama4Processor"),
|
||||
@@ -117,7 +118,6 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
|
||||
("speech_to_text", "Speech2TextProcessor"),
|
||||
("speech_to_text_2", "Speech2Text2Processor"),
|
||||
("speecht5", "SpeechT5Processor"),
|
||||
("stt", "KyutaiSpeechToTextProcessor"),
|
||||
("trocr", "TrOCRProcessor"),
|
||||
("tvlt", "TvltProcessor"),
|
||||
("tvp", "TvpProcessor"),
|
||||
|
||||
@@ -28,7 +28,7 @@ class KyutaiSpeechToTextConfig(PretrainedConfig):
|
||||
architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the
|
||||
2.6b-en model.
|
||||
|
||||
e.g. [kyutai/stt-2.6b-en](https://huggingface.co/kyutai/stt-2.6b-en)
|
||||
e.g. [kyutai/stt-2.6b-en-trfs](https://huggingface.co/kyutai/stt-2.6b-en-trfs)
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
@@ -110,8 +110,7 @@ class KyutaiSpeechToTextConfig(PretrainedConfig):
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
# not the best naming here for `model_type`, but original codebase already uses model type:`stt` for in the config so we keep it to simplify
|
||||
model_type = "stt"
|
||||
model_type = "kyutai_speech_to_text"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
sub_configs = {"codec_config": AutoConfig}
|
||||
|
||||
@@ -190,7 +190,14 @@ def write_model(
|
||||
print("Converting the model.")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
config = KyutaiSpeechToTextConfig()
|
||||
config = KyutaiSpeechToTextConfig(
|
||||
vocab_size=8001,
|
||||
max_position_embeddings=375,
|
||||
num_hidden_layers=16,
|
||||
num_attention_heads=16,
|
||||
num_key_value_heads=16,
|
||||
head_dim=128,
|
||||
)
|
||||
config.use_cache = True
|
||||
config.codec_config.sliding_window = 250
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from src/transformers/models/stt/modular_kyutai_speech_to_text.py.
|
||||
# This file was automatically generated from src/transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_kyutai_speech_to_text.py file directly. One of our CI enforces this.
|
||||
@@ -1,5 +1,5 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from src/transformers/models/stt/modular_kyutai_speech_to_text.py.
|
||||
# This file was automatically generated from src/transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_kyutai_speech_to_text.py file directly. One of our CI enforces this.
|
||||
@@ -713,7 +713,7 @@ class KyutaiSpeechToTextSdpaAttention(KyutaiSpeechToTextAttention):
|
||||
return attn_output, None, past_key_value
|
||||
|
||||
|
||||
STT_ATTENTION_CLASSES = {
|
||||
KYUTAI_SPEECH_TO_TEXT_ATTENTION_CLASSES = {
|
||||
"eager": KyutaiSpeechToTextAttention,
|
||||
"flash_attention_2": KyutaiSpeechToTextFlashAttention2,
|
||||
"sdpa": KyutaiSpeechToTextSdpaAttention,
|
||||
@@ -726,7 +726,7 @@ class KyutaiSpeechToTextDecoderLayer(GradientCheckpointingLayer):
|
||||
self.hidden_size = config.hidden_size
|
||||
self.use_flexible_linear = use_flexible_linear
|
||||
|
||||
self.self_attn = STT_ATTENTION_CLASSES[config._attn_implementation](
|
||||
self.self_attn = KYUTAI_SPEECH_TO_TEXT_ATTENTION_CLASSES[config._attn_implementation](
|
||||
config=config, layer_idx=layer_idx, use_flexible_linear=use_flexible_linear, use_rope=use_rope
|
||||
)
|
||||
|
||||
@@ -1169,7 +1169,7 @@ class KyutaiSpeechToTextForConditionalGeneration(KyutaiSpeechToTextPreTrainedMod
|
||||
>>> from transformers import KyutaiSpeechToTextProcessor, KyutaiSpeechToTextForConditionalGeneration
|
||||
|
||||
>>> torch_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
>>> model_id = "kyutai/stt-2.6b-en"
|
||||
>>> model_id = "kyutai/stt-2.6b-en-trfs"
|
||||
|
||||
>>> processor = KyutaiSpeechToTextProcessor.from_pretrained(model_id)
|
||||
>>> model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(model_id, device_map=torch_device)
|
||||
@@ -278,7 +278,7 @@ class KyutaiSpeechToTextForConditionalGeneration(LlamaForCausalLM, GenerationMix
|
||||
>>> from transformers import KyutaiSpeechToTextProcessor, KyutaiSpeechToTextForConditionalGeneration
|
||||
|
||||
>>> torch_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
>>> model_id = "kyutai/stt-2.6b-en"
|
||||
>>> model_id = "kyutai/stt-2.6b-en-trfs"
|
||||
|
||||
>>> processor = KyutaiSpeechToTextProcessor.from_pretrained(model_id)
|
||||
>>> model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(model_id, device_map=torch_device)
|
||||
@@ -619,7 +619,7 @@ class KyutaiSpeechToTextForConditionalGenerationIntegrationTests(unittest.TestCa
|
||||
_dataset = None
|
||||
|
||||
def setUp(self):
|
||||
self.model_checkpoint = "kyutai/stt-2.6b-en"
|
||||
self.model_checkpoint = "kyutai/stt-2.6b-en-trfs"
|
||||
|
||||
def tearDown(self):
|
||||
cleanup(torch_device, gc_collect=True)
|
||||
|
||||
Reference in New Issue
Block a user