[Kyutai-STT] correct model type + model id (#39035)
* correct model type + model id * udpate doc * init fix * style !!!
This commit is contained in:
@@ -847,7 +847,7 @@
|
|||||||
title: GraniteSpeech
|
title: GraniteSpeech
|
||||||
- local: model_doc/hubert
|
- local: model_doc/hubert
|
||||||
title: Hubert
|
title: Hubert
|
||||||
- local: model_doc/stt
|
- local: model_doc/kyutai_speech_to_text
|
||||||
title: Kyutai Speech-To-Text
|
title: Kyutai Speech-To-Text
|
||||||
- local: model_doc/mctct
|
- local: model_doc/mctct
|
||||||
title: MCTCT
|
title: MCTCT
|
||||||
|
|||||||
@@ -36,10 +36,10 @@ from transformers import KyutaiSpeechToTextProcessor, KyutaiSpeechToTextForCondi
|
|||||||
|
|
||||||
# 1. load the model and the processor
|
# 1. load the model and the processor
|
||||||
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
|
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
model_id = "kyutai/stt-2.6b-en"
|
model_id = "kyutai/stt-2.6b-en-trfs"
|
||||||
|
|
||||||
processor = KyutaiSpeechToTextProcessor.from_pretrained(model_id)
|
processor = KyutaiSpeechToTextProcessor.from_pretrained(model_id)
|
||||||
model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(model_id, device_map=torch_device)
|
model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(model_id, device_map=torch_device, torch_dtype="auto")
|
||||||
|
|
||||||
# 2. load audio samples
|
# 2. load audio samples
|
||||||
ds = load_dataset(
|
ds = load_dataset(
|
||||||
@@ -69,10 +69,10 @@ from transformers import KyutaiSpeechToTextProcessor, KyutaiSpeechToTextForCondi
|
|||||||
|
|
||||||
# 1. load the model and the processor
|
# 1. load the model and the processor
|
||||||
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
|
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
model_id = "kyutai/stt-2.6b-en"
|
model_id = "kyutai/stt-2.6b-en-trfs"
|
||||||
|
|
||||||
processor = KyutaiSpeechToTextProcessor.from_pretrained(model_id)
|
processor = KyutaiSpeechToTextProcessor.from_pretrained(model_id)
|
||||||
model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(model_id, device_map=torch_device)
|
model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(model_id, device_map=torch_device, torch_dtype="auto")
|
||||||
|
|
||||||
# 2. load audio samples
|
# 2. load audio samples
|
||||||
ds = load_dataset(
|
ds = load_dataset(
|
||||||
@@ -158,6 +158,7 @@ if TYPE_CHECKING:
|
|||||||
from .janus import *
|
from .janus import *
|
||||||
from .jetmoe import *
|
from .jetmoe import *
|
||||||
from .kosmos2 import *
|
from .kosmos2 import *
|
||||||
|
from .kyutai_speech_to_text import *
|
||||||
from .layoutlm import *
|
from .layoutlm import *
|
||||||
from .layoutlmv2 import *
|
from .layoutlmv2 import *
|
||||||
from .layoutlmv3 import *
|
from .layoutlmv3 import *
|
||||||
@@ -286,7 +287,6 @@ if TYPE_CHECKING:
|
|||||||
from .squeezebert import *
|
from .squeezebert import *
|
||||||
from .stablelm import *
|
from .stablelm import *
|
||||||
from .starcoder2 import *
|
from .starcoder2 import *
|
||||||
from .stt import *
|
|
||||||
from .superglue import *
|
from .superglue import *
|
||||||
from .superpoint import *
|
from .superpoint import *
|
||||||
from .swiftformer import *
|
from .swiftformer import *
|
||||||
|
|||||||
@@ -184,6 +184,7 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str](
|
|||||||
("jetmoe", "JetMoeConfig"),
|
("jetmoe", "JetMoeConfig"),
|
||||||
("jukebox", "JukeboxConfig"),
|
("jukebox", "JukeboxConfig"),
|
||||||
("kosmos-2", "Kosmos2Config"),
|
("kosmos-2", "Kosmos2Config"),
|
||||||
|
("kyutai_speech_to_text", "KyutaiSpeechToTextConfig"),
|
||||||
("layoutlm", "LayoutLMConfig"),
|
("layoutlm", "LayoutLMConfig"),
|
||||||
("layoutlmv2", "LayoutLMv2Config"),
|
("layoutlmv2", "LayoutLMv2Config"),
|
||||||
("layoutlmv3", "LayoutLMv3Config"),
|
("layoutlmv3", "LayoutLMv3Config"),
|
||||||
@@ -326,7 +327,6 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str](
|
|||||||
("squeezebert", "SqueezeBertConfig"),
|
("squeezebert", "SqueezeBertConfig"),
|
||||||
("stablelm", "StableLmConfig"),
|
("stablelm", "StableLmConfig"),
|
||||||
("starcoder2", "Starcoder2Config"),
|
("starcoder2", "Starcoder2Config"),
|
||||||
("stt", "KyutaiSpeechToTextConfig"),
|
|
||||||
("superglue", "SuperGlueConfig"),
|
("superglue", "SuperGlueConfig"),
|
||||||
("superpoint", "SuperPointConfig"),
|
("superpoint", "SuperPointConfig"),
|
||||||
("swiftformer", "SwiftFormerConfig"),
|
("swiftformer", "SwiftFormerConfig"),
|
||||||
@@ -562,6 +562,7 @@ MODEL_NAMES_MAPPING = OrderedDict[str, str](
|
|||||||
("jetmoe", "JetMoe"),
|
("jetmoe", "JetMoe"),
|
||||||
("jukebox", "Jukebox"),
|
("jukebox", "Jukebox"),
|
||||||
("kosmos-2", "KOSMOS-2"),
|
("kosmos-2", "KOSMOS-2"),
|
||||||
|
("kyutai_speech_to_text", "KyutaiSpeechToText"),
|
||||||
("layoutlm", "LayoutLM"),
|
("layoutlm", "LayoutLM"),
|
||||||
("layoutlmv2", "LayoutLMv2"),
|
("layoutlmv2", "LayoutLMv2"),
|
||||||
("layoutlmv3", "LayoutLMv3"),
|
("layoutlmv3", "LayoutLMv3"),
|
||||||
@@ -717,7 +718,6 @@ MODEL_NAMES_MAPPING = OrderedDict[str, str](
|
|||||||
("squeezebert", "SqueezeBERT"),
|
("squeezebert", "SqueezeBERT"),
|
||||||
("stablelm", "StableLm"),
|
("stablelm", "StableLm"),
|
||||||
("starcoder2", "Starcoder2"),
|
("starcoder2", "Starcoder2"),
|
||||||
("stt", "KyutaiSpeechToText"),
|
|
||||||
("superglue", "SuperGlue"),
|
("superglue", "SuperGlue"),
|
||||||
("superpoint", "SuperPoint"),
|
("superpoint", "SuperPoint"),
|
||||||
("swiftformer", "SwiftFormer"),
|
("swiftformer", "SwiftFormer"),
|
||||||
|
|||||||
@@ -65,6 +65,7 @@ FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
|
|||||||
("groupvit", "CLIPFeatureExtractor"),
|
("groupvit", "CLIPFeatureExtractor"),
|
||||||
("hubert", "Wav2Vec2FeatureExtractor"),
|
("hubert", "Wav2Vec2FeatureExtractor"),
|
||||||
("imagegpt", "ImageGPTFeatureExtractor"),
|
("imagegpt", "ImageGPTFeatureExtractor"),
|
||||||
|
("kyutai_speech_to_text", "KyutaiSpeechToTextFeatureExtractor"),
|
||||||
("layoutlmv2", "LayoutLMv2FeatureExtractor"),
|
("layoutlmv2", "LayoutLMv2FeatureExtractor"),
|
||||||
("layoutlmv3", "LayoutLMv3FeatureExtractor"),
|
("layoutlmv3", "LayoutLMv3FeatureExtractor"),
|
||||||
("levit", "LevitFeatureExtractor"),
|
("levit", "LevitFeatureExtractor"),
|
||||||
@@ -91,7 +92,6 @@ FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
|
|||||||
("sew-d", "Wav2Vec2FeatureExtractor"),
|
("sew-d", "Wav2Vec2FeatureExtractor"),
|
||||||
("speech_to_text", "Speech2TextFeatureExtractor"),
|
("speech_to_text", "Speech2TextFeatureExtractor"),
|
||||||
("speecht5", "SpeechT5FeatureExtractor"),
|
("speecht5", "SpeechT5FeatureExtractor"),
|
||||||
("stt", "KyutaiSpeechToTextFeatureExtractor"),
|
|
||||||
("swiftformer", "ViTFeatureExtractor"),
|
("swiftformer", "ViTFeatureExtractor"),
|
||||||
("swin", "ViTFeatureExtractor"),
|
("swin", "ViTFeatureExtractor"),
|
||||||
("swinv2", "ViTFeatureExtractor"),
|
("swinv2", "ViTFeatureExtractor"),
|
||||||
|
|||||||
@@ -174,6 +174,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
|||||||
("jetmoe", "JetMoeModel"),
|
("jetmoe", "JetMoeModel"),
|
||||||
("jukebox", "JukeboxModel"),
|
("jukebox", "JukeboxModel"),
|
||||||
("kosmos-2", "Kosmos2Model"),
|
("kosmos-2", "Kosmos2Model"),
|
||||||
|
("kyutai_speech_to_text", "KyutaiSpeechToTextModel"),
|
||||||
("layoutlm", "LayoutLMModel"),
|
("layoutlm", "LayoutLMModel"),
|
||||||
("layoutlmv2", "LayoutLMv2Model"),
|
("layoutlmv2", "LayoutLMv2Model"),
|
||||||
("layoutlmv3", "LayoutLMv3Model"),
|
("layoutlmv3", "LayoutLMv3Model"),
|
||||||
@@ -304,7 +305,6 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
|||||||
("squeezebert", "SqueezeBertModel"),
|
("squeezebert", "SqueezeBertModel"),
|
||||||
("stablelm", "StableLmModel"),
|
("stablelm", "StableLmModel"),
|
||||||
("starcoder2", "Starcoder2Model"),
|
("starcoder2", "Starcoder2Model"),
|
||||||
("stt", "KyutaiSpeechToTextModel"),
|
|
||||||
("superglue", "SuperGlueForKeypointMatching"),
|
("superglue", "SuperGlueForKeypointMatching"),
|
||||||
("swiftformer", "SwiftFormerModel"),
|
("swiftformer", "SwiftFormerModel"),
|
||||||
("swin", "SwinModel"),
|
("swin", "SwinModel"),
|
||||||
@@ -1060,6 +1060,7 @@ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
|||||||
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
|
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
|
||||||
[
|
[
|
||||||
("granite_speech", "GraniteSpeechForConditionalGeneration"),
|
("granite_speech", "GraniteSpeechForConditionalGeneration"),
|
||||||
|
("kyutai_speech_to_text", "KyutaiSpeechToTextForConditionalGeneration"),
|
||||||
("moonshine", "MoonshineForConditionalGeneration"),
|
("moonshine", "MoonshineForConditionalGeneration"),
|
||||||
("pop2piano", "Pop2PianoForConditionalGeneration"),
|
("pop2piano", "Pop2PianoForConditionalGeneration"),
|
||||||
("seamless_m4t", "SeamlessM4TForSpeechToText"),
|
("seamless_m4t", "SeamlessM4TForSpeechToText"),
|
||||||
@@ -1067,7 +1068,6 @@ MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
|
|||||||
("speech-encoder-decoder", "SpeechEncoderDecoderModel"),
|
("speech-encoder-decoder", "SpeechEncoderDecoderModel"),
|
||||||
("speech_to_text", "Speech2TextForConditionalGeneration"),
|
("speech_to_text", "Speech2TextForConditionalGeneration"),
|
||||||
("speecht5", "SpeechT5ForSpeechToText"),
|
("speecht5", "SpeechT5ForSpeechToText"),
|
||||||
("stt", "KyutaiSpeechToTextForConditionalGeneration"),
|
|
||||||
("whisper", "WhisperForConditionalGeneration"),
|
("whisper", "WhisperForConditionalGeneration"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -80,6 +80,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
|
|||||||
("internvl", "InternVLProcessor"),
|
("internvl", "InternVLProcessor"),
|
||||||
("janus", "JanusProcessor"),
|
("janus", "JanusProcessor"),
|
||||||
("kosmos-2", "Kosmos2Processor"),
|
("kosmos-2", "Kosmos2Processor"),
|
||||||
|
("kyutai_speech_to_text", "KyutaiSpeechToTextProcessor"),
|
||||||
("layoutlmv2", "LayoutLMv2Processor"),
|
("layoutlmv2", "LayoutLMv2Processor"),
|
||||||
("layoutlmv3", "LayoutLMv3Processor"),
|
("layoutlmv3", "LayoutLMv3Processor"),
|
||||||
("llama4", "Llama4Processor"),
|
("llama4", "Llama4Processor"),
|
||||||
@@ -117,7 +118,6 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
|
|||||||
("speech_to_text", "Speech2TextProcessor"),
|
("speech_to_text", "Speech2TextProcessor"),
|
||||||
("speech_to_text_2", "Speech2Text2Processor"),
|
("speech_to_text_2", "Speech2Text2Processor"),
|
||||||
("speecht5", "SpeechT5Processor"),
|
("speecht5", "SpeechT5Processor"),
|
||||||
("stt", "KyutaiSpeechToTextProcessor"),
|
|
||||||
("trocr", "TrOCRProcessor"),
|
("trocr", "TrOCRProcessor"),
|
||||||
("tvlt", "TvltProcessor"),
|
("tvlt", "TvltProcessor"),
|
||||||
("tvp", "TvpProcessor"),
|
("tvp", "TvpProcessor"),
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ class KyutaiSpeechToTextConfig(PretrainedConfig):
|
|||||||
architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the
|
architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the
|
||||||
2.6b-en model.
|
2.6b-en model.
|
||||||
|
|
||||||
e.g. [kyutai/stt-2.6b-en](https://huggingface.co/kyutai/stt-2.6b-en)
|
e.g. [kyutai/stt-2.6b-en-trfs](https://huggingface.co/kyutai/stt-2.6b-en-trfs)
|
||||||
|
|
||||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||||
documentation from [`PretrainedConfig`] for more information.
|
documentation from [`PretrainedConfig`] for more information.
|
||||||
@@ -110,8 +110,7 @@ class KyutaiSpeechToTextConfig(PretrainedConfig):
|
|||||||
>>> configuration = model.config
|
>>> configuration = model.config
|
||||||
```"""
|
```"""
|
||||||
|
|
||||||
# not the best naming here for `model_type`, but original codebase already uses model type:`stt` for in the config so we keep it to simplify
|
model_type = "kyutai_speech_to_text"
|
||||||
model_type = "stt"
|
|
||||||
keys_to_ignore_at_inference = ["past_key_values"]
|
keys_to_ignore_at_inference = ["past_key_values"]
|
||||||
sub_configs = {"codec_config": AutoConfig}
|
sub_configs = {"codec_config": AutoConfig}
|
||||||
|
|
||||||
@@ -190,7 +190,14 @@ def write_model(
|
|||||||
print("Converting the model.")
|
print("Converting the model.")
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
|
||||||
config = KyutaiSpeechToTextConfig()
|
config = KyutaiSpeechToTextConfig(
|
||||||
|
vocab_size=8001,
|
||||||
|
max_position_embeddings=375,
|
||||||
|
num_hidden_layers=16,
|
||||||
|
num_attention_heads=16,
|
||||||
|
num_key_value_heads=16,
|
||||||
|
head_dim=128,
|
||||||
|
)
|
||||||
config.use_cache = True
|
config.use_cache = True
|
||||||
config.codec_config.sliding_window = 250
|
config.codec_config.sliding_window = 250
|
||||||
|
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
# This file was automatically generated from src/transformers/models/stt/modular_kyutai_speech_to_text.py.
|
# This file was automatically generated from src/transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py.
|
||||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||||
# the file from the modular. If any change should be done, please apply the change to the
|
# the file from the modular. If any change should be done, please apply the change to the
|
||||||
# modular_kyutai_speech_to_text.py file directly. One of our CI enforces this.
|
# modular_kyutai_speech_to_text.py file directly. One of our CI enforces this.
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
# This file was automatically generated from src/transformers/models/stt/modular_kyutai_speech_to_text.py.
|
# This file was automatically generated from src/transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py.
|
||||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||||
# the file from the modular. If any change should be done, please apply the change to the
|
# the file from the modular. If any change should be done, please apply the change to the
|
||||||
# modular_kyutai_speech_to_text.py file directly. One of our CI enforces this.
|
# modular_kyutai_speech_to_text.py file directly. One of our CI enforces this.
|
||||||
@@ -713,7 +713,7 @@ class KyutaiSpeechToTextSdpaAttention(KyutaiSpeechToTextAttention):
|
|||||||
return attn_output, None, past_key_value
|
return attn_output, None, past_key_value
|
||||||
|
|
||||||
|
|
||||||
STT_ATTENTION_CLASSES = {
|
KYUTAI_SPEECH_TO_TEXT_ATTENTION_CLASSES = {
|
||||||
"eager": KyutaiSpeechToTextAttention,
|
"eager": KyutaiSpeechToTextAttention,
|
||||||
"flash_attention_2": KyutaiSpeechToTextFlashAttention2,
|
"flash_attention_2": KyutaiSpeechToTextFlashAttention2,
|
||||||
"sdpa": KyutaiSpeechToTextSdpaAttention,
|
"sdpa": KyutaiSpeechToTextSdpaAttention,
|
||||||
@@ -726,7 +726,7 @@ class KyutaiSpeechToTextDecoderLayer(GradientCheckpointingLayer):
|
|||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
self.use_flexible_linear = use_flexible_linear
|
self.use_flexible_linear = use_flexible_linear
|
||||||
|
|
||||||
self.self_attn = STT_ATTENTION_CLASSES[config._attn_implementation](
|
self.self_attn = KYUTAI_SPEECH_TO_TEXT_ATTENTION_CLASSES[config._attn_implementation](
|
||||||
config=config, layer_idx=layer_idx, use_flexible_linear=use_flexible_linear, use_rope=use_rope
|
config=config, layer_idx=layer_idx, use_flexible_linear=use_flexible_linear, use_rope=use_rope
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1169,7 +1169,7 @@ class KyutaiSpeechToTextForConditionalGeneration(KyutaiSpeechToTextPreTrainedMod
|
|||||||
>>> from transformers import KyutaiSpeechToTextProcessor, KyutaiSpeechToTextForConditionalGeneration
|
>>> from transformers import KyutaiSpeechToTextProcessor, KyutaiSpeechToTextForConditionalGeneration
|
||||||
|
|
||||||
>>> torch_device = "cuda" if torch.cuda.is_available() else "cpu"
|
>>> torch_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
>>> model_id = "kyutai/stt-2.6b-en"
|
>>> model_id = "kyutai/stt-2.6b-en-trfs"
|
||||||
|
|
||||||
>>> processor = KyutaiSpeechToTextProcessor.from_pretrained(model_id)
|
>>> processor = KyutaiSpeechToTextProcessor.from_pretrained(model_id)
|
||||||
>>> model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(model_id, device_map=torch_device)
|
>>> model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(model_id, device_map=torch_device)
|
||||||
@@ -278,7 +278,7 @@ class KyutaiSpeechToTextForConditionalGeneration(LlamaForCausalLM, GenerationMix
|
|||||||
>>> from transformers import KyutaiSpeechToTextProcessor, KyutaiSpeechToTextForConditionalGeneration
|
>>> from transformers import KyutaiSpeechToTextProcessor, KyutaiSpeechToTextForConditionalGeneration
|
||||||
|
|
||||||
>>> torch_device = "cuda" if torch.cuda.is_available() else "cpu"
|
>>> torch_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
>>> model_id = "kyutai/stt-2.6b-en"
|
>>> model_id = "kyutai/stt-2.6b-en-trfs"
|
||||||
|
|
||||||
>>> processor = KyutaiSpeechToTextProcessor.from_pretrained(model_id)
|
>>> processor = KyutaiSpeechToTextProcessor.from_pretrained(model_id)
|
||||||
>>> model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(model_id, device_map=torch_device)
|
>>> model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(model_id, device_map=torch_device)
|
||||||
@@ -619,7 +619,7 @@ class KyutaiSpeechToTextForConditionalGenerationIntegrationTests(unittest.TestCa
|
|||||||
_dataset = None
|
_dataset = None
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_checkpoint = "kyutai/stt-2.6b-en"
|
self.model_checkpoint = "kyutai/stt-2.6b-en-trfs"
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
cleanup(torch_device, gc_collect=True)
|
cleanup(torch_device, gc_collect=True)
|
||||||
|
|||||||
Reference in New Issue
Block a user