From 798f948e88fd0b93fc515ec6b96e0503b78ad6ba Mon Sep 17 00:00:00 2001
From: eustlb <94853470+eustlb@users.noreply.github.com>
Date: Wed, 7 May 2025 10:20:13 -0400
Subject: [PATCH] Add CSM model (#36719)
* draft structure
* depth decoder with forward pre hook
* full model forward draft
* draft update
* depth decoder update
* ConversationalSpeechModelForCausalLM udpates
* add generate
* max length criteria small fix
* udpate
* updates
* generation update
* update in loss compute
* conversion script
* update for correct input embeddings
* handle interleaved rope
* update
* update
* update
* support compile
* update training
* add doc
* update doc
* correct inits
* ConversationalSpeechModel -> Csm
* conf update
* name update
* tests CsmForCausalLMTest
* convert use cached_file
* conf + modeling updates
* generate utils handle third dim shape
* integration test
* modeling + conf updates
* common test handle more than 2 dims
* add nested audio list utils
* processing handle nested audio list
* csm processing draft
* mimi util
* init updates
* modular update
* convert modular
* processing update
* csm tests update
* generate tests handle third dim
* generate utils handle third dim
* propagate _get_initial_cache_position update
* tied_weight_keys update + convert correctly
* fix inputs_embeds
* revert audio nested list
* batch inference update + return audio
* audio_utils update
* processor update
* some more integration tests
* remove old test
* porcessing output labels
* improve
* fix
* update rope values with equivalent ones
* conversion update
* udpate tests
* handle depth decoder generation config
* remove default eos_token_id
* make style
* revert modeling_mimi
* add default generation_config
* remove sdpa since handled by default
* make
* fix conflict
* fix conflicts
* correct naming
* correct imports
* make
* causal -> conditional naming
* causal -> conditional naming
* auto update
* make
* make
* add doc
* test update
* fix weight init
* audio tokens offsets as buffer
* 4d mask in conditional class
* make
* doc update
* fix causal mask
* fix causal mask
* doc update
* doc update
* add processor doc
* update doc
* fix 4d causal mask
* update make_list_of_audio
* do not default to mutable
* remove duplicates
* remove useless reset_parameters
* use GradientCheckpointingLayer
* use can_return_tuple
* formatting
* prepend placeholder in _sample
* torch compile fix
* some more fixies
* convert modular
* fix
* default max_length in convert
* handle depth decoder generation config correctly
* clearer formulation
* handle output_loading_info
* handle softmax warning
* add doc
* propagate _get_initial_cache_position changes
* generation in its own module
* add processor tests
* fix compile witu cuda graphs
* fix compile with cuda graphs
* add csm.md
* include CSM loss
* doc nit
* doc nit
* doc nit
* Update docs/source/en/model_doc/csm.md
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* add save_audio to processor
* Update src/transformers/models/csm/modular_csm.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* doc update
* simplify audio_codes_mask computation
* doc update
* simplify loss computation
* fix static cache test
* fix
* remove comment
* simplify encoded length computation
* use hf-internal-testing
* doc update
* cast to float before numpy
* nit
* mem efficient codebook head
* nit
* cat input values with cutoffs
---------
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
---
docs/source/en/_toctree.yml | 2 +
docs/source/en/model_doc/csm.md | 377 ++++
src/transformers/audio_utils.py | 37 +-
.../generation/stopping_criteria.py | 2 +-
src/transformers/generation/utils.py | 42 +-
src/transformers/loss/loss_utils.py | 1 +
src/transformers/models/__init__.py | 1 +
.../models/auto/configuration_auto.py | 2 +
src/transformers/models/auto/modeling_auto.py | 2 +
src/transformers/models/csm/__init__.py | 28 +
.../models/csm/configuration_csm.py | 440 +++++
src/transformers/models/csm/convert_csm.py | 339 ++++
src/transformers/models/csm/generation_csm.py | 491 +++++
src/transformers/models/csm/modeling_csm.py | 1710 +++++++++++++++++
src/transformers/models/csm/modular_csm.py | 1042 ++++++++++
src/transformers/models/csm/processing_csm.py | 364 ++++
.../gpt_bigcode/modeling_gpt_bigcode.py | 6 +-
.../models/janus/modeling_janus.py | 2 +-
.../models/janus/modular_janus.py | 2 +-
src/transformers/models/mimi/modeling_mimi.py | 65 +
.../qwen2_5_omni/modeling_qwen2_5_omni.py | 4 +-
.../qwen2_5_omni/modular_qwen2_5_omni.py | 4 +-
src/transformers/processing_utils.py | 2 +-
tests/generation/test_utils.py | 102 +-
tests/models/csm/__init__.py | 0
tests/models/csm/test_modeling_csm.py | 693 +++++++
tests/models/csm/test_processor_csm.py | 140 ++
tests/test_modeling_common.py | 6 +-
utils/check_repo.py | 7 +
29 files changed, 5827 insertions(+), 86 deletions(-)
create mode 100644 docs/source/en/model_doc/csm.md
create mode 100644 src/transformers/models/csm/__init__.py
create mode 100644 src/transformers/models/csm/configuration_csm.py
create mode 100644 src/transformers/models/csm/convert_csm.py
create mode 100644 src/transformers/models/csm/generation_csm.py
create mode 100644 src/transformers/models/csm/modeling_csm.py
create mode 100644 src/transformers/models/csm/modular_csm.py
create mode 100644 src/transformers/models/csm/processing_csm.py
create mode 100644 tests/models/csm/__init__.py
create mode 100644 tests/models/csm/test_modeling_csm.py
create mode 100644 tests/models/csm/test_processor_csm.py
diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 7d505721f0..b848976d02 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -825,6 +825,8 @@
title: Bark
- local: model_doc/clap
title: CLAP
+ - local: model_doc/csm
+ title: CSM
- local: model_doc/dac
title: dac
- local: model_doc/encodec
diff --git a/docs/source/en/model_doc/csm.md b/docs/source/en/model_doc/csm.md
new file mode 100644
index 0000000000..2d916da161
--- /dev/null
+++ b/docs/source/en/model_doc/csm.md
@@ -0,0 +1,377 @@
+
+
+# Csm
+
+## Overview
+
+The Conversational Speech Model (CSM) is the first open-source contextual text-to-speech model [released by Sesame](https://www.sesame.com/research/crossing_the_uncanny_valley_of_voice). It is designed to generate natural-sounding speech with or without conversational context. This context typically consists of multi-turn dialogue between speakers, represented as sequences of text and corresponding spoken audio.
+
+**Model Architecture:**
+CSM is composed of two LLaMA-style auto-regressive transformer decoders: a backbone decoder that predicts the first codebook token and a depth decoder that generates the remaining tokens. It uses the pretrained codec model [Mimi](./mimi.md), introduced by Kyutai, to encode speech into discrete codebook tokens and decode them back into audio.
+
+The original csm-1b checkpoint is available under the [Sesame](https://huggingface.co/sesame/csm-1b) organization on Hugging Face.
+
+
+

+
+
+## Usage Tips
+
+### Without Conversational Context
+
+CSM can be used to simply generate speech from a text prompt:
+
+```python
+import torch
+from transformers import CsmForConditionalGeneration, AutoProcessor
+
+model_id = "eustlb/csm-1b"
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+# load the model and the processor
+processor = AutoProcessor.from_pretrained(model_id)
+model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=device)
+
+# prepare the inputs
+text = "[0]The past is just a story we tell ourselves." # `[0]` for speaker id 0
+inputs = processor(text, add_special_tokens=True).to(device)
+
+# another equivalent way to prepare the inputs
+conversation = [
+ {"role": "0", "content": [{"type": "text", "text": "The past is just a story we tell ourselves."}]},
+]
+inputs = processor.apply_chat_template(
+ conversation,
+ tokenize=True,
+ return_dict=True,
+).to(device)
+
+# infer the model
+audio = model.generate(**inputs, output_audio=True)
+processor.save_audio(audio, "example_without_context.wav")
+```
+
+### With Conversational Context
+
+CSM can be used to generate speech given a conversation, allowing consistency in the voices and content-aware generation:
+
+```python
+import torch
+from transformers import CsmForConditionalGeneration, AutoProcessor
+from datasets import load_dataset, Audio
+
+model_id = "eustlb/csm-1b"
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+# load the model and the processor
+processor = AutoProcessor.from_pretrained(model_id)
+model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=device)
+
+# prepare the inputs
+ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
+# ensure the audio is 24kHz
+ds = ds.cast_column("audio", Audio(sampling_rate=24000))
+conversation = []
+
+# 1. context
+for text, audio, speaker_id in zip(ds[:4]["text"], ds[:4]["audio"], ds[:4]["speaker_id"]):
+ conversation.append(
+ {
+ "role": f"{speaker_id}",
+ "content": [{"type": "text", "text": text}, {"type": "audio", "path": audio["array"]}],
+ }
+ )
+
+# 2. text prompt
+conversation.append({"role": f"{ds[4]['speaker_id']}", "content": [{"type": "text", "text": ds[4]["text"]}]})
+
+inputs = processor.apply_chat_template(
+ conversation,
+ tokenize=True,
+ return_dict=True,
+).to(device)
+
+# infer the model
+audio = model.generate(**inputs, output_audio=True)
+processor.save_audio(audio, "example_with_context.wav")
+```
+
+### Batched Inference
+
+CSM supports batched inference!
+
+```python
+import torch
+from transformers import CsmForConditionalGeneration, AutoProcessor
+from datasets import load_dataset, Audio
+
+model_id = "eustlb/csm-1b"
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+# load the model and the processor
+processor = AutoProcessor.from_pretrained(model_id)
+model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=device)
+
+# prepare the inputs
+ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
+# ensure the audio is 24kHz
+ds = ds.cast_column("audio", Audio(sampling_rate=24000))
+# here a batch with two prompts
+conversation = [
+ [
+ {
+ "role": f"{ds[0]['speaker_id']}",
+ "content": [
+ {"type": "text", "text": ds[0]["text"]},
+ {"type": "audio", "path": ds[0]["audio"]["array"]},
+ ],
+ },
+ {
+ "role": f"{ds[1]['speaker_id']}",
+ "content": [
+ {"type": "text", "text": ds[1]["text"]},
+ ],
+ },
+ ],
+ [
+ {
+ "role": f"{ds[0]['speaker_id']}",
+ "content": [
+ {"type": "text", "text": ds[0]["text"]},
+ ],
+ }
+ ],
+]
+inputs = processor.apply_chat_template(
+ conversation,
+ tokenize=True,
+ return_dict=True,
+).to(device)
+
+audio = model.generate(**inputs, output_audio=True)
+processor.save_audio(audio, [f"speech_batch_idx_{i}.wav" for i in range(len(audio))])
+```
+
+### Making The Model Go Brrr
+
+CSM supports full-graph compilation with CUDA graphs!
+
+```python
+import torch
+import copy
+from transformers import CsmForConditionalGeneration, AutoProcessor
+from datasets import load_dataset
+
+model_id = "eustlb/csm-1b"
+device = "cuda"
+
+# set logs to ensure no recompilation and graph breaks
+torch._logging.set_logs(graph_breaks=True, recompiles=True, cudagraphs=True)
+
+# load the model and the processor
+processor = AutoProcessor.from_pretrained(model_id)
+model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=device)
+
+# use static cache, enabling automatically torch compile with fullgraph and reduce-overhead
+model.generation_config.max_length = 250 # big enough to avoid recompilation
+model.generation_config.max_new_tokens = None # would take precedence over max_length
+model.generation_config.cache_implementation = "static"
+model.depth_decoder.generation_config.cache_implementation = "static"
+
+# generation kwargs
+gen_kwargs = {
+ "do_sample": False,
+ "depth_decoder_do_sample": False,
+ "temperature": 1.0,
+ "depth_decoder_temperature": 1.0,
+}
+
+# Define a timing decorator
+class TimerContext:
+ def __init__(self, name="Execution"):
+ self.name = name
+ self.start_event = None
+ self.end_event = None
+
+ def __enter__(self):
+ # Use CUDA events for more accurate GPU timing
+ self.start_event = torch.cuda.Event(enable_timing=True)
+ self.end_event = torch.cuda.Event(enable_timing=True)
+ self.start_event.record()
+ return self
+
+ def __exit__(self, *args):
+ self.end_event.record()
+ torch.cuda.synchronize()
+ elapsed_time = self.start_event.elapsed_time(self.end_event) / 1000.0
+ print(f"{self.name} time: {elapsed_time:.4f} seconds")
+
+# prepare the inputs
+ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
+
+conversation = [
+ {
+ "role": f"{ds[0]['speaker_id']}",
+ "content": [
+ {"type": "text", "text": ds[0]["text"]},
+ {"type": "audio", "path": ds[0]["audio"]["array"]},
+ ],
+ },
+ {
+ "role": f"{ds[1]['speaker_id']}",
+ "content": [
+ {"type": "text", "text": ds[1]["text"]},
+ {"type": "audio", "path": ds[1]["audio"]["array"]},
+ ],
+ },
+ {
+ "role": f"{ds[2]['speaker_id']}",
+ "content": [
+ {"type": "text", "text": ds[2]["text"]},
+ ],
+ },
+]
+
+padded_inputs_1 = processor.apply_chat_template(
+ conversation,
+ tokenize=True,
+ return_dict=True,
+).to(device)
+
+print("\n" + "="*50)
+print("First generation - compiling and recording CUDA graphs...")
+with TimerContext("First generation"):
+ _ = model.generate(**padded_inputs_1, **gen_kwargs)
+print("="*50)
+
+print("\n" + "="*50)
+print("Second generation - fast !!!")
+with TimerContext("Second generation"):
+ _ = model.generate(**padded_inputs_1, **gen_kwargs)
+print("="*50)
+
+# now with different inputs
+conversation = [
+ {
+ "role": f"{ds[0]['speaker_id']}",
+ "content": [
+ {"type": "text", "text": ds[2]["text"]},
+ {"type": "audio", "path": ds[2]["audio"]["array"]},
+ ],
+ },
+ {
+ "role": f"{ds[1]['speaker_id']}",
+ "content": [
+ {"type": "text", "text": ds[3]["text"]},
+ {"type": "audio", "path": ds[3]["audio"]["array"]},
+ ],
+ },
+ {
+ "role": f"{ds[2]['speaker_id']}",
+ "content": [
+ {"type": "text", "text": ds[4]["text"]},
+ ],
+ },
+]
+padded_inputs_2 = processor.apply_chat_template(
+ conversation,
+ tokenize=True,
+ return_dict=True,
+).to(device)
+
+print("\n" + "="*50)
+print("Generation with other inputs!")
+with TimerContext("Generation with different inputs"):
+ _ = model.generate(**padded_inputs_2, **gen_kwargs)
+print("="*50)
+```
+
+### Training
+
+CSM Transformers integration supports training!
+
+```python
+from transformers import CsmForConditionalGeneration, AutoProcessor
+from datasets import load_dataset, Audio
+
+model_id = "eustlb/csm-1b"
+device = "cuda"
+
+# load the model and the processor
+processor = AutoProcessor.from_pretrained(model_id)
+model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=device)
+model.train()
+
+ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
+# ensure the audio is 24kHz
+ds = ds.cast_column("audio", Audio(sampling_rate=24000))
+conversation = []
+
+# context
+for text, audio, speaker_id in zip(ds[:4]["text"], ds[:4]["audio"], ds[:4]["speaker_id"]):
+ conversation.append(
+ {
+ "role": f"{speaker_id}",
+ "content": [{"type": "text", "text": text}, {"type": "audio", "path": audio["array"]}],
+ }
+ )
+
+inputs = processor.apply_chat_template(
+ conversation,
+ tokenize=True,
+ return_dict=True,
+ output_labels=True,
+).to(device)
+
+out = model(**inputs)
+out.loss.backward()
+```
+
+This model was contributed by [Eustache Le Bihan](https://huggingface.co/eustlb).
+The original code can be found [here](https://github.com/SesameAILabs/csm).
+
+
+## CsmConfig
+
+[[autodoc]] CsmConfig
+
+## CsmDepthDecoderConfig
+
+[[autodoc]] CsmDepthDecoderConfig
+
+## CsmProcessor
+
+[[autodoc]] CsmProcessor
+ - __call__
+
+## CsmForConditionalGeneration
+
+[[autodoc]] CsmForConditionalGeneration
+ - forward
+ - generate
+
+## CsmDepthDecoderForCausalLM
+
+[[autodoc]] CsmDepthDecoderForCausalLM
+
+## CsmDepthDecoderModel
+
+[[autodoc]] CsmDepthDecoderModel
+
+## CsmBackboneModel
+
+[[autodoc]] CsmBackboneModel
diff --git a/src/transformers/audio_utils.py b/src/transformers/audio_utils.py
index fced0e5758..e980d4cbef 100644
--- a/src/transformers/audio_utils.py
+++ b/src/transformers/audio_utils.py
@@ -24,7 +24,12 @@ from typing import List, Optional, Tuple, Union
import numpy as np
import requests
-from .utils import is_librosa_available, requires_backends
+from .utils import (
+ is_librosa_available,
+ is_numpy_array,
+ is_torch_tensor,
+ requires_backends,
+)
if is_librosa_available():
@@ -69,6 +74,36 @@ AudioInput = Union[
]
+def is_valid_audio(audio):
+ return is_numpy_array(audio) or is_torch_tensor(audio)
+
+
+def is_valid_list_of_audio(audio):
+ return audio and all(is_valid_audio(audio_i) for audio_i in audio)
+
+
+def make_list_of_audio(
+ audio: Union[list[AudioInput], AudioInput],
+) -> AudioInput:
+ """
+ Ensure that the output is a list of audio.
+ Args:
+ audio (`Union[List[AudioInput], AudioInput]`):
+ The input audio.
+ Returns:
+ list: A list of audio.
+ """
+ # If it's a list of audios, it's already in the right format
+ if isinstance(audio, (list, tuple)) and is_valid_list_of_audio(audio):
+ return audio
+
+ # If it's a single audio, convert it to a list of
+ if is_valid_audio(audio):
+ return [audio]
+
+ raise ValueError("Invalid input type. Must be a single audio or a list of audio")
+
+
def hertz_to_mel(freq: Union[float, np.ndarray], mel_scale: str = "htk") -> Union[float, np.ndarray]:
"""
Convert frequency from hertz to mels.
diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py
index 4627aeb970..6e47b041ab 100644
--- a/src/transformers/generation/stopping_criteria.py
+++ b/src/transformers/generation/stopping_criteria.py
@@ -73,7 +73,7 @@ class MaxLengthCriteria(StoppingCriteria):
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
- cur_len = input_ids.shape[-1]
+ cur_len = input_ids.shape[1]
is_done = cur_len >= self.max_length
if self.max_position_embeddings is not None and not is_done and cur_len >= self.max_position_embeddings:
logger.warning_once(
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 36e65949a4..3c7e83369a 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -563,7 +563,7 @@ class GenerationMixin:
if model_inputs["inputs_embeds"] is not None:
batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
else:
- batch_size, sequence_length = model_inputs[input_ids_key].shape
+ batch_size, sequence_length = model_inputs[input_ids_key].shape[:2]
# Create the causal mask with fixed shape in advance, to reduce recompilations. If the function to create
# the 4D causal mask exists, it should be present in the base model (XXXModel class) or in its decoder.
@@ -1708,7 +1708,7 @@ class GenerationMixin:
return generation_config, model_kwargs
- def _get_initial_cache_position(self, input_ids, model_kwargs):
+ def _get_initial_cache_position(self, seq_length, device, model_kwargs):
"""Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length"""
# `torch.compile`-friendly `torch.arange` from a shape -- the lines below are equivalent to `torch.arange`
if "inputs_embeds" in model_kwargs and not self.config.is_encoder_decoder:
@@ -1718,7 +1718,7 @@ class GenerationMixin:
torch.ones_like(model_kwargs["decoder_inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0) - 1
)
else:
- cache_position = torch.ones_like(input_ids[0, :], dtype=torch.int64).cumsum(0) - 1
+ cache_position = torch.ones(seq_length, dtype=torch.int64, device=device).cumsum(0) - 1
past_length = 0
if model_kwargs.get("past_key_values") is not None:
@@ -2332,7 +2332,7 @@ class GenerationMixin:
streamer.put(input_ids.cpu())
# 6. Prepare `max_length` depending on other stopping criteria.
- input_ids_length = input_ids.shape[-1]
+ input_ids_length = input_ids.shape[1]
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
generation_config = self._prepare_generated_length(
@@ -2805,9 +2805,9 @@ class GenerationMixin:
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
# keep track of which sequences are already finished
- batch_size = input_ids.shape[0]
+ batch_size, cur_length = input_ids.shape[:2]
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
- model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
+ model_kwargs = self._get_initial_cache_position(cur_length, input_ids.device, model_kwargs)
this_peer_finished = False
@@ -3016,9 +3016,9 @@ class GenerationMixin:
)
# keep track of which sequences are already finished
- batch_size = input_ids.shape[0]
+ batch_size, cur_len = input_ids.shape[:2]
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
- model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
+ model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
# Create cosine_matrix_mask based on the attention_mask
cosine_matrix_mask = torch.ones_like(input_ids, dtype=torch.long)
@@ -3428,10 +3428,10 @@ class GenerationMixin:
)
# keep track of which sequences are already finished
- batch_size, cur_len = input_ids.shape
+ batch_size, cur_len = input_ids.shape[:2]
this_peer_finished = False
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
- model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
+ model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
model_forward = self.__call__
compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config)
@@ -3834,7 +3834,7 @@ class GenerationMixin:
num_beams = generation_config.num_beams
num_return_sequences = generation_config.num_return_sequences
- batch_size_unflattened, cur_len = input_ids.shape
+ batch_size_unflattened, cur_len = input_ids.shape[:2]
batch_size = batch_size_unflattened // num_beams
# TODO (joao): standardize special cases
if self.__class__.__name__ == "MoshiDepthDecoder":
@@ -3857,7 +3857,7 @@ class GenerationMixin:
dim=0,
).to(input_ids.device)
- model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
+ model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
# (joao) feature lost in the refactor. Probably won't implement, hurts readability with minimal gains (there
# are newer low-memory alternatives like the offloaded cache)
@@ -4156,7 +4156,7 @@ class GenerationMixin:
device = input_ids.device
batch_beam_size, cur_len = input_ids.shape
- model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
+ model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
if return_dict_in_generate and output_scores:
beam_indices = [tuple(() for _ in range(num_sub_beams * batch_size)) for _ in range(num_beam_groups)]
@@ -4190,7 +4190,7 @@ class GenerationMixin:
this_peer_finished = False
- decoder_prompt_len = input_ids.shape[-1] # record the prompt length of decoder
+ decoder_prompt_len = input_ids.shape[1] # record the prompt length of decoder
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
# predicted tokens in cur_len step
current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)
@@ -4444,8 +4444,8 @@ class GenerationMixin:
batch_size = len(constrained_beam_scorer._beam_hyps)
num_beams = constrained_beam_scorer.num_beams
- batch_beam_size, cur_len = input_ids.shape
- model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
+ batch_beam_size, cur_len = input_ids.shape[:2]
+ model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
if num_beams * batch_size != batch_beam_size:
raise ValueError(
@@ -4477,7 +4477,7 @@ class GenerationMixin:
this_peer_finished = False
- decoder_prompt_len = input_ids.shape[-1] # record the prompt length of decoder
+ decoder_prompt_len = input_ids.shape[1] # record the prompt length of decoder
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
@@ -4698,14 +4698,14 @@ class GenerationMixin:
)
# keep track of which sequences are already finished
- batch_size = input_ids.shape[0]
+ batch_size, cur_len = input_ids.shape[:2]
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
- model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
+ model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
this_peer_finished = False
is_first_iteration = True # to preserve the same API in the output as other generation methods
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
- cur_len = input_ids.shape[-1]
+ cur_len = input_ids.shape[1]
# 1. Fetch candidate sequences from a `CandidateGenerator` and move to the correct device
candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids)
@@ -4795,7 +4795,7 @@ class GenerationMixin:
input_ids = torch.cat((input_ids, valid_tokens), dim=-1)
if streamer is not None:
streamer.put(valid_tokens.cpu())
- new_cur_len = input_ids.shape[-1]
+ new_cur_len = input_ids.shape[1]
# 4.2. Discard past key values relative to unused assistant tokens
new_cache_size = new_cur_len - 1
diff --git a/src/transformers/loss/loss_utils.py b/src/transformers/loss/loss_utils.py
index 12da2f3d7a..aad42d3fd5 100644
--- a/src/transformers/loss/loss_utils.py
+++ b/src/transformers/loss/loss_utils.py
@@ -158,4 +158,5 @@ LOSS_MAPPING = {
"RTDetrForObjectDetection": RTDetrForObjectDetectionLoss,
"RTDetrV2ForObjectDetection": RTDetrForObjectDetectionLoss,
"DFineForObjectDetection": DFineForObjectDetectionLoss,
+ "CsmForConditionalGeneration": ForCausalLMLoss,
}
diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py
index 1be149faf6..8d713b482b 100644
--- a/src/transformers/models/__init__.py
+++ b/src/transformers/models/__init__.py
@@ -68,6 +68,7 @@ if TYPE_CHECKING:
from .convnextv2 import *
from .cpm import *
from .cpmant import *
+ from .csm import *
from .ctrl import *
from .cvt import *
from .d_fine import *
diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py
index 66ed95acad..01c58a5062 100644
--- a/src/transformers/models/auto/configuration_auto.py
+++ b/src/transformers/models/auto/configuration_auto.py
@@ -80,6 +80,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("convnext", "ConvNextConfig"),
("convnextv2", "ConvNextV2Config"),
("cpmant", "CpmAntConfig"),
+ ("csm", "CsmConfig"),
("ctrl", "CTRLConfig"),
("cvt", "CvtConfig"),
("d_fine", "DFineConfig"),
@@ -437,6 +438,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("convnextv2", "ConvNeXTV2"),
("cpm", "CPM"),
("cpmant", "CPM-Ant"),
+ ("csm", "CSM"),
("ctrl", "CTRL"),
("cvt", "CvT"),
("d_fine", "D-FINE"),
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index b8ec5c6ccb..b196a7718f 100644
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -78,6 +78,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("convnext", "ConvNextModel"),
("convnextv2", "ConvNextV2Model"),
("cpmant", "CpmAntModel"),
+ ("csm", "CsmForConditionalGeneration"),
("ctrl", "CTRLModel"),
("cvt", "CvtModel"),
("d_fine", "DFineModel"),
@@ -1446,6 +1447,7 @@ MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES = OrderedDict(
[
# Model for Text-To-Waveform mapping
("bark", "BarkModel"),
+ ("csm", "CsmForConditionalGeneration"),
("fastspeech2_conformer", "FastSpeech2ConformerWithHifiGan"),
("musicgen", "MusicgenForConditionalGeneration"),
("musicgen_melody", "MusicgenMelodyForConditionalGeneration"),
diff --git a/src/transformers/models/csm/__init__.py b/src/transformers/models/csm/__init__.py
new file mode 100644
index 0000000000..59468442b5
--- /dev/null
+++ b/src/transformers/models/csm/__init__.py
@@ -0,0 +1,28 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_csm import *
+ from .modeling_csm import *
+ from .processing_csm import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/src/transformers/models/csm/configuration_csm.py b/src/transformers/models/csm/configuration_csm.py
new file mode 100644
index 0000000000..e6d6d2e27c
--- /dev/null
+++ b/src/transformers/models/csm/configuration_csm.py
@@ -0,0 +1,440 @@
+# coding=utf-8
+# Copyright 2025 Sesame and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...configuration_utils import PretrainedConfig
+from ...modeling_rope_utils import rope_config_validation
+from ...utils import logging
+from ..auto.configuration_auto import AutoConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class CsmDepthDecoderConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`CsmDepthDecoderModel`]. It is used to instantiate an CSM depth decoder
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield
+ a similar configuration to that of the csm-1b.
+
+ e.g. [eustlb/csm-1b](https://huggingface.co/eustlb/csm-1b)
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ num_codebooks (`int`, *optional*, defaults to 32):
+ Number of codebooks used in the underlying codec model responsible for tokenizing the audio.
+ backbone_hidden_size (`int`, *optional*, defaults to 2048):
+ Dimension of the hidden representations of the backbone model used with this depth decoder.
+ vocab_size (`int`, *optional*, defaults to 2051):
+ Vocabulary size of the CsmDepthDecoder model. Defines the number of different audio tokens that can be represented by each codebook.
+ hidden_size (`int`, *optional*, defaults to 1024):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 8192):
+ Dimension of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 4):
+ Number of hidden layers in the Transformer decoder.
+ num_attention_heads (`int`, *optional*, defaults to 8):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ num_key_value_heads (`int`, *optional*, defaults to 2):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details checkout [this
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
+ `num_attention_heads`.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 33):
+ The maximum sequence length that this model might ever be used with.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ pad_token_id (`int`, *optional*, defaults to 2050):
+ Padding token id.
+ bos_token_id (`int`, *optional*):
+ Beginning of stream token id.
+ eos_token_id (`int`, *optional*):
+ End of stream token id.
+ rope_theta (`float`, *optional*, defaults to 500000):
+ The base period of the RoPE embeddings.
+ rope_scaling (`Dict`, *optional*):
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
+ and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
+ accordingly.
+ Expected contents:
+ `rope_type` (`str`):
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+ 'llama3'], with 'default' being the original RoPE implementation.
+ `factor` (`float`, *optional*):
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+ original maximum pre-trained length.
+ `original_max_position_embeddings` (`int`, *optional*):
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+ pretraining.
+ `attention_factor` (`float`, *optional*):
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
+ `factor` field to infer the suggested value.
+ `beta_fast` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+ ramp function. If unspecified, it defaults to 32.
+ `beta_slow` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+ ramp function. If unspecified, it defaults to 1.
+ `short_factor` (`List[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `long_factor` (`List[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (<
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `low_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
+ `high_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
+ attention_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ mlp_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
+ head_dim (`int`, *optional*):
+ The attention head dimension. If None, it will default to hidden_size // num_attention_heads
+
+ ```python
+ >>> from transformers import CsmDepthDecoder, CsmDepthDecoderConfig
+
+ >>> # Initializing a CsmDepthDecoder
+ >>> configuration = CsmDepthDecoderConfig()
+ >>> model = CsmDepthDecoderModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "csm_depth_decoder_model"
+ base_config_key = "depth_decoder_config"
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ num_codebooks=32,
+ backbone_hidden_size=2048,
+ vocab_size=2051,
+ hidden_size=1024,
+ intermediate_size=8192,
+ num_hidden_layers=4,
+ num_attention_heads=8,
+ num_key_value_heads=2,
+ hidden_act="silu",
+ max_position_embeddings=33,
+ initializer_range=0.02,
+ rms_norm_eps=1e-5,
+ use_cache=True,
+ pad_token_id=None,
+ bos_token_id=None,
+ eos_token_id=None,
+ rope_theta=500000,
+ rope_scaling=None,
+ attention_bias=False,
+ attention_dropout=0.0,
+ mlp_bias=False,
+ head_dim=None,
+ **kwargs,
+ ):
+ if kwargs.pop("tie_word_embeddings", False):
+ raise ValueError("`tie_word_embeddings=True` is not supported for CsmDepthDecoderConfig")
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=False,
+ **kwargs,
+ )
+ self.num_codebooks = num_codebooks
+ self.vocab_size = vocab_size
+ self.backbone_hidden_size = backbone_hidden_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+ self.mlp_bias = mlp_bias
+ self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
+ # Validate the correctness of rotary position embeddings parameters
+ # BC: if there is a 'type' field, copy it it to 'rope_type'.
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
+ rope_config_validation(self)
+
+
+class CsmConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`CsmForConditionalGeneration`]. It is used to instantiate an CSM
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the csm-1b.
+
+ e.g. [eustlb/csm-1b](https://huggingface.co/eustlb/csm-1b)
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ num_codebooks (`int`, *optional*, defaults to 32):
+ Number of codebooks used in the underlying codec model responsible for tokenizing the audio.
+ vocab_size (`int`, *optional*, defaults to 2051):
+ Vocabulary size of the Csm model. Defines the number of different audio tokens that can be represented by each codebook.
+ text_vocab_size (`int`, *optional*, defaults to 128256):
+ Vocabulary size of the text input for the Csm model. Defines the number of different text tokens that can be represented.
+ hidden_size (`int`, *optional*, defaults to 2048):
+ Dimension of the hidden representations of the backbone model.
+ intermediate_size (`int`, *optional*, defaults to 8192):
+ Dimension of the MLP representations of the backbone model.
+ num_hidden_layers (`int`, *optional*, defaults to 16):
+ Number of hidden layers in the backbone model Transformer decoder.
+ num_attention_heads (`int`, *optional*, defaults to 32):
+ Number of attention heads for each attention layer in the backbone model Transformer decoder.
+ num_key_value_heads (`int`, *optional*, defaults to 8):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details checkout [this
+ paper](https://arxiv.org/pdf/2305.13245.pdf).
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the backbone model Transformer decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
+ The maximum sequence length that this model might ever be used with.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ pad_token_id (`int`, *optional*, defaults to 128002):
+ Padding token id.
+ codebook_pad_token_id (`int`, *optional*, defaults to 2050):
+ Padding token id for codebook tokens.
+ codebook_eos_token_id (`int`, *optional*, defaults to 0):
+ End of stream token id for codebook tokens.
+ bos_token_id (`int`, *optional*, defaults to 128000):
+ Beginning of stream token id.
+ eos_token_id (`int`, *optional*):
+ End of stream token id.
+ audio_token_id (`int`, *optional*, defaults to 128002):
+ Audio token id in the text input.
+ audio_eos_token_id (`int`, *optional*, defaults to 128003):
+ End of stream token id for audio in the text input.
+ rope_theta (`float`, *optional*, defaults to 500000):
+ The base period of the RoPE embeddings.
+ rope_scaling (`Dict`, *optional*, defaults to `{'factor': 32.0, 'high_freq_factor': 0.5, 'low_freq_factor': 0.125, 'original_max_position_embeddings': 1024, 'rope_type': 'llama3'}`):
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
+ and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
+ accordingly.
+ Expected contents:
+ `rope_type` (`str`):
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+ 'llama3'], with 'default' being the original RoPE implementation.
+ `factor` (`float`, *optional*):
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+ original maximum pre-trained length.
+ `original_max_position_embeddings` (`int`, *optional*):
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+ pretraining.
+ `attention_factor` (`float`, *optional*):
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
+ `factor` field to infer the suggested value.
+ `beta_fast` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+ ramp function. If unspecified, it defaults to 32.
+ `beta_slow` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+ ramp function. If unspecified, it defaults to 1.
+ `short_factor` (`List[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `long_factor` (`List[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (<
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `low_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
+ `high_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
+ attention_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ mlp_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
+ head_dim (`int`, *optional*):
+ The attention head dimension. If None, it will default to hidden_size // num_attention_heads
+ tie_codebooks_embeddings (`bool`, *optional*, defaults to `True`):
+ Whether to tie the codebook tokens embeddings of the backbone model to the codebook tokens embeddings of the depth decoder.
+ depth_decoder_config (`CsmDepthDecoderConfig`, *optional*):
+ Configuration for the depth decoder.
+ codec_config (`PretrainedConfig`, *optional*):
+ Configuration for the codec.
+
+ ```python
+ >>> from transformers import CsmForConditionalGeneration, CsmConfig
+
+ >>> # Initializing a CsmConfig
+ >>> configuration = CsmConfig()
+
+ >>> # Initializing a model
+ >>> model = CsmForConditionalGeneration(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "csm"
+ base_config_key = "csm_config"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ sub_configs = {
+ "codec_config": AutoConfig,
+ "depth_decoder_config": CsmDepthDecoderConfig,
+ }
+
+ def __init__(
+ self,
+ num_codebooks=32,
+ vocab_size=2051,
+ text_vocab_size=128256,
+ hidden_size=2048,
+ intermediate_size=8192,
+ num_hidden_layers=16,
+ num_attention_heads=32,
+ num_key_value_heads=8,
+ hidden_act="silu",
+ max_position_embeddings=2048,
+ initializer_range=0.02,
+ rms_norm_eps=1e-5,
+ use_cache=True,
+ pad_token_id=128002,
+ codebook_pad_token_id=2050,
+ codebook_eos_token_id=0,
+ bos_token_id=128000,
+ eos_token_id=None,
+ audio_token_id=128002,
+ audio_eos_token_id=128003,
+ rope_theta=500000,
+ rope_scaling=None,
+ attention_bias=False,
+ attention_dropout=0.0,
+ mlp_bias=False,
+ head_dim=None,
+ tie_codebooks_embeddings=True,
+ depth_decoder_config=None,
+ codec_config=None,
+ **kwargs,
+ ):
+ if kwargs.pop("tie_word_embeddings", False):
+ raise ValueError("`tie_word_embeddings=True` is not supported for CsmConfig")
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=False,
+ **kwargs,
+ )
+
+ if depth_decoder_config is None:
+ self.depth_decoder_config = CsmDepthDecoderConfig()
+ logger.info("depth_decoder_config is None, using default depth decoder config.")
+ elif isinstance(depth_decoder_config, dict):
+ self.depth_decoder_config = CsmDepthDecoderConfig(**depth_decoder_config)
+ elif isinstance(depth_decoder_config, CsmDepthDecoderConfig):
+ self.depth_decoder_config = depth_decoder_config
+
+ if codec_config is None:
+ self.codec_config = AutoConfig.for_model("mimi")
+ logger.info("codec_config is None, using default audio encoder config.")
+ elif isinstance(codec_config, dict):
+ self.codec_config = AutoConfig.for_model(**codec_config)
+ elif isinstance(codec_config, PretrainedConfig):
+ self.codec_config = codec_config
+
+ self.text_vocab_size = text_vocab_size
+ self.num_codebooks = num_codebooks
+ self.audio_token_id = audio_token_id
+ self.audio_eos_token_id = audio_eos_token_id
+ self.codebook_pad_token_id = codebook_pad_token_id
+ self.codebook_eos_token_id = codebook_eos_token_id
+ self.tie_codebooks_embeddings = tie_codebooks_embeddings
+
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+ self.mlp_bias = mlp_bias
+ self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
+ # Validate the correctness of rotary position embeddings parameters
+ # BC: if there is a 'type' field, copy it it to 'rope_type'.
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
+ rope_config_validation(self)
+
+
+__all__ = [
+ "CsmDepthDecoderConfig",
+ "CsmConfig",
+]
diff --git a/src/transformers/models/csm/convert_csm.py b/src/transformers/models/csm/convert_csm.py
new file mode 100644
index 0000000000..dc84e2cf3d
--- /dev/null
+++ b/src/transformers/models/csm/convert_csm.py
@@ -0,0 +1,339 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import gc
+import os
+import re
+
+import torch
+from tokenizers.processors import TemplateProcessing
+
+from transformers import (
+ AutoFeatureExtractor,
+ AutoTokenizer,
+ CsmConfig,
+ CsmDepthDecoderConfig,
+ CsmForConditionalGeneration,
+ CsmProcessor,
+ MimiModel,
+)
+from transformers.utils.hub import cached_file
+
+
+# fmt: off
+ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
+ r"backbone\.layers\.(\d+)": r"backbone_model.layers.\1",
+ r"decoder\.layers\.(\d+)": r"depth_decoder.model.layers.\1",
+
+ r"attn": r"self_attn",
+ r"output_proj": r"o_proj",
+ r"w1": r"gate_proj",
+ r"w2": r"down_proj",
+ r"w3": r"up_proj",
+
+ r"text_embeddings": r"embed_text_tokens",
+ r"audio_embeddings": r"backbone_model.embed_tokens.embed_audio_tokens",
+
+ r"codebook0_head": r"lm_head",
+ r"audio_head": r"depth_decoder.codebooks_head.weight",
+ r"projection": r"depth_decoder.model.inputs_embeds_projector",
+
+ r"sa_norm.scale": r"input_layernorm.weight",
+ r"mlp_norm.scale": r"post_attention_layernorm.weight",
+ r"decoder.norm.scale": r"depth_decoder.model.norm.weight",
+ r"backbone.norm.scale": r"backbone_model.norm.weight",
+}
+# fmt: on
+
+
+def permute_for_rope(input_tensor, n_heads, dim1, dim2):
+ """
+ When you go from the complex ROPE formulation to sin and cos one, you need
+ to permute the query and key weights (to avoid doing it on the fly)
+ """
+ input_tensor = input_tensor.reshape(dim1, dim2)
+ input_tensor = input_tensor.view(n_heads, dim1 // n_heads // 2, 2, dim2)
+ input_tensor = input_tensor.transpose(1, 2).reshape(dim1, dim2)
+ return input_tensor
+
+
+def convert_key(key, mapping):
+ for pattern, replacement in mapping.items():
+ key = re.sub(pattern, replacement, key)
+ return key
+
+
+def write_model(
+ input_path_or_repo,
+ model_name,
+ codec_model_path_or_repo,
+ output_dir,
+ safe_serialization=True,
+):
+ print("Converting the model.")
+ os.makedirs(output_dir, exist_ok=True)
+
+ codec_model = MimiModel.from_pretrained(codec_model_path_or_repo)
+ codec_model.config._attn_implementation_autoset = False
+
+ # prepare rope scaling args: the model uses originally
+ # 1 - for the depth decoder
+ # rope_theta=500000,
+ # rope_scaling={
+ # "factor": 32.0,
+ # "high_freq_factor": 4.0,
+ # "low_freq_factor": 1.0,
+ # "original_max_position_embeddings": 8192,
+ # "rope_type": "llama3",
+ # },
+ # 2 - for the backbone
+ # rope_theta=500000,
+ # rope_scaling={
+ # "factor": 32.0,
+ # "high_freq_factor": 4.0,
+ # "low_freq_factor": 1.0,
+ # "original_max_position_embeddings": 8192,
+ # "rope_type": "llama3",
+ # },
+ #
+ # Yet we want to use max_position_embeddings=32, resp. 2048
+ # This will throw warning as we would have original_max_position_embeddings >= max_position_embeddings
+ # Therefore, we convert values to equivalent ones
+
+ depth_decoder_config = CsmDepthDecoderConfig(
+ rope_scaling={
+ "factor": 32.0,
+ "high_freq_factor": 0.0078125,
+ "low_freq_factor": 0.001953125,
+ "original_max_position_embeddings": 16,
+ "rope_type": "llama3",
+ },
+ )
+
+ config = CsmConfig(
+ codec_config=codec_model.config,
+ depth_decoder_config=depth_decoder_config,
+ rope_scaling={
+ "factor": 32.0,
+ "high_freq_factor": 0.5,
+ "low_freq_factor": 0.125,
+ "original_max_position_embeddings": 1024,
+ "rope_type": "llama3",
+ },
+ )
+
+ params = {
+ "backbone": {
+ "num_attention_heads": config.num_attention_heads,
+ "num_key_value_heads": config.num_key_value_heads,
+ "dim_per_head": config.head_dim,
+ "key_value_dim": config.head_dim * config.num_key_value_heads,
+ "dim": config.hidden_size,
+ },
+ "depth_decoder": {
+ "num_attention_heads": config.depth_decoder_config.num_attention_heads,
+ "num_key_value_heads": config.depth_decoder_config.num_key_value_heads,
+ "dim_per_head": config.depth_decoder_config.head_dim,
+ "key_value_dim": config.depth_decoder_config.head_dim * config.depth_decoder_config.num_key_value_heads,
+ "dim": config.depth_decoder_config.hidden_size,
+ },
+ }
+
+ model_path = cached_file(
+ input_path_or_repo,
+ model_name,
+ )
+ print(f"Fetching all parameters from the checkpoint at {model_path}...")
+ loaded = torch.load(model_path, map_location="cpu")
+
+ print("Converting model...")
+ state_dict = {}
+
+ # -----------------------
+ # convert parameter names
+ # -----------------------
+
+ # Add codec_model. prefix to every key in the codec model state dict
+ codec_state_dict = {f"codec_model.{k}": v for k, v in codec_model.state_dict().items()}
+ state_dict.update(codec_state_dict)
+
+ for key, value in loaded.items():
+ new_key = convert_key(key, ORIGINAL_TO_CONVERTED_KEY_MAPPING)
+ current_parameter = value
+
+ # Post-process the current_parameter.
+ if re.search("(k|q)_proj.weight", new_key):
+ params_keys = "backbone" if "backbone" in new_key else "depth_decoder"
+ if "q_proj" in new_key:
+ num_heads = params[params_keys]["num_attention_heads"]
+ dim_per_head = params[params_keys]["dim_per_head"]
+ param_dim = params[params_keys]["dim"]
+ dim = params[params_keys]["dim"]
+ else:
+ num_heads = params[params_keys]["num_key_value_heads"]
+ dim_per_head = params[params_keys]["dim_per_head"]
+ param_dim = params[params_keys]["key_value_dim"]
+ dim = params[params_keys]["dim"]
+
+ current_parameter = permute_for_rope(value, num_heads, param_dim, dim)
+ state_dict[new_key] = current_parameter.reshape(num_heads * dim_per_head, dim)
+
+ state_dict[new_key] = current_parameter
+
+ # add the depth decoder embed audio tokens weights, latter tied to the backbone embed audio tokens weights
+ state_dict["depth_decoder.model.embed_tokens.weight"] = state_dict[
+ "backbone_model.embed_tokens.embed_audio_tokens.weight"
+ ].clone()
+ del loaded
+ gc.collect()
+
+ # -------------------------
+ # load the weights and save
+ # -------------------------
+
+ print("Loading the checkpoint in a Csm model.")
+ with torch.device("meta"):
+ model = CsmForConditionalGeneration(config)
+ model.load_state_dict(state_dict, strict=True, assign=True)
+ print("Checkpoint loaded successfully.")
+ del model.config._name_or_path
+
+ # default generation config
+ model.generation_config._from_model_config = False
+ model.generation_config.max_new_tokens = 125
+ model.generation_config.do_sample = True
+ model.generation_config.top_k = 50
+ model.generation_config.temperature = 0.9
+ model.generation_config.depth_decoder_do_sample = True
+ model.generation_config.depth_decoder_top_k = 50
+ model.generation_config.depth_decoder_temperature = 0.9
+
+ print("Saving the model.")
+ model.save_pretrained(output_dir, safe_serialization=safe_serialization)
+ del state_dict, model
+
+ # Safety check: reload the converted model
+ gc.collect()
+ print("Reloading the model to check if it's saved correctly.")
+ CsmForConditionalGeneration.from_pretrained(output_dir, torch_dtype=torch.bfloat16, device_map="auto")
+ print("Model reloaded successfully.")
+
+
+def write_tokenizer(output_dir):
+ # from https://github.com/SesameAILabs/csm/blob/2d720827843b653c4d67bb4445b1c0a4f59e646f/generator.py#L22-L36
+ def load_llama3_tokenizer():
+ """
+ https://github.com/huggingface/transformers/issues/22794#issuecomment-2092623992
+ """
+ tokenizer_name = "meta-llama/Llama-3.2-1B"
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
+ bos = tokenizer.bos_token
+ eos = tokenizer.eos_token
+ tokenizer._tokenizer.post_processor = TemplateProcessing(
+ single=f"{bos}:0 $A:0 {eos}:0",
+ pair=f"{bos}:0 $A:0 {eos}:0 {bos}:1 $B:1 {eos}:1",
+ special_tokens=[(f"{bos}", tokenizer.bos_token_id), (f"{eos}", tokenizer.eos_token_id)],
+ )
+
+ return tokenizer
+
+ tokenizer = load_llama3_tokenizer()
+ tokenizer.pad_token = tokenizer.eos_token
+ tokenizer.save_pretrained(output_dir)
+
+ # manually modify in tokenizer_config.json
+ # "128002": {
+ # "content": "<|AUDIO|>",
+ # ...
+ # }
+ # "128003": {
+ # "content": "<|audio_eos|>",
+ # ...
+ # }
+ print(
+ "Tokenizer saved successfully. Please manually modify in tokenizer_config.json AND tokenizer.json as follows: "
+ )
+ print("""
+ # "128002": {
+ # "content": "<|AUDIO|>",
+ # ...
+ # }
+ # "128003": {
+ # "content": "<|audio_eos|>",
+ # ...
+ # }
+ """)
+
+
+def write_processor(output_dir, codec_model_path_or_repo):
+ chat_template = "\n{%- for message in messages %}\n {#-- Validate role is a stringified integer --#}\n {%- if not message['role'] is string or not message['role'].isdigit() %}\n {{- raise_exception(\"The role must be an integer or a stringified integer (e.g. '0') designating the speaker id\") }}\n {%- endif %}\n\n {#-- Validate content is a list --#}\n {%- set content = message['content'] %}\n {%- if content is not iterable or content is string %}\n {{- raise_exception(\"The content must be a list\") }}\n {%- endif %}\n\n {#-- Collect content types --#}\n {%- set content_types = content | map(attribute='type') | list %}\n {%- set is_last = loop.last %}\n\n {#-- Last message validation --#}\n {%- if is_last %}\n {%- if 'text' not in content_types %}\n {{- raise_exception(\"The last message must include one item of type 'text'\") }}\n {%- elif (content_types | select('equalto', 'text') | list | length > 1) or (content_types | select('equalto', 'audio') | list | length > 1) %}\n {{- raise_exception(\"At most two items are allowed in the last message: one 'text' and one 'audio'\") }}\n {%- endif %}\n\n {#-- All other messages validation --#}\n {%- else %}\n {%- if content_types | select('equalto', 'text') | list | length != 1\n or content_types | select('equalto', 'audio') | list | length != 1 %}\n {{- raise_exception(\"Each message (except the last) must contain exactly one 'text' and one 'audio' item\") }}\n {%- elif content_types | reject('in', ['text', 'audio']) | list | length > 0 %}\n {{- raise_exception(\"Only 'text' and 'audio' types are allowed in content\") }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n\n{%- for message in messages %}\n {{- bos_token }}\n {{- '[' + message['role'] + ']' }}\n {{- message['content'][0]['text'] }}\n {{- eos_token }}\n {%- if message['content']|length > 1 %}\n {{- '<|AUDIO|><|audio_eos|>' }}\n {%- endif %}\n{%- endfor %}\n"
+ tokenizer = AutoTokenizer.from_pretrained(output_dir)
+ feature_extractor = AutoFeatureExtractor.from_pretrained(codec_model_path_or_repo)
+
+ processor = CsmProcessor(
+ tokenizer=tokenizer,
+ feature_extractor=feature_extractor,
+ chat_template=chat_template,
+ )
+
+ processor.save_pretrained(output_dir)
+ print("Processor saved successfully.")
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Convert Csm weights to HuggingFace format")
+ parser.add_argument(
+ "--input_path_or_repo",
+ type=str,
+ required=True,
+ help="Path or repo containing Csm weights",
+ )
+ parser.add_argument(
+ "--model_name",
+ type=str,
+ required=True,
+ help="Name of the model in input_path_or_repo",
+ )
+ parser.add_argument(
+ "--codec_model_path_or_repo",
+ type=str,
+ required=True,
+ help="Path or repo containing the codec model",
+ )
+ parser.add_argument(
+ "--output_dir",
+ help="Location to write HF model and tokenizer",
+ )
+ parser.add_argument(
+ "--safe_serialization", action="store_true", default=True, help="Whether or not to save using `safetensors`."
+ )
+ args = parser.parse_args()
+
+ write_model(
+ args.input_path_or_repo,
+ args.model_name,
+ args.codec_model_path_or_repo,
+ output_dir=args.output_dir,
+ safe_serialization=args.safe_serialization,
+ )
+
+ write_tokenizer(args.output_dir)
+
+ write_processor(args.output_dir, args.codec_model_path_or_repo)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/transformers/models/csm/generation_csm.py b/src/transformers/models/csm/generation_csm.py
new file mode 100644
index 0000000000..b1c2cd920d
--- /dev/null
+++ b/src/transformers/models/csm/generation_csm.py
@@ -0,0 +1,491 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from ...generation import (
+ GenerateDecoderOnlyOutput,
+ GenerationConfig,
+ GenerationMixin,
+ GenerationMode,
+)
+from ...generation.logits_process import LogitsProcessorList
+from ...generation.stopping_criteria import MaxLengthCriteria, StoppingCriteriaList
+from ...generation.utils import GenerateNonBeamOutput
+from ...utils import logging
+
+
+if TYPE_CHECKING:
+ from ...generation.streamers import BaseStreamer
+
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+class CsmGenerateOutput(GenerateDecoderOnlyOutput):
+ """
+ Outputs of CsmForConditionalGeneration.generate.
+
+ Args:
+ sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
+ if all batches finished early due to the `eos_token_id`.
+ scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True`):
+ Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
+ at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
+ each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
+ logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`):
+ Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
+ at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
+ each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
+ attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
+ Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
+ `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
+ hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`):
+ Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
+ `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
+ past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True`):
+ Returns the model cache, used to speed up decoding. Different models have a different cache format, check
+ audio (`list(torch.FloatTensor)` of length `batch_size`):
+ The generated audio.
+ """
+
+ audio: Optional[List[torch.Tensor]] = None
+
+
+class CsmGenerationMixin(GenerationMixin):
+ def _get_stopping_criteria(
+ self,
+ *args,
+ **kwargs,
+ ) -> StoppingCriteriaList:
+ criteria = super()._get_stopping_criteria(*args, **kwargs)
+
+ kept_criteria = StoppingCriteriaList()
+ for criterion in criteria:
+ if not isinstance(criterion, MaxLengthCriteria):
+ logger.warning(
+ f"Csm does not support {criterion.__class__.__name__} stopping criteria, it will be ignored."
+ )
+ else:
+ kept_criteria.append(criterion)
+ return kept_criteria
+
+ def _prepare_generation_config(
+ self, generation_config: Optional[GenerationConfig], use_model_defaults: Optional[bool] = None, **kwargs: Dict
+ ) -> Tuple[GenerationConfig, Dict]:
+ """
+ This method overrides [~generation.utils.GenerationMixin._prepare_generation_config].
+ It ensures that the depth decoder generation config is initialized and that passed args as depth_decoder_* are properly handled.
+ """
+ # extract depth decoder kwargs and remove them from the main kwargs
+ depth_decoder_kwargs = {
+ k[len("depth_decoder_") :]: v for k, v in kwargs.items() if k.startswith("depth_decoder_")
+ }
+
+ # remove the depth decoder keys from the original kwargs
+ kwargs = {k: v for k, v in kwargs.items() if not k.startswith("depth_decoder_")}
+
+ # initialize the generation config
+ generation_config, model_kwargs = super()._prepare_generation_config(
+ generation_config, use_model_defaults, **kwargs
+ )
+ self.depth_decoder.generation_config.update(**depth_decoder_kwargs)
+
+ # ensure the depth decoder generation config is valid
+ depth_decoder_min_new_tokens = getattr(self.depth_decoder.generation_config, "min_new_tokens") or (
+ self.config.num_codebooks - 1
+ )
+ depth_decoder_max_new_tokens = getattr(self.depth_decoder.generation_config, "max_new_tokens") or (
+ self.config.num_codebooks - 1
+ )
+
+ if {depth_decoder_min_new_tokens, depth_decoder_max_new_tokens} != {self.config.num_codebooks - 1}:
+ raise ValueError(
+ f"depth_decoder_generation_config's min_new_tokens ({depth_decoder_min_new_tokens}) and max_new_tokens ({depth_decoder_max_new_tokens}) must be equal to self.config.num_codebooks - 1 ({self.config.num_codebooks - 1})"
+ )
+ elif self.depth_decoder.generation_config.return_dict_in_generate:
+ logger.warning(
+ "depth_decoder_generation_config.return_dict_in_generate is set to True, but this will be ignored as the depth decoder model does not return a dictionary in generate"
+ )
+ self.depth_decoder.generation_config.return_dict_in_generate = False
+
+ self.depth_decoder.generation_config.min_new_tokens = depth_decoder_min_new_tokens
+ self.depth_decoder.generation_config.max_new_tokens = depth_decoder_max_new_tokens
+
+ # Monkey patch the get_generation_mode method to support CSM model
+ original_get_generation_mode = generation_config.get_generation_mode
+
+ def patched_get_generation_mode(assistant_model=None):
+ generation_mode = original_get_generation_mode(assistant_model)
+ if generation_mode not in [GenerationMode.GREEDY_SEARCH, GenerationMode.SAMPLE]:
+ raise ValueError(
+ f"Generation mode {generation_mode} is not supported for CSM model. Please set generation parameters to use greedy or sampling generation."
+ )
+
+ return generation_mode
+
+ generation_config.get_generation_mode = patched_get_generation_mode
+
+ return generation_config, model_kwargs
+
+ def _sample(
+ self,
+ input_ids: torch.LongTensor,
+ logits_processor: LogitsProcessorList,
+ stopping_criteria: StoppingCriteriaList,
+ generation_config: GenerationConfig,
+ synced_gpus: bool,
+ streamer: Optional["BaseStreamer"],
+ **model_kwargs,
+ ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
+ """
+ This method overrides [~generation.utils.GenerationMixin._sample].
+ To ease maintenance, modifications are marked with the comment "Csm specific".
+
+ Indeed, Csm model requires a custom generation sampling step:
+ 1. Infer the backbone model to sample the first codebook token
+ 2. Call generate on the depth decoder with the first codebook token as input_ids to sample the next codebook tokens
+ 3. Use these generated codebook tokens as input_ids to sample the next first codebook token using the backbone model
+ 4. Repeat until stopping criteria is met
+
+ Csm supports two stopping criterias:
+ - stop when the generated sequence is at max_length
+ - stop when all the generated codebook tokens are the codebook_eos_token_id
+ """
+ # init values
+ # *************** Csm specific ***************
+ pad_token_id = self.config.codebook_pad_token_id
+ has_eos_stopping_criteria = generation_config._eos_token_tensor is not None
+ # ============================================
+ output_attentions = generation_config.output_attentions
+ output_hidden_states = generation_config.output_hidden_states
+ output_scores = generation_config.output_scores
+ output_logits = generation_config.output_logits
+ return_dict_in_generate = generation_config.return_dict_in_generate
+ do_sample = generation_config.do_sample
+
+ # init attention / hidden states / scores tuples
+ scores = () if (return_dict_in_generate and output_scores) else None
+ raw_logits = () if (return_dict_in_generate and output_logits) else None
+ decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
+ decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
+
+ # keep track of which sequences are already finished
+ batch_size, cur_len = input_ids.shape[:2]
+ this_peer_finished = False
+ unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
+ model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
+
+ # *************** Csm specific ***************
+ if input_ids.ndim == 2 and model_kwargs.get("inputs_embeds") is None:
+ # in the case where the passed input_ids correspond to text tokens, i.e. don't have a third dimension for codebook ids,
+ # we need to remove the input length to the MaxLengthCriteria stopping criteria has such input are not returned
+ for criterion in stopping_criteria:
+ if isinstance(criterion, MaxLengthCriteria):
+ criterion.max_length -= cur_len
+ # ============================================
+
+ model_forward = self.__call__
+ compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config)
+ if compile_forward:
+ os.environ["TOKENIZERS_PARALLELISM"] = "0"
+ model_forward = self.get_compiled_call(generation_config.compile_config)
+
+ is_prefill = True
+ while self._has_unfinished_sequences(
+ this_peer_finished,
+ synced_gpus,
+ device=input_ids.device,
+ ):
+ # prepare model inputs
+ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+
+ # prepare variable output controls (note: some models won't accept all output controls)
+ model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
+ # *************** Csm specific ***************
+ model_inputs.update({"output_hidden_states": True})
+ # ============================================
+
+ if is_prefill:
+ outputs = self(**model_inputs, return_dict=True)
+ is_prefill = False
+ else:
+ outputs = model_forward(**model_inputs, return_dict=True)
+
+ # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
+ model_kwargs = self._update_model_kwargs_for_generation(
+ outputs,
+ model_kwargs,
+ )
+ if synced_gpus and this_peer_finished:
+ continue
+
+ # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
+ # (the clone itself is always small)
+ next_token_logits = outputs.logits[:, -1, :].clone().float()
+ next_token_logits = next_token_logits.to(input_ids.device)
+
+ # pre-process distribution
+ next_token_scores = logits_processor(input_ids, next_token_logits)
+
+ # Store scores, attentions and hidden_states when required
+ if return_dict_in_generate:
+ if output_scores:
+ scores += (next_token_scores,)
+ if output_logits:
+ raw_logits += (next_token_logits,)
+ if output_attentions:
+ decoder_attentions += (outputs.attentions,)
+
+ if output_hidden_states:
+ decoder_hidden_states += (outputs.hidden_states,)
+
+ # token selection
+ if do_sample:
+ probs = nn.functional.softmax(next_token_scores, dim=-1)
+ # TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution
+ next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+ else:
+ next_tokens = torch.argmax(next_token_scores, dim=-1)
+
+ # *************** Csm specific ***************
+ # infer the depth decoder
+ first_codebook_ids = next_tokens[:, None]
+ # adds place holder in position 0 that will be replaced by the backbone_last_hidden_state
+ depth_decoder_input_ids = nn.functional.pad(first_codebook_ids, (1, 0), value=0)
+ backbone_last_hidden_state = outputs.hidden_states[-1][:, -1, :]
+
+ depth_decoder_outputs = self.depth_decoder.generate(
+ input_ids=depth_decoder_input_ids, backbone_last_hidden_state=backbone_last_hidden_state.clone()
+ )
+ codebook_ids = (
+ depth_decoder_outputs
+ if isinstance(depth_decoder_outputs, torch.Tensor)
+ else depth_decoder_outputs.sequences
+ )
+ # remove the place holder in position 0
+ codebook_ids = codebook_ids[:, 1:]
+ next_tokens = codebook_ids
+
+ # finished sentences should have their next token be a padding token
+ if has_eos_stopping_criteria:
+ next_tokens = next_tokens * unfinished_sequences.unsqueeze(-1) + pad_token_id * (
+ 1 - unfinished_sequences.unsqueeze(-1)
+ )
+
+ # update generated ids, model inputs, and length for next step
+ if input_ids.ndim == 2:
+ input_ids = next_tokens[:, None, :]
+ else:
+ input_ids = torch.cat([input_ids, next_tokens[:, None, :]], dim=1)
+ # ============================================
+
+ if streamer is not None:
+ streamer.put(next_tokens.cpu())
+
+ # *************** Csm specific ***************
+ # for the eos stopping criteria, is it expected that the eos token is the same for each codebook !!!!
+ unfinished_sequences = unfinished_sequences & ~(
+ input_ids[:, -1, :-1] == self.config.codebook_eos_token_id
+ ).all(-1)
+ # ============================================
+ unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
+ this_peer_finished = unfinished_sequences.max() == 0
+ cur_len += 1
+
+ # This is needed to properly delete outputs.logits which may be very large for first iteration
+ # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
+ del outputs
+
+ # *************** Csm specific ***************
+ del depth_decoder_outputs
+ # ============================================
+
+ if streamer is not None:
+ streamer.end()
+
+ if return_dict_in_generate:
+ return GenerateDecoderOnlyOutput(
+ sequences=input_ids,
+ scores=scores,
+ logits=raw_logits,
+ attentions=decoder_attentions,
+ hidden_states=decoder_hidden_states,
+ past_key_values=model_kwargs.get("past_key_values"),
+ )
+ else:
+ return input_ids
+
+ def generate(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ input_values: Optional[torch.Tensor] = None,
+ input_values_cutoffs: Optional[torch.Tensor] = None,
+ generation_config: Optional[GenerationConfig] = None,
+ logits_processor: Optional[LogitsProcessorList] = None,
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
+ synced_gpus: Optional[bool] = None,
+ streamer: Optional["BaseStreamer"] = None,
+ output_audio: Optional[bool] = False,
+ **kwargs,
+ ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
+ r"""
+ This method overrides [`~generation.utils.GenerationMixin.generate`] to match the specifics of the Csm model.
+ Indeed, Csm model requires a custom generation sampling step:
+ 1. Infer the backbone model to sample the first codebook token
+ 2. Call generate on the depth decoder with the first codebook token as `input_ids` to sample the next codebook tokens
+ 3. Use these generated codebook tokens as `input_ids` to sample the next first codebook token using the backbone model
+ 4. Repeat until stopping criteria is met
+
+
+
+ Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
+ model's default generation configuration. You can override any `generation_config` by passing the corresponding
+ parameters to generate(), e.g. `.generate(inputs, do_sample=True)`.
+
+
+ Parameters:
+ inputs_ids (`torch.Tensor` of shape (batch_size, seq_length), *optional*):
+ The sequence used as a prompt for the backbone model.
+ input_values (`torch.Tensor` of shape (batch_size, channels, max_concatenated_audio_length), *optional*):
+ The batched audio input values, where each batch entry contains the concatenation of all audio segments for that entry.
+ These values will be encoded into codebook tokens using the codec model and merged with the text input ids provided in `input_ids`.
+ input_values_cutoffs (`torch.Tensor` of shape (batch_size, max_num_audio), *optional*):
+ Specify the end positions of audio segments within each batch entry, relative to the concatenated audio input.
+ If a batch entry has fewer segments than the maximum, it is padded with -1. For example, in a batch of 2 sequences
+ where the first contains 2 audio segments of length l1, and the second contains 1 audio segment of length l2,
+ the input_values_cutoffs would be: [[l1, 2 * l1], [l2, -1]].
+ generation_config ([`~generation.GenerationConfig`], *optional*):
+ The generation configuration to be used as base parametrization for the generation call. `**kwargs`
+ passed to generate matching the attributes of `generation_config` will override them. If
+ `generation_config` is not provided, the default will be used, which has the following loading
+ priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
+ configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
+ default values, whose documentation should be checked to parameterize generation.
+ logits_processor (`LogitsProcessorList`, *optional*):
+ Custom logits processors that complement the default logits processors built from arguments and
+ generation config. If a logit processor is passed that is already created with the arguments or a
+ generation config an error is thrown. This feature is intended for advanced users.
+ stopping_criteria (`StoppingCriteriaList`, *optional*):
+ Custom stopping criteria that complements the default stopping criteria built from arguments and a
+ generation config. If a stopping criteria is passed that is already created with the arguments or a
+ generation config an error is thrown. If your stopping criteria depends on the `scores` input, make
+ sure you pass `return_dict_in_generate=True, output_scores=True` to `generate`. This feature is
+ intended for advanced users.
+ synced_gpus (`bool`, *optional*):
+ Whether to continue running the while loop until max_length. Unless overridden, this flag will be set
+ to `True` if using `FullyShardedDataParallel` or DeepSpeed ZeRO Stage 3 with multiple GPUs to avoid
+ deadlocking if one GPU finishes generating before other GPUs. Otherwise, defaults to `False`.
+ streamer (`BaseStreamer`, *optional*):
+ Streamer object that will be used to stream the generated sequences. Generated tokens are passed
+ through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
+ output_audio (`bool`, *optional*):
+ Whether to return the generated audio.
+ kwargs (`Dict[str, Any]`, *optional*):
+ Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
+ forwarded to the `forward` function of the model. Depth decoder specific kwargs should be prefixed with *depth_decoder_*.
+
+ Return:
+ [`CsmGenerateOutput`] or `torch.LongTensor` or `List[torch.FloatTensor]`: A [`CsmGenerateOutput`]
+ (if `return_dict_in_generate=True` or when `config.return_dict_in_generate=True`) or a `torch.LongTensor` when `output_audio=False`
+ or a `List[torch.FloatTensor]` otherwise.
+
+ Example:
+
+ ```python
+ >>> from transformers import CsmProcessor, CsmForConditionalGeneration
+ >>> from datasets import load_dataset, Audio
+
+ >>> model_id = "eustlb/csm-1b"
+ >>> torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ >>> processor = AutoProcessor.from_pretrained(model_id)
+
+ >>> ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
+ >>> # ensure the audio is 24kHz
+ >>> ds = ds.cast_column("audio", Audio(sampling_rate=24000))
+
+ >>> conversation = []
+ >>> # prepare a conversation with text and corresponding audio
+ >>> for text, audio, speaker_id in zip(ds[:4]["text"], ds[:4]["audio"], ds[:4]["speaker_id"]):
+ ... conversation.append(
+ ... {
+ ... "role": f"{speaker_id}",
+ ... "content": [{"type": "text", "text": text}, {"type": "audio", "path": audio["array"]}],
+ ... }
+ ... )
+
+ >>> # text prompt
+ >>> conversation.append({"role": f"{ds[4]['speaker_id']}", "content": [{"type": "text", "text": ds[4]["text"]}]})
+
+ >>> inputs = processor.apply_chat_template(
+ ... conversation,
+ ... tokenize=True,
+ ... return_dict=True,
+ ... ).to(torch_device)
+
+ >>> model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=torch_device)
+ >>> audio = model.generate(**inputs, output_audio=True)
+ >>> processor.save_audio(audio, "output.wav")
+ ```
+ """
+ generate_output = super().generate(
+ input_ids=input_ids,
+ input_values=input_values,
+ input_values_cutoffs=input_values_cutoffs,
+ generation_config=generation_config,
+ logits_processor=logits_processor,
+ stopping_criteria=stopping_criteria,
+ synced_gpus=synced_gpus,
+ streamer=streamer,
+ **kwargs,
+ )
+
+ generate_returned_dict = not isinstance(generate_output, torch.Tensor)
+ audio = None
+ if output_audio:
+ generated_audio_codes = generate_output.sequences if generate_returned_dict else generate_output
+
+ # infer the codec model
+ audio = []
+ with torch.no_grad():
+ # =======================================
+ # TODO: @eustlb, this should be batched !!!
+ # but requires making sure batched inference of the codec model works as intended
+ for audio_codes_batch in generated_audio_codes:
+ eos_idxs = (audio_codes_batch == self.config.codebook_eos_token_id).all(dim=-1).nonzero()
+ if eos_idxs.numel() != 0:
+ cutoff_idx = eos_idxs.min()
+ else:
+ cutoff_idx = audio_codes_batch.shape[1]
+
+ audio_codes_batch = audio_codes_batch[:cutoff_idx]
+ codec_decode_output = self.codec_model.decode(audio_codes_batch.transpose(0, 1).unsqueeze(0))
+ audio.append(codec_decode_output.audio_values[0, 0])
+ # =======================================
+
+ if generate_returned_dict:
+ return CsmGenerateOutput(audio=audio, **generate_output)
+ elif output_audio:
+ return audio
+ else:
+ return generate_output
diff --git a/src/transformers/models/csm/modeling_csm.py b/src/transformers/models/csm/modeling_csm.py
new file mode 100644
index 0000000000..03cbc07df4
--- /dev/null
+++ b/src/transformers/models/csm/modeling_csm.py
@@ -0,0 +1,1710 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/csm/modular_csm.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_csm.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 Sesame and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Callable, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache
+from ...generation import GenerationMixin
+from ...integrations import use_kernel_forward_from_hub
+from ...modeling_attn_mask_utils import AttentionMaskConverter
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import (
+ LossKwargs,
+ ModelOutput,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ can_return_tuple,
+ is_torch_flex_attn_available,
+ logging,
+ replace_return_docstrings,
+)
+from ..auto import AutoModel
+from .configuration_csm import CsmConfig, CsmDepthDecoderConfig
+from .generation_csm import CsmGenerationMixin
+
+
+if is_torch_flex_attn_available():
+ from torch.nn.attention.flex_attention import BlockMask
+
+ from ...integrations.flex_attention import make_flex_block_causal_mask
+
+
+logger = logging.get_logger(__name__)
+_CONFIG_FOR_DOC = "CsmConfig"
+
+
+@dataclass
+class CsmOutputWithPast(ModelOutput):
+ """
+ Base class for the model autoregressive outputs.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ depth_decoder_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction) of the depth decoder model.
+ depth_decoder_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the depth decoder (scores for each vocabulary token before SoftMax).
+ depth_decoder_past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+ depth_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ depth_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+ backbone_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction) of the backbone model.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: torch.FloatTensor = None
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+ depth_decoder_loss: Optional[torch.FloatTensor] = None
+ depth_decoder_logits: torch.FloatTensor = None
+ depth_decoder_past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+ depth_decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+ depth_decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+ backbone_loss: Optional[torch.FloatTensor] = None
+
+
+START_DOCSTRING_BASE = r"""
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+ and behavior.
+
+ Parameters:
+ config ([`{config_class}`]):
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
+ load the weights associated with the model, only the configuration. Check out the
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+CSM_START_DOCSTRING = r"""{}""".format(START_DOCSTRING_BASE.format(config_class="CsmConfig"))
+
+
+@add_start_docstrings(
+ "The bare Csm Model outputting raw hidden-states without any specific head on top.",
+ CSM_START_DOCSTRING,
+)
+class CsmPreTrainedModel(PreTrainedModel):
+ config_class = CsmConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["CsmDecoderLayer"]
+ _skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn_2 = True
+ _supports_sdpa = True
+ # does not because of Mimi codec model
+ # _supports_flex_attn = True
+ _supports_cache_class = True
+ _supports_quantized_cache = True
+ _supports_static_cache = True
+ _supports_attention_backend = True
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, CsmCodebooksHead):
+ num_codebooks = module.num_codebooks
+ for i in range(num_codebooks - 1):
+ module.weight.data[i].normal_(mean=0.0, std=std)
+ elif isinstance(module, CsmRMSNorm):
+ module.weight.data.fill_(1.0)
+
+
+@use_kernel_forward_from_hub("RMSNorm")
+class CsmRMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ CsmRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+class CsmRotaryEmbedding(nn.Module):
+ def __init__(self, config: CsmConfig, device=None):
+ super().__init__()
+ # BC: "rope_type" was originally "type"
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
+
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def forward(self, x, position_ids):
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+ position_ids_expanded = position_ids[:, None, :].float()
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+class CsmMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, x):
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+ return down_proj
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs,
+):
+ key_states = repeat_kv(key, module.num_key_value_groups)
+ value_states = repeat_kv(value, module.num_key_value_groups)
+
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+class CsmAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: CsmConfig, layer_idx: int):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+ self.scaling = self.head_dim**-0.5
+ self.attention_dropout = config.attention_dropout
+ self.is_causal = True
+
+ self.q_proj = nn.Linear(
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.k_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.v_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.o_proj = nn.Linear(
+ config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: Tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor],
+ past_key_value: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ attention_interface: Callable = eager_attention_forward
+
+ if self.config._attn_implementation != "eager":
+ if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
+ logger.warning_once(
+ "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
+ 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ else:
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+class CsmDecoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: CsmConfig, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+
+ self.self_attn = CsmAttention(config=config, layer_idx=layer_idx)
+
+ self.mlp = CsmMLP(config)
+ self.input_layernorm = CsmRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = CsmRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+
+ # Self Attention
+ hidden_states, self_attn_weights = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ return outputs
+
+
+CSM_DEPTH_DECODER_START_DOCSTRING = r"""{}""".format(START_DOCSTRING_BASE.format(config_class="CsmDepthDecoderConfig"))
+
+
+INPUTS_DOCSTRING_BASE = r"""
+ Args:
+ {input_ids_docstring}
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
+ `past_key_values`).
+
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+ information on the default strategy.
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.n_positions - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
+
+ Two formats are allowed:
+ - a [`~cache_utils.Cache`] instance, see our
+ [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
+ cache format.
+
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
+ legacy cache format will be returned.
+
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
+ of shape `(batch_size, sequence_length)`.
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
+ the complete sequence length.
+"""
+
+
+DEPTH_DECODER_INPUT_IDS_DOCSTRING = r"""input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)"""
+
+
+CSM_DEPTH_DECODER_INPUTS_DOCSTRING = r"""{}""".format(
+ INPUTS_DOCSTRING_BASE.format(input_ids_docstring=DEPTH_DECODER_INPUT_IDS_DOCSTRING)
+)
+
+
+@add_start_docstrings(
+ "The bare CsmDepthDecoderModel outputting raw hidden-states without any specific head on top.",
+ CSM_DEPTH_DECODER_START_DOCSTRING,
+)
+class CsmDepthDecoderModel(CsmPreTrainedModel):
+ """
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`CsmDecoderLayer`]
+
+ Args:
+ config: CsmDepthDecoderConfig
+ """
+
+ config_class = CsmDepthDecoderConfig
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+ self.embed_tokens = nn.Embedding((config.num_codebooks * config.vocab_size), config.backbone_hidden_size)
+ self.layers = nn.ModuleList(
+ [CsmDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.norm = CsmRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = CsmRotaryEmbedding(config=config)
+ self.gradient_checkpointing = False
+ self.inputs_embeds_projector = nn.Linear(config.backbone_hidden_size, config.hidden_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.embed_tokens = value
+
+ @can_return_tuple
+ @add_start_docstrings_to_model_forward(CSM_DEPTH_DECODER_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ backbone_last_hidden_state: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
+ r"""
+ backbone_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, backbone_hidden_size)`, *optional*):
+ The last hidden state of the backbone model. Such input is required when the first codebook token (the one generated by the backbone model)
+ is provided in the `input_ids` argument.
+ """
+ if position_ids is not None and not torch.compiler.is_compiling():
+ logger.warning_once(
+ "Custom `position_ids` were provided but will be ignored. CSM depth decoder automatically determines position_ids "
+ "from `cache_position` and as it requires them to be identical across the batch, the provided position_ids will be ignored."
+ )
+ position_ids = None
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds.")
+
+ if self.gradient_checkpointing and self.training and use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+ )
+ use_cache = False
+
+ if use_cache and past_key_values is None:
+ past_key_values = DynamicCache()
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ inputs_seq_length = inputs_embeds.shape[1] if inputs_embeds is not None else input_ids.shape[1]
+ device = inputs_embeds.device if inputs_embeds is not None else input_ids.device
+ cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_seq_length, device=device)
+
+ if inputs_embeds is None:
+ codebook_idxs = torch.clamp(cache_position - 1, min=0)
+ offset = codebook_idxs * self.vocab_size
+ inputs_embeds = self.embed_tokens(input_ids + offset)
+
+ input_ids_are_first_codebook = cache_position[0] == 0
+ if backbone_last_hidden_state is not None:
+ inputs_embeds[:, 0] = backbone_last_hidden_state
+ else:
+ if not torch.compiler.is_compiling() and input_ids_are_first_codebook:
+ logger.warning(
+ "When the first codebook token is provided, `backbone_last_hidden_state` should also be provided for correct inference."
+ )
+
+ inputs_embeds = self.inputs_embeds_projector(inputs_embeds)
+
+ causal_mask = self._update_causal_mask(
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+ )
+
+ hidden_states = inputs_embeds
+
+ # create position embeddings to be shared across the decoder layers
+ position_ids = cache_position.unsqueeze(0)
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **flash_attn_kwargs,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values if use_cache else None,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+ def _update_causal_mask(
+ self,
+ attention_mask: Union[torch.Tensor, "BlockMask"],
+ input_tensor: torch.Tensor,
+ cache_position: torch.Tensor,
+ past_key_values: Cache,
+ output_attentions: bool = False,
+ ):
+ if self.config._attn_implementation == "flash_attention_2":
+ if attention_mask is not None and (attention_mask == 0.0).any():
+ return attention_mask
+ return None
+ if self.config._attn_implementation == "flex_attention":
+ if isinstance(attention_mask, torch.Tensor):
+ attention_mask = make_flex_block_causal_mask(attention_mask)
+ return attention_mask
+
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+ # to infer the attention mask.
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
+
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+ if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
+ attention_mask,
+ inputs_embeds=input_tensor,
+ past_key_values_length=past_seen_tokens,
+ is_training=self.training,
+ ):
+ return None
+
+ dtype = input_tensor.dtype
+ sequence_length = input_tensor.shape[1]
+ if using_compilable_cache:
+ target_length = past_key_values.get_max_cache_shape()
+ else:
+ target_length = (
+ attention_mask.shape[-1]
+ if isinstance(attention_mask, torch.Tensor)
+ else past_seen_tokens + sequence_length + 1
+ )
+
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask,
+ sequence_length=sequence_length,
+ target_length=target_length,
+ dtype=dtype,
+ cache_position=cache_position,
+ batch_size=input_tensor.shape[0],
+ )
+
+ if (
+ self.config._attn_implementation == "sdpa"
+ and attention_mask is not None
+ and attention_mask.device.type in ["cuda", "xpu", "npu"]
+ and not output_attentions
+ ):
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+ # Details: https://github.com/pytorch/pytorch/issues/110213
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+ return causal_mask
+
+ @staticmethod
+ def _prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask: torch.Tensor,
+ sequence_length: int,
+ target_length: int,
+ dtype: torch.dtype,
+ cache_position: torch.Tensor,
+ batch_size: int,
+ **kwargs,
+ ):
+ """
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+ Args:
+ attention_mask (`torch.Tensor`):
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+ `(batch_size, 1, query_length, key_value_length)`.
+ sequence_length (`int`):
+ The sequence length being processed.
+ target_length (`int`):
+ The target length: when generating with static cache, the mask should be as long as the static cache,
+ to account for the 0 padding, the part of the cache that is not filled yet.
+ dtype (`torch.dtype`):
+ The dtype to use for the 4D attention mask.
+ cache_position (`torch.Tensor`):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ batch_size (`torch.Tensor`):
+ Batch size.
+ """
+ if attention_mask is not None and attention_mask.dim() == 4:
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+ causal_mask = attention_mask
+ else:
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = torch.full(
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
+ )
+ if sequence_length != 1:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
+ causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+ if attention_mask is not None:
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+ causal_mask.device
+ )
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+ padding_mask, min_dtype
+ )
+
+ return causal_mask
+
+
+class CsmCodebooksHead(nn.Module):
+ def __init__(self, hidden_size, num_codebooks, vocab_size):
+ super().__init__()
+ self.num_codebooks = num_codebooks
+ self.weight = nn.Parameter(torch.empty(self.num_codebooks - 1, hidden_size, vocab_size))
+
+ def forward(self, hidden_states, cache_position=None):
+ if cache_position is None:
+ seq_length = hidden_states.shape[1]
+ codebook_weight = self.weight[torch.arange(seq_length)]
+ else:
+ codebook_idxs = cache_position - 1
+ codebook_weight = self.weight[codebook_idxs]
+
+ hidden_states = [
+ nn.functional.linear(hidden_states[:, codebook_idx, :], codebook_weight[codebook_idx].T)
+ for codebook_idx in range(codebook_weight.shape[0])
+ ]
+ hidden_states = torch.stack(hidden_states, dim=1)
+
+ return hidden_states
+
+
+class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
+
+
+@add_start_docstrings(
+ """
+ The CsmDepthDecoder Model transformer, with a [`CsmCodebooksHead`] on top,
+ which can be seen a position-specific language modeling head, allowing to use a different linear layer for each codebook
+ (e.g. position 0 is the first codebook and uses the first codebook head, etc.)
+ """,
+ CSM_DEPTH_DECODER_START_DOCSTRING,
+)
+class CsmDepthDecoderForCausalLM(CsmPreTrainedModel, GenerationMixin):
+ _tied_weights_keys = None
+ _tp_plan = None
+ _pp_plan = None
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = CsmDepthDecoderModel(config)
+ self.vocab_size = config.vocab_size
+ self.codebooks_head = CsmCodebooksHead(config.hidden_size, config.num_codebooks, config.vocab_size)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ def set_decoder(self, decoder):
+ self.model = decoder
+
+ def get_decoder(self):
+ return self.model
+
+ @can_return_tuple
+ @add_start_docstrings_to_model_forward(CSM_DEPTH_DECODER_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ backbone_last_hidden_state: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs: Unpack[KwargsForCausalLM],
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ logits_to_keep (`int` or `torch.Tensor`, *optional*):
+ If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+ If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+ This is useful when using packed tensor format (single dimension for batch and sequence length).
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, CsmDepthDecoderForCausalLM
+
+ >>> model = CsmDepthDecoderForCausalLM.from_pretrained("meta-csm_depth_decoder/CsmDepthDecoder-2-7b-hf")
+ >>> tokenizer = AutoTokenizer.from_pretrained("meta-csm_depth_decoder/CsmDepthDecoder-2-7b-hf")
+
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```
+ backbone_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, backbone_hidden_size)`, *optional*):
+ The last hidden state of the backbone model. Such input is required when the first codebook token (the one generated by the backbone model)
+ is provided in the `input_ids` argument.
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model(
+ input_ids=input_ids,
+ backbone_last_hidden_state=backbone_last_hidden_state,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ if isinstance(logits_to_keep, int):
+ if logits_to_keep == 0:
+ # skip idx 0 logits since it's for the concatenated backbone last hidden state
+ slice_indices = slice(1, None)
+ else:
+ slice_indices = slice(-logits_to_keep, None)
+ else:
+ slice_indices = logits_to_keep
+
+ logits = self.codebooks_head(
+ hidden_states[:, slice_indices, :], cache_position[slice_indices] if cache_position is not None else None
+ )
+ logits = logits.contiguous()
+
+ loss = None
+ if labels is not None:
+ shift_labels = labels[..., 1:].contiguous()
+ loss = self.loss_function(
+ logits=logits, labels=None, vocab_size=self.config.vocab_size, shift_labels=shift_labels, **kwargs
+ )
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids: torch.LongTensor,
+ past_key_values: Optional[Cache] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ):
+ model_inputs = super().prepare_inputs_for_generation(
+ input_ids, past_key_values, attention_mask, inputs_embeds, cache_position, **kwargs
+ )
+
+ is_first_generation_step = model_inputs["cache_position"][0] == 0
+ if not is_first_generation_step:
+ model_inputs.pop("backbone_last_hidden_state")
+
+ # csm depth decoder does not use position_ids
+ model_inputs.pop("position_ids")
+
+ return model_inputs
+
+
+class CsmBackboneModelEmbeddings(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.embed_audio_tokens = nn.Embedding((config.num_codebooks * config.vocab_size), config.hidden_size)
+ self.register_buffer(
+ "audio_tokens_offsets", torch.arange(config.num_codebooks) * config.vocab_size, persistent=False
+ )
+
+ def forward(self, input_ids):
+ input_embeds = self.embed_audio_tokens(input_ids + self.audio_tokens_offsets)
+ input_embeds = input_embeds.sum(dim=2)
+ return input_embeds
+
+
+INPUT_IDS_DOCSTRING = r"""input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks) or (batch_size, sequence_length)`):
+ 1. (batch_size, sequence_length): corresponds to the input sequence prepared with the processor from the text prompt. Such input
+ requires `input_values` to be provided so that audio can be encoded in codebook tokens and then merged with the text tokens.
+
+ 2. (batch_size, sequence_length, num_codebooks): codebook tokens generated during the autoregressive decoding. Such input is not meant to be used by end users.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)"""
+
+
+CSM_BACKBONE_INPUTS_DOCSTRING = r"""{}""".format(INPUTS_DOCSTRING_BASE.format(input_ids_docstring=INPUT_IDS_DOCSTRING))
+
+
+@add_start_docstrings(
+ "The bare CsmBackboneModel Model outputting raw hidden-states without any specific head on top.",
+ CSM_START_DOCSTRING,
+)
+class CsmBackboneModel(CsmPreTrainedModel):
+ """
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`CsmDecoderLayer`]
+
+ Args:
+ config: CsmBackboneModelConfig
+ """
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+ self.embed_tokens = CsmBackboneModelEmbeddings(config)
+ self.layers = nn.ModuleList(
+ [CsmDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.norm = CsmRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = CsmRotaryEmbedding(config=config)
+ self.gradient_checkpointing = False
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.embed_tokens = value
+
+ @can_return_tuple
+ @add_start_docstrings_to_model_forward(CSM_BACKBONE_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
+ ) -> BaseModelOutputWithPast:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if self.gradient_checkpointing and self.training and use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+ )
+ use_cache = False
+
+ # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
+ if not isinstance(past_key_values, (type(None), Cache)):
+ raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ if use_cache and past_key_values is None:
+ past_key_values = DynamicCache()
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask = self._update_causal_mask(
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+ )
+
+ hidden_states = inputs_embeds
+
+ # create position embeddings to be shared across the decoder layers
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **flash_attn_kwargs,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values if use_cache else None,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+ def _update_causal_mask(
+ self,
+ attention_mask: Union[torch.Tensor, "BlockMask"],
+ input_tensor: torch.Tensor,
+ cache_position: torch.Tensor,
+ past_key_values: Cache,
+ output_attentions: bool = False,
+ ):
+ if self.config._attn_implementation == "flash_attention_2":
+ if attention_mask is not None and (attention_mask == 0.0).any():
+ return attention_mask
+ return None
+ if self.config._attn_implementation == "flex_attention":
+ if isinstance(attention_mask, torch.Tensor):
+ attention_mask = make_flex_block_causal_mask(attention_mask)
+ return attention_mask
+
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+ # to infer the attention mask.
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
+
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+ if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
+ attention_mask,
+ inputs_embeds=input_tensor,
+ past_key_values_length=past_seen_tokens,
+ is_training=self.training,
+ ):
+ return None
+
+ dtype = input_tensor.dtype
+ sequence_length = input_tensor.shape[1]
+ if using_compilable_cache:
+ target_length = past_key_values.get_max_cache_shape()
+ else:
+ target_length = (
+ attention_mask.shape[-1]
+ if isinstance(attention_mask, torch.Tensor)
+ else past_seen_tokens + sequence_length + 1
+ )
+
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask,
+ sequence_length=sequence_length,
+ target_length=target_length,
+ dtype=dtype,
+ cache_position=cache_position,
+ batch_size=input_tensor.shape[0],
+ )
+
+ if (
+ self.config._attn_implementation == "sdpa"
+ and attention_mask is not None
+ and attention_mask.device.type in ["cuda", "xpu", "npu"]
+ and not output_attentions
+ ):
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+ # Details: https://github.com/pytorch/pytorch/issues/110213
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+ return causal_mask
+
+ @staticmethod
+ def _prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask: torch.Tensor,
+ sequence_length: int,
+ target_length: int,
+ dtype: torch.dtype,
+ cache_position: torch.Tensor,
+ batch_size: int,
+ **kwargs,
+ ):
+ """
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+ Args:
+ attention_mask (`torch.Tensor`):
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+ `(batch_size, 1, query_length, key_value_length)`.
+ sequence_length (`int`):
+ The sequence length being processed.
+ target_length (`int`):
+ The target length: when generating with static cache, the mask should be as long as the static cache,
+ to account for the 0 padding, the part of the cache that is not filled yet.
+ dtype (`torch.dtype`):
+ The dtype to use for the 4D attention mask.
+ cache_position (`torch.Tensor`):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ batch_size (`torch.Tensor`):
+ Batch size.
+ """
+ if attention_mask is not None and attention_mask.dim() == 4:
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+ causal_mask = attention_mask
+ else:
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = torch.full(
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
+ )
+ if sequence_length != 1:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
+ causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+ if attention_mask is not None:
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+ causal_mask.device
+ )
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+ padding_mask, min_dtype
+ )
+
+ return causal_mask
+
+
+CSM_INPUTS_DOCSTRING = r"""{}""".format(INPUTS_DOCSTRING_BASE.format(input_ids_docstring=INPUT_IDS_DOCSTRING))
+
+
+@add_start_docstrings(
+ """
+ The Csm model consists of two llama-like auto-regressive transformer models: a backbone model that predicts the first codebook token and a depth decoder that predicts the other codebook tokens.
+ """,
+ CSM_START_DOCSTRING,
+)
+class CsmForConditionalGeneration(CsmPreTrainedModel, CsmGenerationMixin):
+ _tied_weights_keys = [
+ "backbone_model.embed_tokens.embed_audio_tokens.weight",
+ "depth_decoder.model.embed_tokens.weight",
+ ]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+ self.embed_text_tokens = nn.Embedding(config.text_vocab_size, config.hidden_size)
+ self.backbone_model = CsmBackboneModel._from_config(config)
+ self.depth_decoder = CsmDepthDecoderForCausalLM._from_config(config.depth_decoder_config)
+ self.codec_model = AutoModel.from_config(config.codec_config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.backbone_model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.backbone_model.embed_tokens = value
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def _tie_weights(self):
+ if self.config.tie_codebooks_embeddings:
+ self._tie_or_clone_weights(
+ self.backbone_model.embed_tokens.embed_audio_tokens,
+ self.depth_decoder.model.embed_tokens,
+ )
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ if kwargs.get("output_loading_info", False):
+ model, loading_info = super().from_pretrained(*args, **kwargs)
+ else:
+ model = super().from_pretrained(*args, **kwargs)
+
+ # copy depth decoder generation conf attr to the depth decoder generation config
+ prefix = "depth_decoder_"
+ prefix_len = len(prefix)
+ depth_decoder_attrs = {
+ attr[prefix_len:]: value
+ for attr, value in vars(model.generation_config).items()
+ if attr.startswith(prefix)
+ }
+
+ vars(model.depth_decoder.generation_config).update({"_from_model_config": False, **depth_decoder_attrs})
+
+ # remove the depth decoder generation conf attr from the model generation config
+ for attr in depth_decoder_attrs:
+ delattr(model.generation_config, prefix + attr)
+
+ if "output_loading_info" in kwargs:
+ return model, loading_info
+ else:
+ return model
+
+ def save_pretrained(self, *args, **kwargs):
+ # copy the depth decoder generation config attributes to the model generation config
+ prefix = "depth_decoder_"
+ depth_decoder_attrs = self.depth_decoder.generation_config.to_diff_dict()
+ depth_decoder_attrs.pop("transformers_version", None)
+ for attr, value in depth_decoder_attrs.items():
+ setattr(self.generation_config, prefix + attr, value)
+
+ super().save_pretrained(*args, **kwargs)
+
+ def _merge_input_ids_with_input_values(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ input_values: Optional[torch.Tensor] = None,
+ input_values_cutoffs: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ ) -> Optional[torch.Tensor]:
+ """
+ Merges the input_ids and input_values to produce a single inputs_embeds tensor:
+ 1 - Infers the codec model on the input_values to retreive codebook token.
+ 2 - Embeds codebook tokens and places them at the correct positions in the inputs_embeds tensor.
+ 3 - If labels are provided, expands them to match codebook dimensions and position the target codebook tokens in the inputs_embeds tensor.
+
+ Args:
+ input_ids (`torch.Tensor` of shape `(batch_size, sequence_length)`):
+ The input ids to embed.
+ input_values (`torch.Tensor` of shape `(batch_size, channels, audio_sequence_length)`):
+ The audio input values to embed.
+ input_values_cutoffs (`torch.Tensor` of shape `(batch_size, max_num_audio)`):
+ The cutoffs of the audio input values relative to its batch index, padded with -1 when no audio.
+ """
+ inputs_embeds = self.embed_text_tokens(input_ids)
+
+ if input_values is not None:
+ # infer input_values_mask
+ input_values_cutoffs = nn.functional.pad(input_values_cutoffs, (1, 0))
+ audio_lengths = input_values_cutoffs[input_values_cutoffs >= 0].diff()
+ audio_lengths = audio_lengths[audio_lengths > 0]
+ input_values_mask = torch.arange(input_values_cutoffs.max(), device=input_values.device).expand(
+ len(audio_lengths), -1
+ )
+ input_values_mask = input_values_mask < audio_lengths.unsqueeze(1)
+
+ # =======================================
+ # TODO: @eustlb, this should be batched !!!
+ # but requires making sure batched inference of the codec model works as intended
+ audio_tokens_list = []
+ for batch_input_values, batch_input_values_cutoffs in zip(input_values, input_values_cutoffs):
+ batch_input_values_cutoffs = batch_input_values_cutoffs[batch_input_values_cutoffs >= 0]
+ for i in range(batch_input_values_cutoffs.shape[0] - 1):
+ start_idx = batch_input_values_cutoffs[i]
+ end_idx = batch_input_values_cutoffs[i + 1]
+ audio_batch = batch_input_values[..., start_idx:end_idx]
+ codec_outputs = self.codec_model.encode(audio_batch.unsqueeze(0))
+ codebook_ids = codec_outputs.audio_codes.transpose(1, -1)
+ audio_tokens_list.append(codebook_ids[0])
+
+ max_audio_frames = max(el.shape[0] for el in audio_tokens_list)
+ batched_audio_token_ids = torch.stack(
+ [nn.functional.pad(el, (0, 0, 0, max_audio_frames - el.shape[0])) for el in audio_tokens_list]
+ )
+ audio_codes_mask = self.codec_model.get_audio_codes_mask(input_values_mask)
+ # =======================================
+ audio_token_id = self.config.audio_token_id
+ audio_token_mask = input_ids == audio_token_id
+
+ audio_embeds = self.backbone_model.embed_tokens(batched_audio_token_ids)
+ inputs_embeds[audio_token_mask] = audio_embeds[audio_codes_mask]
+
+ # same for the audio eos token
+ audio_eos_frame_ids = (
+ torch.ones((1, 1, self.config.num_codebooks), device=input_ids.device, dtype=torch.long)
+ * self.config.codebook_eos_token_id
+ )
+ audio_eos_embeds = self.backbone_model.embed_tokens(audio_eos_frame_ids).squeeze(1)
+
+ audio_eos_token_mask = input_ids == self.config.audio_eos_token_id
+ inputs_embeds[audio_eos_token_mask] = audio_eos_embeds.repeat(audio_eos_token_mask.sum(), 1)
+
+ # if the labels are provided, we need to expand the labels to (batch_size, seq_length, num_codebooks)
+ if labels is not None:
+ labels_expanded = labels.unsqueeze(-1).repeat(1, 1, self.config.num_codebooks)
+ labels_expanded[audio_token_mask] = batched_audio_token_ids[audio_codes_mask]
+ # mask depth decoder
+ depth_decoder_ignore_frames_idxs = (labels == -101).nonzero(as_tuple=True)
+ labels_expanded[depth_decoder_ignore_frames_idxs[0], depth_decoder_ignore_frames_idxs[1], 1:] = -100
+ labels = labels_expanded
+
+ return {"inputs_embeds": inputs_embeds, "labels": labels}
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids: torch.LongTensor,
+ past_key_values: Optional[Cache] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ):
+ model_inputs = super().prepare_inputs_for_generation(
+ input_ids=input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ if input_ids is not None and input_ids.ndim == 2 and model_inputs.get("inputs_embeds") is None:
+ merged_inputs = self._merge_input_ids_with_input_values(
+ input_ids=input_ids,
+ input_values=kwargs.get("input_values"),
+ input_values_cutoffs=kwargs.get("input_values_cutoffs"),
+ labels=kwargs.get("labels"),
+ )
+ model_inputs.update(
+ {"inputs_embeds": merged_inputs["inputs_embeds"], "labels": merged_inputs["labels"], "input_ids": None}
+ )
+
+ return model_inputs
+
+ @can_return_tuple
+ @add_start_docstrings_to_model_forward(CSM_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=CsmOutputWithPast, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ input_values: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ input_values_cutoffs: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs: Unpack[KwargsForCausalLM],
+ ) -> Union[Tuple, CsmOutputWithPast]:
+ r"""
+ input_values_cutoffs (`torch.Tensor` of shape `(batch_size, max_num_audio)`, *optional*):
+ Specify the end positions of audio segments within each batch entry, relative to the concatenated audio input.
+ If a batch entry has fewer segments than the maximum, it is padded with -1. For example, in a batch of 2 sequences
+ where the first contains 2 audio segments of length l1, and the second contains 1 audio segment of length l2,
+ the input_values_cutoffs would be: [[l1, 2 * l1], [l2, -1]].
+
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should be in `[config.audio_token_id, -100, -101]`.
+ Requires targeted `input_values` to be provided as audio tokens will be infered from it using the `codec_model`.
+ - `config.audio_token_id` indicates an audio frames (considering sequence length elements as frames)
+ - `-100` will be ignored in the loss computation
+ - `-101` indicates the audio frame will be used only for the backbone model (using the first codebook token as labels)
+
+ Such labels can be prepared using `output_labels=True` when calling [`CsmProcessor`].
+
+ logits_to_keep (`int` or `torch.Tensor`, *optional*):
+ Kept for compatibility. Does not support another value than:
+ 1. `0`, which is equivalent to keeping all logits, used in the training regime
+ 2. `1`, which is equivalent to keeping only the last logit, used in the generation regime
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> import torch
+ >>> from transformers import CsmForConditionalGeneration, AutoProcessor
+ >>> from datasets import load_dataset, Audio
+
+ >>> model_id = "eustlb/csm-1b"
+ >>> torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ >>> processor = AutoProcessor.from_pretrained(model_id)
+
+ >>> ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
+ >>> # ensure the audio is 24kHz
+ >>> ds = ds.cast_column("audio", Audio(sampling_rate=24000))
+
+ >>> conversation = []
+ >>> # prepare a conversation with text and corresponding audio
+ >>> for text, audio, speaker_id in zip(ds[:4]["text"], ds[:4]["audio"], ds[:4]["speaker_id"]):
+ ... conversation.append(
+ ... {
+ ... "role": f"{speaker_id}",
+ ... "content": [{"type": "text", "text": text}, {"type": "audio", "path": audio["array"]}],
+ ... }
+ ... )
+
+ >>> inputs = processor.apply_chat_template(
+ ... conversation,
+ ... tokenize=True,
+ ... return_dict=True,
+ ... output_labels=True,
+ ... ).to(torch_device)
+
+ >>> model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=torch_device)
+ >>> output = model(**inputs)
+ >>> output.loss.backward()
+ ```"""
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ if input_ids is not None and input_ids.ndim == 2:
+ merged_inputs = self._merge_input_ids_with_input_values(
+ input_ids, input_values, input_values_cutoffs, labels
+ )
+ inputs_embeds = merged_inputs["inputs_embeds"]
+ labels = merged_inputs["labels"]
+ input_ids = None
+
+ backbone_outputs = self.backbone_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ backbone_hidden_states = backbone_outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ backbone_logits = self.lm_head(backbone_hidden_states[:, slice_indices, :])
+
+ loss = None
+ backbone_loss = None
+ depth_decoder_loss = None
+ depth_decoder_outputs = None
+ if labels is not None:
+ # select first codebook as labels for the backbone model
+ backbone_labels = labels[:, :, 0]
+ backbone_loss = self.loss_function(
+ logits=backbone_logits, labels=backbone_labels, vocab_size=self.config.vocab_size, **kwargs
+ )
+
+ # for the depth decoder, we need to select the frames to train on
+ # those are frames where the label is not uniformly `ignore_index` along the codebook dimension
+ train_mask = ~(labels[:, :, 1:] == -100).all(dim=-1)
+ depth_decoder_input_ids = labels[train_mask][..., : self.config.num_codebooks - 1]
+ # add place holder in position 0 that will be replaced by the backbone_last_hidden_state
+ depth_decoder_input_ids = nn.functional.pad(depth_decoder_input_ids, (1, 0), value=0)
+
+ train_idxs = train_mask.nonzero(as_tuple=True)
+ backbone_last_hidden_states = backbone_hidden_states[train_idxs[0], train_idxs[1] - 1, :]
+ depth_decoder_labels = labels[train_mask]
+
+ depth_decoder_outputs = self.depth_decoder(
+ input_ids=depth_decoder_input_ids,
+ backbone_last_hidden_state=backbone_last_hidden_states,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ labels=depth_decoder_labels,
+ )
+
+ depth_decoder_loss = depth_decoder_outputs.loss
+ loss = backbone_loss + depth_decoder_loss
+
+ return CsmOutputWithPast(
+ loss=loss,
+ backbone_loss=backbone_loss,
+ depth_decoder_loss=depth_decoder_loss,
+ logits=backbone_logits,
+ past_key_values=backbone_outputs.past_key_values,
+ hidden_states=backbone_outputs.hidden_states,
+ attentions=backbone_outputs.attentions,
+ depth_decoder_logits=depth_decoder_outputs.logits if depth_decoder_outputs is not None else None,
+ depth_decoder_past_key_values=depth_decoder_outputs.past_key_values
+ if depth_decoder_outputs is not None
+ else None,
+ depth_decoder_hidden_states=depth_decoder_outputs.hidden_states
+ if depth_decoder_outputs is not None
+ else None,
+ depth_decoder_attentions=depth_decoder_outputs.attentions if depth_decoder_outputs is not None else None,
+ )
+
+ @staticmethod
+ def _prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask: torch.Tensor,
+ sequence_length: int,
+ target_length: int,
+ dtype: torch.dtype,
+ cache_position: torch.Tensor,
+ batch_size: int,
+ **kwargs,
+ ):
+ """
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+ Args:
+ attention_mask (`torch.Tensor`):
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+ `(batch_size, 1, query_length, key_value_length)`.
+ sequence_length (`int`):
+ The sequence length being processed.
+ target_length (`int`):
+ The target length: when generating with static cache, the mask should be as long as the static cache,
+ to account for the 0 padding, the part of the cache that is not filled yet.
+ dtype (`torch.dtype`):
+ The dtype to use for the 4D attention mask.
+ cache_position (`torch.Tensor`):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ batch_size (`torch.Tensor`):
+ Batch size.
+ """
+ if attention_mask is not None and attention_mask.dim() == 4:
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+ causal_mask = attention_mask
+ else:
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = torch.full(
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
+ )
+ if sequence_length != 1:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
+ causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+ if attention_mask is not None:
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+ causal_mask.device
+ )
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+ padding_mask, min_dtype
+ )
+
+ return causal_mask
+
+
+__all__ = [
+ "CsmPreTrainedModel",
+ "CsmBackboneModel",
+ "CsmDepthDecoderModel",
+ "CsmDepthDecoderForCausalLM",
+ "CsmForConditionalGeneration",
+]
diff --git a/src/transformers/models/csm/modular_csm.py b/src/transformers/models/csm/modular_csm.py
new file mode 100644
index 0000000000..ed3d571034
--- /dev/null
+++ b/src/transformers/models/csm/modular_csm.py
@@ -0,0 +1,1042 @@
+# coding=utf-8
+# Copyright 2025 Sesame and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from ...cache_utils import Cache, DynamicCache
+from ...generation import GenerationMixin
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_outputs import (
+ BaseModelOutputWithPast,
+ CausalLMOutputWithPast,
+)
+from ...modeling_utils import PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import (
+ ModelOutput,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ can_return_tuple,
+ logging,
+ replace_return_docstrings,
+)
+from ..auto import AutoModel
+from ..llama.modeling_llama import (
+ KwargsForCausalLM,
+ LlamaAttention,
+ LlamaDecoderLayer,
+ LlamaForCausalLM,
+ LlamaMLP,
+ LlamaModel,
+ LlamaRMSNorm,
+ LlamaRotaryEmbedding,
+)
+from .configuration_csm import (
+ CsmConfig,
+ CsmDepthDecoderConfig,
+)
+from .generation_csm import CsmGenerationMixin
+
+
+logger = logging.get_logger(__name__)
+_CONFIG_FOR_DOC = "CsmConfig"
+
+
+@dataclass
+class CsmOutputWithPast(ModelOutput):
+ """
+ Base class for the model autoregressive outputs.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ depth_decoder_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction) of the depth decoder model.
+ depth_decoder_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the depth decoder (scores for each vocabulary token before SoftMax).
+ depth_decoder_past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+ depth_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ depth_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+ backbone_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction) of the backbone model.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: torch.FloatTensor = None
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+ depth_decoder_loss: Optional[torch.FloatTensor] = None
+ depth_decoder_logits: torch.FloatTensor = None
+ depth_decoder_past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+ depth_decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+ depth_decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+ backbone_loss: Optional[torch.FloatTensor] = None
+
+
+START_DOCSTRING_BASE = r"""
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+ and behavior.
+
+ Parameters:
+ config ([`{config_class}`]):
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
+ load the weights associated with the model, only the configuration. Check out the
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+CSM_DEPTH_DECODER_START_DOCSTRING = r"""{}""".format(START_DOCSTRING_BASE.format(config_class="CsmDepthDecoderConfig"))
+
+
+CSM_START_DOCSTRING = r"""{}""".format(START_DOCSTRING_BASE.format(config_class="CsmConfig"))
+
+
+@add_start_docstrings(
+ "The bare Csm Model outputting raw hidden-states without any specific head on top.",
+ CSM_START_DOCSTRING,
+)
+class CsmPreTrainedModel(PreTrainedModel):
+ config_class = CsmConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["CsmDecoderLayer"]
+ _skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn_2 = True
+ _supports_sdpa = True
+ # does not because of Mimi codec model
+ # _supports_flex_attn = True
+ _supports_cache_class = True
+ _supports_quantized_cache = True
+ _supports_static_cache = True
+ _supports_attention_backend = True
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, CsmCodebooksHead):
+ num_codebooks = module.num_codebooks
+ for i in range(num_codebooks - 1):
+ module.weight.data[i].normal_(mean=0.0, std=std)
+ elif isinstance(module, CsmRMSNorm):
+ module.weight.data.fill_(1.0)
+
+
+INPUTS_DOCSTRING_BASE = r"""
+ Args:
+ {input_ids_docstring}
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
+ `past_key_values`).
+
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+ information on the default strategy.
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.n_positions - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
+
+ Two formats are allowed:
+ - a [`~cache_utils.Cache`] instance, see our
+ [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
+ cache format.
+
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
+ legacy cache format will be returned.
+
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
+ of shape `(batch_size, sequence_length)`.
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
+ the complete sequence length.
+"""
+
+
+DEPTH_DECODER_INPUT_IDS_DOCSTRING = r"""input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)"""
+
+
+INPUT_IDS_DOCSTRING = r"""input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks) or (batch_size, sequence_length)`):
+ 1. (batch_size, sequence_length): corresponds to the input sequence prepared with the processor from the text prompt. Such input
+ requires `input_values` to be provided so that audio can be encoded in codebook tokens and then merged with the text tokens.
+
+ 2. (batch_size, sequence_length, num_codebooks): codebook tokens generated during the autoregressive decoding. Such input is not meant to be used by end users.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)"""
+
+
+CSM_DEPTH_DECODER_INPUTS_DOCSTRING = r"""{}""".format(
+ INPUTS_DOCSTRING_BASE.format(input_ids_docstring=DEPTH_DECODER_INPUT_IDS_DOCSTRING)
+)
+
+
+CSM_BACKBONE_INPUTS_DOCSTRING = r"""{}""".format(INPUTS_DOCSTRING_BASE.format(input_ids_docstring=INPUT_IDS_DOCSTRING))
+
+
+# manually specify names for correct naming when converting from modualr
+class CsmRMSNorm(LlamaRMSNorm):
+ pass
+
+
+class CsmRotaryEmbedding(LlamaRotaryEmbedding):
+ pass
+
+
+class CsmMLP(LlamaMLP):
+ pass
+
+
+class CsmAttention(LlamaAttention):
+ pass
+
+
+class CsmDecoderLayer(LlamaDecoderLayer):
+ pass
+
+
+@add_start_docstrings(
+ "The bare CsmDepthDecoderModel outputting raw hidden-states without any specific head on top.",
+ CSM_DEPTH_DECODER_START_DOCSTRING,
+)
+class CsmDepthDecoderModel(LlamaModel):
+ """
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`CsmDecoderLayer`]
+
+ Args:
+ config: CsmDepthDecoderConfig
+ """
+
+ config_class = CsmDepthDecoderConfig
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.embed_tokens = nn.Embedding((config.num_codebooks * config.vocab_size), config.backbone_hidden_size)
+ self.inputs_embeds_projector = nn.Linear(config.backbone_hidden_size, config.hidden_size, bias=False)
+
+ @can_return_tuple
+ @add_start_docstrings_to_model_forward(CSM_DEPTH_DECODER_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ backbone_last_hidden_state: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
+ r"""
+ backbone_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, backbone_hidden_size)`, *optional*):
+ The last hidden state of the backbone model. Such input is required when the first codebook token (the one generated by the backbone model)
+ is provided in the `input_ids` argument.
+ """
+ if position_ids is not None and not torch.compiler.is_compiling():
+ logger.warning_once(
+ "Custom `position_ids` were provided but will be ignored. CSM depth decoder automatically determines position_ids "
+ "from `cache_position` and as it requires them to be identical across the batch, the provided position_ids will be ignored."
+ )
+ position_ids = None
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds.")
+
+ if self.gradient_checkpointing and self.training and use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+ )
+ use_cache = False
+
+ if use_cache and past_key_values is None:
+ past_key_values = DynamicCache()
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ inputs_seq_length = inputs_embeds.shape[1] if inputs_embeds is not None else input_ids.shape[1]
+ device = inputs_embeds.device if inputs_embeds is not None else input_ids.device
+ cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_seq_length, device=device)
+
+ if inputs_embeds is None:
+ codebook_idxs = torch.clamp(cache_position - 1, min=0)
+ offset = codebook_idxs * self.vocab_size
+ inputs_embeds = self.embed_tokens(input_ids + offset)
+
+ input_ids_are_first_codebook = cache_position[0] == 0
+ if backbone_last_hidden_state is not None:
+ inputs_embeds[:, 0] = backbone_last_hidden_state
+ else:
+ if not torch.compiler.is_compiling() and input_ids_are_first_codebook:
+ logger.warning(
+ "When the first codebook token is provided, `backbone_last_hidden_state` should also be provided for correct inference."
+ )
+
+ inputs_embeds = self.inputs_embeds_projector(inputs_embeds)
+
+ causal_mask = self._update_causal_mask(
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+ )
+
+ hidden_states = inputs_embeds
+
+ # create position embeddings to be shared across the decoder layers
+ position_ids = cache_position.unsqueeze(0)
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **flash_attn_kwargs,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values if use_cache else None,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+
+class CsmCodebooksHead(nn.Module):
+ def __init__(self, hidden_size, num_codebooks, vocab_size):
+ super().__init__()
+ self.num_codebooks = num_codebooks
+ self.weight = nn.Parameter(torch.empty(self.num_codebooks - 1, hidden_size, vocab_size))
+
+ def forward(self, hidden_states, cache_position=None):
+ if cache_position is None:
+ seq_length = hidden_states.shape[1]
+ codebook_weight = self.weight[torch.arange(seq_length)]
+ else:
+ codebook_idxs = cache_position - 1
+ codebook_weight = self.weight[codebook_idxs]
+
+ hidden_states = [
+ nn.functional.linear(hidden_states[:, codebook_idx, :], codebook_weight[codebook_idx].T)
+ for codebook_idx in range(codebook_weight.shape[0])
+ ]
+ hidden_states = torch.stack(hidden_states, dim=1)
+
+ return hidden_states
+
+
+@add_start_docstrings(
+ """
+ The CsmDepthDecoder Model transformer, with a [`CsmCodebooksHead`] on top,
+ which can be seen a position-specific language modeling head, allowing to use a different linear layer for each codebook
+ (e.g. position 0 is the first codebook and uses the first codebook head, etc.)
+ """,
+ CSM_DEPTH_DECODER_START_DOCSTRING,
+)
+class CsmDepthDecoderForCausalLM(LlamaForCausalLM, GenerationMixin):
+ _tied_weights_keys = None
+ _tp_plan = None
+ _pp_plan = None
+
+ def __init__(self, config):
+ super().__init__(config)
+ del self.lm_head
+ self.codebooks_head = CsmCodebooksHead(config.hidden_size, config.num_codebooks, config.vocab_size)
+ self.model = CsmDepthDecoderModel(config)
+
+ def get_output_embeddings(self):
+ raise AttributeError("Not needed for Csm")
+
+ def set_output_embeddings(self, new_embeddings):
+ raise AttributeError("Not needed for Csm")
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids: torch.LongTensor,
+ past_key_values: Optional[Cache] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ):
+ model_inputs = super().prepare_inputs_for_generation(
+ input_ids, past_key_values, attention_mask, inputs_embeds, cache_position, **kwargs
+ )
+
+ is_first_generation_step = model_inputs["cache_position"][0] == 0
+ if not is_first_generation_step:
+ model_inputs.pop("backbone_last_hidden_state")
+
+ # csm depth decoder does not use position_ids
+ model_inputs.pop("position_ids")
+
+ return model_inputs
+
+ @can_return_tuple
+ @add_start_docstrings_to_model_forward(CSM_DEPTH_DECODER_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ backbone_last_hidden_state: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs: Unpack[KwargsForCausalLM],
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+ r"""
+ backbone_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, backbone_hidden_size)`, *optional*):
+ The last hidden state of the backbone model. Such input is required when the first codebook token (the one generated by the backbone model)
+ is provided in the `input_ids` argument.
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model(
+ input_ids=input_ids,
+ backbone_last_hidden_state=backbone_last_hidden_state,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ if isinstance(logits_to_keep, int):
+ if logits_to_keep == 0:
+ # skip idx 0 logits since it's for the concatenated backbone last hidden state
+ slice_indices = slice(1, None)
+ else:
+ slice_indices = slice(-logits_to_keep, None)
+ else:
+ slice_indices = logits_to_keep
+
+ logits = self.codebooks_head(
+ hidden_states[:, slice_indices, :], cache_position[slice_indices] if cache_position is not None else None
+ )
+ logits = logits.contiguous()
+
+ loss = None
+ if labels is not None:
+ shift_labels = labels[..., 1:].contiguous()
+ loss = self.loss_function(
+ logits=logits, labels=None, vocab_size=self.config.vocab_size, shift_labels=shift_labels, **kwargs
+ )
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+class CsmBackboneModelEmbeddings(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.embed_audio_tokens = nn.Embedding((config.num_codebooks * config.vocab_size), config.hidden_size)
+ self.register_buffer(
+ "audio_tokens_offsets", torch.arange(config.num_codebooks) * config.vocab_size, persistent=False
+ )
+
+ def forward(self, input_ids):
+ input_embeds = self.embed_audio_tokens(input_ids + self.audio_tokens_offsets)
+ input_embeds = input_embeds.sum(dim=2)
+ return input_embeds
+
+
+@add_start_docstrings(
+ "The bare CsmBackboneModel Model outputting raw hidden-states without any specific head on top.",
+ CSM_START_DOCSTRING,
+)
+class CsmBackboneModel(LlamaModel):
+ """
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`CsmDecoderLayer`]
+
+ Args:
+ config: CsmBackboneModelConfig
+ """
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.embed_tokens = CsmBackboneModelEmbeddings(config)
+
+ @can_return_tuple
+ @add_start_docstrings_to_model_forward(CSM_BACKBONE_INPUTS_DOCSTRING)
+ def forward(self, **super_kwargs):
+ return super().forward(**super_kwargs)
+
+
+CSM_INPUTS_DOCSTRING = r"""{}""".format(INPUTS_DOCSTRING_BASE.format(input_ids_docstring=INPUT_IDS_DOCSTRING))
+
+
+@add_start_docstrings(
+ """
+ The Csm model consists of two llama-like auto-regressive transformer models: a backbone model that predicts the first codebook token and a depth decoder that predicts the other codebook tokens.
+ """,
+ CSM_START_DOCSTRING,
+)
+class CsmForConditionalGeneration(CsmPreTrainedModel, CsmGenerationMixin):
+ _tied_weights_keys = [
+ "backbone_model.embed_tokens.embed_audio_tokens.weight",
+ "depth_decoder.model.embed_tokens.weight",
+ ]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+ self.embed_text_tokens = nn.Embedding(config.text_vocab_size, config.hidden_size)
+ self.backbone_model = CsmBackboneModel._from_config(config)
+ self.depth_decoder = CsmDepthDecoderForCausalLM._from_config(config.depth_decoder_config)
+ self.codec_model = AutoModel.from_config(config.codec_config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.backbone_model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.backbone_model.embed_tokens = value
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def _tie_weights(self):
+ if self.config.tie_codebooks_embeddings:
+ self._tie_or_clone_weights(
+ self.backbone_model.embed_tokens.embed_audio_tokens,
+ self.depth_decoder.model.embed_tokens,
+ )
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ if kwargs.get("output_loading_info", False):
+ model, loading_info = super().from_pretrained(*args, **kwargs)
+ else:
+ model = super().from_pretrained(*args, **kwargs)
+
+ # copy depth decoder generation conf attr to the depth decoder generation config
+ prefix = "depth_decoder_"
+ prefix_len = len(prefix)
+ depth_decoder_attrs = {
+ attr[prefix_len:]: value
+ for attr, value in vars(model.generation_config).items()
+ if attr.startswith(prefix)
+ }
+
+ vars(model.depth_decoder.generation_config).update({"_from_model_config": False, **depth_decoder_attrs})
+
+ # remove the depth decoder generation conf attr from the model generation config
+ for attr in depth_decoder_attrs:
+ delattr(model.generation_config, prefix + attr)
+
+ if "output_loading_info" in kwargs:
+ return model, loading_info
+ else:
+ return model
+
+ def save_pretrained(self, *args, **kwargs):
+ # copy the depth decoder generation config attributes to the model generation config
+ prefix = "depth_decoder_"
+ depth_decoder_attrs = self.depth_decoder.generation_config.to_diff_dict()
+ depth_decoder_attrs.pop("transformers_version", None)
+ for attr, value in depth_decoder_attrs.items():
+ setattr(self.generation_config, prefix + attr, value)
+
+ super().save_pretrained(*args, **kwargs)
+
+ def _merge_input_ids_with_input_values(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ input_values: Optional[torch.Tensor] = None,
+ input_values_cutoffs: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ ) -> Optional[torch.Tensor]:
+ """
+ Merges the input_ids and input_values to produce a single inputs_embeds tensor:
+ 1 - Infers the codec model on the input_values to retreive codebook token.
+ 2 - Embeds codebook tokens and places them at the correct positions in the inputs_embeds tensor.
+ 3 - If labels are provided, expands them to match codebook dimensions and position the target codebook tokens in the inputs_embeds tensor.
+
+ Args:
+ input_ids (`torch.Tensor` of shape `(batch_size, sequence_length)`):
+ The input ids to embed.
+ input_values (`torch.Tensor` of shape `(batch_size, channels, audio_sequence_length)`):
+ The audio input values to embed.
+ input_values_cutoffs (`torch.Tensor` of shape `(batch_size, max_num_audio)`):
+ The cutoffs of the audio input values relative to its batch index, padded with -1 when no audio.
+ """
+ inputs_embeds = self.embed_text_tokens(input_ids)
+
+ if input_values is not None:
+ # infer input_values_mask
+ input_values_cutoffs = nn.functional.pad(input_values_cutoffs, (1, 0))
+ audio_lengths = input_values_cutoffs[input_values_cutoffs >= 0].diff()
+ audio_lengths = audio_lengths[audio_lengths > 0]
+ input_values_mask = torch.arange(input_values_cutoffs.max(), device=input_values.device).expand(
+ len(audio_lengths), -1
+ )
+ input_values_mask = input_values_mask < audio_lengths.unsqueeze(1)
+
+ # =======================================
+ # TODO: @eustlb, this should be batched !!!
+ # but requires making sure batched inference of the codec model works as intended
+ audio_tokens_list = []
+ for batch_input_values, batch_input_values_cutoffs in zip(input_values, input_values_cutoffs):
+ batch_input_values_cutoffs = batch_input_values_cutoffs[batch_input_values_cutoffs >= 0]
+ for i in range(batch_input_values_cutoffs.shape[0] - 1):
+ start_idx = batch_input_values_cutoffs[i]
+ end_idx = batch_input_values_cutoffs[i + 1]
+ audio_batch = batch_input_values[..., start_idx:end_idx]
+ codec_outputs = self.codec_model.encode(audio_batch.unsqueeze(0))
+ codebook_ids = codec_outputs.audio_codes.transpose(1, -1)
+ audio_tokens_list.append(codebook_ids[0])
+
+ max_audio_frames = max(el.shape[0] for el in audio_tokens_list)
+ batched_audio_token_ids = torch.stack(
+ [nn.functional.pad(el, (0, 0, 0, max_audio_frames - el.shape[0])) for el in audio_tokens_list]
+ )
+ audio_codes_mask = self.codec_model.get_audio_codes_mask(input_values_mask)
+ # =======================================
+ audio_token_id = self.config.audio_token_id
+ audio_token_mask = input_ids == audio_token_id
+
+ audio_embeds = self.backbone_model.embed_tokens(batched_audio_token_ids)
+ inputs_embeds[audio_token_mask] = audio_embeds[audio_codes_mask]
+
+ # same for the audio eos token
+ audio_eos_frame_ids = (
+ torch.ones((1, 1, self.config.num_codebooks), device=input_ids.device, dtype=torch.long)
+ * self.config.codebook_eos_token_id
+ )
+ audio_eos_embeds = self.backbone_model.embed_tokens(audio_eos_frame_ids).squeeze(1)
+
+ audio_eos_token_mask = input_ids == self.config.audio_eos_token_id
+ inputs_embeds[audio_eos_token_mask] = audio_eos_embeds.repeat(audio_eos_token_mask.sum(), 1)
+
+ # if the labels are provided, we need to expand the labels to (batch_size, seq_length, num_codebooks)
+ if labels is not None:
+ labels_expanded = labels.unsqueeze(-1).repeat(1, 1, self.config.num_codebooks)
+ labels_expanded[audio_token_mask] = batched_audio_token_ids[audio_codes_mask]
+ # mask depth decoder
+ depth_decoder_ignore_frames_idxs = (labels == -101).nonzero(as_tuple=True)
+ labels_expanded[depth_decoder_ignore_frames_idxs[0], depth_decoder_ignore_frames_idxs[1], 1:] = -100
+ labels = labels_expanded
+
+ return {"inputs_embeds": inputs_embeds, "labels": labels}
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids: torch.LongTensor,
+ past_key_values: Optional[Cache] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ):
+ model_inputs = super().prepare_inputs_for_generation(
+ input_ids=input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ if input_ids is not None and input_ids.ndim == 2 and model_inputs.get("inputs_embeds") is None:
+ merged_inputs = self._merge_input_ids_with_input_values(
+ input_ids=input_ids,
+ input_values=kwargs.get("input_values"),
+ input_values_cutoffs=kwargs.get("input_values_cutoffs"),
+ labels=kwargs.get("labels"),
+ )
+ model_inputs.update(
+ {"inputs_embeds": merged_inputs["inputs_embeds"], "labels": merged_inputs["labels"], "input_ids": None}
+ )
+
+ return model_inputs
+
+ @can_return_tuple
+ @add_start_docstrings_to_model_forward(CSM_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=CsmOutputWithPast, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ input_values: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ input_values_cutoffs: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs: Unpack[KwargsForCausalLM],
+ ) -> Union[Tuple, CsmOutputWithPast]:
+ r"""
+ input_values_cutoffs (`torch.Tensor` of shape `(batch_size, max_num_audio)`, *optional*):
+ Specify the end positions of audio segments within each batch entry, relative to the concatenated audio input.
+ If a batch entry has fewer segments than the maximum, it is padded with -1. For example, in a batch of 2 sequences
+ where the first contains 2 audio segments of length l1, and the second contains 1 audio segment of length l2,
+ the input_values_cutoffs would be: [[l1, 2 * l1], [l2, -1]].
+
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should be in `[config.audio_token_id, -100, -101]`.
+ Requires targeted `input_values` to be provided as audio tokens will be infered from it using the `codec_model`.
+ - `config.audio_token_id` indicates an audio frames (considering sequence length elements as frames)
+ - `-100` will be ignored in the loss computation
+ - `-101` indicates the audio frame will be used only for the backbone model (using the first codebook token as labels)
+
+ Such labels can be prepared using `output_labels=True` when calling [`CsmProcessor`].
+
+ logits_to_keep (`int` or `torch.Tensor`, *optional*):
+ Kept for compatibility. Does not support another value than:
+ 1. `0`, which is equivalent to keeping all logits, used in the training regime
+ 2. `1`, which is equivalent to keeping only the last logit, used in the generation regime
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> import torch
+ >>> from transformers import CsmForConditionalGeneration, AutoProcessor
+ >>> from datasets import load_dataset, Audio
+
+ >>> model_id = "eustlb/csm-1b"
+ >>> torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ >>> processor = AutoProcessor.from_pretrained(model_id)
+
+ >>> ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
+ >>> # ensure the audio is 24kHz
+ >>> ds = ds.cast_column("audio", Audio(sampling_rate=24000))
+
+ >>> conversation = []
+ >>> # prepare a conversation with text and corresponding audio
+ >>> for text, audio, speaker_id in zip(ds[:4]["text"], ds[:4]["audio"], ds[:4]["speaker_id"]):
+ ... conversation.append(
+ ... {
+ ... "role": f"{speaker_id}",
+ ... "content": [{"type": "text", "text": text}, {"type": "audio", "path": audio["array"]}],
+ ... }
+ ... )
+
+ >>> inputs = processor.apply_chat_template(
+ ... conversation,
+ ... tokenize=True,
+ ... return_dict=True,
+ ... output_labels=True,
+ ... ).to(torch_device)
+
+ >>> model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=torch_device)
+ >>> output = model(**inputs)
+ >>> output.loss.backward()
+ ```"""
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ if input_ids is not None and input_ids.ndim == 2:
+ merged_inputs = self._merge_input_ids_with_input_values(
+ input_ids, input_values, input_values_cutoffs, labels
+ )
+ inputs_embeds = merged_inputs["inputs_embeds"]
+ labels = merged_inputs["labels"]
+ input_ids = None
+
+ backbone_outputs = self.backbone_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ backbone_hidden_states = backbone_outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ backbone_logits = self.lm_head(backbone_hidden_states[:, slice_indices, :])
+
+ loss = None
+ backbone_loss = None
+ depth_decoder_loss = None
+ depth_decoder_outputs = None
+ if labels is not None:
+ # select first codebook as labels for the backbone model
+ backbone_labels = labels[:, :, 0]
+ backbone_loss = self.loss_function(
+ logits=backbone_logits, labels=backbone_labels, vocab_size=self.config.vocab_size, **kwargs
+ )
+
+ # for the depth decoder, we need to select the frames to train on
+ # those are frames where the label is not uniformly `ignore_index` along the codebook dimension
+ train_mask = ~(labels[:, :, 1:] == -100).all(dim=-1)
+ depth_decoder_input_ids = labels[train_mask][..., : self.config.num_codebooks - 1]
+ # add place holder in position 0 that will be replaced by the backbone_last_hidden_state
+ depth_decoder_input_ids = nn.functional.pad(depth_decoder_input_ids, (1, 0), value=0)
+
+ train_idxs = train_mask.nonzero(as_tuple=True)
+ backbone_last_hidden_states = backbone_hidden_states[train_idxs[0], train_idxs[1] - 1, :]
+ depth_decoder_labels = labels[train_mask]
+
+ depth_decoder_outputs = self.depth_decoder(
+ input_ids=depth_decoder_input_ids,
+ backbone_last_hidden_state=backbone_last_hidden_states,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ labels=depth_decoder_labels,
+ )
+
+ depth_decoder_loss = depth_decoder_outputs.loss
+ loss = backbone_loss + depth_decoder_loss
+
+ return CsmOutputWithPast(
+ loss=loss,
+ backbone_loss=backbone_loss,
+ depth_decoder_loss=depth_decoder_loss,
+ logits=backbone_logits,
+ past_key_values=backbone_outputs.past_key_values,
+ hidden_states=backbone_outputs.hidden_states,
+ attentions=backbone_outputs.attentions,
+ depth_decoder_logits=depth_decoder_outputs.logits if depth_decoder_outputs is not None else None,
+ depth_decoder_past_key_values=depth_decoder_outputs.past_key_values
+ if depth_decoder_outputs is not None
+ else None,
+ depth_decoder_hidden_states=depth_decoder_outputs.hidden_states
+ if depth_decoder_outputs is not None
+ else None,
+ depth_decoder_attentions=depth_decoder_outputs.attentions if depth_decoder_outputs is not None else None,
+ )
+
+ @staticmethod
+ def _prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask: torch.Tensor,
+ sequence_length: int,
+ target_length: int,
+ dtype: torch.dtype,
+ cache_position: torch.Tensor,
+ batch_size: int,
+ **kwargs,
+ ):
+ """
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+ Args:
+ attention_mask (`torch.Tensor`):
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+ `(batch_size, 1, query_length, key_value_length)`.
+ sequence_length (`int`):
+ The sequence length being processed.
+ target_length (`int`):
+ The target length: when generating with static cache, the mask should be as long as the static cache,
+ to account for the 0 padding, the part of the cache that is not filled yet.
+ dtype (`torch.dtype`):
+ The dtype to use for the 4D attention mask.
+ cache_position (`torch.Tensor`):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ batch_size (`torch.Tensor`):
+ Batch size.
+ """
+ if attention_mask is not None and attention_mask.dim() == 4:
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+ causal_mask = attention_mask
+ else:
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = torch.full(
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
+ )
+ if sequence_length != 1:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
+ causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+ if attention_mask is not None:
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+ causal_mask.device
+ )
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+ padding_mask, min_dtype
+ )
+
+ return causal_mask
+
+
+__all__ = [
+ "CsmPreTrainedModel",
+ "CsmBackboneModel",
+ "CsmDepthDecoderModel",
+ "CsmDepthDecoderForCausalLM",
+ "CsmForConditionalGeneration",
+]
diff --git a/src/transformers/models/csm/processing_csm.py b/src/transformers/models/csm/processing_csm.py
new file mode 100644
index 0000000000..486c5eda4c
--- /dev/null
+++ b/src/transformers/models/csm/processing_csm.py
@@ -0,0 +1,364 @@
+# coding=utf-8
+# Copyright 2025 Sesame and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Union
+
+import numpy as np
+
+from ...utils import is_soundfile_available, is_torch_available
+
+
+if is_torch_available():
+ import torch
+
+if is_soundfile_available():
+ import soundfile as sf
+
+from ...audio_utils import AudioInput, make_list_of_audio
+from ...feature_extraction_utils import BatchFeature
+from ...processing_utils import AudioKwargs, ProcessingKwargs, ProcessorMixin, Unpack
+from ...tokenization_utils_base import (
+ PreTokenizedInput,
+ TextInput,
+)
+
+
+class CsmAudioKwargs(AudioKwargs, total=False):
+ encoded_length_kwargs: Optional[Dict[str, Any]]
+
+
+class CsmProcessorKwargs(ProcessingKwargs, total=False):
+ audio_kwargs: CsmAudioKwargs
+ _defaults = {
+ "text_kwargs": {
+ "padding": True,
+ "padding_side": "left",
+ "add_special_tokens": False,
+ },
+ "audio_kwargs": {
+ "encoded_length_kwargs": {
+ "kernel_sizes": [7, 3, 1, 8, 3, 1, 10, 3, 1, 12, 3, 1, 16, 3, 4],
+ "strides": [1, 1, 1, 4, 1, 1, 5, 1, 1, 6, 1, 1, 8, 1, 2],
+ "dilations": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
+ "use_causal_conv": True,
+ },
+ "sampling_rate": 24000,
+ },
+ "common_kwargs": {"return_tensors": "pt"},
+ }
+
+
+class CsmProcessor(ProcessorMixin):
+ r"""
+ Constructs a Csm processor which wraps [`EncodecFeatureExtractor`] and
+ [`PretrainedTokenizerFast`] into a single processor that inherits both the audio feature extraction and
+ tokenizer functionalities. See the [`~CsmProcessor.__call__`] for more
+ information.
+ The preferred way of passing kwargs is as a dictionary per modality, see usage example below.
+ ```python
+ from transformers import CsmProcessor
+ from datasets import load_dataset
+
+ ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
+ audio = ds[0]["audio"]["array"]
+
+ processor = CsmProcessor.from_pretrained("eustlb/csm-1b")
+
+ processor(
+ text=["<|begin_of_text|>[0]What are you working on?<|end_of_text|><|AUDIO|><|audio_eos|><|begin_of_text|>[1]I'm figuring out my budget.<|end_of_text|>"],
+ audio=audio,
+ text_kwargs = {"padding": False},
+ audio_kwargs = {"sampling_rate": 16000},
+ common_kwargs = {"return_tensors": "pt"},
+ )
+ # this should error out because EncodecFeatureExtractor expects a 24kHz audio :)
+ ```
+
+ Args:
+ feature_extractor ([`EncodecFeatureExtractor`]):
+ The feature extractor is a required input.
+ tokenizer ([`PreTrainedTokenizer`, `PreTrainedTokenizerFast`]):
+ The tokenizer is a required input.
+ chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
+ in a chat into a tokenizable string.
+
+ """
+
+ attributes = ["feature_extractor", "tokenizer"]
+ valid_kwargs = ["chat_template"]
+ feature_extractor_class = "EncodecFeatureExtractor"
+ tokenizer_class = "PreTrainedTokenizerFast"
+
+ def __init__(
+ self,
+ feature_extractor,
+ tokenizer,
+ chat_template=None,
+ ):
+ if not hasattr(tokenizer, "audio_token"):
+ self.audio_token = "<|AUDIO|>"
+ self.audio_token_id = tokenizer.convert_tokens_to_ids(self.audio_token)
+ else:
+ self.audio_token = tokenizer.audio_token
+ self.audio_token_id = tokenizer.audio_token_id
+
+ if not hasattr(tokenizer, "audio_eos_token"):
+ self.audio_eos_token = "<|audio_eos|>"
+ self.audio_eos_token_id = tokenizer.convert_tokens_to_ids(self.audio_eos_token)
+ else:
+ self.audio_eos_token = tokenizer.audio_eos_token
+ self.audio_eos_token_id = tokenizer.audio_eos_token_id
+
+ super().__init__(feature_extractor, tokenizer, chat_template=chat_template)
+
+ @staticmethod
+ def _get_encoded_length(audio_length, kernel_sizes=None, strides=None, dilations=None, use_causal_conv=None):
+ """
+ Compute the length of the encoded audio sequence.
+
+ Args:
+ audio_length (int): The length of the audio sequence.
+ kernel_sizes (List[int]): The kernel sizes for the convolutional layers.
+ strides (List[int]): The strides for the convolutional layers.
+ use_causal_conv (bool): Whether to use causal convolutions.
+ """
+ cur_length = audio_length
+
+ if kernel_sizes is None or strides is None or dilations is None or use_causal_conv is None:
+ return cur_length
+
+ for kernel_size, stride, dilation in zip(kernel_sizes, strides, dilations):
+ effective_kernel_size = (kernel_size - 1) * dilation + 1
+ padding_total = kernel_size - stride
+ padding_right = padding_total // 2
+ padding_left = padding_total - padding_right
+
+ n_frames = (cur_length - effective_kernel_size + padding_total) / stride + 1
+ n_frames = math.ceil(n_frames) - 1
+ ideal_length = n_frames * stride + kernel_size - padding_total
+ extra_padding = ideal_length - cur_length
+
+ if use_causal_conv:
+ padding_left = padding_total
+ padding_right = extra_padding
+ else:
+ padding_left = padding_left
+ padding_right = padding_right + extra_padding
+
+ cur_length = cur_length + padding_left + padding_right
+ cur_length = (cur_length - dilation * (kernel_size - 1) - 1) // stride + 1
+
+ return cur_length
+
+ def save_audio(
+ self,
+ audio: AudioInput,
+ saving_path: Union[str, Path, List[Union[str, Path]]],
+ **kwargs: Unpack[CsmProcessorKwargs],
+ ):
+ # TODO: @eustlb, this should be in AudioProcessor
+ if not is_soundfile_available():
+ raise ImportError("Please install `soundfile` to save audio files.")
+
+ # ensure correct audio input
+ audio = make_list_of_audio(audio)
+
+ # ensure correct saving path
+ if isinstance(saving_path, (str, Path)):
+ saving_path = [saving_path]
+ elif not (isinstance(saving_path, (list, tuple)) and all(isinstance(p, (str, Path)) for p in saving_path)):
+ raise ValueError("Invalid input path. Please provide a string, or a list of strings")
+
+ if len(audio) != len(saving_path):
+ raise ValueError("The number of audio and saving paths must be the same")
+
+ output_kwargs = self._merge_kwargs(
+ CsmProcessorKwargs,
+ **kwargs,
+ )
+ audio_kwargs = output_kwargs["audio_kwargs"]
+ sampling_rate = audio_kwargs["sampling_rate"]
+
+ for audio_value, p in zip(audio, saving_path):
+ if isinstance(audio_value, torch.Tensor):
+ audio_value = audio_value.cpu().float().numpy()
+ sf.write(p, audio_value, sampling_rate)
+
+ def __call__(
+ self,
+ text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]],
+ audio: Optional[AudioInput] = None,
+ output_labels: Optional[bool] = False,
+ depth_decoder_labels_ratio: Optional[float] = 1.0,
+ **kwargs: Unpack[CsmProcessorKwargs],
+ ):
+ r"""
+ Main method to prepare text(s) and audio to be fed as input to the model. This method forwards the `text`
+ arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizerFast.__call__`] to encode
+ the text. To prepare the audio, this method forwards the `audio` arguments to
+ EncodecFeatureExtractor's [`~EncodecFeatureExtractor.__call__`]. Please refer
+ to the docstring of the above two methods for more information.
+
+ Args:
+ audio (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
+ The audio or batch of audio to be prepared. Each audio can be a NumPy array or PyTorch
+ tensor.
+ text (`str`, `List[str]`, `List[List[str]]`):
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+ output_labels (bool, *optional*, default=False):
+ Whether to return labels for training. Indices will be in `[config.audio_token_id, -100, -101]`.
+ - `config.audio_token_id` indicates an audio frame (considering sequence length elements as frames)
+ - `-100` will be ignored in the loss computation
+ - `-101` indicates the audio frame will be used only for the backbone model (using the first codebook token as labels)
+ depth_decoder_labels_ratio (float, *optional*, default=1.0):
+ The ratio of audio frames to keep for the depth decoder labels.
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
+ If set, will return tensors of a particular framework. Acceptable values are:
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return NumPy `np.ndarray` objects.
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
+ Returns:
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
+ - **input_values** -- List of audio values to be fed to a model. Returned when `audio` is not `None`.
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
+ `None`).
+ - **labels** -- List of labels for the audio frames. Returned when `output_labels=True`.
+ """
+
+ output_kwargs = self._merge_kwargs(
+ CsmProcessorKwargs,
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+ **kwargs,
+ )
+
+ text_kwargs = output_kwargs["text_kwargs"]
+ audio_kwargs = output_kwargs["audio_kwargs"]
+ common_kwargs = output_kwargs["common_kwargs"]
+
+ return_tensors = common_kwargs.pop("return_tensors", None)
+ if return_tensors != "pt":
+ raise ValueError(f"{self.__class__.__name__} only supports `return_tensors='pt'`.")
+
+ if isinstance(text, str):
+ text = [text]
+ elif not (isinstance(text, (list, tuple)) and all(isinstance(t, str) for t in text)):
+ raise ValueError("Invalid input text. Please provide a string, or a list of strings")
+ n_audio_in_text = [t.count(self.audio_token) for t in text]
+
+ n_audio = 0
+ if audio is not None:
+ audio = make_list_of_audio(audio)
+ n_audio = len(audio)
+
+ if sum(n_audio_in_text) > 0 and n_audio != sum(n_audio_in_text):
+ if audio is None:
+ raise ValueError("No audio were provided, but there are audio tokens in the prompt")
+ else:
+ raise ValueError(
+ f"The number of audio tokens in each text ({n_audio_in_text}) should be the same as the "
+ f"number of provided audios ({n_audio})."
+ )
+
+ if audio is not None:
+ encoded_length_kwargs = audio_kwargs.pop("encoded_length_kwargs", {})
+ num_audio_tokens_list = [
+ self._get_encoded_length(audio_array.shape[-1], **encoded_length_kwargs) for audio_array in audio
+ ]
+ num_audio_tokens_list_copy = num_audio_tokens_list.copy()
+
+ # expand the text to repeat the audio token for the corresponding number of frames
+ expanded_text = []
+ for sample in text:
+ replace_str = []
+ while self.audio_token in sample:
+ num_audio_tokens = num_audio_tokens_list_copy.pop(0)
+ expanded_audio_token = self.audio_token * num_audio_tokens
+
+ replace_str.append(expanded_audio_token)
+ sample = sample.replace(self.audio_token, "", 1)
+
+ while "" in sample:
+ sample = sample.replace("", replace_str.pop(0), 1)
+ expanded_text.append(sample)
+
+ text = expanded_text
+
+ encoding = self.tokenizer(text, **text_kwargs)
+ data = {}
+ data.update(encoding)
+
+ if audio is not None:
+ audio_kwargs.pop("return_attention_mask", None) # not supported by the feature extractor
+
+ concatenated_audio, input_values_cutoffs = [], []
+ offset = 0
+ for n_audio in n_audio_in_text:
+ if n_audio == 0:
+ concatenated_audio.append(np.zeros(0))
+ input_values_cutoffs.append(torch.tensor([-1]))
+ else:
+ concatenated_audio.append(
+ np.concatenate(
+ [
+ el.cpu().numpy() if isinstance(el, torch.Tensor) else el
+ for el in audio[offset : offset + n_audio]
+ ],
+ axis=-1,
+ )
+ )
+ input_values_cutoffs.append(
+ torch.tensor([el.shape[-1] for el in audio[offset : offset + n_audio]]).cumsum(dim=-1)
+ )
+ offset += n_audio
+
+ audio_inputs = self.feature_extractor(concatenated_audio, **audio_kwargs)
+ audio_inputs.pop("padding_mask", None) # not applicable here
+ data.update(audio_inputs)
+
+ # pad and stack the audio cut idxs
+ max_len = max(cut_idxs.shape[-1] for cut_idxs in input_values_cutoffs)
+ input_values_cutoffs = [
+ torch.nn.functional.pad(cut_idxs, (0, max_len - cut_idxs.shape[-1]), value=-1)
+ for cut_idxs in input_values_cutoffs
+ ]
+ data["input_values_cutoffs"] = torch.stack(input_values_cutoffs, dim=0)
+
+ if output_labels:
+ audio_frame_idxs = (data["input_ids"] == self.audio_token_id).nonzero()
+ n_audio_frames = audio_frame_idxs.shape[0]
+
+ if depth_decoder_labels_ratio <= 1.0:
+ rand_idxs = torch.randperm(n_audio_frames)[: int(n_audio_frames * (1 - depth_decoder_labels_ratio))]
+ skip_frames_idxs = audio_frame_idxs[rand_idxs]
+ else:
+ skip_frames_idxs = audio_frame_idxs
+
+ labels = torch.where(data["input_ids"] == self.audio_token_id, data["input_ids"], -100)
+ labels[skip_frames_idxs[:, 0], skip_frames_idxs[:, 1]] = -101
+
+ data["labels"] = labels
+
+ return BatchFeature(data=data, tensor_type=return_tensors)
+
+
+__all__ = ["CsmProcessor"]
diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
index 03550e09ed..1ba8536839 100644
--- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
+++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
@@ -1111,7 +1111,7 @@ class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel, GenerationMixin):
)
return model_inputs
- def _get_initial_cache_position(self, input_ids, model_kwargs):
+ def _get_initial_cache_position(self, seq_length, device, model_kwargs):
"""
Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length.
Since gpt bigcode is special, the method is overridden here, other models use it from `generation.utils.py`.
@@ -1125,8 +1125,8 @@ class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel, GenerationMixin):
if "inputs_embeds" in model_kwargs:
cur_len = model_kwargs["inputs_embeds"].shape[1]
else:
- cur_len = input_ids.shape[-1]
- model_kwargs["cache_position"] = torch.arange(past_length, cur_len, device=input_ids.device)
+ cur_len = seq_length
+ model_kwargs["cache_position"] = torch.arange(past_length, cur_len, device=device)
return model_kwargs
@add_start_docstrings_to_model_forward(GPT_BIGCODE_INPUTS_DOCSTRING)
diff --git a/src/transformers/models/janus/modeling_janus.py b/src/transformers/models/janus/modeling_janus.py
index 6a30954345..853e371c89 100644
--- a/src/transformers/models/janus/modeling_janus.py
+++ b/src/transformers/models/janus/modeling_janus.py
@@ -1563,7 +1563,7 @@ class JanusForConditionalGeneration(JanusPreTrainedModel, GenerationMixin):
inputs_embeds = self.get_input_embeddings()(input_tokens)
- model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
+ model_kwargs = self._get_initial_cache_position(seq_len, device, model_kwargs)
if model_kwargs.get("past_key_values", None) is None:
# Prepare cache if not provided.
diff --git a/src/transformers/models/janus/modular_janus.py b/src/transformers/models/janus/modular_janus.py
index 36499e43ea..50cb0021bf 100644
--- a/src/transformers/models/janus/modular_janus.py
+++ b/src/transformers/models/janus/modular_janus.py
@@ -1378,7 +1378,7 @@ class JanusForConditionalGeneration(JanusPreTrainedModel, GenerationMixin):
inputs_embeds = self.get_input_embeddings()(input_tokens)
- model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
+ model_kwargs = self._get_initial_cache_position(seq_len, device, model_kwargs)
if model_kwargs.get("past_key_values", None) is None:
# Prepare cache if not provided.
diff --git a/src/transformers/models/mimi/modeling_mimi.py b/src/transformers/models/mimi/modeling_mimi.py
index b8f24040dd..bfe6698c55 100644
--- a/src/transformers/models/mimi/modeling_mimi.py
+++ b/src/transformers/models/mimi/modeling_mimi.py
@@ -216,6 +216,32 @@ class MimiConv1d(nn.Module):
end = padded.shape[-1] - extra_pad
return padded[..., :end]
+ def _get_output_length(self, input_length: torch.LongTensor) -> torch.LongTensor:
+ """
+ Return the length of the output of the MimiConv1d.
+ """
+ # padding size
+ n_frames = (input_length - self.kernel_size + self.padding_total) / self.stride + 1
+ n_frames = torch.ceil(n_frames).to(torch.int64) - 1
+ ideal_length = n_frames * self.stride + self.kernel_size - self.padding_total
+ extra_padding = ideal_length - input_length
+
+ if self.causal:
+ padding_left = self.padding_total
+ padding_right = extra_padding
+ else:
+ padding_left = self.padding_left
+ padding_right = self.padding_right + extra_padding
+
+ # padding
+ input_length = input_length + padding_left + padding_right
+
+ # conv
+ output_lenght = (
+ input_length + 2 * self.conv.padding[0] - self.conv.dilation[0] * (self.conv.kernel_size[0] - 1) - 1
+ ) // self.conv.stride[0] + 1
+ return output_lenght
+
def forward(self, hidden_states):
extra_padding = self._get_extra_padding_for_conv1d(hidden_states)
@@ -331,21 +357,28 @@ class MimiEncoder(nn.Module):
model = [MimiConv1d(config, config.audio_channels, config.num_filters, config.kernel_size)]
scaling = 1
+ # keep track of MimiConv1d submodule layer names for easy encoded length computation
+ mimiconv1d_layer_names = ["layers.0"]
+
# Downsample to raw audio scale
for ratio in reversed(config.upsampling_ratios):
current_scale = scaling * config.num_filters
# Add residual layers
for j in range(config.num_residual_layers):
+ mimiconv1d_layer_names.extend([f"layers.{len(model)}.block.1", f"layers.{len(model)}.block.3"])
model += [MimiResnetBlock(config, current_scale, [config.dilation_growth_rate**j, 1])]
# Add downsampling layers
model += [nn.ELU()]
+ mimiconv1d_layer_names.append(f"layers.{len(model)}")
model += [MimiConv1d(config, current_scale, current_scale * 2, kernel_size=ratio * 2, stride=ratio)]
scaling *= 2
model += [nn.ELU()]
+ mimiconv1d_layer_names.append(f"layers.{len(model)}")
model += [MimiConv1d(config, scaling * config.num_filters, config.hidden_size, config.last_kernel_size)]
self.layers = nn.ModuleList(model)
+ self._mimiconv1d_layer_names = mimiconv1d_layer_names
# Copied from transformers.models.encodec.modeling_encodec.EncodecEncoder.forward
def forward(self, hidden_states):
@@ -1567,6 +1600,38 @@ class MimiModel(MimiPreTrainedModel):
codes = codes.transpose(0, 1)
return codes, past_key_values
+ def get_encoded_length(self, input_length: torch.LongTensor) -> torch.LongTensor:
+ """
+ Return the number of frames of the encoded audio waveform.
+ """
+ output_length = input_length
+
+ # encoder
+ for layer_name in self.encoder._mimiconv1d_layer_names:
+ output_length = self.encoder.get_submodule(layer_name)._get_output_length(output_length)
+
+ # downsample
+ output_length = self.downsample._get_output_length(output_length)
+
+ return output_length
+
+ def get_audio_codes_mask(self, padding_mask: torch.Tensor, padding_side: str = "right"):
+ """
+ Get the mask for the audio codes from the original padding mask.
+ """
+ encoded_lengths = self.get_encoded_length(padding_mask.sum(dim=-1))
+
+ audio_codes_mask = torch.arange(encoded_lengths.max(), device=encoded_lengths.device).expand(
+ len(encoded_lengths), -1
+ )
+ audio_codes_mask = audio_codes_mask < encoded_lengths.unsqueeze(1)
+ audio_codes_mask = audio_codes_mask.to(padding_mask.device)
+
+ if padding_side == "right":
+ return audio_codes_mask
+ else:
+ return audio_codes_mask.flip(dims=[-1])
+
def encode(
self,
input_values: torch.Tensor,
diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py
index 19c88c047a..9737e1437e 100644
--- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py
+++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py
@@ -3084,10 +3084,10 @@ class Qwen2_5OmniTalkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCon
thinker_reply_part=thinker_reply_part,
)
- def _get_initial_cache_position(self, input_ids, model_kwargs):
+ def _get_initial_cache_position(self, seq_length, device, model_kwargs):
# Talker needs to calculate cache_position with input_ids, so pop inputs_embeds temporarily
inputs_embeds = model_kwargs.pop("inputs_embeds")
- model_kwargs = super()._get_initial_cache_position(input_ids, model_kwargs)
+ model_kwargs = super()._get_initial_cache_position(seq_length, device, model_kwargs)
model_kwargs["inputs_embeds"] = inputs_embeds
return model_kwargs
diff --git a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py
index 2123be2903..9b5c416764 100644
--- a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py
+++ b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py
@@ -2771,10 +2771,10 @@ class Qwen2_5OmniTalkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCon
thinker_reply_part=thinker_reply_part,
)
- def _get_initial_cache_position(self, input_ids, model_kwargs):
+ def _get_initial_cache_position(self, seq_length, device, model_kwargs):
# Talker needs to calculate cache_position with input_ids, so pop inputs_embeds temporarily
inputs_embeds = model_kwargs.pop("inputs_embeds")
- model_kwargs = super()._get_initial_cache_position(input_ids, model_kwargs)
+ model_kwargs = super()._get_initial_cache_position(seq_length, device, model_kwargs)
model_kwargs["inputs_embeds"] = inputs_embeds
return model_kwargs
diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py
index 403569ae89..d07dad0205 100644
--- a/src/transformers/processing_utils.py
+++ b/src/transformers/processing_utils.py
@@ -1058,7 +1058,7 @@ class ProcessorMixin(PushToHubMixin):
# update defaults with arguments from tokenizer init
for modality_key in ModelProcessorKwargs.__annotations__[modality].__annotations__.keys():
# init with tokenizer init kwargs if necessary
- if modality_key in tokenizer_init_kwargs:
+ if tokenizer_init_kwargs is not None and modality_key in tokenizer_init_kwargs:
value = (
getattr(self.tokenizer, modality_key)
if hasattr(self.tokenizer, modality_key)
diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index bb4965cb76..20edc1c897 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -501,9 +501,9 @@ class GenerationTesterMixin:
output_generate = self._greedy_generate(model=model, inputs_dict=inputs_dict)
if model.config.is_encoder_decoder:
- self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
+ self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
else:
- self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1])
+ self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
@pytest.mark.generate
def test_greedy_generate_dict_outputs(self):
@@ -525,13 +525,13 @@ class GenerationTesterMixin:
)
if model.config.is_encoder_decoder:
- self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
+ self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
# Retrocompatibility check
self.assertIsInstance(output_generate, GreedySearchEncoderDecoderOutput)
else:
self.assertTrue(
- output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
+ output_generate.sequences.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1]
)
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
# Retrocompatibility check
@@ -565,10 +565,10 @@ class GenerationTesterMixin:
)
if model.config.is_encoder_decoder:
- self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
+ self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
else:
self.assertTrue(
- output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
+ output_generate.sequences.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1]
)
self._check_generate_outputs(output_generate, model.config, use_cache=True)
@@ -582,9 +582,9 @@ class GenerationTesterMixin:
output_generate = self._sample_generate(model=model, inputs_dict=inputs_dict, num_return_sequences=1)
if model.config.is_encoder_decoder:
- self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
+ self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
else:
- self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1])
+ self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
@pytest.mark.generate
def test_sample_generate_dict_output(self):
@@ -607,13 +607,13 @@ class GenerationTesterMixin:
)
if model.config.is_encoder_decoder:
- self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
+ self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
# Retrocompatibility check
self.assertIsInstance(output_generate, SampleEncoderDecoderOutput)
else:
self.assertTrue(
- output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
+ output_generate.sequences.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1]
)
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
# Retrocompatibility check
@@ -632,9 +632,9 @@ class GenerationTesterMixin:
output_generate = self._beam_search_generate(model=model, inputs_dict=inputs_dict, beam_kwargs=beam_kwargs)
if model.config.is_encoder_decoder:
- self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
+ self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
else:
- self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1])
+ self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
@pytest.mark.generate
def test_beam_search_generate_dict_output(self):
@@ -657,13 +657,13 @@ class GenerationTesterMixin:
use_cache=False,
)
if model.config.is_encoder_decoder:
- self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
+ self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
# Retrocompatibility check
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
else:
self.assertTrue(
- output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
+ output_generate.sequences.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1]
)
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
# Retrocompatibility check
@@ -706,10 +706,10 @@ class GenerationTesterMixin:
)
if model.config.is_encoder_decoder:
- self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
+ self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
else:
self.assertTrue(
- output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
+ output_generate.sequences.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1]
)
self._check_generate_outputs(
@@ -759,9 +759,9 @@ class GenerationTesterMixin:
)
if model.config.is_encoder_decoder:
- self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
+ self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
else:
- self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1])
+ self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
@pytest.mark.generate
def test_beam_sample_generate_dict_output(self):
@@ -786,13 +786,13 @@ class GenerationTesterMixin:
)
if model.config.is_encoder_decoder:
- self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
+ self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
# Retrocompatibility check
self.assertIsInstance(output_generate, BeamSampleEncoderDecoderOutput)
else:
self.assertTrue(
- output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
+ output_generate.sequences.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1]
)
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
# Retrocompatibility check
@@ -840,9 +840,9 @@ class GenerationTesterMixin:
beam_kwargs=beam_kwargs,
)
if model.config.is_encoder_decoder:
- self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
+ self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
else:
- self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1])
+ self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
# check `group_beam_search` for higher than 1 `num_return_sequences`
num_return_sequences = 2
@@ -853,9 +853,9 @@ class GenerationTesterMixin:
beam_kwargs=beam_kwargs,
)
if model.config.is_encoder_decoder:
- self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
+ self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
else:
- self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1])
+ self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
@pytest.mark.generate
def test_group_beam_search_generate_dict_output(self):
@@ -878,13 +878,13 @@ class GenerationTesterMixin:
use_cache=False,
)
if model.config.is_encoder_decoder:
- self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
+ self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
# Retrocompatibility check
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
else:
self.assertTrue(
- output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
+ output_generate.sequences.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1]
)
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
# Retrocompatibility check
@@ -923,9 +923,9 @@ class GenerationTesterMixin:
)
if model.config.is_encoder_decoder:
- self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
+ self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
else:
- self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1])
+ self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
for generation_output in output_generate:
self._check_sequence_inside_sequence(force_tokens, generation_output)
@@ -947,9 +947,9 @@ class GenerationTesterMixin:
)
if model.config.is_encoder_decoder:
- self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
+ self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
else:
- self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1])
+ self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
for generation_output in output_generate:
self._check_sequence_inside_sequence(force_tokens, generation_output)
@@ -987,13 +987,13 @@ class GenerationTesterMixin:
)
if model.config.is_encoder_decoder:
- self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
+ self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
# Retrocompatibility check
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
else:
self.assertTrue(
- output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
+ output_generate.sequences.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1]
)
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
# Retrocompatibility check
@@ -1031,9 +1031,9 @@ class GenerationTesterMixin:
use_cache=True, # Enable cache
)
if model.config.is_encoder_decoder:
- self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
+ self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
else:
- self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1])
+ self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
@pytest.mark.generate
def test_contrastive_generate_dict_outputs_use_cache(self):
@@ -1067,10 +1067,10 @@ class GenerationTesterMixin:
)
if model.config.is_encoder_decoder:
- self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
+ self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
else:
self.assertTrue(
- output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
+ output_generate.sequences.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1]
)
self._check_generate_outputs(output_generate, model.config, use_cache=True)
@@ -1499,7 +1499,7 @@ class GenerationTesterMixin:
position_ids.masked_fill_(attention_mask == 0, 1)
model_kwargs["position_ids"] = position_ids
if "cache_position" in signature:
- cache_position = torch.arange(input_ids.shape[-1], device=torch_device)
+ cache_position = torch.arange(input_ids.shape[1], device=torch_device)
model_kwargs["cache_position"] = cache_position
return model_kwargs
@@ -1525,10 +1525,12 @@ class GenerationTesterMixin:
pad_token_id = (
config.get_text_config().pad_token_id if config.get_text_config().pad_token_id is not None else 0
)
- pad_size = (input_ids.shape[0], 32)
+ pad_size = (input_ids.shape[0], 32, *input_ids.shape[2:])
padding = torch.ones(pad_size, dtype=input_ids.dtype, device=torch_device) * pad_token_id
padded_input_ids = torch.cat((padding, input_ids), dim=1)
- padded_attention_mask = torch.cat((torch.zeros_like(padding), attention_mask), dim=1)
+ padded_attention_mask = torch.cat(
+ (torch.zeros(pad_size[:2], dtype=input_ids.dtype, device=torch_device), attention_mask), dim=1
+ )
model_kwargs = _prepare_model_kwargs(padded_input_ids, padded_attention_mask, signature)
next_logits_with_padding = model(**model_kwargs).logits[:, -1, :]
@@ -1587,7 +1589,7 @@ class GenerationTesterMixin:
else text_config.num_attention_heads
)
encoder_per_head_embed_dim = embed_dim // encoder_num_attention_heads
- batch_size, seq_length = inputs["decoder_input_ids"].shape
+ batch_size, seq_length = inputs["decoder_input_ids"].shape[:2]
# The sequence length for the encoder K V depends on the model. Since it is not manipulated in
# autoregressive generation, we're keeping the test general and not checking the 3rd dim
default_cross_attention_shape = (
@@ -1606,7 +1608,7 @@ class GenerationTesterMixin:
for _ in range(num_decoder_layers)
]
else:
- batch_size, seq_length = inputs["input_ids"].shape
+ batch_size, seq_length = inputs["input_ids"].shape[:2]
default_self_attention_shape = (batch_size, num_key_value_heads, seq_length, per_head_embed_dim)
all_cache_shapes = [
[default_self_attention_shape, default_self_attention_shape] for _ in range(num_decoder_layers)
@@ -1727,7 +1729,7 @@ class GenerationTesterMixin:
"min_new_tokens": 5, # generate exactly 5 tokens
}
outputs_from_ids = model.generate(input_ids, **generation_kwargs, **inputs_dict)
- self.assertEqual(outputs_from_ids.sequences.shape, (input_ids.shape[0], input_ids.shape[1] + 5))
+ self.assertEqual(outputs_from_ids.sequences.shape[:2], (input_ids.shape[0], input_ids.shape[1] + 5))
# Same thing, but from input embeddings (`input_ids` is passed so the prompt is present in the output).
# The output of the two calls should be the same.
@@ -2262,11 +2264,11 @@ class GenerationTesterMixin:
self.assertTrue(hasattr(model, "_compiled_call")) # our auto compile should have been called
if model.config.is_encoder_decoder:
- self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
+ self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
else:
self.assertTrue(
- output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
+ output_generate.sequences.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1]
)
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
@@ -2408,7 +2410,7 @@ class GenerationTesterMixin:
config = config.text_config if hasattr(config, "text_config") else config
generated_length = (
- output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - prompt_length
+ output.sequences.shape[1] - 1 if config.is_encoder_decoder else output.sequences.shape[1] - prompt_length
)
decoder_past_key_values = getattr(output, "past_key_values", None)
if config.is_encoder_decoder and isinstance(decoder_past_key_values, EncoderDecoderCache):
@@ -2441,7 +2443,7 @@ class GenerationTesterMixin:
batch_size=internal_batch_size,
attentions=output.decoder_attentions,
prompt_length=1, # the BOS token
- output_length=output.sequences.shape[-1],
+ output_length=output.sequences.shape[1],
config=config,
decoder_past_key_values=decoder_past_key_values,
)
@@ -2450,7 +2452,7 @@ class GenerationTesterMixin:
batch_size=internal_batch_size,
attentions=output.attentions,
prompt_length=prompt_length,
- output_length=output.sequences.shape[-1],
+ output_length=output.sequences.shape[1],
config=config,
decoder_past_key_values=decoder_past_key_values,
)
@@ -2469,7 +2471,7 @@ class GenerationTesterMixin:
batch_size=internal_batch_size,
hidden_states=output.decoder_hidden_states,
prompt_length=1, # the BOS token
- output_length=output.sequences.shape[-1],
+ output_length=output.sequences.shape[1],
config=config,
use_cache=use_cache,
)
@@ -2478,7 +2480,7 @@ class GenerationTesterMixin:
batch_size=internal_batch_size,
hidden_states=output.hidden_states,
prompt_length=prompt_length,
- output_length=output.sequences.shape[-1],
+ output_length=output.sequences.shape[1],
config=config,
use_cache=use_cache,
)
@@ -2506,7 +2508,7 @@ class GenerationTesterMixin:
)
if has_standard_cache:
if use_cache:
- cache_length = output.sequences.shape[-1] - 1
+ cache_length = output.sequences.shape[1] - 1
self._check_past_key_values_for_generate(
batch_size=internal_batch_size,
decoder_past_key_values=decoder_past_key_values,
diff --git a/tests/models/csm/__init__.py b/tests/models/csm/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/models/csm/test_modeling_csm.py b/tests/models/csm/test_modeling_csm.py
new file mode 100644
index 0000000000..93423598ba
--- /dev/null
+++ b/tests/models/csm/test_modeling_csm.py
@@ -0,0 +1,693 @@
+# coding=utf-8
+# Copyright 2024, The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Testing suite for the PyTorch ConversationalSpeechModel model."""
+
+import collections
+import copy
+import re
+import unittest
+
+import pytest
+from parameterized import parameterized
+
+from transformers import (
+ AutoProcessor,
+ CsmConfig,
+ CsmForConditionalGeneration,
+ is_torch_available,
+)
+from transformers.testing_utils import (
+ cleanup,
+ require_torch_gpu,
+ slow,
+ torch_device,
+)
+from transformers.utils.import_utils import is_datasets_available
+
+from ...generation.test_utils import GenerationTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import (
+ ModelTesterMixin,
+ _config_zero_init,
+ ids_tensor,
+)
+
+
+if is_datasets_available():
+ from datasets import load_dataset
+
+if is_torch_available():
+ import torch
+
+ from transformers.pytorch_utils import id_tensor_storage
+
+
+class CsmModelTester:
+ def __init__(
+ self,
+ parent,
+ ignore_index=-100,
+ batch_size=3,
+ seq_length=7,
+ is_training=True,
+ depth_decoder_config={
+ "num_codebooks": 10,
+ "backbone_hidden_size": 64,
+ "vocab_size": 6,
+ "hidden_size": 64,
+ "intermediate_size": 128,
+ "num_hidden_layers": 2,
+ "num_attention_heads": 4,
+ "num_key_value_heads": 2,
+ "hidden_act": "silu",
+ "max_position_embeddings": 10,
+ },
+ codec_config={
+ "model_type": "mimi",
+ "audio_channels": 1,
+ "chunk_in_sec": None,
+ "hidden_size": 32,
+ "num_filters": 8,
+ "num_residual_layers": 1,
+ "upsampling_ratios": [8, 4],
+ "codebook_size": 64,
+ "vector_quantization_hidden_dimension": 64,
+ "upsample_groups": 32,
+ "num_hidden_layers": 2,
+ "num_attention_heads": 2,
+ "num_key_value_heads": 2,
+ "sliding_window": 4,
+ "codebook_dim": 64,
+ "use_cache": False,
+ },
+ config={
+ "num_codebooks": 10,
+ "vocab_size": 6,
+ "text_vocab_size": 99,
+ "hidden_size": 64,
+ "intermediate_size": 64,
+ "num_hidden_layers": 2,
+ "num_attention_heads": 4,
+ "num_key_value_heads": 2,
+ "hidden_act": "silu",
+ "max_position_embeddings": 10,
+ "bos_token_id": 1,
+ "pad_token_id": 2,
+ "eos_token_id": 3,
+ "codebook_pad_token_id": 2,
+ "codebook_eos_token_id": 3,
+ },
+ ):
+ self.parent = parent
+ self.is_training = is_training
+ self.ignore_index = ignore_index
+ self.depth_decoder_config = depth_decoder_config
+ self.codec_config = codec_config
+ self.config = config
+ self.seq_length = seq_length
+ self.batch_size = batch_size
+
+ self.num_hidden_layers = config["num_hidden_layers"]
+ self.vocab_size = config["vocab_size"]
+ self.hidden_size = config["hidden_size"]
+ self.num_attention_heads = config["num_attention_heads"]
+ self.pad_token_id = config["pad_token_id"]
+
+ def get_config(self):
+ return CsmConfig(
+ depth_decoder_config=self.depth_decoder_config,
+ codec_config=self.codec_config,
+ **self.config,
+ )
+
+ def prepare_config_and_inputs(self):
+ config = self.get_config()
+ input_ids = ids_tensor([self.batch_size, self.seq_length, config.num_codebooks], config.vocab_size - 1) + 1
+ attention_mask = input_ids[..., -1].ne(1).to(torch_device)
+ return config, input_ids, attention_mask
+
+ def prepare_config_and_inputs_for_common(self):
+ config, input_ids, attention_mask = self.prepare_config_and_inputs()
+ inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask}
+ return config, inputs_dict
+
+
+class CsmForConditionalGenerationTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
+ all_model_classes = (CsmForConditionalGeneration,) if is_torch_available() else ()
+ test_pruning = False
+ test_headmasking = False
+ test_resize_embeddings = False
+ test_resize_embeddings_untied = False
+ test_torch_exportable = True
+
+ def setUp(self):
+ self.model_tester = CsmModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=CsmConfig)
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
+ """
+ Overrides [ModelTesterMixin._prepare_for_class] to handle third input_ids dimension.
+ """
+ inputs_dict = copy.deepcopy(inputs_dict)
+
+ if return_labels:
+ inputs_dict["labels"] = torch.zeros(
+ (
+ self.model_tester.batch_size,
+ self.model_tester.seq_length,
+ self.model_tester.config["num_codebooks"],
+ ),
+ dtype=torch.long,
+ device=torch_device,
+ )
+
+ return inputs_dict
+
+ def _get_logits_processor_kwargs(self, do_sample=False, config=None):
+ """
+ Overrides [GenerationTesterMixin._get_logits_processor_kwargs] to restrict to top_k, top_p, and temperature sampling.
+ """
+ logits_processor_kwargs = {}
+ if do_sample:
+ logits_processor_kwargs.update(
+ {
+ "top_k": 10,
+ "top_p": 0.7,
+ "temperature": 0.7,
+ }
+ )
+
+ return logits_processor_kwargs
+
+ def test_initialization(self):
+ """
+ Overrides [ModelTesterMixin.test_initialization] because of specificities of Mimi codec model.
+ See https://github.com/huggingface/transformers/blob/1077603410cd73ba71d64a522033574d66d64b55/tests/models/mimi/test_modeling_mimi.py#L384-L397
+ """
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ configs_no_init = _config_zero_init(config)
+ for model_class in self.all_model_classes:
+ model = model_class(config=configs_no_init)
+ for name, param in model.named_parameters():
+ uniform_init_parms = ["conv", "input_proj", "output_proj"]
+ if param.requires_grad:
+ if any(x in name for x in uniform_init_parms):
+ self.assertTrue(
+ -1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0,
+ msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+ )
+
+ def _check_similar_generate_outputs(self, output_1, output_2, atol=1e-5, rtol=1e-5):
+ """
+ Overrides [GenerationTesterMixin._check_similar_generate_outputs] to handle third input_ids dimension.
+ Here we only look a the first codebook (index 0 on last dimension of the generated sequences) since returned scores
+ are for this token.
+ """
+ # scores doesn't include data regarding decoder input tokens
+ decoder_input_length = output_1.sequences.shape[1] - len(output_1.scores)
+ output_matches = output_1.sequences[..., 0] == output_2.sequences[..., 0]
+ has_matching_outputs = output_matches.all()
+ has_matching_scores = None
+ if not has_matching_outputs:
+ for batch_idx in range(output_1.sequences.shape[0]):
+ batch_matches = output_matches[batch_idx]
+ if batch_matches.all():
+ continue
+ first_mismatch_idx = batch_matches.int().argmin() # gets the index of the first False
+ first_mismatch_idx -= decoder_input_length
+ output_1_first_mismatch_scores = output_1.scores[first_mismatch_idx][batch_idx]
+ output_2_first_mismatch_scores = output_2.scores[first_mismatch_idx][batch_idx]
+ has_matching_scores = torch.allclose(
+ output_1_first_mismatch_scores, output_2_first_mismatch_scores, rtol=atol, atol=rtol
+ )
+ if not has_matching_scores:
+ break
+ self.assertTrue(has_matching_outputs or has_matching_scores)
+
+ @parameterized.expand([("random",), ("same",)])
+ @pytest.mark.generate
+ @unittest.skip(reason="CSM does not support assisted decoding.")
+ def test_assisted_decoding_matches_greedy_search(self, assistant_type):
+ pass
+
+ @pytest.mark.generate
+ @unittest.skip(reason="CSM does not support assisted decoding.")
+ def test_assisted_decoding_sample(self):
+ pass
+
+ @pytest.mark.generate
+ @unittest.skip(reason="CSM does not support Dola decoding.")
+ def test_dola_decoding_sample(self):
+ pass
+
+ @pytest.mark.generate
+ @unittest.skip(reason="CSM does not support beam search.")
+ def test_beam_sample_generate(self):
+ pass
+
+ @pytest.mark.generate
+ @unittest.skip(reason="CSM does not support beam search.")
+ def test_beam_search_generate(self):
+ pass
+
+ @pytest.mark.generate
+ @unittest.skip(reason="CSM does not support beam search.")
+ def test_beam_search_generate_dict_output(self):
+ pass
+
+ @pytest.mark.generate
+ @unittest.skip(reason="CSM does not support beam search.")
+ def test_beam_search_generate_dict_outputs_use_cache(self):
+ pass
+
+ @pytest.mark.generate
+ @unittest.skip(reason="CSM does not support beam search.")
+ def test_beam_sample_generate_dict_output(self):
+ pass
+
+ @pytest.mark.generate
+ @unittest.skip(reason="CSM does not support group beam search.")
+ def test_group_beam_search_generate(self):
+ pass
+
+ @pytest.mark.generate
+ @unittest.skip(reason="CSM does not support group beam search.")
+ def test_group_beam_search_generate_dict_output(self):
+ pass
+
+ @pytest.mark.generate
+ @unittest.skip(reason="CSM does not support constrained beam search.")
+ def test_constrained_beam_search_generate(self):
+ pass
+
+ @pytest.mark.generate
+ @unittest.skip(reason="CSM does not support constrained beam search.")
+ def test_constrained_beam_search_generate_dict_output(self):
+ pass
+
+ @pytest.mark.generate
+ @unittest.skip(reason="CSM does not support contrastive search.")
+ def test_contrastive_generate(self):
+ pass
+
+ @pytest.mark.generate
+ @unittest.skip(reason="CSM does not support contrastive search.")
+ def test_contrastive_generate_dict_outputs_use_cache(self):
+ pass
+
+ @pytest.mark.generate
+ @unittest.skip(reason="CSM does not support contrastive search.")
+ def test_contrastive_generate_low_memory(self):
+ pass
+
+ @pytest.mark.generate
+ @unittest.skip(reason="CSM does not support prompt lookup decoding.")
+ def test_prompt_lookup_decoding_matches_greedy_search(self):
+ pass
+
+ @pytest.mark.generate
+ @unittest.skip(reason="CSM does not support prompt lookup decoding.")
+ def test_prompt_lookup_decoding_stops_at_eos(self):
+ pass
+
+ @pytest.mark.skip(reason="CSM has custom embedding approach (text and audio embeddings).")
+ def test_model_get_set_embeddings(self):
+ pass
+
+ @pytest.mark.skip(reason="CSM has custom embedding approach (text and audio embeddings).")
+ def test_tie_model_weights(self):
+ pass
+
+ @pytest.mark.generate
+ @unittest.skip(reason="CSM does not support beam search.")
+ def test_generate_from_inputs_embeds_1_beam_search(self, _, num_beams):
+ pass
+
+ @pytest.mark.generate
+ @unittest.skip(reason="CSM does not support beam search.")
+ def test_model_parallel_beam_search(self):
+ pass
+
+ def test_tied_weights_keys(self):
+ """
+ Overrides [ModelTesterMixin.test_tied_weights_keys] to not test for text config (not applicable to CSM).
+ """
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+ for model_class in self.all_model_classes:
+ model_tied = model_class(config)
+
+ ptrs = collections.defaultdict(list)
+ for name, tensor in model_tied.state_dict().items():
+ ptrs[id_tensor_storage(tensor)].append(name)
+
+ # These are all the pointers of shared tensors.
+ tied_params = [names for _, names in ptrs.items() if len(names) > 1]
+
+ tied_weight_keys = model_tied._tied_weights_keys if model_tied._tied_weights_keys is not None else []
+ # Detect we get a hit for each key
+ for key in tied_weight_keys:
+ is_tied_key = any(re.search(key, p) for group in tied_params for p in group)
+ self.assertTrue(is_tied_key, f"{key} is not a tied weight key for {model_class}.")
+
+ # Removed tied weights found from tied params -> there should only be one left after
+ for key in tied_weight_keys:
+ for i in range(len(tied_params)):
+ tied_params[i] = [p for p in tied_params[i] if re.search(key, p) is None]
+
+ tied_params = [group for group in tied_params if len(group) > 1]
+ self.assertListEqual(
+ tied_params,
+ [],
+ f"Missing `_tied_weights_keys` for {model_class}: add all of {tied_params} except one.",
+ )
+
+ def _get_custom_4d_mask_test_data(self):
+ """
+ Overrides [ModelTesterMixin._get_custom_4d_mask_test_data] to handle third input_ids dimension.
+ """
+ # Sequence in which all but the last token is the same
+ input_ids = torch.tensor([[0, 1, 2, 3], [0, 1, 2, 4], [0, 1, 2, 5]], device=torch_device, dtype=torch.int64)
+ input_ids = input_ids.unsqueeze(-1).expand(-1, -1, self.model_tester.config["num_codebooks"])
+ position_ids = torch.tensor([[0, 1, 2, 3]] * 3, device=torch_device, dtype=torch.int64)
+
+ # Combining common prefix with the unique ending tokens:
+ input_ids_shared_prefix = torch.cat([input_ids[0][:-1], input_ids[:, -1]]).unsqueeze(0)
+
+ # Creating a 4D mask where each of the last 3 tokens do not attend to each other.
+ mask_shared_prefix = torch.tensor(
+ [
+ [
+ [
+ [1, 0, 0, 0, 0, 0],
+ [1, 1, 0, 0, 0, 0],
+ [1, 1, 1, 0, 0, 0],
+ [1, 1, 1, 1, 0, 0],
+ [1, 1, 1, 0, 1, 0],
+ [1, 1, 1, 0, 0, 1],
+ ]
+ ]
+ ],
+ )
+ # inverting the attention mask
+ mask_dtype = torch.float32
+ min_dtype = torch.finfo(mask_dtype).min
+ mask_shared_prefix = (mask_shared_prefix.eq(0.0)).to(dtype=mask_dtype, device=torch_device) * min_dtype
+
+ # Creating a position_ids tensor. note the repeating figures in the end.
+ position_ids_shared_prefix = torch.tensor([[0, 1, 2, 3, 3, 3]], device=torch_device, dtype=torch.int64)
+
+ return input_ids, position_ids, input_ids_shared_prefix, mask_shared_prefix, position_ids_shared_prefix
+
+
+class CsmForConditionalGenerationIntegrationTest(unittest.TestCase):
+ def setUp(self):
+ # TODO: @eustlb, update with correct sesame's repo
+ self.model_checkpoint = "eustlb/csm-1b"
+
+ def tearDown(self):
+ cleanup(torch_device, gc_collect=True)
+
+ def _load_conversation(self):
+ ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
+ ds = ds.filter(lambda x: x["conversation_id"] == 0)
+ ds = ds.sort("turn_id")
+ return ds[0]
+
+ @slow
+ @require_torch_gpu
+ def test_1b_model_integration_generate(self):
+ """
+ Tests the generated tokens match the ones from the original model implementation.
+ Such tokens are to be retreived using https://gist.github.com/eustlb/d25577a357ddcf8f4a8cd0d00baca551, which is a script that infers the original model.
+ """
+ processor = AutoProcessor.from_pretrained(self.model_checkpoint)
+ prompt = "<|begin_of_text|>[0]What are you working on?<|end_of_text|><|AUDIO|><|audio_eos|><|begin_of_text|>[1]I'm figuring out my budget.<|end_of_text|>"
+
+ ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
+ audio = ds[0]["audio"]["array"]
+ inputs = processor(text=prompt, audio=audio, return_tensors="pt").to(torch_device)
+
+ model = CsmForConditionalGeneration.from_pretrained(self.model_checkpoint, device_map=torch_device)
+ output_tokens = model.generate(**inputs, do_sample=False, depth_decoder_do_sample=False)
+
+ # fmt: off
+ EXPECTED_OUTPUT_TOKENS = torch.tensor([[
+ [1140, 10, 37, 1180, 1100, 1319, 601, 1482, 1918, 1739, 372, 856, 674, 1, 854, 459, 1843, 1191, 347, 349, 1087, 846, 759, 1690, 947, 1280, 580, 1909, 1192, 487, 1302, 1601],
+ [1494, 1412, 1824, 1852, 150, 928, 91, 326, 623, 1632, 1163, 1221, 1949, 999, 1779, 248, 693, 1149, 1423, 1503, 598, 80, 223, 1798, 251, 385, 1391, 1692, 1228, 1631, 1101, 866],
+ [778, 645, 830, 1812, 524, 1704, 1805, 1289, 74, 1069, 243, 1622, 1755, 1281, 1397, 620, 1962, 1995, 253, 1124, 1007, 518, 89, 559, 1304, 1482, 523, 1747, 1979, 1003, 1707, 1578],
+ [1356, 481, 642, 989, 287, 1819, 171, 1115, 824, 1253, 1488, 1074, 1019, 342, 279, 513, 1275, 1364, 893, 2007, 553, 407, 882, 1170, 1586, 485, 762, 559, 100, 542, 911, 1460],
+ [1860, 593, 1944, 404, 575, 545, 862, 830, 1002, 125, 2010, 268, 1779, 804, 811, 809, 255, 373, 387, 1756, 259, 822, 1191, 700, 1686, 390, 1676, 844, 2006, 286, 1376, 719],
+ [1165, 1047, 848, 212, 1018, 1470, 93, 1709, 1487, 1691, 1190, 275, 1278, 2018, 121, 1023, 485, 463, 39, 1825, 1936, 1817, 569, 209, 1553, 1599, 1137, 769, 968, 558, 1957, 265],
+ [902, 1608, 719, 850, 371, 1920, 75, 1917, 2005, 1238, 562, 1743, 713, 95, 1107, 1463, 696, 840, 8, 487, 1950, 1171, 1004, 1516, 1130, 303, 1866, 1728, 2046, 238, 265, 153],
+ [1932, 839, 334, 1167, 134, 2025, 40, 505, 1244, 1238, 1840, 800, 697, 72, 216, 486, 940, 1312, 510, 361, 549, 583, 1364, 844, 397, 1181, 1779, 962, 457, 1782, 1316, 465],
+ [31, 1558, 1048, 404, 354, 7, 827, 414, 1082, 807, 243, 1517, 801, 1364, 99, 1276, 1655, 1488, 1313, 464, 828, 1612, 774, 1558, 745, 1496, 960, 1874, 995, 1943, 255, 213],
+ [355, 1270, 413, 1519, 1659, 1904, 690, 552, 1279, 1821, 2022, 458, 1779, 2003, 604, 832, 661, 1295, 305, 1701, 173, 869, 230, 539, 1188, 669, 117, 692, 250, 388, 1995, 294],
+ [629, 199, 1899, 1123, 1070, 344, 578, 1795, 1451, 1257, 168, 1410, 1120, 1270, 316, 983, 1245, 1870, 165, 471, 966, 1337, 308, 1118, 746, 67, 1767, 1480, 1517, 1585, 871, 1110],
+ [1281, 1173, 784, 404, 368, 403, 580, 526, 853, 1692, 792, 895, 1286, 573, 1368, 896, 931, 1958, 1912, 644, 583, 1706, 1176, 1262, 1637, 315, 524, 1629, 795, 1211, 915, 533],
+ [9, 1783, 621, 1954, 1212, 993, 197, 977, 1662, 1340, 618, 1997, 1689, 1001, 74, 1765, 1865, 797, 1219, 1609, 671, 1491, 950, 1849, 1301, 2031, 875, 323, 203, 1063, 1490, 1538],
+ [1944, 1578, 1256, 1169, 790, 1444, 1382, 1616, 1100, 1264, 214, 1646, 488, 573, 1333, 285, 1954, 74, 1333, 674, 1303, 266, 622, 1290, 402, 109, 1331, 1666, 1347, 780, 106, 605],
+ [221, 161, 1322, 1, 565, 1507, 1403, 1091, 1557, 932, 1664, 1165, 1828, 1647, 2008, 1616, 648, 1113, 1870, 22, 734, 1458, 1940, 1756, 1689, 925, 1318, 1095, 985, 473, 604, 1974],
+ [1178, 597, 1804, 747, 1383, 360, 1497, 406, 1053, 1023, 1901, 56, 1221, 628, 75, 1729, 575, 1681, 840, 410, 650, 794, 1171, 1889, 187, 54, 1364, 1390, 505, 1285, 1814, 90],
+ [1432, 1221, 1800, 1873, 1255, 627, 41, 9, 630, 896, 1469, 1195, 1098, 145, 442, 1460, 13, 57, 2039, 1015, 149, 461, 1084, 1288, 1099, 910, 63, 157, 906, 111, 1394, 460],
+ [1352, 593, 307, 780, 1614, 1675, 1491, 1253, 723, 1793, 1032, 1486, 1805, 1904, 777, 398, 1791, 951, 770, 499, 1858, 244, 1372, 1514, 1858, 1200, 69, 181, 673, 1144, 1938, 1191],
+ [905, 403, 1626, 1529, 581, 1443, 976, 754, 1561, 1370, 1048, 253, 194, 1271, 853, 959, 1532, 30, 286, 1594, 1255, 1135, 1410, 1699, 1423, 2002, 260, 69, 941, 1640, 895, 722],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
+ ]])
+ # fmt: on
+
+ torch.testing.assert_close(output_tokens.cpu(), EXPECTED_OUTPUT_TOKENS)
+
+ @slow
+ @require_torch_gpu
+ def test_1b_model_integration_generate_no_audio(self):
+ """
+ Tests the generated tokens match the ones from the original model implementation.
+ Such tokens are to be retreived using https://gist.github.com/eustlb/aed822f765e928b9612e01b0d8836d69, which is a script that infers the original model.
+ """
+
+ processor = AutoProcessor.from_pretrained(self.model_checkpoint)
+
+ conversation = [
+ {"role": "0", "content": [{"type": "text", "text": "The past is just a story we tell ourselves."}]},
+ ]
+
+ inputs = processor.apply_chat_template(conversation, tokenize=True, return_dict=True).to(torch_device)
+
+ model = CsmForConditionalGeneration.from_pretrained(self.model_checkpoint, device_map=torch_device)
+ output_tokens = model.generate(**inputs, do_sample=False, depth_decoder_do_sample=False)
+
+ print(output_tokens)
+ # fmt: off
+ EXPECTED_OUTPUT_TOKENS = torch.tensor([[
+ [1656, 629, 723, 1785, 206, 1873, 1059, 1190, 1833, 240, 618, 350, 156, 109, 2010, 452, 435, 1764, 77, 654, 1133, 908, 1095, 74, 804, 494, 1760, 1343, 1312, 1464, 1657, 324],
+ [366, 1532, 1945, 21, 145, 1428, 1417, 1987, 1793, 1444, 356, 1491, 849, 333, 788, 426, 1423, 1004, 414, 1823, 1169, 257, 1892, 696, 1572, 998, 1098, 523, 390, 1977, 546, 1692],
+ [1343, 1382, 1288, 1744, 1685, 1154, 1837, 1156, 1680, 1641, 1479, 1548, 632, 824, 694, 2010, 671, 1251, 1822, 343, 638, 1372, 696, 1272, 144, 125, 1332, 579, 936, 77, 159, 357],
+ [456, 1534, 349, 274, 1956, 1502, 1268, 1038, 1911, 523, 1360, 1159, 761, 293, 718, 1143, 63, 705, 168, 550, 413, 1372, 1771, 787, 631, 693, 784, 1789, 2039, 1131, 1601, 918],
+ [456, 829, 2026, 1108, 1649, 207, 1308, 1440, 1192, 1394, 426, 546, 590, 36, 1682, 1827, 1387, 1425, 1909, 1500, 1438, 1297, 5, 888, 948, 1745, 1304, 1364, 1692, 131, 300, 1908],
+ [2027, 1431, 1037, 1789, 1296, 1264, 1331, 1787, 1235, 1902, 1161, 1591, 590, 561, 1633, 1218, 510, 148, 1962, 118, 212, 608, 565, 1869, 583, 598, 532, 658, 1416, 9, 1172, 493],
+ [1215, 460, 1722, 317, 1423, 716, 1589, 1177, 1927, 1860, 1756, 1552, 1674, 643, 74, 1256, 587, 1742, 771, 2028, 469, 1070, 1683, 1614, 699, 494, 2020, 139, 1365, 1171, 171, 904],
+ [1615, 339, 323, 317, 469, 714, 104, 2015, 1407, 278, 468, 77, 2007, 650, 1630, 269, 168, 934, 1544, 58, 1487, 1373, 705, 874, 1252, 2031, 1995, 254, 1334, 1171, 1911, 1607],
+ [1259, 693, 666, 1700, 1115, 607, 982, 769, 1106, 1500, 101, 88, 1698, 1864, 1358, 1594, 192, 153, 1868, 1654, 604, 1948, 526, 778, 172, 1664, 1966, 99, 1334, 1030, 1349, 1209],
+ [1211, 579, 1369, 492, 1725, 203, 1125, 778, 701, 1982, 1420, 155, 736, 1145, 2018, 609, 658, 561, 1147, 923, 1794, 1753, 116, 1374, 612, 956, 1587, 392, 1062, 2047, 901, 1931],
+ [460, 1093, 1346, 1917, 1223, 470, 271, 390, 547, 112, 143, 1633, 1030, 643, 96, 1759, 920, 1959, 75, 1280, 1630, 999, 333, 853, 1110, 1291, 1911, 57, 171, 1658, 1704, 1508],
+ [908, 500, 393, 184, 1437, 482, 2008, 1834, 356, 1435, 1550, 1407, 1236, 109, 1167, 452, 1141, 934, 207, 957, 660, 670, 28, 1066, 1252, 1932, 669, 906, 1904, 1820, 2043, 881],
+ [1599, 1031, 1474, 336, 1540, 571, 437, 1440, 1616, 1365, 1412, 1246, 400, 405, 1776, 96, 296, 38, 1597, 466, 1630, 1256, 1940, 887, 1769, 294, 285, 842, 1756, 1619, 451, 1529],
+ [1615, 339, 1722, 525, 942, 105, 1365, 670, 785, 1316, 465, 1860, 438, 968, 547, 1938, 1816, 1429, 1065, 1942, 660, 1446, 1093, 1066, 931, 121, 688, 1033, 1178, 754, 1783, 94],
+ [912, 1354, 598, 254, 341, 1980, 1166, 585, 1302, 473, 554, 242, 174, 2030, 2011, 325, 978, 1690, 258, 396, 1831, 1768, 1291, 1699, 2001, 433, 1414, 2012, 1045, 511, 533, 1104],
+ [80, 1791, 1062, 1136, 391, 568, 1651, 101, 959, 2043, 1683, 760, 794, 181, 570, 540, 1599, 20, 1017, 973, 1654, 396, 586, 778, 2044, 1664, 1911, 929, 66, 897, 510, 643],
+ [1161, 1093, 161, 1296, 589, 54, 906, 981, 1927, 605, 516, 1731, 1461, 1204, 1902, 920, 1488, 177, 805, 1402, 610, 1446, 1154, 1067, 2025, 645, 762, 1715, 415, 1658, 1713, 1607],
+ [374, 1444, 1577, 792, 1450, 628, 604, 1729, 322, 514, 1725, 540, 1070, 575, 653, 800, 250, 187, 569, 349, 354, 1573, 176, 793, 897, 359, 536, 276, 1224, 23, 145, 1287],
+ [1184, 415, 1644, 1737, 1788, 385, 784, 1861, 1172, 1118, 367, 1156, 234, 1946, 1742, 981, 828, 1798, 1821, 361, 1148, 670, 518, 1288, 761, 1050, 1642, 1006, 1747, 840, 1599, 720],
+ [1141, 1731, 1670, 1542, 1347, 1907, 683, 753, 1347, 68, 2031, 153, 556, 719, 736, 1759, 1131, 1073, 1747, 1730, 1487, 1137, 1869, 1624, 699, 1900, 748, 49, 1312, 735, 726, 1268],
+ [1141, 1383, 405, 1033, 490, 488, 1102, 471, 713, 1630, 447, 703, 1495, 1001, 1855, 354, 456, 411, 786, 853, 168, 407, 116, 699, 605, 128, 532, 1076, 208, 447, 1448, 1071],
+ [345, 1013, 948, 1728, 1837, 337, 930, 1226, 1643, 1729, 983, 1688, 2009, 435, 1358, 721, 42, 1779, 1332, 1077, 1873, 128, 1327, 125, 1226, 1704, 705, 1459, 1449, 862, 155, 1870],
+ [336, 904, 684, 184, 1542, 714, 1752, 1180, 1373, 1816, 504, 1716, 1066, 1086, 1212, 530, 1413, 1278, 75, 1347, 82, 1623, 1307, 1717, 1861, 494, 888, 1589, 670, 1999, 905, 1430],
+ [578, 554, 14, 523, 1016, 300, 1589, 1017, 356, 1583, 1654, 414, 449, 376, 1413, 58, 706, 963, 388, 1626, 131, 352, 1024, 1054, 2025, 1561, 77, 1589, 1486, 431, 1249, 1508],
+ [184, 2043, 169, 1673, 580, 162, 1752, 397, 1119, 2009, 697, 150, 1475, 157, 1523, 1402, 575, 86, 1373, 1230, 1564, 1308, 626, 1093, 1603, 1446, 1390, 1543, 1778, 1142, 1357, 1831],
+ [1484, 1987, 932, 1728, 1504, 1618, 291, 1865, 1151, 460, 1792, 141, 234, 2043, 829, 513, 435, 791, 1037, 1541, 65, 424, 1589, 1711, 312, 1306, 212, 686, 673, 984, 1914, 1549],
+ [513, 1536, 1844, 1319, 572, 1069, 121, 735, 1949, 1211, 1362, 1027, 105, 1379, 315, 1782, 706, 1658, 1510, 1989, 1443, 1690, 822, 1614, 1194, 1460, 992, 2040, 1178, 1474, 1110, 1326],
+ [1858, 194, 1594, 1935, 1622, 1892, 1577, 137, 1907, 2015, 757, 414, 1823, 836, 496, 530, 1385, 1503, 1065, 1554, 664, 525, 1031, 433, 69, 466, 1016, 1846, 1609, 1658, 911, 94],
+ [1134, 1744, 323, 691, 1837, 347, 1871, 172, 811, 91, 1883, 436, 1912, 23, 1336, 1684, 519, 1612, 1219, 1402, 728, 1953, 1658, 641, 27, 1340, 436, 139, 2008, 1030, 159, 324],
+ [1270, 1536, 1639, 414, 1387, 1170, 1067, 1701, 1414, 505, 1122, 36, 1731, 350, 1552, 1214, 1444, 30, 107, 172, 480, 1858, 655, 168, 1107, 691, 1272, 797, 1656, 548, 1407, 1375],
+ [1270, 286, 1371, 1552, 1622, 1739, 1348, 2018, 345, 1537, 1941, 2024, 1423, 740, 284, 513, 91, 1228, 2015, 385, 992, 39, 813, 803, 2025, 497, 663, 462, 1609, 334, 927, 1470],
+ [1718, 994, 265, 1421, 1622, 1098, 845, 1868, 832, 459, 447, 619, 1970, 929, 513, 63, 1448, 1509, 1219, 1942, 285, 1373, 1259, 1004, 11, 1040, 1984, 57, 188, 1687, 1475, 805],
+ [1157, 832, 480, 1225, 1019, 347, 326, 999, 125, 1542, 118, 1383, 1343, 1077, 1821, 1602, 1978, 1642, 618, 808, 692, 1953, 1353, 963, 619, 1291, 1016, 1458, 1995, 1688, 1872, 1718],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
+ ]])
+ # fmt: on
+
+ torch.testing.assert_close(output_tokens.cpu(), EXPECTED_OUTPUT_TOKENS)
+
+ @slow
+ @require_torch_gpu
+ def test_1b_model_integration_generate_multiple_audio(self):
+ """
+ Test the generated tokens match the ones from the original model implementation.
+ Such tokens are to be retreived using https://gist.github.com/eustlb/0c94de002e1325abb61d32217f74c0f8, which is a script that infers the original model.
+ """
+ processor = AutoProcessor.from_pretrained(self.model_checkpoint)
+
+ ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
+ conversation = []
+
+ # context
+ for text, audio, speaker_id in zip(ds[:4]["text"], ds[:4]["audio"], ds[:4]["speaker_id"]):
+ conversation.append(
+ {
+ "role": f"{speaker_id}",
+ "content": [{"type": "text", "text": text}, {"type": "audio", "path": audio["array"]}],
+ }
+ )
+
+ # text prompt
+ conversation.append({"role": f"{ds[4]['speaker_id']}", "content": [{"type": "text", "text": ds[4]["text"]}]})
+
+ inputs = processor.apply_chat_template(
+ conversation,
+ tokenize=True,
+ return_dict=True,
+ ).to(torch_device)
+
+ model = CsmForConditionalGeneration.from_pretrained(self.model_checkpoint, device_map=torch_device)
+ output_tokens = model.generate(**inputs, do_sample=False, depth_decoder_do_sample=False)
+
+ # fmt: off
+ EXPECTED_OUTPUT_TOKENS = torch.tensor([[
+ [420, 1189, 1311, 318, 359, 694, 1550, 1044, 1614, 1437, 1978, 537, 554, 1681, 147, 1225, 422, 1357, 1681, 1619, 165, 641, 1132, 1975, 1568, 406, 756, 503, 1673, 1428, 762, 781],
+ [1848, 1412, 957, 1656, 871, 540, 1999, 175, 711, 1383, 1814, 104, 742, 1285, 733, 1251, 1165, 1915, 1392, 645, 1804, 913, 1772, 632, 376, 1507, 1132, 725, 716, 1121, 1769, 1509],
+ [429, 1138, 895, 1018, 1099, 257, 1395, 1015, 576, 1599, 497, 19, 1858, 1437, 282, 357, 1143, 828, 1481, 70, 985, 551, 935, 278, 1102, 1453, 1902, 755, 526, 498, 1441, 1733],
+ [546, 343, 1547, 879, 2039, 692, 1999, 1150, 1969, 1866, 1178, 199, 1913, 1738, 1530, 1728, 1193, 74, 695, 612, 1095, 1597, 1381, 683, 1385, 2045, 1069, 865, 438, 70, 1437, 318],
+ [1741, 1621, 733, 1580, 1006, 1790, 1031, 1563, 569, 1822, 1229, 854, 142, 1554, 792, 741, 147, 552, 731, 772, 908, 831, 1291, 1819, 296, 290, 1871, 100, 1904, 1420, 1903, 1653],
+ [1264, 1576, 963, 12, 1403, 453, 259, 1359, 1270, 466, 1744, 1579, 1081, 1691, 1495, 1293, 110, 1020, 2042, 189, 1358, 955, 784, 1317, 2, 1794, 388, 376, 327, 511, 866, 1308],
+ [1407, 1412, 1665, 1683, 284, 874, 1859, 326, 1491, 1343, 777, 695, 1424, 396, 274, 202, 178, 747, 470, 1805, 1414, 2000, 127, 1884, 531, 215, 1322, 1098, 1674, 1227, 1092, 204],
+ [584, 637, 1665, 1683, 1136, 1201, 212, 310, 1441, 1619, 190, 1611, 1629, 2011, 1754, 1587, 413, 1287, 1251, 1382, 1904, 444, 1665, 1047, 1982, 1169, 1200, 809, 117, 327, 958, 1877],
+ [471, 1469, 1679, 1184, 343, 974, 1442, 897, 1888, 1468, 1092, 1398, 1714, 963, 1577, 1797, 766, 565, 403, 920, 1806, 466, 1193, 446, 825, 775, 1886, 1095, 159, 1085, 858, 504],
+ [28, 1511, 1510, 1580, 447, 1934, 1031, 1439, 202, 1435, 474, 1731, 724, 1080, 1121, 421, 625, 1410, 95, 605, 815, 1825, 127, 785, 900, 1673, 178, 1242, 2033, 1230, 350, 139],
+ [20, 1215, 253, 955, 871, 1689, 1986, 24, 1648, 423, 562, 1937, 1146, 26, 1266, 346, 188, 318, 179, 1164, 1100, 1978, 478, 1192, 715, 392, 1837, 425, 1492, 766, 1651, 822],
+ [1879, 1401, 1444, 723, 1754, 732, 1307, 702, 1768, 2013, 1284, 577, 1287, 1532, 647, 189, 903, 587, 800, 152, 898, 182, 2016, 639, 1074, 1220, 1934, 264, 250, 745, 1652, 536],
+ [1874, 1526, 232, 1580, 1980, 988, 1623, 341, 1768, 956, 1430, 1667, 1687, 1289, 826, 1378, 173, 1466, 479, 835, 1786, 1671, 328, 131, 815, 871, 379, 1329, 440, 1117, 392, 272],
+ [1762, 426, 1350, 1590, 314, 190, 1514, 344, 1926, 822, 534, 523, 703, 36, 379, 494, 464, 1886, 1555, 1318, 1654, 1469, 1976, 304, 218, 655, 1826, 958, 502, 326, 1898, 861],
+ [1577, 386, 503, 1492, 698, 405, 1031, 349, 1804, 2012, 1450, 996, 1140, 26, 449, 33, 1917, 354, 702, 1255, 1942, 1184, 864, 2045, 514, 744, 466, 54, 37, 486, 362, 525],
+ [1109, 1920, 445, 1719, 1670, 1220, 745, 40, 171, 1921, 999, 104, 489, 1911, 883, 306, 649, 1751, 762, 1183, 1085, 1112, 1912, 2035, 1940, 1129, 1592, 1276, 1570, 1236, 738, 209],
+ [1837, 990, 1063, 318, 1398, 1838, 1678, 906, 754, 802, 562, 353, 1389, 207, 1319, 1188, 2013, 1079, 888, 1706, 1042, 657, 482, 953, 94, 2007, 871, 485, 1596, 275, 410, 1855],
+ [872, 974, 1344, 1798, 655, 805, 1604, 1913, 455, 615, 1827, 966, 1330, 1826, 1285, 359, 544, 221, 1538, 1658, 374, 1352, 1714, 1925, 235, 65, 350, 931, 1009, 1164, 218, 736],
+ [1547, 617, 1622, 740, 655, 265, 1324, 1265, 1449, 482, 1037, 105, 1128, 701, 1866, 1674, 1999, 1302, 985, 1942, 663, 449, 1881, 698, 805, 1446, 1742, 1192, 1623, 605, 948, 2],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
+ ]])
+ # fmt: on
+
+ torch.testing.assert_close(output_tokens.cpu(), EXPECTED_OUTPUT_TOKENS)
+
+ @slow
+ @require_torch_gpu
+ def test_1b_model_integration_generate_batched(self):
+ """
+ Test the generated tokens match the ones from the original model implementation.
+ Such tokens are to be retreived using https://gist.github.com/eustlb/bcc532b53161bc31da3d66cb07ae193f, which is a script that infers the original model.
+ """
+ processor = AutoProcessor.from_pretrained(self.model_checkpoint)
+
+ ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
+ conversation = [
+ [
+ {
+ "role": f"{ds[0]['speaker_id']}",
+ "content": [
+ {"type": "text", "text": ds[0]["text"]},
+ {"type": "audio", "path": ds[0]["audio"]["array"]},
+ ],
+ },
+ {
+ "role": f"{ds[1]['speaker_id']}",
+ "content": [
+ {"type": "text", "text": ds[1]["text"]},
+ ],
+ },
+ ],
+ [
+ {
+ "role": f"{ds[0]['speaker_id']}",
+ "content": [
+ {"type": "text", "text": ds[0]["text"]},
+ ],
+ }
+ ],
+ ]
+
+ inputs = processor.apply_chat_template(
+ conversation,
+ tokenize=True,
+ return_dict=True,
+ ).to(torch_device)
+
+ model = CsmForConditionalGeneration.from_pretrained(self.model_checkpoint, device_map=torch_device)
+ output_tokens = model.generate(**inputs, do_sample=False, depth_decoder_do_sample=False)
+
+ # fmt: off
+ EXPECTED_OUTPUT_TOKENS = torch.tensor([
+ [
+ [1140, 10, 37, 1180, 1100, 1319, 601, 1482, 1918, 1739, 372, 856, 674, 1, 854, 459, 1843, 1191, 347, 349, 1087, 846, 759, 1690, 947, 1280, 580, 1909, 1192, 487, 1302, 1601],
+ [1494, 1412, 1824, 1852, 150, 928, 91, 326, 623, 1632, 1163, 1221, 1949, 999, 1779, 248, 693, 1149, 1423, 1503, 1656, 80, 1947, 1666, 933, 1950, 1544, 1577, 1612, 1791, 1883, 765],
+ [778, 645, 830, 1051, 524, 1704, 1805, 1438, 211, 906, 691, 814, 1798, 1642, 1042, 284, 1906, 1513, 520, 137, 1052, 1548, 423, 1564, 330, 873, 1381, 188, 317, 1503, 1707, 1744],
+ [1416, 864, 242, 1653, 604, 1577, 202, 1808, 926, 1867, 204, 134, 1096, 1765, 496, 1680, 268, 1796, 2024, 1989, 583, 183, 952, 105, 765, 1534, 669, 895, 2008, 11, 1199, 195],
+ [1356, 796, 25, 1580, 15, 344, 1730, 99, 1330, 315, 955, 1964, 1731, 543, 1159, 1860, 671, 732, 63, 382, 143, 395, 1749, 1421, 1640, 1340, 650, 100, 171, 1346, 41, 806],
+ [1860, 1835, 823, 388, 254, 1734, 1135, 324, 1508, 983, 937, 1703, 1541, 875, 1319, 799, 1259, 1175, 1295, 807, 261, 760, 1916, 1606, 1616, 1894, 1605, 441, 387, 167, 2016, 222],
+ [1165, 919, 1318, 54, 1727, 1766, 777, 1128, 623, 353, 1840, 241, 977, 424, 1055, 898, 395, 655, 1695, 1084, 1346, 616, 1028, 1927, 603, 858, 758, 1539, 0, 1655, 1853, 1661],
+ [902, 1746, 1318, 298, 1982, 1184, 775, 328, 1676, 871, 133, 1374, 1927, 1984, 698, 1037, 100, 1884, 1596, 429, 1794, 2046, 105, 2037, 1767, 178, 176, 1293, 1893, 1780, 1832, 1382],
+ [1932, 714, 1084, 1167, 624, 509, 1213, 651, 1000, 1686, 1537, 555, 461, 623, 1433, 1089, 1212, 1628, 834, 1111, 943, 1816, 1947, 1063, 354, 1843, 1741, 2015, 404, 928, 1488, 168],
+ [1437, 314, 1356, 404, 1274, 2016, 998, 1350, 155, 553, 368, 1501, 1431, 1563, 1105, 1353, 535, 908, 1305, 1214, 1656, 65, 1469, 1517, 480, 252, 1289, 696, 302, 632, 246, 72],
+ [724, 848, 1140, 927, 1669, 296, 447, 1708, 1898, 685, 1041, 1685, 708, 1510, 1623, 876, 11, 99, 43, 586, 1705, 1753, 1477, 1191, 583, 1249, 1613, 992, 1319, 677, 418, 668],
+ [925, 54, 1810, 674, 1306, 848, 573, 1772, 105, 301, 1753, 989, 440, 1057, 823, 1313, 1663, 750, 1477, 102, 1437, 1114, 399, 1440, 319, 118, 1827, 295, 1429, 139, 1594, 55],
+ [629, 149, 784, 838, 984, 604, 685, 1229, 1432, 859, 1526, 1336, 1949, 281, 988, 1260, 52, 6, 1216, 1542, 1426, 1938, 253, 280, 1319, 794, 901, 843, 615, 437, 814, 20],
+ [1281, 502, 1237, 404, 625, 1444, 397, 1999, 2016, 1686, 533, 1785, 1152, 1245, 579, 1906, 1204, 549, 1334, 536, 1351, 1979, 208, 111, 2011, 751, 677, 1948, 1772, 1525, 2038, 419],
+ [9, 490, 869, 2026, 1928, 1489, 587, 549, 1241, 460, 1458, 1636, 924, 222, 1246, 480, 706, 398, 75, 1717, 604, 1446, 333, 237, 805, 1446, 421, 1343, 78, 1260, 1872, 1116],
+ [1944, 755, 375, 332, 1464, 828, 1273, 579, 1457, 353, 1510, 1910, 1609, 705, 400, 1666, 227, 1544, 1270, 136, 1857, 1975, 1762, 2006, 1102, 221, 1965, 151, 2041, 198, 1830, 287],
+ [221, 502, 440, 247, 181, 1912, 42, 357, 1883, 596, 919, 953, 1774, 772, 915, 188, 438, 1226, 544, 1313, 726, 1298, 85, 677, 566, 1581, 30, 341, 878, 1732, 591, 1446],
+ [1178, 1690, 320, 1746, 1798, 685, 1941, 666, 832, 623, 1907, 128, 337, 1779, 824, 923, 1041, 287, 1165, 437, 1803, 1222, 870, 646, 358, 220, 2009, 735, 468, 1908, 1349, 1603],
+ [1432, 1286, 540, 1687, 1741, 951, 299, 1233, 1061, 1128, 985, 953, 1917, 198, 2031, 1559, 1096, 1455, 780, 437, 163, 1268, 649, 1029, 1081, 1518, 304, 1638, 814, 364, 140, 1385],
+ [905, 463, 1739, 1063, 351, 936, 1652, 101, 1323, 1731, 298, 1193, 266, 1554, 1837, 1659, 409, 1739, 1012, 725, 851, 1909, 213, 1918, 1759, 1561, 1250, 970, 1571, 352, 911, 195],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
+ ],
+ [
+ [1375, 203, 265, 164, 200, 1867, 976, 924, 1972, 1637, 1048, 271, 1912, 1430, 853, 1942, 260, 1642, 400, 57, 1376, 1626, 1821, 1163, 619, 777, 1076, 951, 389, 1820, 84, 1417],
+ [914, 527, 286, 968, 305, 1314, 805, 1703, 87, 559, 1980, 1124, 1726, 36, 1139, 618, 1628, 519, 1943, 781, 400, 1265, 438, 113, 87, 856, 465, 162, 1099, 352, 1141, 274],
+ [1408, 6, 126, 2009, 90, 996, 934, 134, 1857, 126, 602, 876, 1092, 1962, 1205, 828, 707, 1063, 393, 1533, 123, 1086, 1749, 1324, 1, 1763, 1707, 1191, 34, 1323, 1017, 1787],
+ [1000, 683, 1630, 703, 1574, 587, 25, 1049, 213, 1270, 1641, 1072, 1892, 1634, 1603, 90, 867, 2037, 1021, 715, 206, 507, 1138, 959, 1822, 1785, 280, 1100, 1660, 251, 1903, 988],
+ [1657, 1981, 246, 1048, 1952, 451, 305, 423, 2000, 416, 756, 1748, 7, 748, 1866, 1795, 1682, 1832, 338, 212, 1685, 518, 154, 1407, 416, 765, 776, 25, 55, 458, 612, 262],
+ [1034, 564, 667, 1474, 1212, 350, 712, 941, 1151, 1182, 1280, 640, 924, 1722, 1816, 458, 226, 359, 1518, 102, 1203, 459, 676, 1788, 1110, 393, 1974, 1721, 795, 1459, 798, 1723],
+ [742, 1616, 119, 653, 441, 679, 246, 1432, 486, 1615, 1191, 500, 650, 223, 687, 1765, 1875, 963, 1385, 863, 151, 1771, 458, 1170, 737, 1932, 785, 1954, 1067, 16, 1986, 2029],
+ [1437, 1078, 1767, 1452, 1392, 45, 2010, 1664, 245, 2015, 1416, 1055, 457, 985, 740, 1594, 1562, 1838, 258, 1431, 701, 604, 1813, 352, 792, 632, 21, 895, 70, 609, 850, 1599],
+ [983, 1961, 54, 135, 846, 711, 473, 1630, 1373, 1094, 251, 525, 632, 1014, 1594, 1594, 1752, 398, 1266, 1357, 942, 1680, 191, 874, 483, 1291, 381, 1873, 1964, 1278, 1477, 122],
+ [1663, 1969, 1887, 113, 145, 251, 1133, 156, 245, 1641, 209, 1322, 2037, 836, 539, 667, 940, 797, 1758, 1357, 191, 1137, 587, 1699, 27, 701, 395, 99, 1682, 876, 762, 839],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ ]
+ ])
+ # fmt: on
+
+ torch.testing.assert_close(output_tokens.cpu(), EXPECTED_OUTPUT_TOKENS)
diff --git a/tests/models/csm/test_processor_csm.py b/tests/models/csm/test_processor_csm.py
new file mode 100644
index 0000000000..da96381246
--- /dev/null
+++ b/tests/models/csm/test_processor_csm.py
@@ -0,0 +1,140 @@
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import shutil
+import tempfile
+import unittest
+
+import jinja2
+import numpy as np
+
+from transformers import CsmProcessor
+from transformers.testing_utils import require_torch
+from transformers.utils import is_torch_available
+
+from ...test_processing_common import ProcessorTesterMixin
+
+
+if is_torch_available():
+ import torch
+
+
+@require_torch
+class CsmProcessorTest(ProcessorTesterMixin, unittest.TestCase):
+ processor_class = CsmProcessor
+
+ @classmethod
+ def setUpClass(cls):
+ # TODO: @eustlb, change for hf-internal-testing/csm-1b
+ cls.checkpoint = "eustlb/csm-1b"
+ processor = CsmProcessor.from_pretrained(cls.checkpoint)
+ cls.audio_token = processor.audio_token
+ cls.audio_token_id = processor.audio_token_id
+ cls.pad_token_id = processor.tokenizer.pad_token_id
+ cls.bos_token_id = processor.tokenizer.bos_token_id
+ cls.tmpdirname = tempfile.mkdtemp()
+ processor.save_pretrained(cls.tmpdirname)
+
+ @classmethod
+ def tearDownClass(cls):
+ shutil.rmtree(cls.tmpdirname, ignore_errors=True)
+
+ def prepare_processor_dict(self):
+ return {"chat_template": "\n{%- for message in messages %}\n {#-- Validate role is a stringified integer --#}\n {%- if not message['role'] is string or not message['role'].isdigit() %}\n {{- raise_exception(\"The role must be an integer or a stringified integer (e.g. '0') designating the speaker id\") }}\n {%- endif %}\n\n {#-- Validate content is a list --#}\n {%- set content = message['content'] %}\n {%- if content is not iterable or content is string %}\n {{- raise_exception(\"The content must be a list\") }}\n {%- endif %}\n\n {#-- Collect content types --#}\n {%- set content_types = content | map(attribute='type') | list %}\n {%- set is_last = loop.last %}\n\n {#-- Last message validation --#}\n {%- if is_last %}\n {%- if 'text' not in content_types %}\n {{- raise_exception(\"The last message must include one item of type 'text'\") }}\n {%- elif (content_types | select('equalto', 'text') | list | length > 1) or (content_types | select('equalto', 'audio') | list | length > 1) %}\n {{- raise_exception(\"At most two items are allowed in the last message: one 'text' and one 'audio'\") }}\n {%- endif %}\n\n {#-- All other messages validation --#}\n {%- else %}\n {%- if content_types | select('equalto', 'text') | list | length != 1\n or content_types | select('equalto', 'audio') | list | length != 1 %}\n {{- raise_exception(\"Each message (except the last) must contain exactly one 'text' and one 'audio' item\") }}\n {%- elif content_types | reject('in', ['text', 'audio']) | list | length > 0 %}\n {{- raise_exception(\"Only 'text' and 'audio' types are allowed in content\") }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n\n{%- for message in messages %}\n {{- bos_token }}\n {{- '[' + message['role'] + ']' }}\n {{- message['content'][0]['text'] }}\n {{- eos_token }}\n {%- if message['content']|length > 1 %}\n {{- '<|AUDIO|><|audio_eos|>' }}\n {%- endif %}\n{%- endfor %}\n"} # fmt: skip
+
+ def test_chat_template_is_saved(self):
+ processor_loaded = self.processor_class.from_pretrained(self.tmpdirname)
+ processor_dict_loaded = json.loads(processor_loaded.to_json_string())
+ # chat templates aren't serialized to json in processors
+ self.assertFalse("chat_template" in processor_dict_loaded.keys())
+
+ # they have to be saved as separate file and loaded back from that file
+ # so we check if the same template is loaded
+ processor_dict = self.prepare_processor_dict()
+ self.assertTrue(processor_loaded.chat_template == processor_dict.get("chat_template", None))
+
+ def test_apply_chat_template(self):
+ # Message contains content which a mix of lists with images and image urls and string
+ messages = [
+ {
+ "role": "0",
+ "content": [
+ {"type": "text", "text": "This is a test sentence 0."},
+ {"type": "audio"},
+ ],
+ },
+ {
+ "role": "1",
+ "content": [
+ {"type": "text", "text": "This is a test sentence 1."},
+ {"type": "audio"},
+ ],
+ },
+ {
+ "role": "0",
+ "content": [
+ {"type": "text", "text": "This is a prompt."},
+ ],
+ },
+ ]
+ processor = CsmProcessor.from_pretrained(self.tmpdirname)
+ rendered = processor.apply_chat_template(messages, tokenize=False)
+
+ expected_rendered = (
+ "<|begin_of_text|>[0]This is a test sentence 0.<|end_of_text|>"
+ "<|AUDIO|><|audio_eos|>"
+ "<|begin_of_text|>[1]This is a test sentence 1.<|end_of_text|>"
+ "<|AUDIO|><|audio_eos|>"
+ "<|begin_of_text|>[0]This is a prompt.<|end_of_text|>"
+ )
+ self.assertEqual(rendered, expected_rendered)
+
+ messages = [
+ {
+ "role": "0",
+ "content": [
+ {"type": "text", "text": "This is a test sentence."},
+ ],
+ },
+ {
+ "role": "1",
+ "content": [
+ {"type": "text", "text": "This is a test sentence."},
+ ],
+ },
+ ]
+
+ # this should raise an error because the CSM processor requires audio content in the messages expect the last one
+ with self.assertRaises(jinja2.exceptions.TemplateError):
+ input_ids = processor.apply_chat_template(messages, tokenize=False)
+
+ # now let's very that it expands audio tokens correctly
+ messages = [
+ {
+ "role": "0",
+ "content": [
+ {"type": "text", "text": "This is a test sentence."},
+ {"type": "audio", "audio": np.zeros(4096)},
+ ],
+ },
+ ]
+
+ input_ids = processor.apply_chat_template(messages, tokenize=True)
+
+ # 4096 audio input values should give 3 audio tokens
+ expected_ids = torch.tensor(
+ [[128000, 58, 15, 60, 2028, 374, 264, 1296, 11914, 13, 128001, 128002, 128002, 128002, 128003]]
+ )
+ torch.testing.assert_close(input_ids, expected_ids)
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 62508002fe..974bfb7b5a 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -4350,8 +4350,8 @@ class ModelTesterMixin:
out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :] # last three tokens
# comparing softmax-normalized logits:
- normalized_0 = F.softmax(out_last_tokens)
- normalized_1 = F.softmax(out_shared_prefix_last_tokens)
+ normalized_0 = F.softmax(out_last_tokens, dim=-1)
+ normalized_1 = F.softmax(out_shared_prefix_last_tokens, dim=-1)
torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4)
@slow
@@ -4403,7 +4403,7 @@ class ModelTesterMixin:
self.skipTest(reason="This model does not support `logits_to_keep` argument.")
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
- batch_size, sequence_length = inputs["input_ids"].shape
+ batch_size, sequence_length = inputs["input_ids"].shape[:2]
vocab_size = config.get_text_config().vocab_size
model = model_class(config).to(device=torch_device).eval()
# some models have labels but `logits_to_keep` should not be used in train mode
diff --git a/utils/check_repo.py b/utils/check_repo.py
index 1164ac9db0..960a73e734 100644
--- a/utils/check_repo.py
+++ b/utils/check_repo.py
@@ -159,6 +159,9 @@ IGNORE_NON_TESTED = (
"InternVLVisionModel", # Building part of bigger (tested) model
"JanusVisionModel", # Building part of bigger (tested) model
"TimesFmModel", # Building part of bigger (tested) model
+ "CsmDepthDecoderForCausalLM", # Building part of bigger (tested) model. Tested implicitly through CsmForConditionalGenerationIntegrationTest.
+ "CsmDepthDecoderModel", # Building part of bigger (tested) model. Tested implicitly through CsmForConditionalGenerationIntegrationTest.
+ "CsmBackboneModel", # Building part of bigger (tested) model. Tested implicitly through CsmForConditionalGenerationIntegrationTest.
]
)
@@ -368,6 +371,10 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
"Qwen2_5OmniToken2WavModel", # Building part of a bigger model
"Qwen2_5OmniToken2WavBigVGANModel", # Building part of a bigger model
"Qwen2_5OmniToken2WavDiTModel", # Building part of a bigger model
+ "CsmBackboneModel", # Building part of a bigger model
+ "CsmDepthDecoderModel", # Building part of a bigger model
+ "CsmDepthDecoderForCausalLM", # Building part of a bigger model
+ "CsmForConditionalGeneration", # Building part of a bigger model
]
# DO NOT edit this list!