From 798f948e88fd0b93fc515ec6b96e0503b78ad6ba Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Wed, 7 May 2025 10:20:13 -0400 Subject: [PATCH] Add CSM model (#36719) * draft structure * depth decoder with forward pre hook * full model forward draft * draft update * depth decoder update * ConversationalSpeechModelForCausalLM udpates * add generate * max length criteria small fix * udpate * updates * generation update * update in loss compute * conversion script * update for correct input embeddings * handle interleaved rope * update * update * update * support compile * update training * add doc * update doc * correct inits * ConversationalSpeechModel -> Csm * conf update * name update * tests CsmForCausalLMTest * convert use cached_file * conf + modeling updates * generate utils handle third dim shape * integration test * modeling + conf updates * common test handle more than 2 dims * add nested audio list utils * processing handle nested audio list * csm processing draft * mimi util * init updates * modular update * convert modular * processing update * csm tests update * generate tests handle third dim * generate utils handle third dim * propagate _get_initial_cache_position update * tied_weight_keys update + convert correctly * fix inputs_embeds * revert audio nested list * batch inference update + return audio * audio_utils update * processor update * some more integration tests * remove old test * porcessing output labels * improve * fix * update rope values with equivalent ones * conversion update * udpate tests * handle depth decoder generation config * remove default eos_token_id * make style * revert modeling_mimi * add default generation_config * remove sdpa since handled by default * make * fix conflict * fix conflicts * correct naming * correct imports * make * causal -> conditional naming * causal -> conditional naming * auto update * make * make * add doc * test update * fix weight init * audio tokens offsets as buffer * 4d mask in conditional class * make * doc update * fix causal mask * fix causal mask * doc update * doc update * add processor doc * update doc * fix 4d causal mask * update make_list_of_audio * do not default to mutable * remove duplicates * remove useless reset_parameters * use GradientCheckpointingLayer * use can_return_tuple * formatting * prepend placeholder in _sample * torch compile fix * some more fixies * convert modular * fix * default max_length in convert * handle depth decoder generation config correctly * clearer formulation * handle output_loading_info * handle softmax warning * add doc * propagate _get_initial_cache_position changes * generation in its own module * add processor tests * fix compile witu cuda graphs * fix compile with cuda graphs * add csm.md * include CSM loss * doc nit * doc nit * doc nit * Update docs/source/en/model_doc/csm.md Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * add save_audio to processor * Update src/transformers/models/csm/modular_csm.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * doc update * simplify audio_codes_mask computation * doc update * simplify loss computation * fix static cache test * fix * remove comment * simplify encoded length computation * use hf-internal-testing * doc update * cast to float before numpy * nit * mem efficient codebook head * nit * cat input values with cutoffs --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- docs/source/en/_toctree.yml | 2 + docs/source/en/model_doc/csm.md | 377 ++++ src/transformers/audio_utils.py | 37 +- .../generation/stopping_criteria.py | 2 +- src/transformers/generation/utils.py | 42 +- src/transformers/loss/loss_utils.py | 1 + src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 2 + src/transformers/models/auto/modeling_auto.py | 2 + src/transformers/models/csm/__init__.py | 28 + .../models/csm/configuration_csm.py | 440 +++++ src/transformers/models/csm/convert_csm.py | 339 ++++ src/transformers/models/csm/generation_csm.py | 491 +++++ src/transformers/models/csm/modeling_csm.py | 1710 +++++++++++++++++ src/transformers/models/csm/modular_csm.py | 1042 ++++++++++ src/transformers/models/csm/processing_csm.py | 364 ++++ .../gpt_bigcode/modeling_gpt_bigcode.py | 6 +- .../models/janus/modeling_janus.py | 2 +- .../models/janus/modular_janus.py | 2 +- src/transformers/models/mimi/modeling_mimi.py | 65 + .../qwen2_5_omni/modeling_qwen2_5_omni.py | 4 +- .../qwen2_5_omni/modular_qwen2_5_omni.py | 4 +- src/transformers/processing_utils.py | 2 +- tests/generation/test_utils.py | 102 +- tests/models/csm/__init__.py | 0 tests/models/csm/test_modeling_csm.py | 693 +++++++ tests/models/csm/test_processor_csm.py | 140 ++ tests/test_modeling_common.py | 6 +- utils/check_repo.py | 7 + 29 files changed, 5827 insertions(+), 86 deletions(-) create mode 100644 docs/source/en/model_doc/csm.md create mode 100644 src/transformers/models/csm/__init__.py create mode 100644 src/transformers/models/csm/configuration_csm.py create mode 100644 src/transformers/models/csm/convert_csm.py create mode 100644 src/transformers/models/csm/generation_csm.py create mode 100644 src/transformers/models/csm/modeling_csm.py create mode 100644 src/transformers/models/csm/modular_csm.py create mode 100644 src/transformers/models/csm/processing_csm.py create mode 100644 tests/models/csm/__init__.py create mode 100644 tests/models/csm/test_modeling_csm.py create mode 100644 tests/models/csm/test_processor_csm.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 7d505721f0..b848976d02 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -825,6 +825,8 @@ title: Bark - local: model_doc/clap title: CLAP + - local: model_doc/csm + title: CSM - local: model_doc/dac title: dac - local: model_doc/encodec diff --git a/docs/source/en/model_doc/csm.md b/docs/source/en/model_doc/csm.md new file mode 100644 index 0000000000..2d916da161 --- /dev/null +++ b/docs/source/en/model_doc/csm.md @@ -0,0 +1,377 @@ + + +# Csm + +## Overview + +The Conversational Speech Model (CSM) is the first open-source contextual text-to-speech model [released by Sesame](https://www.sesame.com/research/crossing_the_uncanny_valley_of_voice). It is designed to generate natural-sounding speech with or without conversational context. This context typically consists of multi-turn dialogue between speakers, represented as sequences of text and corresponding spoken audio. + +**Model Architecture:** +CSM is composed of two LLaMA-style auto-regressive transformer decoders: a backbone decoder that predicts the first codebook token and a depth decoder that generates the remaining tokens. It uses the pretrained codec model [Mimi](./mimi.md), introduced by Kyutai, to encode speech into discrete codebook tokens and decode them back into audio. + +The original csm-1b checkpoint is available under the [Sesame](https://huggingface.co/sesame/csm-1b) organization on Hugging Face. + +
+ +
+ +## Usage Tips + +### Without Conversational Context + +CSM can be used to simply generate speech from a text prompt: + +```python +import torch +from transformers import CsmForConditionalGeneration, AutoProcessor + +model_id = "eustlb/csm-1b" +device = "cuda" if torch.cuda.is_available() else "cpu" + +# load the model and the processor +processor = AutoProcessor.from_pretrained(model_id) +model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=device) + +# prepare the inputs +text = "[0]The past is just a story we tell ourselves." # `[0]` for speaker id 0 +inputs = processor(text, add_special_tokens=True).to(device) + +# another equivalent way to prepare the inputs +conversation = [ + {"role": "0", "content": [{"type": "text", "text": "The past is just a story we tell ourselves."}]}, +] +inputs = processor.apply_chat_template( + conversation, + tokenize=True, + return_dict=True, +).to(device) + +# infer the model +audio = model.generate(**inputs, output_audio=True) +processor.save_audio(audio, "example_without_context.wav") +``` + +### With Conversational Context + +CSM can be used to generate speech given a conversation, allowing consistency in the voices and content-aware generation: + +```python +import torch +from transformers import CsmForConditionalGeneration, AutoProcessor +from datasets import load_dataset, Audio + +model_id = "eustlb/csm-1b" +device = "cuda" if torch.cuda.is_available() else "cpu" + +# load the model and the processor +processor = AutoProcessor.from_pretrained(model_id) +model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=device) + +# prepare the inputs +ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train") +# ensure the audio is 24kHz +ds = ds.cast_column("audio", Audio(sampling_rate=24000)) +conversation = [] + +# 1. context +for text, audio, speaker_id in zip(ds[:4]["text"], ds[:4]["audio"], ds[:4]["speaker_id"]): + conversation.append( + { + "role": f"{speaker_id}", + "content": [{"type": "text", "text": text}, {"type": "audio", "path": audio["array"]}], + } + ) + +# 2. text prompt +conversation.append({"role": f"{ds[4]['speaker_id']}", "content": [{"type": "text", "text": ds[4]["text"]}]}) + +inputs = processor.apply_chat_template( + conversation, + tokenize=True, + return_dict=True, +).to(device) + +# infer the model +audio = model.generate(**inputs, output_audio=True) +processor.save_audio(audio, "example_with_context.wav") +``` + +### Batched Inference + +CSM supports batched inference! + +```python +import torch +from transformers import CsmForConditionalGeneration, AutoProcessor +from datasets import load_dataset, Audio + +model_id = "eustlb/csm-1b" +device = "cuda" if torch.cuda.is_available() else "cpu" + +# load the model and the processor +processor = AutoProcessor.from_pretrained(model_id) +model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=device) + +# prepare the inputs +ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train") +# ensure the audio is 24kHz +ds = ds.cast_column("audio", Audio(sampling_rate=24000)) +# here a batch with two prompts +conversation = [ + [ + { + "role": f"{ds[0]['speaker_id']}", + "content": [ + {"type": "text", "text": ds[0]["text"]}, + {"type": "audio", "path": ds[0]["audio"]["array"]}, + ], + }, + { + "role": f"{ds[1]['speaker_id']}", + "content": [ + {"type": "text", "text": ds[1]["text"]}, + ], + }, + ], + [ + { + "role": f"{ds[0]['speaker_id']}", + "content": [ + {"type": "text", "text": ds[0]["text"]}, + ], + } + ], +] +inputs = processor.apply_chat_template( + conversation, + tokenize=True, + return_dict=True, +).to(device) + +audio = model.generate(**inputs, output_audio=True) +processor.save_audio(audio, [f"speech_batch_idx_{i}.wav" for i in range(len(audio))]) +``` + +### Making The Model Go Brrr + +CSM supports full-graph compilation with CUDA graphs! + +```python +import torch +import copy +from transformers import CsmForConditionalGeneration, AutoProcessor +from datasets import load_dataset + +model_id = "eustlb/csm-1b" +device = "cuda" + +# set logs to ensure no recompilation and graph breaks +torch._logging.set_logs(graph_breaks=True, recompiles=True, cudagraphs=True) + +# load the model and the processor +processor = AutoProcessor.from_pretrained(model_id) +model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=device) + +# use static cache, enabling automatically torch compile with fullgraph and reduce-overhead +model.generation_config.max_length = 250 # big enough to avoid recompilation +model.generation_config.max_new_tokens = None # would take precedence over max_length +model.generation_config.cache_implementation = "static" +model.depth_decoder.generation_config.cache_implementation = "static" + +# generation kwargs +gen_kwargs = { + "do_sample": False, + "depth_decoder_do_sample": False, + "temperature": 1.0, + "depth_decoder_temperature": 1.0, +} + +# Define a timing decorator +class TimerContext: + def __init__(self, name="Execution"): + self.name = name + self.start_event = None + self.end_event = None + + def __enter__(self): + # Use CUDA events for more accurate GPU timing + self.start_event = torch.cuda.Event(enable_timing=True) + self.end_event = torch.cuda.Event(enable_timing=True) + self.start_event.record() + return self + + def __exit__(self, *args): + self.end_event.record() + torch.cuda.synchronize() + elapsed_time = self.start_event.elapsed_time(self.end_event) / 1000.0 + print(f"{self.name} time: {elapsed_time:.4f} seconds") + +# prepare the inputs +ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train") + +conversation = [ + { + "role": f"{ds[0]['speaker_id']}", + "content": [ + {"type": "text", "text": ds[0]["text"]}, + {"type": "audio", "path": ds[0]["audio"]["array"]}, + ], + }, + { + "role": f"{ds[1]['speaker_id']}", + "content": [ + {"type": "text", "text": ds[1]["text"]}, + {"type": "audio", "path": ds[1]["audio"]["array"]}, + ], + }, + { + "role": f"{ds[2]['speaker_id']}", + "content": [ + {"type": "text", "text": ds[2]["text"]}, + ], + }, +] + +padded_inputs_1 = processor.apply_chat_template( + conversation, + tokenize=True, + return_dict=True, +).to(device) + +print("\n" + "="*50) +print("First generation - compiling and recording CUDA graphs...") +with TimerContext("First generation"): + _ = model.generate(**padded_inputs_1, **gen_kwargs) +print("="*50) + +print("\n" + "="*50) +print("Second generation - fast !!!") +with TimerContext("Second generation"): + _ = model.generate(**padded_inputs_1, **gen_kwargs) +print("="*50) + +# now with different inputs +conversation = [ + { + "role": f"{ds[0]['speaker_id']}", + "content": [ + {"type": "text", "text": ds[2]["text"]}, + {"type": "audio", "path": ds[2]["audio"]["array"]}, + ], + }, + { + "role": f"{ds[1]['speaker_id']}", + "content": [ + {"type": "text", "text": ds[3]["text"]}, + {"type": "audio", "path": ds[3]["audio"]["array"]}, + ], + }, + { + "role": f"{ds[2]['speaker_id']}", + "content": [ + {"type": "text", "text": ds[4]["text"]}, + ], + }, +] +padded_inputs_2 = processor.apply_chat_template( + conversation, + tokenize=True, + return_dict=True, +).to(device) + +print("\n" + "="*50) +print("Generation with other inputs!") +with TimerContext("Generation with different inputs"): + _ = model.generate(**padded_inputs_2, **gen_kwargs) +print("="*50) +``` + +### Training + +CSM Transformers integration supports training! + +```python +from transformers import CsmForConditionalGeneration, AutoProcessor +from datasets import load_dataset, Audio + +model_id = "eustlb/csm-1b" +device = "cuda" + +# load the model and the processor +processor = AutoProcessor.from_pretrained(model_id) +model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=device) +model.train() + +ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train") +# ensure the audio is 24kHz +ds = ds.cast_column("audio", Audio(sampling_rate=24000)) +conversation = [] + +# context +for text, audio, speaker_id in zip(ds[:4]["text"], ds[:4]["audio"], ds[:4]["speaker_id"]): + conversation.append( + { + "role": f"{speaker_id}", + "content": [{"type": "text", "text": text}, {"type": "audio", "path": audio["array"]}], + } + ) + +inputs = processor.apply_chat_template( + conversation, + tokenize=True, + return_dict=True, + output_labels=True, +).to(device) + +out = model(**inputs) +out.loss.backward() +``` + +This model was contributed by [Eustache Le Bihan](https://huggingface.co/eustlb). +The original code can be found [here](https://github.com/SesameAILabs/csm). + + +## CsmConfig + +[[autodoc]] CsmConfig + +## CsmDepthDecoderConfig + +[[autodoc]] CsmDepthDecoderConfig + +## CsmProcessor + +[[autodoc]] CsmProcessor + - __call__ + +## CsmForConditionalGeneration + +[[autodoc]] CsmForConditionalGeneration + - forward + - generate + +## CsmDepthDecoderForCausalLM + +[[autodoc]] CsmDepthDecoderForCausalLM + +## CsmDepthDecoderModel + +[[autodoc]] CsmDepthDecoderModel + +## CsmBackboneModel + +[[autodoc]] CsmBackboneModel diff --git a/src/transformers/audio_utils.py b/src/transformers/audio_utils.py index fced0e5758..e980d4cbef 100644 --- a/src/transformers/audio_utils.py +++ b/src/transformers/audio_utils.py @@ -24,7 +24,12 @@ from typing import List, Optional, Tuple, Union import numpy as np import requests -from .utils import is_librosa_available, requires_backends +from .utils import ( + is_librosa_available, + is_numpy_array, + is_torch_tensor, + requires_backends, +) if is_librosa_available(): @@ -69,6 +74,36 @@ AudioInput = Union[ ] +def is_valid_audio(audio): + return is_numpy_array(audio) or is_torch_tensor(audio) + + +def is_valid_list_of_audio(audio): + return audio and all(is_valid_audio(audio_i) for audio_i in audio) + + +def make_list_of_audio( + audio: Union[list[AudioInput], AudioInput], +) -> AudioInput: + """ + Ensure that the output is a list of audio. + Args: + audio (`Union[List[AudioInput], AudioInput]`): + The input audio. + Returns: + list: A list of audio. + """ + # If it's a list of audios, it's already in the right format + if isinstance(audio, (list, tuple)) and is_valid_list_of_audio(audio): + return audio + + # If it's a single audio, convert it to a list of + if is_valid_audio(audio): + return [audio] + + raise ValueError("Invalid input type. Must be a single audio or a list of audio") + + def hertz_to_mel(freq: Union[float, np.ndarray], mel_scale: str = "htk") -> Union[float, np.ndarray]: """ Convert frequency from hertz to mels. diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index 4627aeb970..6e47b041ab 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -73,7 +73,7 @@ class MaxLengthCriteria(StoppingCriteria): @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor: - cur_len = input_ids.shape[-1] + cur_len = input_ids.shape[1] is_done = cur_len >= self.max_length if self.max_position_embeddings is not None and not is_done and cur_len >= self.max_position_embeddings: logger.warning_once( diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 36e65949a4..3c7e83369a 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -563,7 +563,7 @@ class GenerationMixin: if model_inputs["inputs_embeds"] is not None: batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape else: - batch_size, sequence_length = model_inputs[input_ids_key].shape + batch_size, sequence_length = model_inputs[input_ids_key].shape[:2] # Create the causal mask with fixed shape in advance, to reduce recompilations. If the function to create # the 4D causal mask exists, it should be present in the base model (XXXModel class) or in its decoder. @@ -1708,7 +1708,7 @@ class GenerationMixin: return generation_config, model_kwargs - def _get_initial_cache_position(self, input_ids, model_kwargs): + def _get_initial_cache_position(self, seq_length, device, model_kwargs): """Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length""" # `torch.compile`-friendly `torch.arange` from a shape -- the lines below are equivalent to `torch.arange` if "inputs_embeds" in model_kwargs and not self.config.is_encoder_decoder: @@ -1718,7 +1718,7 @@ class GenerationMixin: torch.ones_like(model_kwargs["decoder_inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0) - 1 ) else: - cache_position = torch.ones_like(input_ids[0, :], dtype=torch.int64).cumsum(0) - 1 + cache_position = torch.ones(seq_length, dtype=torch.int64, device=device).cumsum(0) - 1 past_length = 0 if model_kwargs.get("past_key_values") is not None: @@ -2332,7 +2332,7 @@ class GenerationMixin: streamer.put(input_ids.cpu()) # 6. Prepare `max_length` depending on other stopping criteria. - input_ids_length = input_ids.shape[-1] + input_ids_length = input_ids.shape[1] has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None generation_config = self._prepare_generated_length( @@ -2805,9 +2805,9 @@ class GenerationMixin: decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None # keep track of which sequences are already finished - batch_size = input_ids.shape[0] + batch_size, cur_length = input_ids.shape[:2] unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) - model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) + model_kwargs = self._get_initial_cache_position(cur_length, input_ids.device, model_kwargs) this_peer_finished = False @@ -3016,9 +3016,9 @@ class GenerationMixin: ) # keep track of which sequences are already finished - batch_size = input_ids.shape[0] + batch_size, cur_len = input_ids.shape[:2] unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) - model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) + model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs) # Create cosine_matrix_mask based on the attention_mask cosine_matrix_mask = torch.ones_like(input_ids, dtype=torch.long) @@ -3428,10 +3428,10 @@ class GenerationMixin: ) # keep track of which sequences are already finished - batch_size, cur_len = input_ids.shape + batch_size, cur_len = input_ids.shape[:2] this_peer_finished = False unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) - model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) + model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs) model_forward = self.__call__ compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config) @@ -3834,7 +3834,7 @@ class GenerationMixin: num_beams = generation_config.num_beams num_return_sequences = generation_config.num_return_sequences - batch_size_unflattened, cur_len = input_ids.shape + batch_size_unflattened, cur_len = input_ids.shape[:2] batch_size = batch_size_unflattened // num_beams # TODO (joao): standardize special cases if self.__class__.__name__ == "MoshiDepthDecoder": @@ -3857,7 +3857,7 @@ class GenerationMixin: dim=0, ).to(input_ids.device) - model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) + model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs) # (joao) feature lost in the refactor. Probably won't implement, hurts readability with minimal gains (there # are newer low-memory alternatives like the offloaded cache) @@ -4156,7 +4156,7 @@ class GenerationMixin: device = input_ids.device batch_beam_size, cur_len = input_ids.shape - model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) + model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs) if return_dict_in_generate and output_scores: beam_indices = [tuple(() for _ in range(num_sub_beams * batch_size)) for _ in range(num_beam_groups)] @@ -4190,7 +4190,7 @@ class GenerationMixin: this_peer_finished = False - decoder_prompt_len = input_ids.shape[-1] # record the prompt length of decoder + decoder_prompt_len = input_ids.shape[1] # record the prompt length of decoder while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): # predicted tokens in cur_len step current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device) @@ -4444,8 +4444,8 @@ class GenerationMixin: batch_size = len(constrained_beam_scorer._beam_hyps) num_beams = constrained_beam_scorer.num_beams - batch_beam_size, cur_len = input_ids.shape - model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) + batch_beam_size, cur_len = input_ids.shape[:2] + model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs) if num_beams * batch_size != batch_beam_size: raise ValueError( @@ -4477,7 +4477,7 @@ class GenerationMixin: this_peer_finished = False - decoder_prompt_len = input_ids.shape[-1] # record the prompt length of decoder + decoder_prompt_len = input_ids.shape[1] # record the prompt length of decoder while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) @@ -4698,14 +4698,14 @@ class GenerationMixin: ) # keep track of which sequences are already finished - batch_size = input_ids.shape[0] + batch_size, cur_len = input_ids.shape[:2] unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) - model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) + model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs) this_peer_finished = False is_first_iteration = True # to preserve the same API in the output as other generation methods while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): - cur_len = input_ids.shape[-1] + cur_len = input_ids.shape[1] # 1. Fetch candidate sequences from a `CandidateGenerator` and move to the correct device candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids) @@ -4795,7 +4795,7 @@ class GenerationMixin: input_ids = torch.cat((input_ids, valid_tokens), dim=-1) if streamer is not None: streamer.put(valid_tokens.cpu()) - new_cur_len = input_ids.shape[-1] + new_cur_len = input_ids.shape[1] # 4.2. Discard past key values relative to unused assistant tokens new_cache_size = new_cur_len - 1 diff --git a/src/transformers/loss/loss_utils.py b/src/transformers/loss/loss_utils.py index 12da2f3d7a..aad42d3fd5 100644 --- a/src/transformers/loss/loss_utils.py +++ b/src/transformers/loss/loss_utils.py @@ -158,4 +158,5 @@ LOSS_MAPPING = { "RTDetrForObjectDetection": RTDetrForObjectDetectionLoss, "RTDetrV2ForObjectDetection": RTDetrForObjectDetectionLoss, "DFineForObjectDetection": DFineForObjectDetectionLoss, + "CsmForConditionalGeneration": ForCausalLMLoss, } diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 1be149faf6..8d713b482b 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -68,6 +68,7 @@ if TYPE_CHECKING: from .convnextv2 import * from .cpm import * from .cpmant import * + from .csm import * from .ctrl import * from .cvt import * from .d_fine import * diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 66ed95acad..01c58a5062 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -80,6 +80,7 @@ CONFIG_MAPPING_NAMES = OrderedDict( ("convnext", "ConvNextConfig"), ("convnextv2", "ConvNextV2Config"), ("cpmant", "CpmAntConfig"), + ("csm", "CsmConfig"), ("ctrl", "CTRLConfig"), ("cvt", "CvtConfig"), ("d_fine", "DFineConfig"), @@ -437,6 +438,7 @@ MODEL_NAMES_MAPPING = OrderedDict( ("convnextv2", "ConvNeXTV2"), ("cpm", "CPM"), ("cpmant", "CPM-Ant"), + ("csm", "CSM"), ("ctrl", "CTRL"), ("cvt", "CvT"), ("d_fine", "D-FINE"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index b8ec5c6ccb..b196a7718f 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -78,6 +78,7 @@ MODEL_MAPPING_NAMES = OrderedDict( ("convnext", "ConvNextModel"), ("convnextv2", "ConvNextV2Model"), ("cpmant", "CpmAntModel"), + ("csm", "CsmForConditionalGeneration"), ("ctrl", "CTRLModel"), ("cvt", "CvtModel"), ("d_fine", "DFineModel"), @@ -1446,6 +1447,7 @@ MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES = OrderedDict( [ # Model for Text-To-Waveform mapping ("bark", "BarkModel"), + ("csm", "CsmForConditionalGeneration"), ("fastspeech2_conformer", "FastSpeech2ConformerWithHifiGan"), ("musicgen", "MusicgenForConditionalGeneration"), ("musicgen_melody", "MusicgenMelodyForConditionalGeneration"), diff --git a/src/transformers/models/csm/__init__.py b/src/transformers/models/csm/__init__.py new file mode 100644 index 0000000000..59468442b5 --- /dev/null +++ b/src/transformers/models/csm/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_csm import * + from .modeling_csm import * + from .processing_csm import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/csm/configuration_csm.py b/src/transformers/models/csm/configuration_csm.py new file mode 100644 index 0000000000..e6d6d2e27c --- /dev/null +++ b/src/transformers/models/csm/configuration_csm.py @@ -0,0 +1,440 @@ +# coding=utf-8 +# Copyright 2025 Sesame and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...configuration_utils import PretrainedConfig +from ...modeling_rope_utils import rope_config_validation +from ...utils import logging +from ..auto.configuration_auto import AutoConfig + + +logger = logging.get_logger(__name__) + + +class CsmDepthDecoderConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`CsmDepthDecoderModel`]. It is used to instantiate an CSM depth decoder + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield + a similar configuration to that of the csm-1b. + + e.g. [eustlb/csm-1b](https://huggingface.co/eustlb/csm-1b) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + num_codebooks (`int`, *optional*, defaults to 32): + Number of codebooks used in the underlying codec model responsible for tokenizing the audio. + backbone_hidden_size (`int`, *optional*, defaults to 2048): + Dimension of the hidden representations of the backbone model used with this depth decoder. + vocab_size (`int`, *optional*, defaults to 2051): + Vocabulary size of the CsmDepthDecoder model. Defines the number of different audio tokens that can be represented by each codebook. + hidden_size (`int`, *optional*, defaults to 1024): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 8192): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 4): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*, defaults to 2): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 33): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*, defaults to 2050): + Padding token id. + bos_token_id (`int`, *optional*): + Beginning of stream token id. + eos_token_id (`int`, *optional*): + End of stream token id. + rope_theta (`float`, *optional*, defaults to 500000): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + mlp_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers. + head_dim (`int`, *optional*): + The attention head dimension. If None, it will default to hidden_size // num_attention_heads + + ```python + >>> from transformers import CsmDepthDecoder, CsmDepthDecoderConfig + + >>> # Initializing a CsmDepthDecoder + >>> configuration = CsmDepthDecoderConfig() + >>> model = CsmDepthDecoderModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "csm_depth_decoder_model" + base_config_key = "depth_decoder_config" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + num_codebooks=32, + backbone_hidden_size=2048, + vocab_size=2051, + hidden_size=1024, + intermediate_size=8192, + num_hidden_layers=4, + num_attention_heads=8, + num_key_value_heads=2, + hidden_act="silu", + max_position_embeddings=33, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + pad_token_id=None, + bos_token_id=None, + eos_token_id=None, + rope_theta=500000, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + mlp_bias=False, + head_dim=None, + **kwargs, + ): + if kwargs.pop("tie_word_embeddings", False): + raise ValueError("`tie_word_embeddings=True` is not supported for CsmDepthDecoderConfig") + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=False, + **kwargs, + ) + self.num_codebooks = num_codebooks + self.vocab_size = vocab_size + self.backbone_hidden_size = backbone_hidden_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.mlp_bias = mlp_bias + self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, copy it it to 'rope_type'. + if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) + + +class CsmConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`CsmForConditionalGeneration`]. It is used to instantiate an CSM + model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the csm-1b. + + e.g. [eustlb/csm-1b](https://huggingface.co/eustlb/csm-1b) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + num_codebooks (`int`, *optional*, defaults to 32): + Number of codebooks used in the underlying codec model responsible for tokenizing the audio. + vocab_size (`int`, *optional*, defaults to 2051): + Vocabulary size of the Csm model. Defines the number of different audio tokens that can be represented by each codebook. + text_vocab_size (`int`, *optional*, defaults to 128256): + Vocabulary size of the text input for the Csm model. Defines the number of different text tokens that can be represented. + hidden_size (`int`, *optional*, defaults to 2048): + Dimension of the hidden representations of the backbone model. + intermediate_size (`int`, *optional*, defaults to 8192): + Dimension of the MLP representations of the backbone model. + num_hidden_layers (`int`, *optional*, defaults to 16): + Number of hidden layers in the backbone model Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the backbone model Transformer decoder. + num_key_value_heads (`int`, *optional*, defaults to 8): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the backbone model Transformer decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*, defaults to 128002): + Padding token id. + codebook_pad_token_id (`int`, *optional*, defaults to 2050): + Padding token id for codebook tokens. + codebook_eos_token_id (`int`, *optional*, defaults to 0): + End of stream token id for codebook tokens. + bos_token_id (`int`, *optional*, defaults to 128000): + Beginning of stream token id. + eos_token_id (`int`, *optional*): + End of stream token id. + audio_token_id (`int`, *optional*, defaults to 128002): + Audio token id in the text input. + audio_eos_token_id (`int`, *optional*, defaults to 128003): + End of stream token id for audio in the text input. + rope_theta (`float`, *optional*, defaults to 500000): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*, defaults to `{'factor': 32.0, 'high_freq_factor': 0.5, 'low_freq_factor': 0.125, 'original_max_position_embeddings': 1024, 'rope_type': 'llama3'}`): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + mlp_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers. + head_dim (`int`, *optional*): + The attention head dimension. If None, it will default to hidden_size // num_attention_heads + tie_codebooks_embeddings (`bool`, *optional*, defaults to `True`): + Whether to tie the codebook tokens embeddings of the backbone model to the codebook tokens embeddings of the depth decoder. + depth_decoder_config (`CsmDepthDecoderConfig`, *optional*): + Configuration for the depth decoder. + codec_config (`PretrainedConfig`, *optional*): + Configuration for the codec. + + ```python + >>> from transformers import CsmForConditionalGeneration, CsmConfig + + >>> # Initializing a CsmConfig + >>> configuration = CsmConfig() + + >>> # Initializing a model + >>> model = CsmForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "csm" + base_config_key = "csm_config" + keys_to_ignore_at_inference = ["past_key_values"] + sub_configs = { + "codec_config": AutoConfig, + "depth_decoder_config": CsmDepthDecoderConfig, + } + + def __init__( + self, + num_codebooks=32, + vocab_size=2051, + text_vocab_size=128256, + hidden_size=2048, + intermediate_size=8192, + num_hidden_layers=16, + num_attention_heads=32, + num_key_value_heads=8, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + pad_token_id=128002, + codebook_pad_token_id=2050, + codebook_eos_token_id=0, + bos_token_id=128000, + eos_token_id=None, + audio_token_id=128002, + audio_eos_token_id=128003, + rope_theta=500000, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + mlp_bias=False, + head_dim=None, + tie_codebooks_embeddings=True, + depth_decoder_config=None, + codec_config=None, + **kwargs, + ): + if kwargs.pop("tie_word_embeddings", False): + raise ValueError("`tie_word_embeddings=True` is not supported for CsmConfig") + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=False, + **kwargs, + ) + + if depth_decoder_config is None: + self.depth_decoder_config = CsmDepthDecoderConfig() + logger.info("depth_decoder_config is None, using default depth decoder config.") + elif isinstance(depth_decoder_config, dict): + self.depth_decoder_config = CsmDepthDecoderConfig(**depth_decoder_config) + elif isinstance(depth_decoder_config, CsmDepthDecoderConfig): + self.depth_decoder_config = depth_decoder_config + + if codec_config is None: + self.codec_config = AutoConfig.for_model("mimi") + logger.info("codec_config is None, using default audio encoder config.") + elif isinstance(codec_config, dict): + self.codec_config = AutoConfig.for_model(**codec_config) + elif isinstance(codec_config, PretrainedConfig): + self.codec_config = codec_config + + self.text_vocab_size = text_vocab_size + self.num_codebooks = num_codebooks + self.audio_token_id = audio_token_id + self.audio_eos_token_id = audio_eos_token_id + self.codebook_pad_token_id = codebook_pad_token_id + self.codebook_eos_token_id = codebook_eos_token_id + self.tie_codebooks_embeddings = tie_codebooks_embeddings + + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.mlp_bias = mlp_bias + self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, copy it it to 'rope_type'. + if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) + + +__all__ = [ + "CsmDepthDecoderConfig", + "CsmConfig", +] diff --git a/src/transformers/models/csm/convert_csm.py b/src/transformers/models/csm/convert_csm.py new file mode 100644 index 0000000000..dc84e2cf3d --- /dev/null +++ b/src/transformers/models/csm/convert_csm.py @@ -0,0 +1,339 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import gc +import os +import re + +import torch +from tokenizers.processors import TemplateProcessing + +from transformers import ( + AutoFeatureExtractor, + AutoTokenizer, + CsmConfig, + CsmDepthDecoderConfig, + CsmForConditionalGeneration, + CsmProcessor, + MimiModel, +) +from transformers.utils.hub import cached_file + + +# fmt: off +ORIGINAL_TO_CONVERTED_KEY_MAPPING = { + r"backbone\.layers\.(\d+)": r"backbone_model.layers.\1", + r"decoder\.layers\.(\d+)": r"depth_decoder.model.layers.\1", + + r"attn": r"self_attn", + r"output_proj": r"o_proj", + r"w1": r"gate_proj", + r"w2": r"down_proj", + r"w3": r"up_proj", + + r"text_embeddings": r"embed_text_tokens", + r"audio_embeddings": r"backbone_model.embed_tokens.embed_audio_tokens", + + r"codebook0_head": r"lm_head", + r"audio_head": r"depth_decoder.codebooks_head.weight", + r"projection": r"depth_decoder.model.inputs_embeds_projector", + + r"sa_norm.scale": r"input_layernorm.weight", + r"mlp_norm.scale": r"post_attention_layernorm.weight", + r"decoder.norm.scale": r"depth_decoder.model.norm.weight", + r"backbone.norm.scale": r"backbone_model.norm.weight", +} +# fmt: on + + +def permute_for_rope(input_tensor, n_heads, dim1, dim2): + """ + When you go from the complex ROPE formulation to sin and cos one, you need + to permute the query and key weights (to avoid doing it on the fly) + """ + input_tensor = input_tensor.reshape(dim1, dim2) + input_tensor = input_tensor.view(n_heads, dim1 // n_heads // 2, 2, dim2) + input_tensor = input_tensor.transpose(1, 2).reshape(dim1, dim2) + return input_tensor + + +def convert_key(key, mapping): + for pattern, replacement in mapping.items(): + key = re.sub(pattern, replacement, key) + return key + + +def write_model( + input_path_or_repo, + model_name, + codec_model_path_or_repo, + output_dir, + safe_serialization=True, +): + print("Converting the model.") + os.makedirs(output_dir, exist_ok=True) + + codec_model = MimiModel.from_pretrained(codec_model_path_or_repo) + codec_model.config._attn_implementation_autoset = False + + # prepare rope scaling args: the model uses originally + # 1 - for the depth decoder + # rope_theta=500000, + # rope_scaling={ + # "factor": 32.0, + # "high_freq_factor": 4.0, + # "low_freq_factor": 1.0, + # "original_max_position_embeddings": 8192, + # "rope_type": "llama3", + # }, + # 2 - for the backbone + # rope_theta=500000, + # rope_scaling={ + # "factor": 32.0, + # "high_freq_factor": 4.0, + # "low_freq_factor": 1.0, + # "original_max_position_embeddings": 8192, + # "rope_type": "llama3", + # }, + # + # Yet we want to use max_position_embeddings=32, resp. 2048 + # This will throw warning as we would have original_max_position_embeddings >= max_position_embeddings + # Therefore, we convert values to equivalent ones + + depth_decoder_config = CsmDepthDecoderConfig( + rope_scaling={ + "factor": 32.0, + "high_freq_factor": 0.0078125, + "low_freq_factor": 0.001953125, + "original_max_position_embeddings": 16, + "rope_type": "llama3", + }, + ) + + config = CsmConfig( + codec_config=codec_model.config, + depth_decoder_config=depth_decoder_config, + rope_scaling={ + "factor": 32.0, + "high_freq_factor": 0.5, + "low_freq_factor": 0.125, + "original_max_position_embeddings": 1024, + "rope_type": "llama3", + }, + ) + + params = { + "backbone": { + "num_attention_heads": config.num_attention_heads, + "num_key_value_heads": config.num_key_value_heads, + "dim_per_head": config.head_dim, + "key_value_dim": config.head_dim * config.num_key_value_heads, + "dim": config.hidden_size, + }, + "depth_decoder": { + "num_attention_heads": config.depth_decoder_config.num_attention_heads, + "num_key_value_heads": config.depth_decoder_config.num_key_value_heads, + "dim_per_head": config.depth_decoder_config.head_dim, + "key_value_dim": config.depth_decoder_config.head_dim * config.depth_decoder_config.num_key_value_heads, + "dim": config.depth_decoder_config.hidden_size, + }, + } + + model_path = cached_file( + input_path_or_repo, + model_name, + ) + print(f"Fetching all parameters from the checkpoint at {model_path}...") + loaded = torch.load(model_path, map_location="cpu") + + print("Converting model...") + state_dict = {} + + # ----------------------- + # convert parameter names + # ----------------------- + + # Add codec_model. prefix to every key in the codec model state dict + codec_state_dict = {f"codec_model.{k}": v for k, v in codec_model.state_dict().items()} + state_dict.update(codec_state_dict) + + for key, value in loaded.items(): + new_key = convert_key(key, ORIGINAL_TO_CONVERTED_KEY_MAPPING) + current_parameter = value + + # Post-process the current_parameter. + if re.search("(k|q)_proj.weight", new_key): + params_keys = "backbone" if "backbone" in new_key else "depth_decoder" + if "q_proj" in new_key: + num_heads = params[params_keys]["num_attention_heads"] + dim_per_head = params[params_keys]["dim_per_head"] + param_dim = params[params_keys]["dim"] + dim = params[params_keys]["dim"] + else: + num_heads = params[params_keys]["num_key_value_heads"] + dim_per_head = params[params_keys]["dim_per_head"] + param_dim = params[params_keys]["key_value_dim"] + dim = params[params_keys]["dim"] + + current_parameter = permute_for_rope(value, num_heads, param_dim, dim) + state_dict[new_key] = current_parameter.reshape(num_heads * dim_per_head, dim) + + state_dict[new_key] = current_parameter + + # add the depth decoder embed audio tokens weights, latter tied to the backbone embed audio tokens weights + state_dict["depth_decoder.model.embed_tokens.weight"] = state_dict[ + "backbone_model.embed_tokens.embed_audio_tokens.weight" + ].clone() + del loaded + gc.collect() + + # ------------------------- + # load the weights and save + # ------------------------- + + print("Loading the checkpoint in a Csm model.") + with torch.device("meta"): + model = CsmForConditionalGeneration(config) + model.load_state_dict(state_dict, strict=True, assign=True) + print("Checkpoint loaded successfully.") + del model.config._name_or_path + + # default generation config + model.generation_config._from_model_config = False + model.generation_config.max_new_tokens = 125 + model.generation_config.do_sample = True + model.generation_config.top_k = 50 + model.generation_config.temperature = 0.9 + model.generation_config.depth_decoder_do_sample = True + model.generation_config.depth_decoder_top_k = 50 + model.generation_config.depth_decoder_temperature = 0.9 + + print("Saving the model.") + model.save_pretrained(output_dir, safe_serialization=safe_serialization) + del state_dict, model + + # Safety check: reload the converted model + gc.collect() + print("Reloading the model to check if it's saved correctly.") + CsmForConditionalGeneration.from_pretrained(output_dir, torch_dtype=torch.bfloat16, device_map="auto") + print("Model reloaded successfully.") + + +def write_tokenizer(output_dir): + # from https://github.com/SesameAILabs/csm/blob/2d720827843b653c4d67bb4445b1c0a4f59e646f/generator.py#L22-L36 + def load_llama3_tokenizer(): + """ + https://github.com/huggingface/transformers/issues/22794#issuecomment-2092623992 + """ + tokenizer_name = "meta-llama/Llama-3.2-1B" + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + bos = tokenizer.bos_token + eos = tokenizer.eos_token + tokenizer._tokenizer.post_processor = TemplateProcessing( + single=f"{bos}:0 $A:0 {eos}:0", + pair=f"{bos}:0 $A:0 {eos}:0 {bos}:1 $B:1 {eos}:1", + special_tokens=[(f"{bos}", tokenizer.bos_token_id), (f"{eos}", tokenizer.eos_token_id)], + ) + + return tokenizer + + tokenizer = load_llama3_tokenizer() + tokenizer.pad_token = tokenizer.eos_token + tokenizer.save_pretrained(output_dir) + + # manually modify in tokenizer_config.json + # "128002": { + # "content": "<|AUDIO|>", + # ... + # } + # "128003": { + # "content": "<|audio_eos|>", + # ... + # } + print( + "Tokenizer saved successfully. Please manually modify in tokenizer_config.json AND tokenizer.json as follows: " + ) + print(""" + # "128002": { + # "content": "<|AUDIO|>", + # ... + # } + # "128003": { + # "content": "<|audio_eos|>", + # ... + # } + """) + + +def write_processor(output_dir, codec_model_path_or_repo): + chat_template = "\n{%- for message in messages %}\n {#-- Validate role is a stringified integer --#}\n {%- if not message['role'] is string or not message['role'].isdigit() %}\n {{- raise_exception(\"The role must be an integer or a stringified integer (e.g. '0') designating the speaker id\") }}\n {%- endif %}\n\n {#-- Validate content is a list --#}\n {%- set content = message['content'] %}\n {%- if content is not iterable or content is string %}\n {{- raise_exception(\"The content must be a list\") }}\n {%- endif %}\n\n {#-- Collect content types --#}\n {%- set content_types = content | map(attribute='type') | list %}\n {%- set is_last = loop.last %}\n\n {#-- Last message validation --#}\n {%- if is_last %}\n {%- if 'text' not in content_types %}\n {{- raise_exception(\"The last message must include one item of type 'text'\") }}\n {%- elif (content_types | select('equalto', 'text') | list | length > 1) or (content_types | select('equalto', 'audio') | list | length > 1) %}\n {{- raise_exception(\"At most two items are allowed in the last message: one 'text' and one 'audio'\") }}\n {%- endif %}\n\n {#-- All other messages validation --#}\n {%- else %}\n {%- if content_types | select('equalto', 'text') | list | length != 1\n or content_types | select('equalto', 'audio') | list | length != 1 %}\n {{- raise_exception(\"Each message (except the last) must contain exactly one 'text' and one 'audio' item\") }}\n {%- elif content_types | reject('in', ['text', 'audio']) | list | length > 0 %}\n {{- raise_exception(\"Only 'text' and 'audio' types are allowed in content\") }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n\n{%- for message in messages %}\n {{- bos_token }}\n {{- '[' + message['role'] + ']' }}\n {{- message['content'][0]['text'] }}\n {{- eos_token }}\n {%- if message['content']|length > 1 %}\n {{- '<|AUDIO|><|audio_eos|>' }}\n {%- endif %}\n{%- endfor %}\n" + tokenizer = AutoTokenizer.from_pretrained(output_dir) + feature_extractor = AutoFeatureExtractor.from_pretrained(codec_model_path_or_repo) + + processor = CsmProcessor( + tokenizer=tokenizer, + feature_extractor=feature_extractor, + chat_template=chat_template, + ) + + processor.save_pretrained(output_dir) + print("Processor saved successfully.") + + +def main(): + parser = argparse.ArgumentParser(description="Convert Csm weights to HuggingFace format") + parser.add_argument( + "--input_path_or_repo", + type=str, + required=True, + help="Path or repo containing Csm weights", + ) + parser.add_argument( + "--model_name", + type=str, + required=True, + help="Name of the model in input_path_or_repo", + ) + parser.add_argument( + "--codec_model_path_or_repo", + type=str, + required=True, + help="Path or repo containing the codec model", + ) + parser.add_argument( + "--output_dir", + help="Location to write HF model and tokenizer", + ) + parser.add_argument( + "--safe_serialization", action="store_true", default=True, help="Whether or not to save using `safetensors`." + ) + args = parser.parse_args() + + write_model( + args.input_path_or_repo, + args.model_name, + args.codec_model_path_or_repo, + output_dir=args.output_dir, + safe_serialization=args.safe_serialization, + ) + + write_tokenizer(args.output_dir) + + write_processor(args.output_dir, args.codec_model_path_or_repo) + + +if __name__ == "__main__": + main() diff --git a/src/transformers/models/csm/generation_csm.py b/src/transformers/models/csm/generation_csm.py new file mode 100644 index 0000000000..b1c2cd920d --- /dev/null +++ b/src/transformers/models/csm/generation_csm.py @@ -0,0 +1,491 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from ...generation import ( + GenerateDecoderOnlyOutput, + GenerationConfig, + GenerationMixin, + GenerationMode, +) +from ...generation.logits_process import LogitsProcessorList +from ...generation.stopping_criteria import MaxLengthCriteria, StoppingCriteriaList +from ...generation.utils import GenerateNonBeamOutput +from ...utils import logging + + +if TYPE_CHECKING: + from ...generation.streamers import BaseStreamer + + +logger = logging.get_logger(__name__) + + +@dataclass +class CsmGenerateOutput(GenerateDecoderOnlyOutput): + """ + Outputs of CsmForConditionalGeneration.generate. + + Args: + sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter + if all batches finished early due to the `eos_token_id`. + scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True`): + Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) + at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for + each generated token), with each tensor of shape `(batch_size, config.vocab_size)`. + logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`): + Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) + at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for + each generated token), with each tensor of shape `(batch_size, config.vocab_size)`. + attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. + hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`. + past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True`): + Returns the model cache, used to speed up decoding. Different models have a different cache format, check + audio (`list(torch.FloatTensor)` of length `batch_size`): + The generated audio. + """ + + audio: Optional[List[torch.Tensor]] = None + + +class CsmGenerationMixin(GenerationMixin): + def _get_stopping_criteria( + self, + *args, + **kwargs, + ) -> StoppingCriteriaList: + criteria = super()._get_stopping_criteria(*args, **kwargs) + + kept_criteria = StoppingCriteriaList() + for criterion in criteria: + if not isinstance(criterion, MaxLengthCriteria): + logger.warning( + f"Csm does not support {criterion.__class__.__name__} stopping criteria, it will be ignored." + ) + else: + kept_criteria.append(criterion) + return kept_criteria + + def _prepare_generation_config( + self, generation_config: Optional[GenerationConfig], use_model_defaults: Optional[bool] = None, **kwargs: Dict + ) -> Tuple[GenerationConfig, Dict]: + """ + This method overrides [~generation.utils.GenerationMixin._prepare_generation_config]. + It ensures that the depth decoder generation config is initialized and that passed args as depth_decoder_* are properly handled. + """ + # extract depth decoder kwargs and remove them from the main kwargs + depth_decoder_kwargs = { + k[len("depth_decoder_") :]: v for k, v in kwargs.items() if k.startswith("depth_decoder_") + } + + # remove the depth decoder keys from the original kwargs + kwargs = {k: v for k, v in kwargs.items() if not k.startswith("depth_decoder_")} + + # initialize the generation config + generation_config, model_kwargs = super()._prepare_generation_config( + generation_config, use_model_defaults, **kwargs + ) + self.depth_decoder.generation_config.update(**depth_decoder_kwargs) + + # ensure the depth decoder generation config is valid + depth_decoder_min_new_tokens = getattr(self.depth_decoder.generation_config, "min_new_tokens") or ( + self.config.num_codebooks - 1 + ) + depth_decoder_max_new_tokens = getattr(self.depth_decoder.generation_config, "max_new_tokens") or ( + self.config.num_codebooks - 1 + ) + + if {depth_decoder_min_new_tokens, depth_decoder_max_new_tokens} != {self.config.num_codebooks - 1}: + raise ValueError( + f"depth_decoder_generation_config's min_new_tokens ({depth_decoder_min_new_tokens}) and max_new_tokens ({depth_decoder_max_new_tokens}) must be equal to self.config.num_codebooks - 1 ({self.config.num_codebooks - 1})" + ) + elif self.depth_decoder.generation_config.return_dict_in_generate: + logger.warning( + "depth_decoder_generation_config.return_dict_in_generate is set to True, but this will be ignored as the depth decoder model does not return a dictionary in generate" + ) + self.depth_decoder.generation_config.return_dict_in_generate = False + + self.depth_decoder.generation_config.min_new_tokens = depth_decoder_min_new_tokens + self.depth_decoder.generation_config.max_new_tokens = depth_decoder_max_new_tokens + + # Monkey patch the get_generation_mode method to support CSM model + original_get_generation_mode = generation_config.get_generation_mode + + def patched_get_generation_mode(assistant_model=None): + generation_mode = original_get_generation_mode(assistant_model) + if generation_mode not in [GenerationMode.GREEDY_SEARCH, GenerationMode.SAMPLE]: + raise ValueError( + f"Generation mode {generation_mode} is not supported for CSM model. Please set generation parameters to use greedy or sampling generation." + ) + + return generation_mode + + generation_config.get_generation_mode = patched_get_generation_mode + + return generation_config, model_kwargs + + def _sample( + self, + input_ids: torch.LongTensor, + logits_processor: LogitsProcessorList, + stopping_criteria: StoppingCriteriaList, + generation_config: GenerationConfig, + synced_gpus: bool, + streamer: Optional["BaseStreamer"], + **model_kwargs, + ) -> Union[GenerateNonBeamOutput, torch.LongTensor]: + """ + This method overrides [~generation.utils.GenerationMixin._sample]. + To ease maintenance, modifications are marked with the comment "Csm specific". + + Indeed, Csm model requires a custom generation sampling step: + 1. Infer the backbone model to sample the first codebook token + 2. Call generate on the depth decoder with the first codebook token as input_ids to sample the next codebook tokens + 3. Use these generated codebook tokens as input_ids to sample the next first codebook token using the backbone model + 4. Repeat until stopping criteria is met + + Csm supports two stopping criterias: + - stop when the generated sequence is at max_length + - stop when all the generated codebook tokens are the codebook_eos_token_id + """ + # init values + # *************** Csm specific *************** + pad_token_id = self.config.codebook_pad_token_id + has_eos_stopping_criteria = generation_config._eos_token_tensor is not None + # ============================================ + output_attentions = generation_config.output_attentions + output_hidden_states = generation_config.output_hidden_states + output_scores = generation_config.output_scores + output_logits = generation_config.output_logits + return_dict_in_generate = generation_config.return_dict_in_generate + do_sample = generation_config.do_sample + + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + raw_logits = () if (return_dict_in_generate and output_logits) else None + decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None + + # keep track of which sequences are already finished + batch_size, cur_len = input_ids.shape[:2] + this_peer_finished = False + unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) + model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs) + + # *************** Csm specific *************** + if input_ids.ndim == 2 and model_kwargs.get("inputs_embeds") is None: + # in the case where the passed input_ids correspond to text tokens, i.e. don't have a third dimension for codebook ids, + # we need to remove the input length to the MaxLengthCriteria stopping criteria has such input are not returned + for criterion in stopping_criteria: + if isinstance(criterion, MaxLengthCriteria): + criterion.max_length -= cur_len + # ============================================ + + model_forward = self.__call__ + compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config) + if compile_forward: + os.environ["TOKENIZERS_PARALLELISM"] = "0" + model_forward = self.get_compiled_call(generation_config.compile_config) + + is_prefill = True + while self._has_unfinished_sequences( + this_peer_finished, + synced_gpus, + device=input_ids.device, + ): + # prepare model inputs + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + + # prepare variable output controls (note: some models won't accept all output controls) + model_inputs.update({"output_attentions": output_attentions} if output_attentions else {}) + # *************** Csm specific *************** + model_inputs.update({"output_hidden_states": True}) + # ============================================ + + if is_prefill: + outputs = self(**model_inputs, return_dict=True) + is_prefill = False + else: + outputs = model_forward(**model_inputs, return_dict=True) + + # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping + model_kwargs = self._update_model_kwargs_for_generation( + outputs, + model_kwargs, + ) + if synced_gpus and this_peer_finished: + continue + + # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration + # (the clone itself is always small) + next_token_logits = outputs.logits[:, -1, :].clone().float() + next_token_logits = next_token_logits.to(input_ids.device) + + # pre-process distribution + next_token_scores = logits_processor(input_ids, next_token_logits) + + # Store scores, attentions and hidden_states when required + if return_dict_in_generate: + if output_scores: + scores += (next_token_scores,) + if output_logits: + raw_logits += (next_token_logits,) + if output_attentions: + decoder_attentions += (outputs.attentions,) + + if output_hidden_states: + decoder_hidden_states += (outputs.hidden_states,) + + # token selection + if do_sample: + probs = nn.functional.softmax(next_token_scores, dim=-1) + # TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + else: + next_tokens = torch.argmax(next_token_scores, dim=-1) + + # *************** Csm specific *************** + # infer the depth decoder + first_codebook_ids = next_tokens[:, None] + # adds place holder in position 0 that will be replaced by the backbone_last_hidden_state + depth_decoder_input_ids = nn.functional.pad(first_codebook_ids, (1, 0), value=0) + backbone_last_hidden_state = outputs.hidden_states[-1][:, -1, :] + + depth_decoder_outputs = self.depth_decoder.generate( + input_ids=depth_decoder_input_ids, backbone_last_hidden_state=backbone_last_hidden_state.clone() + ) + codebook_ids = ( + depth_decoder_outputs + if isinstance(depth_decoder_outputs, torch.Tensor) + else depth_decoder_outputs.sequences + ) + # remove the place holder in position 0 + codebook_ids = codebook_ids[:, 1:] + next_tokens = codebook_ids + + # finished sentences should have their next token be a padding token + if has_eos_stopping_criteria: + next_tokens = next_tokens * unfinished_sequences.unsqueeze(-1) + pad_token_id * ( + 1 - unfinished_sequences.unsqueeze(-1) + ) + + # update generated ids, model inputs, and length for next step + if input_ids.ndim == 2: + input_ids = next_tokens[:, None, :] + else: + input_ids = torch.cat([input_ids, next_tokens[:, None, :]], dim=1) + # ============================================ + + if streamer is not None: + streamer.put(next_tokens.cpu()) + + # *************** Csm specific *************** + # for the eos stopping criteria, is it expected that the eos token is the same for each codebook !!!! + unfinished_sequences = unfinished_sequences & ~( + input_ids[:, -1, :-1] == self.config.codebook_eos_token_id + ).all(-1) + # ============================================ + unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores) + this_peer_finished = unfinished_sequences.max() == 0 + cur_len += 1 + + # This is needed to properly delete outputs.logits which may be very large for first iteration + # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration + del outputs + + # *************** Csm specific *************** + del depth_decoder_outputs + # ============================================ + + if streamer is not None: + streamer.end() + + if return_dict_in_generate: + return GenerateDecoderOnlyOutput( + sequences=input_ids, + scores=scores, + logits=raw_logits, + attentions=decoder_attentions, + hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), + ) + else: + return input_ids + + def generate( + self, + input_ids: Optional[torch.Tensor] = None, + input_values: Optional[torch.Tensor] = None, + input_values_cutoffs: Optional[torch.Tensor] = None, + generation_config: Optional[GenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + synced_gpus: Optional[bool] = None, + streamer: Optional["BaseStreamer"] = None, + output_audio: Optional[bool] = False, + **kwargs, + ) -> Union[GenerateNonBeamOutput, torch.LongTensor]: + r""" + This method overrides [`~generation.utils.GenerationMixin.generate`] to match the specifics of the Csm model. + Indeed, Csm model requires a custom generation sampling step: + 1. Infer the backbone model to sample the first codebook token + 2. Call generate on the depth decoder with the first codebook token as `input_ids` to sample the next codebook tokens + 3. Use these generated codebook tokens as `input_ids` to sample the next first codebook token using the backbone model + 4. Repeat until stopping criteria is met + + + + Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the + model's default generation configuration. You can override any `generation_config` by passing the corresponding + parameters to generate(), e.g. `.generate(inputs, do_sample=True)`. + + + Parameters: + inputs_ids (`torch.Tensor` of shape (batch_size, seq_length), *optional*): + The sequence used as a prompt for the backbone model. + input_values (`torch.Tensor` of shape (batch_size, channels, max_concatenated_audio_length), *optional*): + The batched audio input values, where each batch entry contains the concatenation of all audio segments for that entry. + These values will be encoded into codebook tokens using the codec model and merged with the text input ids provided in `input_ids`. + input_values_cutoffs (`torch.Tensor` of shape (batch_size, max_num_audio), *optional*): + Specify the end positions of audio segments within each batch entry, relative to the concatenated audio input. + If a batch entry has fewer segments than the maximum, it is padded with -1. For example, in a batch of 2 sequences + where the first contains 2 audio segments of length l1, and the second contains 1 audio segment of length l2, + the input_values_cutoffs would be: [[l1, 2 * l1], [l2, -1]]. + generation_config ([`~generation.GenerationConfig`], *optional*): + The generation configuration to be used as base parametrization for the generation call. `**kwargs` + passed to generate matching the attributes of `generation_config` will override them. If + `generation_config` is not provided, the default will be used, which has the following loading + priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model + configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s + default values, whose documentation should be checked to parameterize generation. + logits_processor (`LogitsProcessorList`, *optional*): + Custom logits processors that complement the default logits processors built from arguments and + generation config. If a logit processor is passed that is already created with the arguments or a + generation config an error is thrown. This feature is intended for advanced users. + stopping_criteria (`StoppingCriteriaList`, *optional*): + Custom stopping criteria that complements the default stopping criteria built from arguments and a + generation config. If a stopping criteria is passed that is already created with the arguments or a + generation config an error is thrown. If your stopping criteria depends on the `scores` input, make + sure you pass `return_dict_in_generate=True, output_scores=True` to `generate`. This feature is + intended for advanced users. + synced_gpus (`bool`, *optional*): + Whether to continue running the while loop until max_length. Unless overridden, this flag will be set + to `True` if using `FullyShardedDataParallel` or DeepSpeed ZeRO Stage 3 with multiple GPUs to avoid + deadlocking if one GPU finishes generating before other GPUs. Otherwise, defaults to `False`. + streamer (`BaseStreamer`, *optional*): + Streamer object that will be used to stream the generated sequences. Generated tokens are passed + through `streamer.put(token_ids)` and the streamer is responsible for any further processing. + output_audio (`bool`, *optional*): + Whether to return the generated audio. + kwargs (`Dict[str, Any]`, *optional*): + Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be + forwarded to the `forward` function of the model. Depth decoder specific kwargs should be prefixed with *depth_decoder_*. + + Return: + [`CsmGenerateOutput`] or `torch.LongTensor` or `List[torch.FloatTensor]`: A [`CsmGenerateOutput`] + (if `return_dict_in_generate=True` or when `config.return_dict_in_generate=True`) or a `torch.LongTensor` when `output_audio=False` + or a `List[torch.FloatTensor]` otherwise. + + Example: + + ```python + >>> from transformers import CsmProcessor, CsmForConditionalGeneration + >>> from datasets import load_dataset, Audio + + >>> model_id = "eustlb/csm-1b" + >>> torch_device = "cuda" if torch.cuda.is_available() else "cpu" + + >>> processor = AutoProcessor.from_pretrained(model_id) + + >>> ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train") + >>> # ensure the audio is 24kHz + >>> ds = ds.cast_column("audio", Audio(sampling_rate=24000)) + + >>> conversation = [] + >>> # prepare a conversation with text and corresponding audio + >>> for text, audio, speaker_id in zip(ds[:4]["text"], ds[:4]["audio"], ds[:4]["speaker_id"]): + ... conversation.append( + ... { + ... "role": f"{speaker_id}", + ... "content": [{"type": "text", "text": text}, {"type": "audio", "path": audio["array"]}], + ... } + ... ) + + >>> # text prompt + >>> conversation.append({"role": f"{ds[4]['speaker_id']}", "content": [{"type": "text", "text": ds[4]["text"]}]}) + + >>> inputs = processor.apply_chat_template( + ... conversation, + ... tokenize=True, + ... return_dict=True, + ... ).to(torch_device) + + >>> model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=torch_device) + >>> audio = model.generate(**inputs, output_audio=True) + >>> processor.save_audio(audio, "output.wav") + ``` + """ + generate_output = super().generate( + input_ids=input_ids, + input_values=input_values, + input_values_cutoffs=input_values_cutoffs, + generation_config=generation_config, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + synced_gpus=synced_gpus, + streamer=streamer, + **kwargs, + ) + + generate_returned_dict = not isinstance(generate_output, torch.Tensor) + audio = None + if output_audio: + generated_audio_codes = generate_output.sequences if generate_returned_dict else generate_output + + # infer the codec model + audio = [] + with torch.no_grad(): + # ======================================= + # TODO: @eustlb, this should be batched !!! + # but requires making sure batched inference of the codec model works as intended + for audio_codes_batch in generated_audio_codes: + eos_idxs = (audio_codes_batch == self.config.codebook_eos_token_id).all(dim=-1).nonzero() + if eos_idxs.numel() != 0: + cutoff_idx = eos_idxs.min() + else: + cutoff_idx = audio_codes_batch.shape[1] + + audio_codes_batch = audio_codes_batch[:cutoff_idx] + codec_decode_output = self.codec_model.decode(audio_codes_batch.transpose(0, 1).unsqueeze(0)) + audio.append(codec_decode_output.audio_values[0, 0]) + # ======================================= + + if generate_returned_dict: + return CsmGenerateOutput(audio=audio, **generate_output) + elif output_audio: + return audio + else: + return generate_output diff --git a/src/transformers/models/csm/modeling_csm.py b/src/transformers/models/csm/modeling_csm.py new file mode 100644 index 0000000000..03cbc07df4 --- /dev/null +++ b/src/transformers/models/csm/modeling_csm.py @@ -0,0 +1,1710 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/csm/modular_csm.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_csm.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 Sesame and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Callable, List, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ( + LossKwargs, + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + can_return_tuple, + is_torch_flex_attn_available, + logging, + replace_return_docstrings, +) +from ..auto import AutoModel +from .configuration_csm import CsmConfig, CsmDepthDecoderConfig +from .generation_csm import CsmGenerationMixin + + +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + +logger = logging.get_logger(__name__) +_CONFIG_FOR_DOC = "CsmConfig" + + +@dataclass +class CsmOutputWithPast(ModelOutput): + """ + Base class for the model autoregressive outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + depth_decoder_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction) of the depth decoder model. + depth_decoder_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the depth decoder (scores for each vocabulary token before SoftMax). + depth_decoder_past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + depth_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + depth_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + backbone_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction) of the backbone model. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + depth_decoder_loss: Optional[torch.FloatTensor] = None + depth_decoder_logits: torch.FloatTensor = None + depth_decoder_past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + depth_decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + depth_decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + backbone_loss: Optional[torch.FloatTensor] = None + + +START_DOCSTRING_BASE = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`{config_class}`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +CSM_START_DOCSTRING = r"""{}""".format(START_DOCSTRING_BASE.format(config_class="CsmConfig")) + + +@add_start_docstrings( + "The bare Csm Model outputting raw hidden-states without any specific head on top.", + CSM_START_DOCSTRING, +) +class CsmPreTrainedModel(PreTrainedModel): + config_class = CsmConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["CsmDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + # does not because of Mimi codec model + # _supports_flex_attn = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + _supports_attention_backend = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, CsmCodebooksHead): + num_codebooks = module.num_codebooks + for i in range(num_codebooks - 1): + module.weight.data[i].normal_(mean=0.0, std=std) + elif isinstance(module, CsmRMSNorm): + module.weight.data.fill_(1.0) + + +@use_kernel_forward_from_hub("RMSNorm") +class CsmRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + CsmRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class CsmRotaryEmbedding(nn.Module): + def __init__(self, config: CsmConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class CsmMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class CsmAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: CsmConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class CsmDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: CsmConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = CsmAttention(config=config, layer_idx=layer_idx) + + self.mlp = CsmMLP(config) + self.input_layernorm = CsmRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = CsmRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + +CSM_DEPTH_DECODER_START_DOCSTRING = r"""{}""".format(START_DOCSTRING_BASE.format(config_class="CsmDepthDecoderConfig")) + + +INPUTS_DOCSTRING_BASE = r""" + Args: + {input_ids_docstring} + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance, see our + [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +DEPTH_DECODER_INPUT_IDS_DOCSTRING = r"""input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids)""" + + +CSM_DEPTH_DECODER_INPUTS_DOCSTRING = r"""{}""".format( + INPUTS_DOCSTRING_BASE.format(input_ids_docstring=DEPTH_DECODER_INPUT_IDS_DOCSTRING) +) + + +@add_start_docstrings( + "The bare CsmDepthDecoderModel outputting raw hidden-states without any specific head on top.", + CSM_DEPTH_DECODER_START_DOCSTRING, +) +class CsmDepthDecoderModel(CsmPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`CsmDecoderLayer`] + + Args: + config: CsmDepthDecoderConfig + """ + + config_class = CsmDepthDecoderConfig + + def __init__(self, config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.embed_tokens = nn.Embedding((config.num_codebooks * config.vocab_size), config.backbone_hidden_size) + self.layers = nn.ModuleList( + [CsmDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = CsmRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = CsmRotaryEmbedding(config=config) + self.gradient_checkpointing = False + self.inputs_embeds_projector = nn.Linear(config.backbone_hidden_size, config.hidden_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @can_return_tuple + @add_start_docstrings_to_model_forward(CSM_DEPTH_DECODER_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + backbone_last_hidden_state: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[Tuple, BaseModelOutputWithPast]: + r""" + backbone_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, backbone_hidden_size)`, *optional*): + The last hidden state of the backbone model. Such input is required when the first codebook token (the one generated by the backbone model) + is provided in the `input_ids` argument. + """ + if position_ids is not None and not torch.compiler.is_compiling(): + logger.warning_once( + "Custom `position_ids` were provided but will be ignored. CSM depth decoder automatically determines position_ids " + "from `cache_position` and as it requires them to be identical across the batch, the provided position_ids will be ignored." + ) + position_ids = None + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds.") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + inputs_seq_length = inputs_embeds.shape[1] if inputs_embeds is not None else input_ids.shape[1] + device = inputs_embeds.device if inputs_embeds is not None else input_ids.device + cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_seq_length, device=device) + + if inputs_embeds is None: + codebook_idxs = torch.clamp(cache_position - 1, min=0) + offset = codebook_idxs * self.vocab_size + inputs_embeds = self.embed_tokens(input_ids + offset) + + input_ids_are_first_codebook = cache_position[0] == 0 + if backbone_last_hidden_state is not None: + inputs_embeds[:, 0] = backbone_last_hidden_state + else: + if not torch.compiler.is_compiling() and input_ids_are_first_codebook: + logger.warning( + "When the first codebook token is provided, `backbone_last_hidden_state` should also be provided for correct inference." + ) + + inputs_embeds = self.inputs_embeds_projector(inputs_embeds) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_ids = cache_position.unsqueeze(0) + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, "BlockMask"], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool = False, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype = input_tensor.dtype + sequence_length = input_tensor.shape[1] + if using_compilable_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +class CsmCodebooksHead(nn.Module): + def __init__(self, hidden_size, num_codebooks, vocab_size): + super().__init__() + self.num_codebooks = num_codebooks + self.weight = nn.Parameter(torch.empty(self.num_codebooks - 1, hidden_size, vocab_size)) + + def forward(self, hidden_states, cache_position=None): + if cache_position is None: + seq_length = hidden_states.shape[1] + codebook_weight = self.weight[torch.arange(seq_length)] + else: + codebook_idxs = cache_position - 1 + codebook_weight = self.weight[codebook_idxs] + + hidden_states = [ + nn.functional.linear(hidden_states[:, codebook_idx, :], codebook_weight[codebook_idx].T) + for codebook_idx in range(codebook_weight.shape[0]) + ] + hidden_states = torch.stack(hidden_states, dim=1) + + return hidden_states + + +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + +@add_start_docstrings( + """ + The CsmDepthDecoder Model transformer, with a [`CsmCodebooksHead`] on top, + which can be seen a position-specific language modeling head, allowing to use a different linear layer for each codebook + (e.g. position 0 is the first codebook and uses the first codebook head, etc.) + """, + CSM_DEPTH_DECODER_START_DOCSTRING, +) +class CsmDepthDecoderForCausalLM(CsmPreTrainedModel, GenerationMixin): + _tied_weights_keys = None + _tp_plan = None + _pp_plan = None + + def __init__(self, config): + super().__init__(config) + self.model = CsmDepthDecoderModel(config) + self.vocab_size = config.vocab_size + self.codebooks_head = CsmCodebooksHead(config.hidden_size, config.num_codebooks, config.vocab_size) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @can_return_tuple + @add_start_docstrings_to_model_forward(CSM_DEPTH_DECODER_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + backbone_last_hidden_state: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[KwargsForCausalLM], + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, CsmDepthDecoderForCausalLM + + >>> model = CsmDepthDecoderForCausalLM.from_pretrained("meta-csm_depth_decoder/CsmDepthDecoder-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-csm_depth_decoder/CsmDepthDecoder-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ``` + backbone_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, backbone_hidden_size)`, *optional*): + The last hidden state of the backbone model. Such input is required when the first codebook token (the one generated by the backbone model) + is provided in the `input_ids` argument. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + backbone_last_hidden_state=backbone_last_hidden_state, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + if isinstance(logits_to_keep, int): + if logits_to_keep == 0: + # skip idx 0 logits since it's for the concatenated backbone last hidden state + slice_indices = slice(1, None) + else: + slice_indices = slice(-logits_to_keep, None) + else: + slice_indices = logits_to_keep + + logits = self.codebooks_head( + hidden_states[:, slice_indices, :], cache_position[slice_indices] if cache_position is not None else None + ) + logits = logits.contiguous() + + loss = None + if labels is not None: + shift_labels = labels[..., 1:].contiguous() + loss = self.loss_function( + logits=logits, labels=None, vocab_size=self.config.vocab_size, shift_labels=shift_labels, **kwargs + ) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor, + past_key_values: Optional[Cache] = None, + attention_mask: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ): + model_inputs = super().prepare_inputs_for_generation( + input_ids, past_key_values, attention_mask, inputs_embeds, cache_position, **kwargs + ) + + is_first_generation_step = model_inputs["cache_position"][0] == 0 + if not is_first_generation_step: + model_inputs.pop("backbone_last_hidden_state") + + # csm depth decoder does not use position_ids + model_inputs.pop("position_ids") + + return model_inputs + + +class CsmBackboneModelEmbeddings(nn.Module): + def __init__(self, config): + super().__init__() + self.embed_audio_tokens = nn.Embedding((config.num_codebooks * config.vocab_size), config.hidden_size) + self.register_buffer( + "audio_tokens_offsets", torch.arange(config.num_codebooks) * config.vocab_size, persistent=False + ) + + def forward(self, input_ids): + input_embeds = self.embed_audio_tokens(input_ids + self.audio_tokens_offsets) + input_embeds = input_embeds.sum(dim=2) + return input_embeds + + +INPUT_IDS_DOCSTRING = r"""input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks) or (batch_size, sequence_length)`): + 1. (batch_size, sequence_length): corresponds to the input sequence prepared with the processor from the text prompt. Such input + requires `input_values` to be provided so that audio can be encoded in codebook tokens and then merged with the text tokens. + + 2. (batch_size, sequence_length, num_codebooks): codebook tokens generated during the autoregressive decoding. Such input is not meant to be used by end users. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids)""" + + +CSM_BACKBONE_INPUTS_DOCSTRING = r"""{}""".format(INPUTS_DOCSTRING_BASE.format(input_ids_docstring=INPUT_IDS_DOCSTRING)) + + +@add_start_docstrings( + "The bare CsmBackboneModel Model outputting raw hidden-states without any specific head on top.", + CSM_START_DOCSTRING, +) +class CsmBackboneModel(CsmPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`CsmDecoderLayer`] + + Args: + config: CsmBackboneModelConfig + """ + + def __init__(self, config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.embed_tokens = CsmBackboneModelEmbeddings(config) + self.layers = nn.ModuleList( + [CsmDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = CsmRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = CsmRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @can_return_tuple + @add_start_docstrings_to_model_forward(CSM_BACKBONE_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> BaseModelOutputWithPast: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache + if not isinstance(past_key_values, (type(None), Cache)): + raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, "BlockMask"], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool = False, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype = input_tensor.dtype + sequence_length = input_tensor.shape[1] + if using_compilable_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +CSM_INPUTS_DOCSTRING = r"""{}""".format(INPUTS_DOCSTRING_BASE.format(input_ids_docstring=INPUT_IDS_DOCSTRING)) + + +@add_start_docstrings( + """ + The Csm model consists of two llama-like auto-regressive transformer models: a backbone model that predicts the first codebook token and a depth decoder that predicts the other codebook tokens. + """, + CSM_START_DOCSTRING, +) +class CsmForConditionalGeneration(CsmPreTrainedModel, CsmGenerationMixin): + _tied_weights_keys = [ + "backbone_model.embed_tokens.embed_audio_tokens.weight", + "depth_decoder.model.embed_tokens.weight", + ] + + def __init__(self, config): + super().__init__(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.embed_text_tokens = nn.Embedding(config.text_vocab_size, config.hidden_size) + self.backbone_model = CsmBackboneModel._from_config(config) + self.depth_decoder = CsmDepthDecoderForCausalLM._from_config(config.depth_decoder_config) + self.codec_model = AutoModel.from_config(config.codec_config) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.backbone_model.embed_tokens + + def set_input_embeddings(self, value): + self.backbone_model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def _tie_weights(self): + if self.config.tie_codebooks_embeddings: + self._tie_or_clone_weights( + self.backbone_model.embed_tokens.embed_audio_tokens, + self.depth_decoder.model.embed_tokens, + ) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + if kwargs.get("output_loading_info", False): + model, loading_info = super().from_pretrained(*args, **kwargs) + else: + model = super().from_pretrained(*args, **kwargs) + + # copy depth decoder generation conf attr to the depth decoder generation config + prefix = "depth_decoder_" + prefix_len = len(prefix) + depth_decoder_attrs = { + attr[prefix_len:]: value + for attr, value in vars(model.generation_config).items() + if attr.startswith(prefix) + } + + vars(model.depth_decoder.generation_config).update({"_from_model_config": False, **depth_decoder_attrs}) + + # remove the depth decoder generation conf attr from the model generation config + for attr in depth_decoder_attrs: + delattr(model.generation_config, prefix + attr) + + if "output_loading_info" in kwargs: + return model, loading_info + else: + return model + + def save_pretrained(self, *args, **kwargs): + # copy the depth decoder generation config attributes to the model generation config + prefix = "depth_decoder_" + depth_decoder_attrs = self.depth_decoder.generation_config.to_diff_dict() + depth_decoder_attrs.pop("transformers_version", None) + for attr, value in depth_decoder_attrs.items(): + setattr(self.generation_config, prefix + attr, value) + + super().save_pretrained(*args, **kwargs) + + def _merge_input_ids_with_input_values( + self, + input_ids: Optional[torch.Tensor] = None, + input_values: Optional[torch.Tensor] = None, + input_values_cutoffs: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + ) -> Optional[torch.Tensor]: + """ + Merges the input_ids and input_values to produce a single inputs_embeds tensor: + 1 - Infers the codec model on the input_values to retreive codebook token. + 2 - Embeds codebook tokens and places them at the correct positions in the inputs_embeds tensor. + 3 - If labels are provided, expands them to match codebook dimensions and position the target codebook tokens in the inputs_embeds tensor. + + Args: + input_ids (`torch.Tensor` of shape `(batch_size, sequence_length)`): + The input ids to embed. + input_values (`torch.Tensor` of shape `(batch_size, channels, audio_sequence_length)`): + The audio input values to embed. + input_values_cutoffs (`torch.Tensor` of shape `(batch_size, max_num_audio)`): + The cutoffs of the audio input values relative to its batch index, padded with -1 when no audio. + """ + inputs_embeds = self.embed_text_tokens(input_ids) + + if input_values is not None: + # infer input_values_mask + input_values_cutoffs = nn.functional.pad(input_values_cutoffs, (1, 0)) + audio_lengths = input_values_cutoffs[input_values_cutoffs >= 0].diff() + audio_lengths = audio_lengths[audio_lengths > 0] + input_values_mask = torch.arange(input_values_cutoffs.max(), device=input_values.device).expand( + len(audio_lengths), -1 + ) + input_values_mask = input_values_mask < audio_lengths.unsqueeze(1) + + # ======================================= + # TODO: @eustlb, this should be batched !!! + # but requires making sure batched inference of the codec model works as intended + audio_tokens_list = [] + for batch_input_values, batch_input_values_cutoffs in zip(input_values, input_values_cutoffs): + batch_input_values_cutoffs = batch_input_values_cutoffs[batch_input_values_cutoffs >= 0] + for i in range(batch_input_values_cutoffs.shape[0] - 1): + start_idx = batch_input_values_cutoffs[i] + end_idx = batch_input_values_cutoffs[i + 1] + audio_batch = batch_input_values[..., start_idx:end_idx] + codec_outputs = self.codec_model.encode(audio_batch.unsqueeze(0)) + codebook_ids = codec_outputs.audio_codes.transpose(1, -1) + audio_tokens_list.append(codebook_ids[0]) + + max_audio_frames = max(el.shape[0] for el in audio_tokens_list) + batched_audio_token_ids = torch.stack( + [nn.functional.pad(el, (0, 0, 0, max_audio_frames - el.shape[0])) for el in audio_tokens_list] + ) + audio_codes_mask = self.codec_model.get_audio_codes_mask(input_values_mask) + # ======================================= + audio_token_id = self.config.audio_token_id + audio_token_mask = input_ids == audio_token_id + + audio_embeds = self.backbone_model.embed_tokens(batched_audio_token_ids) + inputs_embeds[audio_token_mask] = audio_embeds[audio_codes_mask] + + # same for the audio eos token + audio_eos_frame_ids = ( + torch.ones((1, 1, self.config.num_codebooks), device=input_ids.device, dtype=torch.long) + * self.config.codebook_eos_token_id + ) + audio_eos_embeds = self.backbone_model.embed_tokens(audio_eos_frame_ids).squeeze(1) + + audio_eos_token_mask = input_ids == self.config.audio_eos_token_id + inputs_embeds[audio_eos_token_mask] = audio_eos_embeds.repeat(audio_eos_token_mask.sum(), 1) + + # if the labels are provided, we need to expand the labels to (batch_size, seq_length, num_codebooks) + if labels is not None: + labels_expanded = labels.unsqueeze(-1).repeat(1, 1, self.config.num_codebooks) + labels_expanded[audio_token_mask] = batched_audio_token_ids[audio_codes_mask] + # mask depth decoder + depth_decoder_ignore_frames_idxs = (labels == -101).nonzero(as_tuple=True) + labels_expanded[depth_decoder_ignore_frames_idxs[0], depth_decoder_ignore_frames_idxs[1], 1:] = -100 + labels = labels_expanded + + return {"inputs_embeds": inputs_embeds, "labels": labels} + + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor, + past_key_values: Optional[Cache] = None, + attention_mask: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ): + model_inputs = super().prepare_inputs_for_generation( + input_ids=input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + **kwargs, + ) + + if input_ids is not None and input_ids.ndim == 2 and model_inputs.get("inputs_embeds") is None: + merged_inputs = self._merge_input_ids_with_input_values( + input_ids=input_ids, + input_values=kwargs.get("input_values"), + input_values_cutoffs=kwargs.get("input_values_cutoffs"), + labels=kwargs.get("labels"), + ) + model_inputs.update( + {"inputs_embeds": merged_inputs["inputs_embeds"], "labels": merged_inputs["labels"], "input_ids": None} + ) + + return model_inputs + + @can_return_tuple + @add_start_docstrings_to_model_forward(CSM_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CsmOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + input_values: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + input_values_cutoffs: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[KwargsForCausalLM], + ) -> Union[Tuple, CsmOutputWithPast]: + r""" + input_values_cutoffs (`torch.Tensor` of shape `(batch_size, max_num_audio)`, *optional*): + Specify the end positions of audio segments within each batch entry, relative to the concatenated audio input. + If a batch entry has fewer segments than the maximum, it is padded with -1. For example, in a batch of 2 sequences + where the first contains 2 audio segments of length l1, and the second contains 1 audio segment of length l2, + the input_values_cutoffs would be: [[l1, 2 * l1], [l2, -1]]. + + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[config.audio_token_id, -100, -101]`. + Requires targeted `input_values` to be provided as audio tokens will be infered from it using the `codec_model`. + - `config.audio_token_id` indicates an audio frames (considering sequence length elements as frames) + - `-100` will be ignored in the loss computation + - `-101` indicates the audio frame will be used only for the backbone model (using the first codebook token as labels) + + Such labels can be prepared using `output_labels=True` when calling [`CsmProcessor`]. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + Kept for compatibility. Does not support another value than: + 1. `0`, which is equivalent to keeping all logits, used in the training regime + 2. `1`, which is equivalent to keeping only the last logit, used in the generation regime + + Returns: + + Example: + + ```python + >>> import torch + >>> from transformers import CsmForConditionalGeneration, AutoProcessor + >>> from datasets import load_dataset, Audio + + >>> model_id = "eustlb/csm-1b" + >>> torch_device = "cuda" if torch.cuda.is_available() else "cpu" + + >>> processor = AutoProcessor.from_pretrained(model_id) + + >>> ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train") + >>> # ensure the audio is 24kHz + >>> ds = ds.cast_column("audio", Audio(sampling_rate=24000)) + + >>> conversation = [] + >>> # prepare a conversation with text and corresponding audio + >>> for text, audio, speaker_id in zip(ds[:4]["text"], ds[:4]["audio"], ds[:4]["speaker_id"]): + ... conversation.append( + ... { + ... "role": f"{speaker_id}", + ... "content": [{"type": "text", "text": text}, {"type": "audio", "path": audio["array"]}], + ... } + ... ) + + >>> inputs = processor.apply_chat_template( + ... conversation, + ... tokenize=True, + ... return_dict=True, + ... output_labels=True, + ... ).to(torch_device) + + >>> model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=torch_device) + >>> output = model(**inputs) + >>> output.loss.backward() + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + if input_ids is not None and input_ids.ndim == 2: + merged_inputs = self._merge_input_ids_with_input_values( + input_ids, input_values, input_values_cutoffs, labels + ) + inputs_embeds = merged_inputs["inputs_embeds"] + labels = merged_inputs["labels"] + input_ids = None + + backbone_outputs = self.backbone_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **kwargs, + ) + + backbone_hidden_states = backbone_outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + backbone_logits = self.lm_head(backbone_hidden_states[:, slice_indices, :]) + + loss = None + backbone_loss = None + depth_decoder_loss = None + depth_decoder_outputs = None + if labels is not None: + # select first codebook as labels for the backbone model + backbone_labels = labels[:, :, 0] + backbone_loss = self.loss_function( + logits=backbone_logits, labels=backbone_labels, vocab_size=self.config.vocab_size, **kwargs + ) + + # for the depth decoder, we need to select the frames to train on + # those are frames where the label is not uniformly `ignore_index` along the codebook dimension + train_mask = ~(labels[:, :, 1:] == -100).all(dim=-1) + depth_decoder_input_ids = labels[train_mask][..., : self.config.num_codebooks - 1] + # add place holder in position 0 that will be replaced by the backbone_last_hidden_state + depth_decoder_input_ids = nn.functional.pad(depth_decoder_input_ids, (1, 0), value=0) + + train_idxs = train_mask.nonzero(as_tuple=True) + backbone_last_hidden_states = backbone_hidden_states[train_idxs[0], train_idxs[1] - 1, :] + depth_decoder_labels = labels[train_mask] + + depth_decoder_outputs = self.depth_decoder( + input_ids=depth_decoder_input_ids, + backbone_last_hidden_state=backbone_last_hidden_states, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + labels=depth_decoder_labels, + ) + + depth_decoder_loss = depth_decoder_outputs.loss + loss = backbone_loss + depth_decoder_loss + + return CsmOutputWithPast( + loss=loss, + backbone_loss=backbone_loss, + depth_decoder_loss=depth_decoder_loss, + logits=backbone_logits, + past_key_values=backbone_outputs.past_key_values, + hidden_states=backbone_outputs.hidden_states, + attentions=backbone_outputs.attentions, + depth_decoder_logits=depth_decoder_outputs.logits if depth_decoder_outputs is not None else None, + depth_decoder_past_key_values=depth_decoder_outputs.past_key_values + if depth_decoder_outputs is not None + else None, + depth_decoder_hidden_states=depth_decoder_outputs.hidden_states + if depth_decoder_outputs is not None + else None, + depth_decoder_attentions=depth_decoder_outputs.attentions if depth_decoder_outputs is not None else None, + ) + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +__all__ = [ + "CsmPreTrainedModel", + "CsmBackboneModel", + "CsmDepthDecoderModel", + "CsmDepthDecoderForCausalLM", + "CsmForConditionalGeneration", +] diff --git a/src/transformers/models/csm/modular_csm.py b/src/transformers/models/csm/modular_csm.py new file mode 100644 index 0000000000..ed3d571034 --- /dev/null +++ b/src/transformers/models/csm/modular_csm.py @@ -0,0 +1,1042 @@ +# coding=utf-8 +# Copyright 2025 Sesame and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, +) +from ...modeling_utils import PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + can_return_tuple, + logging, + replace_return_docstrings, +) +from ..auto import AutoModel +from ..llama.modeling_llama import ( + KwargsForCausalLM, + LlamaAttention, + LlamaDecoderLayer, + LlamaForCausalLM, + LlamaMLP, + LlamaModel, + LlamaRMSNorm, + LlamaRotaryEmbedding, +) +from .configuration_csm import ( + CsmConfig, + CsmDepthDecoderConfig, +) +from .generation_csm import CsmGenerationMixin + + +logger = logging.get_logger(__name__) +_CONFIG_FOR_DOC = "CsmConfig" + + +@dataclass +class CsmOutputWithPast(ModelOutput): + """ + Base class for the model autoregressive outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + depth_decoder_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction) of the depth decoder model. + depth_decoder_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the depth decoder (scores for each vocabulary token before SoftMax). + depth_decoder_past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + depth_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + depth_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + backbone_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction) of the backbone model. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + depth_decoder_loss: Optional[torch.FloatTensor] = None + depth_decoder_logits: torch.FloatTensor = None + depth_decoder_past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + depth_decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + depth_decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + backbone_loss: Optional[torch.FloatTensor] = None + + +START_DOCSTRING_BASE = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`{config_class}`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +CSM_DEPTH_DECODER_START_DOCSTRING = r"""{}""".format(START_DOCSTRING_BASE.format(config_class="CsmDepthDecoderConfig")) + + +CSM_START_DOCSTRING = r"""{}""".format(START_DOCSTRING_BASE.format(config_class="CsmConfig")) + + +@add_start_docstrings( + "The bare Csm Model outputting raw hidden-states without any specific head on top.", + CSM_START_DOCSTRING, +) +class CsmPreTrainedModel(PreTrainedModel): + config_class = CsmConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["CsmDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + # does not because of Mimi codec model + # _supports_flex_attn = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + _supports_attention_backend = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, CsmCodebooksHead): + num_codebooks = module.num_codebooks + for i in range(num_codebooks - 1): + module.weight.data[i].normal_(mean=0.0, std=std) + elif isinstance(module, CsmRMSNorm): + module.weight.data.fill_(1.0) + + +INPUTS_DOCSTRING_BASE = r""" + Args: + {input_ids_docstring} + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance, see our + [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +DEPTH_DECODER_INPUT_IDS_DOCSTRING = r"""input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids)""" + + +INPUT_IDS_DOCSTRING = r"""input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks) or (batch_size, sequence_length)`): + 1. (batch_size, sequence_length): corresponds to the input sequence prepared with the processor from the text prompt. Such input + requires `input_values` to be provided so that audio can be encoded in codebook tokens and then merged with the text tokens. + + 2. (batch_size, sequence_length, num_codebooks): codebook tokens generated during the autoregressive decoding. Such input is not meant to be used by end users. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids)""" + + +CSM_DEPTH_DECODER_INPUTS_DOCSTRING = r"""{}""".format( + INPUTS_DOCSTRING_BASE.format(input_ids_docstring=DEPTH_DECODER_INPUT_IDS_DOCSTRING) +) + + +CSM_BACKBONE_INPUTS_DOCSTRING = r"""{}""".format(INPUTS_DOCSTRING_BASE.format(input_ids_docstring=INPUT_IDS_DOCSTRING)) + + +# manually specify names for correct naming when converting from modualr +class CsmRMSNorm(LlamaRMSNorm): + pass + + +class CsmRotaryEmbedding(LlamaRotaryEmbedding): + pass + + +class CsmMLP(LlamaMLP): + pass + + +class CsmAttention(LlamaAttention): + pass + + +class CsmDecoderLayer(LlamaDecoderLayer): + pass + + +@add_start_docstrings( + "The bare CsmDepthDecoderModel outputting raw hidden-states without any specific head on top.", + CSM_DEPTH_DECODER_START_DOCSTRING, +) +class CsmDepthDecoderModel(LlamaModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`CsmDecoderLayer`] + + Args: + config: CsmDepthDecoderConfig + """ + + config_class = CsmDepthDecoderConfig + + def __init__(self, config): + super().__init__(config) + self.embed_tokens = nn.Embedding((config.num_codebooks * config.vocab_size), config.backbone_hidden_size) + self.inputs_embeds_projector = nn.Linear(config.backbone_hidden_size, config.hidden_size, bias=False) + + @can_return_tuple + @add_start_docstrings_to_model_forward(CSM_DEPTH_DECODER_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + backbone_last_hidden_state: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[Tuple, BaseModelOutputWithPast]: + r""" + backbone_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, backbone_hidden_size)`, *optional*): + The last hidden state of the backbone model. Such input is required when the first codebook token (the one generated by the backbone model) + is provided in the `input_ids` argument. + """ + if position_ids is not None and not torch.compiler.is_compiling(): + logger.warning_once( + "Custom `position_ids` were provided but will be ignored. CSM depth decoder automatically determines position_ids " + "from `cache_position` and as it requires them to be identical across the batch, the provided position_ids will be ignored." + ) + position_ids = None + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds.") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + inputs_seq_length = inputs_embeds.shape[1] if inputs_embeds is not None else input_ids.shape[1] + device = inputs_embeds.device if inputs_embeds is not None else input_ids.device + cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_seq_length, device=device) + + if inputs_embeds is None: + codebook_idxs = torch.clamp(cache_position - 1, min=0) + offset = codebook_idxs * self.vocab_size + inputs_embeds = self.embed_tokens(input_ids + offset) + + input_ids_are_first_codebook = cache_position[0] == 0 + if backbone_last_hidden_state is not None: + inputs_embeds[:, 0] = backbone_last_hidden_state + else: + if not torch.compiler.is_compiling() and input_ids_are_first_codebook: + logger.warning( + "When the first codebook token is provided, `backbone_last_hidden_state` should also be provided for correct inference." + ) + + inputs_embeds = self.inputs_embeds_projector(inputs_embeds) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_ids = cache_position.unsqueeze(0) + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class CsmCodebooksHead(nn.Module): + def __init__(self, hidden_size, num_codebooks, vocab_size): + super().__init__() + self.num_codebooks = num_codebooks + self.weight = nn.Parameter(torch.empty(self.num_codebooks - 1, hidden_size, vocab_size)) + + def forward(self, hidden_states, cache_position=None): + if cache_position is None: + seq_length = hidden_states.shape[1] + codebook_weight = self.weight[torch.arange(seq_length)] + else: + codebook_idxs = cache_position - 1 + codebook_weight = self.weight[codebook_idxs] + + hidden_states = [ + nn.functional.linear(hidden_states[:, codebook_idx, :], codebook_weight[codebook_idx].T) + for codebook_idx in range(codebook_weight.shape[0]) + ] + hidden_states = torch.stack(hidden_states, dim=1) + + return hidden_states + + +@add_start_docstrings( + """ + The CsmDepthDecoder Model transformer, with a [`CsmCodebooksHead`] on top, + which can be seen a position-specific language modeling head, allowing to use a different linear layer for each codebook + (e.g. position 0 is the first codebook and uses the first codebook head, etc.) + """, + CSM_DEPTH_DECODER_START_DOCSTRING, +) +class CsmDepthDecoderForCausalLM(LlamaForCausalLM, GenerationMixin): + _tied_weights_keys = None + _tp_plan = None + _pp_plan = None + + def __init__(self, config): + super().__init__(config) + del self.lm_head + self.codebooks_head = CsmCodebooksHead(config.hidden_size, config.num_codebooks, config.vocab_size) + self.model = CsmDepthDecoderModel(config) + + def get_output_embeddings(self): + raise AttributeError("Not needed for Csm") + + def set_output_embeddings(self, new_embeddings): + raise AttributeError("Not needed for Csm") + + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor, + past_key_values: Optional[Cache] = None, + attention_mask: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ): + model_inputs = super().prepare_inputs_for_generation( + input_ids, past_key_values, attention_mask, inputs_embeds, cache_position, **kwargs + ) + + is_first_generation_step = model_inputs["cache_position"][0] == 0 + if not is_first_generation_step: + model_inputs.pop("backbone_last_hidden_state") + + # csm depth decoder does not use position_ids + model_inputs.pop("position_ids") + + return model_inputs + + @can_return_tuple + @add_start_docstrings_to_model_forward(CSM_DEPTH_DECODER_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + backbone_last_hidden_state: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[KwargsForCausalLM], + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + backbone_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, backbone_hidden_size)`, *optional*): + The last hidden state of the backbone model. Such input is required when the first codebook token (the one generated by the backbone model) + is provided in the `input_ids` argument. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + backbone_last_hidden_state=backbone_last_hidden_state, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + if isinstance(logits_to_keep, int): + if logits_to_keep == 0: + # skip idx 0 logits since it's for the concatenated backbone last hidden state + slice_indices = slice(1, None) + else: + slice_indices = slice(-logits_to_keep, None) + else: + slice_indices = logits_to_keep + + logits = self.codebooks_head( + hidden_states[:, slice_indices, :], cache_position[slice_indices] if cache_position is not None else None + ) + logits = logits.contiguous() + + loss = None + if labels is not None: + shift_labels = labels[..., 1:].contiguous() + loss = self.loss_function( + logits=logits, labels=None, vocab_size=self.config.vocab_size, shift_labels=shift_labels, **kwargs + ) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class CsmBackboneModelEmbeddings(nn.Module): + def __init__(self, config): + super().__init__() + self.embed_audio_tokens = nn.Embedding((config.num_codebooks * config.vocab_size), config.hidden_size) + self.register_buffer( + "audio_tokens_offsets", torch.arange(config.num_codebooks) * config.vocab_size, persistent=False + ) + + def forward(self, input_ids): + input_embeds = self.embed_audio_tokens(input_ids + self.audio_tokens_offsets) + input_embeds = input_embeds.sum(dim=2) + return input_embeds + + +@add_start_docstrings( + "The bare CsmBackboneModel Model outputting raw hidden-states without any specific head on top.", + CSM_START_DOCSTRING, +) +class CsmBackboneModel(LlamaModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`CsmDecoderLayer`] + + Args: + config: CsmBackboneModelConfig + """ + + def __init__(self, config): + super().__init__(config) + self.embed_tokens = CsmBackboneModelEmbeddings(config) + + @can_return_tuple + @add_start_docstrings_to_model_forward(CSM_BACKBONE_INPUTS_DOCSTRING) + def forward(self, **super_kwargs): + return super().forward(**super_kwargs) + + +CSM_INPUTS_DOCSTRING = r"""{}""".format(INPUTS_DOCSTRING_BASE.format(input_ids_docstring=INPUT_IDS_DOCSTRING)) + + +@add_start_docstrings( + """ + The Csm model consists of two llama-like auto-regressive transformer models: a backbone model that predicts the first codebook token and a depth decoder that predicts the other codebook tokens. + """, + CSM_START_DOCSTRING, +) +class CsmForConditionalGeneration(CsmPreTrainedModel, CsmGenerationMixin): + _tied_weights_keys = [ + "backbone_model.embed_tokens.embed_audio_tokens.weight", + "depth_decoder.model.embed_tokens.weight", + ] + + def __init__(self, config): + super().__init__(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.embed_text_tokens = nn.Embedding(config.text_vocab_size, config.hidden_size) + self.backbone_model = CsmBackboneModel._from_config(config) + self.depth_decoder = CsmDepthDecoderForCausalLM._from_config(config.depth_decoder_config) + self.codec_model = AutoModel.from_config(config.codec_config) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.backbone_model.embed_tokens + + def set_input_embeddings(self, value): + self.backbone_model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def _tie_weights(self): + if self.config.tie_codebooks_embeddings: + self._tie_or_clone_weights( + self.backbone_model.embed_tokens.embed_audio_tokens, + self.depth_decoder.model.embed_tokens, + ) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + if kwargs.get("output_loading_info", False): + model, loading_info = super().from_pretrained(*args, **kwargs) + else: + model = super().from_pretrained(*args, **kwargs) + + # copy depth decoder generation conf attr to the depth decoder generation config + prefix = "depth_decoder_" + prefix_len = len(prefix) + depth_decoder_attrs = { + attr[prefix_len:]: value + for attr, value in vars(model.generation_config).items() + if attr.startswith(prefix) + } + + vars(model.depth_decoder.generation_config).update({"_from_model_config": False, **depth_decoder_attrs}) + + # remove the depth decoder generation conf attr from the model generation config + for attr in depth_decoder_attrs: + delattr(model.generation_config, prefix + attr) + + if "output_loading_info" in kwargs: + return model, loading_info + else: + return model + + def save_pretrained(self, *args, **kwargs): + # copy the depth decoder generation config attributes to the model generation config + prefix = "depth_decoder_" + depth_decoder_attrs = self.depth_decoder.generation_config.to_diff_dict() + depth_decoder_attrs.pop("transformers_version", None) + for attr, value in depth_decoder_attrs.items(): + setattr(self.generation_config, prefix + attr, value) + + super().save_pretrained(*args, **kwargs) + + def _merge_input_ids_with_input_values( + self, + input_ids: Optional[torch.Tensor] = None, + input_values: Optional[torch.Tensor] = None, + input_values_cutoffs: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + ) -> Optional[torch.Tensor]: + """ + Merges the input_ids and input_values to produce a single inputs_embeds tensor: + 1 - Infers the codec model on the input_values to retreive codebook token. + 2 - Embeds codebook tokens and places them at the correct positions in the inputs_embeds tensor. + 3 - If labels are provided, expands them to match codebook dimensions and position the target codebook tokens in the inputs_embeds tensor. + + Args: + input_ids (`torch.Tensor` of shape `(batch_size, sequence_length)`): + The input ids to embed. + input_values (`torch.Tensor` of shape `(batch_size, channels, audio_sequence_length)`): + The audio input values to embed. + input_values_cutoffs (`torch.Tensor` of shape `(batch_size, max_num_audio)`): + The cutoffs of the audio input values relative to its batch index, padded with -1 when no audio. + """ + inputs_embeds = self.embed_text_tokens(input_ids) + + if input_values is not None: + # infer input_values_mask + input_values_cutoffs = nn.functional.pad(input_values_cutoffs, (1, 0)) + audio_lengths = input_values_cutoffs[input_values_cutoffs >= 0].diff() + audio_lengths = audio_lengths[audio_lengths > 0] + input_values_mask = torch.arange(input_values_cutoffs.max(), device=input_values.device).expand( + len(audio_lengths), -1 + ) + input_values_mask = input_values_mask < audio_lengths.unsqueeze(1) + + # ======================================= + # TODO: @eustlb, this should be batched !!! + # but requires making sure batched inference of the codec model works as intended + audio_tokens_list = [] + for batch_input_values, batch_input_values_cutoffs in zip(input_values, input_values_cutoffs): + batch_input_values_cutoffs = batch_input_values_cutoffs[batch_input_values_cutoffs >= 0] + for i in range(batch_input_values_cutoffs.shape[0] - 1): + start_idx = batch_input_values_cutoffs[i] + end_idx = batch_input_values_cutoffs[i + 1] + audio_batch = batch_input_values[..., start_idx:end_idx] + codec_outputs = self.codec_model.encode(audio_batch.unsqueeze(0)) + codebook_ids = codec_outputs.audio_codes.transpose(1, -1) + audio_tokens_list.append(codebook_ids[0]) + + max_audio_frames = max(el.shape[0] for el in audio_tokens_list) + batched_audio_token_ids = torch.stack( + [nn.functional.pad(el, (0, 0, 0, max_audio_frames - el.shape[0])) for el in audio_tokens_list] + ) + audio_codes_mask = self.codec_model.get_audio_codes_mask(input_values_mask) + # ======================================= + audio_token_id = self.config.audio_token_id + audio_token_mask = input_ids == audio_token_id + + audio_embeds = self.backbone_model.embed_tokens(batched_audio_token_ids) + inputs_embeds[audio_token_mask] = audio_embeds[audio_codes_mask] + + # same for the audio eos token + audio_eos_frame_ids = ( + torch.ones((1, 1, self.config.num_codebooks), device=input_ids.device, dtype=torch.long) + * self.config.codebook_eos_token_id + ) + audio_eos_embeds = self.backbone_model.embed_tokens(audio_eos_frame_ids).squeeze(1) + + audio_eos_token_mask = input_ids == self.config.audio_eos_token_id + inputs_embeds[audio_eos_token_mask] = audio_eos_embeds.repeat(audio_eos_token_mask.sum(), 1) + + # if the labels are provided, we need to expand the labels to (batch_size, seq_length, num_codebooks) + if labels is not None: + labels_expanded = labels.unsqueeze(-1).repeat(1, 1, self.config.num_codebooks) + labels_expanded[audio_token_mask] = batched_audio_token_ids[audio_codes_mask] + # mask depth decoder + depth_decoder_ignore_frames_idxs = (labels == -101).nonzero(as_tuple=True) + labels_expanded[depth_decoder_ignore_frames_idxs[0], depth_decoder_ignore_frames_idxs[1], 1:] = -100 + labels = labels_expanded + + return {"inputs_embeds": inputs_embeds, "labels": labels} + + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor, + past_key_values: Optional[Cache] = None, + attention_mask: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ): + model_inputs = super().prepare_inputs_for_generation( + input_ids=input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + **kwargs, + ) + + if input_ids is not None and input_ids.ndim == 2 and model_inputs.get("inputs_embeds") is None: + merged_inputs = self._merge_input_ids_with_input_values( + input_ids=input_ids, + input_values=kwargs.get("input_values"), + input_values_cutoffs=kwargs.get("input_values_cutoffs"), + labels=kwargs.get("labels"), + ) + model_inputs.update( + {"inputs_embeds": merged_inputs["inputs_embeds"], "labels": merged_inputs["labels"], "input_ids": None} + ) + + return model_inputs + + @can_return_tuple + @add_start_docstrings_to_model_forward(CSM_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CsmOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + input_values: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + input_values_cutoffs: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[KwargsForCausalLM], + ) -> Union[Tuple, CsmOutputWithPast]: + r""" + input_values_cutoffs (`torch.Tensor` of shape `(batch_size, max_num_audio)`, *optional*): + Specify the end positions of audio segments within each batch entry, relative to the concatenated audio input. + If a batch entry has fewer segments than the maximum, it is padded with -1. For example, in a batch of 2 sequences + where the first contains 2 audio segments of length l1, and the second contains 1 audio segment of length l2, + the input_values_cutoffs would be: [[l1, 2 * l1], [l2, -1]]. + + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[config.audio_token_id, -100, -101]`. + Requires targeted `input_values` to be provided as audio tokens will be infered from it using the `codec_model`. + - `config.audio_token_id` indicates an audio frames (considering sequence length elements as frames) + - `-100` will be ignored in the loss computation + - `-101` indicates the audio frame will be used only for the backbone model (using the first codebook token as labels) + + Such labels can be prepared using `output_labels=True` when calling [`CsmProcessor`]. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + Kept for compatibility. Does not support another value than: + 1. `0`, which is equivalent to keeping all logits, used in the training regime + 2. `1`, which is equivalent to keeping only the last logit, used in the generation regime + + Returns: + + Example: + + ```python + >>> import torch + >>> from transformers import CsmForConditionalGeneration, AutoProcessor + >>> from datasets import load_dataset, Audio + + >>> model_id = "eustlb/csm-1b" + >>> torch_device = "cuda" if torch.cuda.is_available() else "cpu" + + >>> processor = AutoProcessor.from_pretrained(model_id) + + >>> ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train") + >>> # ensure the audio is 24kHz + >>> ds = ds.cast_column("audio", Audio(sampling_rate=24000)) + + >>> conversation = [] + >>> # prepare a conversation with text and corresponding audio + >>> for text, audio, speaker_id in zip(ds[:4]["text"], ds[:4]["audio"], ds[:4]["speaker_id"]): + ... conversation.append( + ... { + ... "role": f"{speaker_id}", + ... "content": [{"type": "text", "text": text}, {"type": "audio", "path": audio["array"]}], + ... } + ... ) + + >>> inputs = processor.apply_chat_template( + ... conversation, + ... tokenize=True, + ... return_dict=True, + ... output_labels=True, + ... ).to(torch_device) + + >>> model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=torch_device) + >>> output = model(**inputs) + >>> output.loss.backward() + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + if input_ids is not None and input_ids.ndim == 2: + merged_inputs = self._merge_input_ids_with_input_values( + input_ids, input_values, input_values_cutoffs, labels + ) + inputs_embeds = merged_inputs["inputs_embeds"] + labels = merged_inputs["labels"] + input_ids = None + + backbone_outputs = self.backbone_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **kwargs, + ) + + backbone_hidden_states = backbone_outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + backbone_logits = self.lm_head(backbone_hidden_states[:, slice_indices, :]) + + loss = None + backbone_loss = None + depth_decoder_loss = None + depth_decoder_outputs = None + if labels is not None: + # select first codebook as labels for the backbone model + backbone_labels = labels[:, :, 0] + backbone_loss = self.loss_function( + logits=backbone_logits, labels=backbone_labels, vocab_size=self.config.vocab_size, **kwargs + ) + + # for the depth decoder, we need to select the frames to train on + # those are frames where the label is not uniformly `ignore_index` along the codebook dimension + train_mask = ~(labels[:, :, 1:] == -100).all(dim=-1) + depth_decoder_input_ids = labels[train_mask][..., : self.config.num_codebooks - 1] + # add place holder in position 0 that will be replaced by the backbone_last_hidden_state + depth_decoder_input_ids = nn.functional.pad(depth_decoder_input_ids, (1, 0), value=0) + + train_idxs = train_mask.nonzero(as_tuple=True) + backbone_last_hidden_states = backbone_hidden_states[train_idxs[0], train_idxs[1] - 1, :] + depth_decoder_labels = labels[train_mask] + + depth_decoder_outputs = self.depth_decoder( + input_ids=depth_decoder_input_ids, + backbone_last_hidden_state=backbone_last_hidden_states, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + labels=depth_decoder_labels, + ) + + depth_decoder_loss = depth_decoder_outputs.loss + loss = backbone_loss + depth_decoder_loss + + return CsmOutputWithPast( + loss=loss, + backbone_loss=backbone_loss, + depth_decoder_loss=depth_decoder_loss, + logits=backbone_logits, + past_key_values=backbone_outputs.past_key_values, + hidden_states=backbone_outputs.hidden_states, + attentions=backbone_outputs.attentions, + depth_decoder_logits=depth_decoder_outputs.logits if depth_decoder_outputs is not None else None, + depth_decoder_past_key_values=depth_decoder_outputs.past_key_values + if depth_decoder_outputs is not None + else None, + depth_decoder_hidden_states=depth_decoder_outputs.hidden_states + if depth_decoder_outputs is not None + else None, + depth_decoder_attentions=depth_decoder_outputs.attentions if depth_decoder_outputs is not None else None, + ) + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +__all__ = [ + "CsmPreTrainedModel", + "CsmBackboneModel", + "CsmDepthDecoderModel", + "CsmDepthDecoderForCausalLM", + "CsmForConditionalGeneration", +] diff --git a/src/transformers/models/csm/processing_csm.py b/src/transformers/models/csm/processing_csm.py new file mode 100644 index 0000000000..486c5eda4c --- /dev/null +++ b/src/transformers/models/csm/processing_csm.py @@ -0,0 +1,364 @@ +# coding=utf-8 +# Copyright 2025 Sesame and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from pathlib import Path +from typing import Any, Dict, List, Optional, Union + +import numpy as np + +from ...utils import is_soundfile_available, is_torch_available + + +if is_torch_available(): + import torch + +if is_soundfile_available(): + import soundfile as sf + +from ...audio_utils import AudioInput, make_list_of_audio +from ...feature_extraction_utils import BatchFeature +from ...processing_utils import AudioKwargs, ProcessingKwargs, ProcessorMixin, Unpack +from ...tokenization_utils_base import ( + PreTokenizedInput, + TextInput, +) + + +class CsmAudioKwargs(AudioKwargs, total=False): + encoded_length_kwargs: Optional[Dict[str, Any]] + + +class CsmProcessorKwargs(ProcessingKwargs, total=False): + audio_kwargs: CsmAudioKwargs + _defaults = { + "text_kwargs": { + "padding": True, + "padding_side": "left", + "add_special_tokens": False, + }, + "audio_kwargs": { + "encoded_length_kwargs": { + "kernel_sizes": [7, 3, 1, 8, 3, 1, 10, 3, 1, 12, 3, 1, 16, 3, 4], + "strides": [1, 1, 1, 4, 1, 1, 5, 1, 1, 6, 1, 1, 8, 1, 2], + "dilations": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + "use_causal_conv": True, + }, + "sampling_rate": 24000, + }, + "common_kwargs": {"return_tensors": "pt"}, + } + + +class CsmProcessor(ProcessorMixin): + r""" + Constructs a Csm processor which wraps [`EncodecFeatureExtractor`] and + [`PretrainedTokenizerFast`] into a single processor that inherits both the audio feature extraction and + tokenizer functionalities. See the [`~CsmProcessor.__call__`] for more + information. + The preferred way of passing kwargs is as a dictionary per modality, see usage example below. + ```python + from transformers import CsmProcessor + from datasets import load_dataset + + ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train") + audio = ds[0]["audio"]["array"] + + processor = CsmProcessor.from_pretrained("eustlb/csm-1b") + + processor( + text=["<|begin_of_text|>[0]What are you working on?<|end_of_text|><|AUDIO|><|audio_eos|><|begin_of_text|>[1]I'm figuring out my budget.<|end_of_text|>"], + audio=audio, + text_kwargs = {"padding": False}, + audio_kwargs = {"sampling_rate": 16000}, + common_kwargs = {"return_tensors": "pt"}, + ) + # this should error out because EncodecFeatureExtractor expects a 24kHz audio :) + ``` + + Args: + feature_extractor ([`EncodecFeatureExtractor`]): + The feature extractor is a required input. + tokenizer ([`PreTrainedTokenizer`, `PreTrainedTokenizerFast`]): + The tokenizer is a required input. + chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages + in a chat into a tokenizable string. + + """ + + attributes = ["feature_extractor", "tokenizer"] + valid_kwargs = ["chat_template"] + feature_extractor_class = "EncodecFeatureExtractor" + tokenizer_class = "PreTrainedTokenizerFast" + + def __init__( + self, + feature_extractor, + tokenizer, + chat_template=None, + ): + if not hasattr(tokenizer, "audio_token"): + self.audio_token = "<|AUDIO|>" + self.audio_token_id = tokenizer.convert_tokens_to_ids(self.audio_token) + else: + self.audio_token = tokenizer.audio_token + self.audio_token_id = tokenizer.audio_token_id + + if not hasattr(tokenizer, "audio_eos_token"): + self.audio_eos_token = "<|audio_eos|>" + self.audio_eos_token_id = tokenizer.convert_tokens_to_ids(self.audio_eos_token) + else: + self.audio_eos_token = tokenizer.audio_eos_token + self.audio_eos_token_id = tokenizer.audio_eos_token_id + + super().__init__(feature_extractor, tokenizer, chat_template=chat_template) + + @staticmethod + def _get_encoded_length(audio_length, kernel_sizes=None, strides=None, dilations=None, use_causal_conv=None): + """ + Compute the length of the encoded audio sequence. + + Args: + audio_length (int): The length of the audio sequence. + kernel_sizes (List[int]): The kernel sizes for the convolutional layers. + strides (List[int]): The strides for the convolutional layers. + use_causal_conv (bool): Whether to use causal convolutions. + """ + cur_length = audio_length + + if kernel_sizes is None or strides is None or dilations is None or use_causal_conv is None: + return cur_length + + for kernel_size, stride, dilation in zip(kernel_sizes, strides, dilations): + effective_kernel_size = (kernel_size - 1) * dilation + 1 + padding_total = kernel_size - stride + padding_right = padding_total // 2 + padding_left = padding_total - padding_right + + n_frames = (cur_length - effective_kernel_size + padding_total) / stride + 1 + n_frames = math.ceil(n_frames) - 1 + ideal_length = n_frames * stride + kernel_size - padding_total + extra_padding = ideal_length - cur_length + + if use_causal_conv: + padding_left = padding_total + padding_right = extra_padding + else: + padding_left = padding_left + padding_right = padding_right + extra_padding + + cur_length = cur_length + padding_left + padding_right + cur_length = (cur_length - dilation * (kernel_size - 1) - 1) // stride + 1 + + return cur_length + + def save_audio( + self, + audio: AudioInput, + saving_path: Union[str, Path, List[Union[str, Path]]], + **kwargs: Unpack[CsmProcessorKwargs], + ): + # TODO: @eustlb, this should be in AudioProcessor + if not is_soundfile_available(): + raise ImportError("Please install `soundfile` to save audio files.") + + # ensure correct audio input + audio = make_list_of_audio(audio) + + # ensure correct saving path + if isinstance(saving_path, (str, Path)): + saving_path = [saving_path] + elif not (isinstance(saving_path, (list, tuple)) and all(isinstance(p, (str, Path)) for p in saving_path)): + raise ValueError("Invalid input path. Please provide a string, or a list of strings") + + if len(audio) != len(saving_path): + raise ValueError("The number of audio and saving paths must be the same") + + output_kwargs = self._merge_kwargs( + CsmProcessorKwargs, + **kwargs, + ) + audio_kwargs = output_kwargs["audio_kwargs"] + sampling_rate = audio_kwargs["sampling_rate"] + + for audio_value, p in zip(audio, saving_path): + if isinstance(audio_value, torch.Tensor): + audio_value = audio_value.cpu().float().numpy() + sf.write(p, audio_value, sampling_rate) + + def __call__( + self, + text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]], + audio: Optional[AudioInput] = None, + output_labels: Optional[bool] = False, + depth_decoder_labels_ratio: Optional[float] = 1.0, + **kwargs: Unpack[CsmProcessorKwargs], + ): + r""" + Main method to prepare text(s) and audio to be fed as input to the model. This method forwards the `text` + arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizerFast.__call__`] to encode + the text. To prepare the audio, this method forwards the `audio` arguments to + EncodecFeatureExtractor's [`~EncodecFeatureExtractor.__call__`]. Please refer + to the docstring of the above two methods for more information. + + Args: + audio (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`): + The audio or batch of audio to be prepared. Each audio can be a NumPy array or PyTorch + tensor. + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + output_labels (bool, *optional*, default=False): + Whether to return labels for training. Indices will be in `[config.audio_token_id, -100, -101]`. + - `config.audio_token_id` indicates an audio frame (considering sequence length elements as frames) + - `-100` will be ignored in the loss computation + - `-101` indicates the audio frame will be used only for the backbone model (using the first codebook token as labels) + depth_decoder_labels_ratio (float, *optional*, default=1.0): + The ratio of audio frames to keep for the depth decoder labels. + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **input_values** -- List of audio values to be fed to a model. Returned when `audio` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **labels** -- List of labels for the audio frames. Returned when `output_labels=True`. + """ + + output_kwargs = self._merge_kwargs( + CsmProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + + text_kwargs = output_kwargs["text_kwargs"] + audio_kwargs = output_kwargs["audio_kwargs"] + common_kwargs = output_kwargs["common_kwargs"] + + return_tensors = common_kwargs.pop("return_tensors", None) + if return_tensors != "pt": + raise ValueError(f"{self.__class__.__name__} only supports `return_tensors='pt'`.") + + if isinstance(text, str): + text = [text] + elif not (isinstance(text, (list, tuple)) and all(isinstance(t, str) for t in text)): + raise ValueError("Invalid input text. Please provide a string, or a list of strings") + n_audio_in_text = [t.count(self.audio_token) for t in text] + + n_audio = 0 + if audio is not None: + audio = make_list_of_audio(audio) + n_audio = len(audio) + + if sum(n_audio_in_text) > 0 and n_audio != sum(n_audio_in_text): + if audio is None: + raise ValueError("No audio were provided, but there are audio tokens in the prompt") + else: + raise ValueError( + f"The number of audio tokens in each text ({n_audio_in_text}) should be the same as the " + f"number of provided audios ({n_audio})." + ) + + if audio is not None: + encoded_length_kwargs = audio_kwargs.pop("encoded_length_kwargs", {}) + num_audio_tokens_list = [ + self._get_encoded_length(audio_array.shape[-1], **encoded_length_kwargs) for audio_array in audio + ] + num_audio_tokens_list_copy = num_audio_tokens_list.copy() + + # expand the text to repeat the audio token for the corresponding number of frames + expanded_text = [] + for sample in text: + replace_str = [] + while self.audio_token in sample: + num_audio_tokens = num_audio_tokens_list_copy.pop(0) + expanded_audio_token = self.audio_token * num_audio_tokens + + replace_str.append(expanded_audio_token) + sample = sample.replace(self.audio_token, "", 1) + + while "" in sample: + sample = sample.replace("", replace_str.pop(0), 1) + expanded_text.append(sample) + + text = expanded_text + + encoding = self.tokenizer(text, **text_kwargs) + data = {} + data.update(encoding) + + if audio is not None: + audio_kwargs.pop("return_attention_mask", None) # not supported by the feature extractor + + concatenated_audio, input_values_cutoffs = [], [] + offset = 0 + for n_audio in n_audio_in_text: + if n_audio == 0: + concatenated_audio.append(np.zeros(0)) + input_values_cutoffs.append(torch.tensor([-1])) + else: + concatenated_audio.append( + np.concatenate( + [ + el.cpu().numpy() if isinstance(el, torch.Tensor) else el + for el in audio[offset : offset + n_audio] + ], + axis=-1, + ) + ) + input_values_cutoffs.append( + torch.tensor([el.shape[-1] for el in audio[offset : offset + n_audio]]).cumsum(dim=-1) + ) + offset += n_audio + + audio_inputs = self.feature_extractor(concatenated_audio, **audio_kwargs) + audio_inputs.pop("padding_mask", None) # not applicable here + data.update(audio_inputs) + + # pad and stack the audio cut idxs + max_len = max(cut_idxs.shape[-1] for cut_idxs in input_values_cutoffs) + input_values_cutoffs = [ + torch.nn.functional.pad(cut_idxs, (0, max_len - cut_idxs.shape[-1]), value=-1) + for cut_idxs in input_values_cutoffs + ] + data["input_values_cutoffs"] = torch.stack(input_values_cutoffs, dim=0) + + if output_labels: + audio_frame_idxs = (data["input_ids"] == self.audio_token_id).nonzero() + n_audio_frames = audio_frame_idxs.shape[0] + + if depth_decoder_labels_ratio <= 1.0: + rand_idxs = torch.randperm(n_audio_frames)[: int(n_audio_frames * (1 - depth_decoder_labels_ratio))] + skip_frames_idxs = audio_frame_idxs[rand_idxs] + else: + skip_frames_idxs = audio_frame_idxs + + labels = torch.where(data["input_ids"] == self.audio_token_id, data["input_ids"], -100) + labels[skip_frames_idxs[:, 0], skip_frames_idxs[:, 1]] = -101 + + data["labels"] = labels + + return BatchFeature(data=data, tensor_type=return_tensors) + + +__all__ = ["CsmProcessor"] diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 03550e09ed..1ba8536839 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -1111,7 +1111,7 @@ class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel, GenerationMixin): ) return model_inputs - def _get_initial_cache_position(self, input_ids, model_kwargs): + def _get_initial_cache_position(self, seq_length, device, model_kwargs): """ Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length. Since gpt bigcode is special, the method is overridden here, other models use it from `generation.utils.py`. @@ -1125,8 +1125,8 @@ class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel, GenerationMixin): if "inputs_embeds" in model_kwargs: cur_len = model_kwargs["inputs_embeds"].shape[1] else: - cur_len = input_ids.shape[-1] - model_kwargs["cache_position"] = torch.arange(past_length, cur_len, device=input_ids.device) + cur_len = seq_length + model_kwargs["cache_position"] = torch.arange(past_length, cur_len, device=device) return model_kwargs @add_start_docstrings_to_model_forward(GPT_BIGCODE_INPUTS_DOCSTRING) diff --git a/src/transformers/models/janus/modeling_janus.py b/src/transformers/models/janus/modeling_janus.py index 6a30954345..853e371c89 100644 --- a/src/transformers/models/janus/modeling_janus.py +++ b/src/transformers/models/janus/modeling_janus.py @@ -1563,7 +1563,7 @@ class JanusForConditionalGeneration(JanusPreTrainedModel, GenerationMixin): inputs_embeds = self.get_input_embeddings()(input_tokens) - model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) + model_kwargs = self._get_initial_cache_position(seq_len, device, model_kwargs) if model_kwargs.get("past_key_values", None) is None: # Prepare cache if not provided. diff --git a/src/transformers/models/janus/modular_janus.py b/src/transformers/models/janus/modular_janus.py index 36499e43ea..50cb0021bf 100644 --- a/src/transformers/models/janus/modular_janus.py +++ b/src/transformers/models/janus/modular_janus.py @@ -1378,7 +1378,7 @@ class JanusForConditionalGeneration(JanusPreTrainedModel, GenerationMixin): inputs_embeds = self.get_input_embeddings()(input_tokens) - model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) + model_kwargs = self._get_initial_cache_position(seq_len, device, model_kwargs) if model_kwargs.get("past_key_values", None) is None: # Prepare cache if not provided. diff --git a/src/transformers/models/mimi/modeling_mimi.py b/src/transformers/models/mimi/modeling_mimi.py index b8f24040dd..bfe6698c55 100644 --- a/src/transformers/models/mimi/modeling_mimi.py +++ b/src/transformers/models/mimi/modeling_mimi.py @@ -216,6 +216,32 @@ class MimiConv1d(nn.Module): end = padded.shape[-1] - extra_pad return padded[..., :end] + def _get_output_length(self, input_length: torch.LongTensor) -> torch.LongTensor: + """ + Return the length of the output of the MimiConv1d. + """ + # padding size + n_frames = (input_length - self.kernel_size + self.padding_total) / self.stride + 1 + n_frames = torch.ceil(n_frames).to(torch.int64) - 1 + ideal_length = n_frames * self.stride + self.kernel_size - self.padding_total + extra_padding = ideal_length - input_length + + if self.causal: + padding_left = self.padding_total + padding_right = extra_padding + else: + padding_left = self.padding_left + padding_right = self.padding_right + extra_padding + + # padding + input_length = input_length + padding_left + padding_right + + # conv + output_lenght = ( + input_length + 2 * self.conv.padding[0] - self.conv.dilation[0] * (self.conv.kernel_size[0] - 1) - 1 + ) // self.conv.stride[0] + 1 + return output_lenght + def forward(self, hidden_states): extra_padding = self._get_extra_padding_for_conv1d(hidden_states) @@ -331,21 +357,28 @@ class MimiEncoder(nn.Module): model = [MimiConv1d(config, config.audio_channels, config.num_filters, config.kernel_size)] scaling = 1 + # keep track of MimiConv1d submodule layer names for easy encoded length computation + mimiconv1d_layer_names = ["layers.0"] + # Downsample to raw audio scale for ratio in reversed(config.upsampling_ratios): current_scale = scaling * config.num_filters # Add residual layers for j in range(config.num_residual_layers): + mimiconv1d_layer_names.extend([f"layers.{len(model)}.block.1", f"layers.{len(model)}.block.3"]) model += [MimiResnetBlock(config, current_scale, [config.dilation_growth_rate**j, 1])] # Add downsampling layers model += [nn.ELU()] + mimiconv1d_layer_names.append(f"layers.{len(model)}") model += [MimiConv1d(config, current_scale, current_scale * 2, kernel_size=ratio * 2, stride=ratio)] scaling *= 2 model += [nn.ELU()] + mimiconv1d_layer_names.append(f"layers.{len(model)}") model += [MimiConv1d(config, scaling * config.num_filters, config.hidden_size, config.last_kernel_size)] self.layers = nn.ModuleList(model) + self._mimiconv1d_layer_names = mimiconv1d_layer_names # Copied from transformers.models.encodec.modeling_encodec.EncodecEncoder.forward def forward(self, hidden_states): @@ -1567,6 +1600,38 @@ class MimiModel(MimiPreTrainedModel): codes = codes.transpose(0, 1) return codes, past_key_values + def get_encoded_length(self, input_length: torch.LongTensor) -> torch.LongTensor: + """ + Return the number of frames of the encoded audio waveform. + """ + output_length = input_length + + # encoder + for layer_name in self.encoder._mimiconv1d_layer_names: + output_length = self.encoder.get_submodule(layer_name)._get_output_length(output_length) + + # downsample + output_length = self.downsample._get_output_length(output_length) + + return output_length + + def get_audio_codes_mask(self, padding_mask: torch.Tensor, padding_side: str = "right"): + """ + Get the mask for the audio codes from the original padding mask. + """ + encoded_lengths = self.get_encoded_length(padding_mask.sum(dim=-1)) + + audio_codes_mask = torch.arange(encoded_lengths.max(), device=encoded_lengths.device).expand( + len(encoded_lengths), -1 + ) + audio_codes_mask = audio_codes_mask < encoded_lengths.unsqueeze(1) + audio_codes_mask = audio_codes_mask.to(padding_mask.device) + + if padding_side == "right": + return audio_codes_mask + else: + return audio_codes_mask.flip(dims=[-1]) + def encode( self, input_values: torch.Tensor, diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py index 19c88c047a..9737e1437e 100644 --- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py @@ -3084,10 +3084,10 @@ class Qwen2_5OmniTalkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCon thinker_reply_part=thinker_reply_part, ) - def _get_initial_cache_position(self, input_ids, model_kwargs): + def _get_initial_cache_position(self, seq_length, device, model_kwargs): # Talker needs to calculate cache_position with input_ids, so pop inputs_embeds temporarily inputs_embeds = model_kwargs.pop("inputs_embeds") - model_kwargs = super()._get_initial_cache_position(input_ids, model_kwargs) + model_kwargs = super()._get_initial_cache_position(seq_length, device, model_kwargs) model_kwargs["inputs_embeds"] = inputs_embeds return model_kwargs diff --git a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py index 2123be2903..9b5c416764 100644 --- a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py @@ -2771,10 +2771,10 @@ class Qwen2_5OmniTalkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCon thinker_reply_part=thinker_reply_part, ) - def _get_initial_cache_position(self, input_ids, model_kwargs): + def _get_initial_cache_position(self, seq_length, device, model_kwargs): # Talker needs to calculate cache_position with input_ids, so pop inputs_embeds temporarily inputs_embeds = model_kwargs.pop("inputs_embeds") - model_kwargs = super()._get_initial_cache_position(input_ids, model_kwargs) + model_kwargs = super()._get_initial_cache_position(seq_length, device, model_kwargs) model_kwargs["inputs_embeds"] = inputs_embeds return model_kwargs diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py index 403569ae89..d07dad0205 100644 --- a/src/transformers/processing_utils.py +++ b/src/transformers/processing_utils.py @@ -1058,7 +1058,7 @@ class ProcessorMixin(PushToHubMixin): # update defaults with arguments from tokenizer init for modality_key in ModelProcessorKwargs.__annotations__[modality].__annotations__.keys(): # init with tokenizer init kwargs if necessary - if modality_key in tokenizer_init_kwargs: + if tokenizer_init_kwargs is not None and modality_key in tokenizer_init_kwargs: value = ( getattr(self.tokenizer, modality_key) if hasattr(self.tokenizer, modality_key) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index bb4965cb76..20edc1c897 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -501,9 +501,9 @@ class GenerationTesterMixin: output_generate = self._greedy_generate(model=model, inputs_dict=inputs_dict) if model.config.is_encoder_decoder: - self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1) + self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1) else: - self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]) + self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1]) @pytest.mark.generate def test_greedy_generate_dict_outputs(self): @@ -525,13 +525,13 @@ class GenerationTesterMixin: ) if model.config.is_encoder_decoder: - self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) + self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1) self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput) # Retrocompatibility check self.assertIsInstance(output_generate, GreedySearchEncoderDecoderOutput) else: self.assertTrue( - output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1] + output_generate.sequences.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1] ) self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput) # Retrocompatibility check @@ -565,10 +565,10 @@ class GenerationTesterMixin: ) if model.config.is_encoder_decoder: - self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) + self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1) else: self.assertTrue( - output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1] + output_generate.sequences.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1] ) self._check_generate_outputs(output_generate, model.config, use_cache=True) @@ -582,9 +582,9 @@ class GenerationTesterMixin: output_generate = self._sample_generate(model=model, inputs_dict=inputs_dict, num_return_sequences=1) if model.config.is_encoder_decoder: - self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1) + self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1) else: - self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]) + self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1]) @pytest.mark.generate def test_sample_generate_dict_output(self): @@ -607,13 +607,13 @@ class GenerationTesterMixin: ) if model.config.is_encoder_decoder: - self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) + self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1) self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput) # Retrocompatibility check self.assertIsInstance(output_generate, SampleEncoderDecoderOutput) else: self.assertTrue( - output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1] + output_generate.sequences.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1] ) self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput) # Retrocompatibility check @@ -632,9 +632,9 @@ class GenerationTesterMixin: output_generate = self._beam_search_generate(model=model, inputs_dict=inputs_dict, beam_kwargs=beam_kwargs) if model.config.is_encoder_decoder: - self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1) + self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1) else: - self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]) + self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1]) @pytest.mark.generate def test_beam_search_generate_dict_output(self): @@ -657,13 +657,13 @@ class GenerationTesterMixin: use_cache=False, ) if model.config.is_encoder_decoder: - self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) + self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1) self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput) # Retrocompatibility check self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput) else: self.assertTrue( - output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1] + output_generate.sequences.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1] ) self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput) # Retrocompatibility check @@ -706,10 +706,10 @@ class GenerationTesterMixin: ) if model.config.is_encoder_decoder: - self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) + self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1) else: self.assertTrue( - output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1] + output_generate.sequences.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1] ) self._check_generate_outputs( @@ -759,9 +759,9 @@ class GenerationTesterMixin: ) if model.config.is_encoder_decoder: - self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1) + self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1) else: - self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]) + self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1]) @pytest.mark.generate def test_beam_sample_generate_dict_output(self): @@ -786,13 +786,13 @@ class GenerationTesterMixin: ) if model.config.is_encoder_decoder: - self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) + self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1) self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput) # Retrocompatibility check self.assertIsInstance(output_generate, BeamSampleEncoderDecoderOutput) else: self.assertTrue( - output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1] + output_generate.sequences.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1] ) self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput) # Retrocompatibility check @@ -840,9 +840,9 @@ class GenerationTesterMixin: beam_kwargs=beam_kwargs, ) if model.config.is_encoder_decoder: - self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1) + self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1) else: - self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]) + self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1]) # check `group_beam_search` for higher than 1 `num_return_sequences` num_return_sequences = 2 @@ -853,9 +853,9 @@ class GenerationTesterMixin: beam_kwargs=beam_kwargs, ) if model.config.is_encoder_decoder: - self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1) + self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1) else: - self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]) + self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1]) @pytest.mark.generate def test_group_beam_search_generate_dict_output(self): @@ -878,13 +878,13 @@ class GenerationTesterMixin: use_cache=False, ) if model.config.is_encoder_decoder: - self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) + self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1) self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput) # Retrocompatibility check self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput) else: self.assertTrue( - output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1] + output_generate.sequences.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1] ) self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput) # Retrocompatibility check @@ -923,9 +923,9 @@ class GenerationTesterMixin: ) if model.config.is_encoder_decoder: - self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1) + self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1) else: - self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]) + self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1]) for generation_output in output_generate: self._check_sequence_inside_sequence(force_tokens, generation_output) @@ -947,9 +947,9 @@ class GenerationTesterMixin: ) if model.config.is_encoder_decoder: - self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1) + self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1) else: - self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]) + self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1]) for generation_output in output_generate: self._check_sequence_inside_sequence(force_tokens, generation_output) @@ -987,13 +987,13 @@ class GenerationTesterMixin: ) if model.config.is_encoder_decoder: - self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) + self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1) self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput) # Retrocompatibility check self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput) else: self.assertTrue( - output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1] + output_generate.sequences.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1] ) self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput) # Retrocompatibility check @@ -1031,9 +1031,9 @@ class GenerationTesterMixin: use_cache=True, # Enable cache ) if model.config.is_encoder_decoder: - self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1) + self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1) else: - self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]) + self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1]) @pytest.mark.generate def test_contrastive_generate_dict_outputs_use_cache(self): @@ -1067,10 +1067,10 @@ class GenerationTesterMixin: ) if model.config.is_encoder_decoder: - self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) + self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1) else: self.assertTrue( - output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1] + output_generate.sequences.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1] ) self._check_generate_outputs(output_generate, model.config, use_cache=True) @@ -1499,7 +1499,7 @@ class GenerationTesterMixin: position_ids.masked_fill_(attention_mask == 0, 1) model_kwargs["position_ids"] = position_ids if "cache_position" in signature: - cache_position = torch.arange(input_ids.shape[-1], device=torch_device) + cache_position = torch.arange(input_ids.shape[1], device=torch_device) model_kwargs["cache_position"] = cache_position return model_kwargs @@ -1525,10 +1525,12 @@ class GenerationTesterMixin: pad_token_id = ( config.get_text_config().pad_token_id if config.get_text_config().pad_token_id is not None else 0 ) - pad_size = (input_ids.shape[0], 32) + pad_size = (input_ids.shape[0], 32, *input_ids.shape[2:]) padding = torch.ones(pad_size, dtype=input_ids.dtype, device=torch_device) * pad_token_id padded_input_ids = torch.cat((padding, input_ids), dim=1) - padded_attention_mask = torch.cat((torch.zeros_like(padding), attention_mask), dim=1) + padded_attention_mask = torch.cat( + (torch.zeros(pad_size[:2], dtype=input_ids.dtype, device=torch_device), attention_mask), dim=1 + ) model_kwargs = _prepare_model_kwargs(padded_input_ids, padded_attention_mask, signature) next_logits_with_padding = model(**model_kwargs).logits[:, -1, :] @@ -1587,7 +1589,7 @@ class GenerationTesterMixin: else text_config.num_attention_heads ) encoder_per_head_embed_dim = embed_dim // encoder_num_attention_heads - batch_size, seq_length = inputs["decoder_input_ids"].shape + batch_size, seq_length = inputs["decoder_input_ids"].shape[:2] # The sequence length for the encoder K V depends on the model. Since it is not manipulated in # autoregressive generation, we're keeping the test general and not checking the 3rd dim default_cross_attention_shape = ( @@ -1606,7 +1608,7 @@ class GenerationTesterMixin: for _ in range(num_decoder_layers) ] else: - batch_size, seq_length = inputs["input_ids"].shape + batch_size, seq_length = inputs["input_ids"].shape[:2] default_self_attention_shape = (batch_size, num_key_value_heads, seq_length, per_head_embed_dim) all_cache_shapes = [ [default_self_attention_shape, default_self_attention_shape] for _ in range(num_decoder_layers) @@ -1727,7 +1729,7 @@ class GenerationTesterMixin: "min_new_tokens": 5, # generate exactly 5 tokens } outputs_from_ids = model.generate(input_ids, **generation_kwargs, **inputs_dict) - self.assertEqual(outputs_from_ids.sequences.shape, (input_ids.shape[0], input_ids.shape[1] + 5)) + self.assertEqual(outputs_from_ids.sequences.shape[:2], (input_ids.shape[0], input_ids.shape[1] + 5)) # Same thing, but from input embeddings (`input_ids` is passed so the prompt is present in the output). # The output of the two calls should be the same. @@ -2262,11 +2264,11 @@ class GenerationTesterMixin: self.assertTrue(hasattr(model, "_compiled_call")) # our auto compile should have been called if model.config.is_encoder_decoder: - self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) + self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1) self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput) else: self.assertTrue( - output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1] + output_generate.sequences.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1] ) self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput) @@ -2408,7 +2410,7 @@ class GenerationTesterMixin: config = config.text_config if hasattr(config, "text_config") else config generated_length = ( - output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - prompt_length + output.sequences.shape[1] - 1 if config.is_encoder_decoder else output.sequences.shape[1] - prompt_length ) decoder_past_key_values = getattr(output, "past_key_values", None) if config.is_encoder_decoder and isinstance(decoder_past_key_values, EncoderDecoderCache): @@ -2441,7 +2443,7 @@ class GenerationTesterMixin: batch_size=internal_batch_size, attentions=output.decoder_attentions, prompt_length=1, # the BOS token - output_length=output.sequences.shape[-1], + output_length=output.sequences.shape[1], config=config, decoder_past_key_values=decoder_past_key_values, ) @@ -2450,7 +2452,7 @@ class GenerationTesterMixin: batch_size=internal_batch_size, attentions=output.attentions, prompt_length=prompt_length, - output_length=output.sequences.shape[-1], + output_length=output.sequences.shape[1], config=config, decoder_past_key_values=decoder_past_key_values, ) @@ -2469,7 +2471,7 @@ class GenerationTesterMixin: batch_size=internal_batch_size, hidden_states=output.decoder_hidden_states, prompt_length=1, # the BOS token - output_length=output.sequences.shape[-1], + output_length=output.sequences.shape[1], config=config, use_cache=use_cache, ) @@ -2478,7 +2480,7 @@ class GenerationTesterMixin: batch_size=internal_batch_size, hidden_states=output.hidden_states, prompt_length=prompt_length, - output_length=output.sequences.shape[-1], + output_length=output.sequences.shape[1], config=config, use_cache=use_cache, ) @@ -2506,7 +2508,7 @@ class GenerationTesterMixin: ) if has_standard_cache: if use_cache: - cache_length = output.sequences.shape[-1] - 1 + cache_length = output.sequences.shape[1] - 1 self._check_past_key_values_for_generate( batch_size=internal_batch_size, decoder_past_key_values=decoder_past_key_values, diff --git a/tests/models/csm/__init__.py b/tests/models/csm/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/models/csm/test_modeling_csm.py b/tests/models/csm/test_modeling_csm.py new file mode 100644 index 0000000000..93423598ba --- /dev/null +++ b/tests/models/csm/test_modeling_csm.py @@ -0,0 +1,693 @@ +# coding=utf-8 +# Copyright 2024, The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the PyTorch ConversationalSpeechModel model.""" + +import collections +import copy +import re +import unittest + +import pytest +from parameterized import parameterized + +from transformers import ( + AutoProcessor, + CsmConfig, + CsmForConditionalGeneration, + is_torch_available, +) +from transformers.testing_utils import ( + cleanup, + require_torch_gpu, + slow, + torch_device, +) +from transformers.utils.import_utils import is_datasets_available + +from ...generation.test_utils import GenerationTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ( + ModelTesterMixin, + _config_zero_init, + ids_tensor, +) + + +if is_datasets_available(): + from datasets import load_dataset + +if is_torch_available(): + import torch + + from transformers.pytorch_utils import id_tensor_storage + + +class CsmModelTester: + def __init__( + self, + parent, + ignore_index=-100, + batch_size=3, + seq_length=7, + is_training=True, + depth_decoder_config={ + "num_codebooks": 10, + "backbone_hidden_size": 64, + "vocab_size": 6, + "hidden_size": 64, + "intermediate_size": 128, + "num_hidden_layers": 2, + "num_attention_heads": 4, + "num_key_value_heads": 2, + "hidden_act": "silu", + "max_position_embeddings": 10, + }, + codec_config={ + "model_type": "mimi", + "audio_channels": 1, + "chunk_in_sec": None, + "hidden_size": 32, + "num_filters": 8, + "num_residual_layers": 1, + "upsampling_ratios": [8, 4], + "codebook_size": 64, + "vector_quantization_hidden_dimension": 64, + "upsample_groups": 32, + "num_hidden_layers": 2, + "num_attention_heads": 2, + "num_key_value_heads": 2, + "sliding_window": 4, + "codebook_dim": 64, + "use_cache": False, + }, + config={ + "num_codebooks": 10, + "vocab_size": 6, + "text_vocab_size": 99, + "hidden_size": 64, + "intermediate_size": 64, + "num_hidden_layers": 2, + "num_attention_heads": 4, + "num_key_value_heads": 2, + "hidden_act": "silu", + "max_position_embeddings": 10, + "bos_token_id": 1, + "pad_token_id": 2, + "eos_token_id": 3, + "codebook_pad_token_id": 2, + "codebook_eos_token_id": 3, + }, + ): + self.parent = parent + self.is_training = is_training + self.ignore_index = ignore_index + self.depth_decoder_config = depth_decoder_config + self.codec_config = codec_config + self.config = config + self.seq_length = seq_length + self.batch_size = batch_size + + self.num_hidden_layers = config["num_hidden_layers"] + self.vocab_size = config["vocab_size"] + self.hidden_size = config["hidden_size"] + self.num_attention_heads = config["num_attention_heads"] + self.pad_token_id = config["pad_token_id"] + + def get_config(self): + return CsmConfig( + depth_decoder_config=self.depth_decoder_config, + codec_config=self.codec_config, + **self.config, + ) + + def prepare_config_and_inputs(self): + config = self.get_config() + input_ids = ids_tensor([self.batch_size, self.seq_length, config.num_codebooks], config.vocab_size - 1) + 1 + attention_mask = input_ids[..., -1].ne(1).to(torch_device) + return config, input_ids, attention_mask + + def prepare_config_and_inputs_for_common(self): + config, input_ids, attention_mask = self.prepare_config_and_inputs() + inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask} + return config, inputs_dict + + +class CsmForConditionalGenerationTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): + all_model_classes = (CsmForConditionalGeneration,) if is_torch_available() else () + test_pruning = False + test_headmasking = False + test_resize_embeddings = False + test_resize_embeddings_untied = False + test_torch_exportable = True + + def setUp(self): + self.model_tester = CsmModelTester(self) + self.config_tester = ConfigTester(self, config_class=CsmConfig) + + def test_config(self): + self.config_tester.run_common_tests() + + def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): + """ + Overrides [ModelTesterMixin._prepare_for_class] to handle third input_ids dimension. + """ + inputs_dict = copy.deepcopy(inputs_dict) + + if return_labels: + inputs_dict["labels"] = torch.zeros( + ( + self.model_tester.batch_size, + self.model_tester.seq_length, + self.model_tester.config["num_codebooks"], + ), + dtype=torch.long, + device=torch_device, + ) + + return inputs_dict + + def _get_logits_processor_kwargs(self, do_sample=False, config=None): + """ + Overrides [GenerationTesterMixin._get_logits_processor_kwargs] to restrict to top_k, top_p, and temperature sampling. + """ + logits_processor_kwargs = {} + if do_sample: + logits_processor_kwargs.update( + { + "top_k": 10, + "top_p": 0.7, + "temperature": 0.7, + } + ) + + return logits_processor_kwargs + + def test_initialization(self): + """ + Overrides [ModelTesterMixin.test_initialization] because of specificities of Mimi codec model. + See https://github.com/huggingface/transformers/blob/1077603410cd73ba71d64a522033574d66d64b55/tests/models/mimi/test_modeling_mimi.py#L384-L397 + """ + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + configs_no_init = _config_zero_init(config) + for model_class in self.all_model_classes: + model = model_class(config=configs_no_init) + for name, param in model.named_parameters(): + uniform_init_parms = ["conv", "input_proj", "output_proj"] + if param.requires_grad: + if any(x in name for x in uniform_init_parms): + self.assertTrue( + -1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0, + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + + def _check_similar_generate_outputs(self, output_1, output_2, atol=1e-5, rtol=1e-5): + """ + Overrides [GenerationTesterMixin._check_similar_generate_outputs] to handle third input_ids dimension. + Here we only look a the first codebook (index 0 on last dimension of the generated sequences) since returned scores + are for this token. + """ + # scores doesn't include data regarding decoder input tokens + decoder_input_length = output_1.sequences.shape[1] - len(output_1.scores) + output_matches = output_1.sequences[..., 0] == output_2.sequences[..., 0] + has_matching_outputs = output_matches.all() + has_matching_scores = None + if not has_matching_outputs: + for batch_idx in range(output_1.sequences.shape[0]): + batch_matches = output_matches[batch_idx] + if batch_matches.all(): + continue + first_mismatch_idx = batch_matches.int().argmin() # gets the index of the first False + first_mismatch_idx -= decoder_input_length + output_1_first_mismatch_scores = output_1.scores[first_mismatch_idx][batch_idx] + output_2_first_mismatch_scores = output_2.scores[first_mismatch_idx][batch_idx] + has_matching_scores = torch.allclose( + output_1_first_mismatch_scores, output_2_first_mismatch_scores, rtol=atol, atol=rtol + ) + if not has_matching_scores: + break + self.assertTrue(has_matching_outputs or has_matching_scores) + + @parameterized.expand([("random",), ("same",)]) + @pytest.mark.generate + @unittest.skip(reason="CSM does not support assisted decoding.") + def test_assisted_decoding_matches_greedy_search(self, assistant_type): + pass + + @pytest.mark.generate + @unittest.skip(reason="CSM does not support assisted decoding.") + def test_assisted_decoding_sample(self): + pass + + @pytest.mark.generate + @unittest.skip(reason="CSM does not support Dola decoding.") + def test_dola_decoding_sample(self): + pass + + @pytest.mark.generate + @unittest.skip(reason="CSM does not support beam search.") + def test_beam_sample_generate(self): + pass + + @pytest.mark.generate + @unittest.skip(reason="CSM does not support beam search.") + def test_beam_search_generate(self): + pass + + @pytest.mark.generate + @unittest.skip(reason="CSM does not support beam search.") + def test_beam_search_generate_dict_output(self): + pass + + @pytest.mark.generate + @unittest.skip(reason="CSM does not support beam search.") + def test_beam_search_generate_dict_outputs_use_cache(self): + pass + + @pytest.mark.generate + @unittest.skip(reason="CSM does not support beam search.") + def test_beam_sample_generate_dict_output(self): + pass + + @pytest.mark.generate + @unittest.skip(reason="CSM does not support group beam search.") + def test_group_beam_search_generate(self): + pass + + @pytest.mark.generate + @unittest.skip(reason="CSM does not support group beam search.") + def test_group_beam_search_generate_dict_output(self): + pass + + @pytest.mark.generate + @unittest.skip(reason="CSM does not support constrained beam search.") + def test_constrained_beam_search_generate(self): + pass + + @pytest.mark.generate + @unittest.skip(reason="CSM does not support constrained beam search.") + def test_constrained_beam_search_generate_dict_output(self): + pass + + @pytest.mark.generate + @unittest.skip(reason="CSM does not support contrastive search.") + def test_contrastive_generate(self): + pass + + @pytest.mark.generate + @unittest.skip(reason="CSM does not support contrastive search.") + def test_contrastive_generate_dict_outputs_use_cache(self): + pass + + @pytest.mark.generate + @unittest.skip(reason="CSM does not support contrastive search.") + def test_contrastive_generate_low_memory(self): + pass + + @pytest.mark.generate + @unittest.skip(reason="CSM does not support prompt lookup decoding.") + def test_prompt_lookup_decoding_matches_greedy_search(self): + pass + + @pytest.mark.generate + @unittest.skip(reason="CSM does not support prompt lookup decoding.") + def test_prompt_lookup_decoding_stops_at_eos(self): + pass + + @pytest.mark.skip(reason="CSM has custom embedding approach (text and audio embeddings).") + def test_model_get_set_embeddings(self): + pass + + @pytest.mark.skip(reason="CSM has custom embedding approach (text and audio embeddings).") + def test_tie_model_weights(self): + pass + + @pytest.mark.generate + @unittest.skip(reason="CSM does not support beam search.") + def test_generate_from_inputs_embeds_1_beam_search(self, _, num_beams): + pass + + @pytest.mark.generate + @unittest.skip(reason="CSM does not support beam search.") + def test_model_parallel_beam_search(self): + pass + + def test_tied_weights_keys(self): + """ + Overrides [ModelTesterMixin.test_tied_weights_keys] to not test for text config (not applicable to CSM). + """ + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + for model_class in self.all_model_classes: + model_tied = model_class(config) + + ptrs = collections.defaultdict(list) + for name, tensor in model_tied.state_dict().items(): + ptrs[id_tensor_storage(tensor)].append(name) + + # These are all the pointers of shared tensors. + tied_params = [names for _, names in ptrs.items() if len(names) > 1] + + tied_weight_keys = model_tied._tied_weights_keys if model_tied._tied_weights_keys is not None else [] + # Detect we get a hit for each key + for key in tied_weight_keys: + is_tied_key = any(re.search(key, p) for group in tied_params for p in group) + self.assertTrue(is_tied_key, f"{key} is not a tied weight key for {model_class}.") + + # Removed tied weights found from tied params -> there should only be one left after + for key in tied_weight_keys: + for i in range(len(tied_params)): + tied_params[i] = [p for p in tied_params[i] if re.search(key, p) is None] + + tied_params = [group for group in tied_params if len(group) > 1] + self.assertListEqual( + tied_params, + [], + f"Missing `_tied_weights_keys` for {model_class}: add all of {tied_params} except one.", + ) + + def _get_custom_4d_mask_test_data(self): + """ + Overrides [ModelTesterMixin._get_custom_4d_mask_test_data] to handle third input_ids dimension. + """ + # Sequence in which all but the last token is the same + input_ids = torch.tensor([[0, 1, 2, 3], [0, 1, 2, 4], [0, 1, 2, 5]], device=torch_device, dtype=torch.int64) + input_ids = input_ids.unsqueeze(-1).expand(-1, -1, self.model_tester.config["num_codebooks"]) + position_ids = torch.tensor([[0, 1, 2, 3]] * 3, device=torch_device, dtype=torch.int64) + + # Combining common prefix with the unique ending tokens: + input_ids_shared_prefix = torch.cat([input_ids[0][:-1], input_ids[:, -1]]).unsqueeze(0) + + # Creating a 4D mask where each of the last 3 tokens do not attend to each other. + mask_shared_prefix = torch.tensor( + [ + [ + [ + [1, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 1, 0, 0], + [1, 1, 1, 0, 1, 0], + [1, 1, 1, 0, 0, 1], + ] + ] + ], + ) + # inverting the attention mask + mask_dtype = torch.float32 + min_dtype = torch.finfo(mask_dtype).min + mask_shared_prefix = (mask_shared_prefix.eq(0.0)).to(dtype=mask_dtype, device=torch_device) * min_dtype + + # Creating a position_ids tensor. note the repeating figures in the end. + position_ids_shared_prefix = torch.tensor([[0, 1, 2, 3, 3, 3]], device=torch_device, dtype=torch.int64) + + return input_ids, position_ids, input_ids_shared_prefix, mask_shared_prefix, position_ids_shared_prefix + + +class CsmForConditionalGenerationIntegrationTest(unittest.TestCase): + def setUp(self): + # TODO: @eustlb, update with correct sesame's repo + self.model_checkpoint = "eustlb/csm-1b" + + def tearDown(self): + cleanup(torch_device, gc_collect=True) + + def _load_conversation(self): + ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train") + ds = ds.filter(lambda x: x["conversation_id"] == 0) + ds = ds.sort("turn_id") + return ds[0] + + @slow + @require_torch_gpu + def test_1b_model_integration_generate(self): + """ + Tests the generated tokens match the ones from the original model implementation. + Such tokens are to be retreived using https://gist.github.com/eustlb/d25577a357ddcf8f4a8cd0d00baca551, which is a script that infers the original model. + """ + processor = AutoProcessor.from_pretrained(self.model_checkpoint) + prompt = "<|begin_of_text|>[0]What are you working on?<|end_of_text|><|AUDIO|><|audio_eos|><|begin_of_text|>[1]I'm figuring out my budget.<|end_of_text|>" + + ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train") + audio = ds[0]["audio"]["array"] + inputs = processor(text=prompt, audio=audio, return_tensors="pt").to(torch_device) + + model = CsmForConditionalGeneration.from_pretrained(self.model_checkpoint, device_map=torch_device) + output_tokens = model.generate(**inputs, do_sample=False, depth_decoder_do_sample=False) + + # fmt: off + EXPECTED_OUTPUT_TOKENS = torch.tensor([[ + [1140, 10, 37, 1180, 1100, 1319, 601, 1482, 1918, 1739, 372, 856, 674, 1, 854, 459, 1843, 1191, 347, 349, 1087, 846, 759, 1690, 947, 1280, 580, 1909, 1192, 487, 1302, 1601], + [1494, 1412, 1824, 1852, 150, 928, 91, 326, 623, 1632, 1163, 1221, 1949, 999, 1779, 248, 693, 1149, 1423, 1503, 598, 80, 223, 1798, 251, 385, 1391, 1692, 1228, 1631, 1101, 866], + [778, 645, 830, 1812, 524, 1704, 1805, 1289, 74, 1069, 243, 1622, 1755, 1281, 1397, 620, 1962, 1995, 253, 1124, 1007, 518, 89, 559, 1304, 1482, 523, 1747, 1979, 1003, 1707, 1578], + [1356, 481, 642, 989, 287, 1819, 171, 1115, 824, 1253, 1488, 1074, 1019, 342, 279, 513, 1275, 1364, 893, 2007, 553, 407, 882, 1170, 1586, 485, 762, 559, 100, 542, 911, 1460], + [1860, 593, 1944, 404, 575, 545, 862, 830, 1002, 125, 2010, 268, 1779, 804, 811, 809, 255, 373, 387, 1756, 259, 822, 1191, 700, 1686, 390, 1676, 844, 2006, 286, 1376, 719], + [1165, 1047, 848, 212, 1018, 1470, 93, 1709, 1487, 1691, 1190, 275, 1278, 2018, 121, 1023, 485, 463, 39, 1825, 1936, 1817, 569, 209, 1553, 1599, 1137, 769, 968, 558, 1957, 265], + [902, 1608, 719, 850, 371, 1920, 75, 1917, 2005, 1238, 562, 1743, 713, 95, 1107, 1463, 696, 840, 8, 487, 1950, 1171, 1004, 1516, 1130, 303, 1866, 1728, 2046, 238, 265, 153], + [1932, 839, 334, 1167, 134, 2025, 40, 505, 1244, 1238, 1840, 800, 697, 72, 216, 486, 940, 1312, 510, 361, 549, 583, 1364, 844, 397, 1181, 1779, 962, 457, 1782, 1316, 465], + [31, 1558, 1048, 404, 354, 7, 827, 414, 1082, 807, 243, 1517, 801, 1364, 99, 1276, 1655, 1488, 1313, 464, 828, 1612, 774, 1558, 745, 1496, 960, 1874, 995, 1943, 255, 213], + [355, 1270, 413, 1519, 1659, 1904, 690, 552, 1279, 1821, 2022, 458, 1779, 2003, 604, 832, 661, 1295, 305, 1701, 173, 869, 230, 539, 1188, 669, 117, 692, 250, 388, 1995, 294], + [629, 199, 1899, 1123, 1070, 344, 578, 1795, 1451, 1257, 168, 1410, 1120, 1270, 316, 983, 1245, 1870, 165, 471, 966, 1337, 308, 1118, 746, 67, 1767, 1480, 1517, 1585, 871, 1110], + [1281, 1173, 784, 404, 368, 403, 580, 526, 853, 1692, 792, 895, 1286, 573, 1368, 896, 931, 1958, 1912, 644, 583, 1706, 1176, 1262, 1637, 315, 524, 1629, 795, 1211, 915, 533], + [9, 1783, 621, 1954, 1212, 993, 197, 977, 1662, 1340, 618, 1997, 1689, 1001, 74, 1765, 1865, 797, 1219, 1609, 671, 1491, 950, 1849, 1301, 2031, 875, 323, 203, 1063, 1490, 1538], + [1944, 1578, 1256, 1169, 790, 1444, 1382, 1616, 1100, 1264, 214, 1646, 488, 573, 1333, 285, 1954, 74, 1333, 674, 1303, 266, 622, 1290, 402, 109, 1331, 1666, 1347, 780, 106, 605], + [221, 161, 1322, 1, 565, 1507, 1403, 1091, 1557, 932, 1664, 1165, 1828, 1647, 2008, 1616, 648, 1113, 1870, 22, 734, 1458, 1940, 1756, 1689, 925, 1318, 1095, 985, 473, 604, 1974], + [1178, 597, 1804, 747, 1383, 360, 1497, 406, 1053, 1023, 1901, 56, 1221, 628, 75, 1729, 575, 1681, 840, 410, 650, 794, 1171, 1889, 187, 54, 1364, 1390, 505, 1285, 1814, 90], + [1432, 1221, 1800, 1873, 1255, 627, 41, 9, 630, 896, 1469, 1195, 1098, 145, 442, 1460, 13, 57, 2039, 1015, 149, 461, 1084, 1288, 1099, 910, 63, 157, 906, 111, 1394, 460], + [1352, 593, 307, 780, 1614, 1675, 1491, 1253, 723, 1793, 1032, 1486, 1805, 1904, 777, 398, 1791, 951, 770, 499, 1858, 244, 1372, 1514, 1858, 1200, 69, 181, 673, 1144, 1938, 1191], + [905, 403, 1626, 1529, 581, 1443, 976, 754, 1561, 1370, 1048, 253, 194, 1271, 853, 959, 1532, 30, 286, 1594, 1255, 1135, 1410, 1699, 1423, 2002, 260, 69, 941, 1640, 895, 722], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + ]]) + # fmt: on + + torch.testing.assert_close(output_tokens.cpu(), EXPECTED_OUTPUT_TOKENS) + + @slow + @require_torch_gpu + def test_1b_model_integration_generate_no_audio(self): + """ + Tests the generated tokens match the ones from the original model implementation. + Such tokens are to be retreived using https://gist.github.com/eustlb/aed822f765e928b9612e01b0d8836d69, which is a script that infers the original model. + """ + + processor = AutoProcessor.from_pretrained(self.model_checkpoint) + + conversation = [ + {"role": "0", "content": [{"type": "text", "text": "The past is just a story we tell ourselves."}]}, + ] + + inputs = processor.apply_chat_template(conversation, tokenize=True, return_dict=True).to(torch_device) + + model = CsmForConditionalGeneration.from_pretrained(self.model_checkpoint, device_map=torch_device) + output_tokens = model.generate(**inputs, do_sample=False, depth_decoder_do_sample=False) + + print(output_tokens) + # fmt: off + EXPECTED_OUTPUT_TOKENS = torch.tensor([[ + [1656, 629, 723, 1785, 206, 1873, 1059, 1190, 1833, 240, 618, 350, 156, 109, 2010, 452, 435, 1764, 77, 654, 1133, 908, 1095, 74, 804, 494, 1760, 1343, 1312, 1464, 1657, 324], + [366, 1532, 1945, 21, 145, 1428, 1417, 1987, 1793, 1444, 356, 1491, 849, 333, 788, 426, 1423, 1004, 414, 1823, 1169, 257, 1892, 696, 1572, 998, 1098, 523, 390, 1977, 546, 1692], + [1343, 1382, 1288, 1744, 1685, 1154, 1837, 1156, 1680, 1641, 1479, 1548, 632, 824, 694, 2010, 671, 1251, 1822, 343, 638, 1372, 696, 1272, 144, 125, 1332, 579, 936, 77, 159, 357], + [456, 1534, 349, 274, 1956, 1502, 1268, 1038, 1911, 523, 1360, 1159, 761, 293, 718, 1143, 63, 705, 168, 550, 413, 1372, 1771, 787, 631, 693, 784, 1789, 2039, 1131, 1601, 918], + [456, 829, 2026, 1108, 1649, 207, 1308, 1440, 1192, 1394, 426, 546, 590, 36, 1682, 1827, 1387, 1425, 1909, 1500, 1438, 1297, 5, 888, 948, 1745, 1304, 1364, 1692, 131, 300, 1908], + [2027, 1431, 1037, 1789, 1296, 1264, 1331, 1787, 1235, 1902, 1161, 1591, 590, 561, 1633, 1218, 510, 148, 1962, 118, 212, 608, 565, 1869, 583, 598, 532, 658, 1416, 9, 1172, 493], + [1215, 460, 1722, 317, 1423, 716, 1589, 1177, 1927, 1860, 1756, 1552, 1674, 643, 74, 1256, 587, 1742, 771, 2028, 469, 1070, 1683, 1614, 699, 494, 2020, 139, 1365, 1171, 171, 904], + [1615, 339, 323, 317, 469, 714, 104, 2015, 1407, 278, 468, 77, 2007, 650, 1630, 269, 168, 934, 1544, 58, 1487, 1373, 705, 874, 1252, 2031, 1995, 254, 1334, 1171, 1911, 1607], + [1259, 693, 666, 1700, 1115, 607, 982, 769, 1106, 1500, 101, 88, 1698, 1864, 1358, 1594, 192, 153, 1868, 1654, 604, 1948, 526, 778, 172, 1664, 1966, 99, 1334, 1030, 1349, 1209], + [1211, 579, 1369, 492, 1725, 203, 1125, 778, 701, 1982, 1420, 155, 736, 1145, 2018, 609, 658, 561, 1147, 923, 1794, 1753, 116, 1374, 612, 956, 1587, 392, 1062, 2047, 901, 1931], + [460, 1093, 1346, 1917, 1223, 470, 271, 390, 547, 112, 143, 1633, 1030, 643, 96, 1759, 920, 1959, 75, 1280, 1630, 999, 333, 853, 1110, 1291, 1911, 57, 171, 1658, 1704, 1508], + [908, 500, 393, 184, 1437, 482, 2008, 1834, 356, 1435, 1550, 1407, 1236, 109, 1167, 452, 1141, 934, 207, 957, 660, 670, 28, 1066, 1252, 1932, 669, 906, 1904, 1820, 2043, 881], + [1599, 1031, 1474, 336, 1540, 571, 437, 1440, 1616, 1365, 1412, 1246, 400, 405, 1776, 96, 296, 38, 1597, 466, 1630, 1256, 1940, 887, 1769, 294, 285, 842, 1756, 1619, 451, 1529], + [1615, 339, 1722, 525, 942, 105, 1365, 670, 785, 1316, 465, 1860, 438, 968, 547, 1938, 1816, 1429, 1065, 1942, 660, 1446, 1093, 1066, 931, 121, 688, 1033, 1178, 754, 1783, 94], + [912, 1354, 598, 254, 341, 1980, 1166, 585, 1302, 473, 554, 242, 174, 2030, 2011, 325, 978, 1690, 258, 396, 1831, 1768, 1291, 1699, 2001, 433, 1414, 2012, 1045, 511, 533, 1104], + [80, 1791, 1062, 1136, 391, 568, 1651, 101, 959, 2043, 1683, 760, 794, 181, 570, 540, 1599, 20, 1017, 973, 1654, 396, 586, 778, 2044, 1664, 1911, 929, 66, 897, 510, 643], + [1161, 1093, 161, 1296, 589, 54, 906, 981, 1927, 605, 516, 1731, 1461, 1204, 1902, 920, 1488, 177, 805, 1402, 610, 1446, 1154, 1067, 2025, 645, 762, 1715, 415, 1658, 1713, 1607], + [374, 1444, 1577, 792, 1450, 628, 604, 1729, 322, 514, 1725, 540, 1070, 575, 653, 800, 250, 187, 569, 349, 354, 1573, 176, 793, 897, 359, 536, 276, 1224, 23, 145, 1287], + [1184, 415, 1644, 1737, 1788, 385, 784, 1861, 1172, 1118, 367, 1156, 234, 1946, 1742, 981, 828, 1798, 1821, 361, 1148, 670, 518, 1288, 761, 1050, 1642, 1006, 1747, 840, 1599, 720], + [1141, 1731, 1670, 1542, 1347, 1907, 683, 753, 1347, 68, 2031, 153, 556, 719, 736, 1759, 1131, 1073, 1747, 1730, 1487, 1137, 1869, 1624, 699, 1900, 748, 49, 1312, 735, 726, 1268], + [1141, 1383, 405, 1033, 490, 488, 1102, 471, 713, 1630, 447, 703, 1495, 1001, 1855, 354, 456, 411, 786, 853, 168, 407, 116, 699, 605, 128, 532, 1076, 208, 447, 1448, 1071], + [345, 1013, 948, 1728, 1837, 337, 930, 1226, 1643, 1729, 983, 1688, 2009, 435, 1358, 721, 42, 1779, 1332, 1077, 1873, 128, 1327, 125, 1226, 1704, 705, 1459, 1449, 862, 155, 1870], + [336, 904, 684, 184, 1542, 714, 1752, 1180, 1373, 1816, 504, 1716, 1066, 1086, 1212, 530, 1413, 1278, 75, 1347, 82, 1623, 1307, 1717, 1861, 494, 888, 1589, 670, 1999, 905, 1430], + [578, 554, 14, 523, 1016, 300, 1589, 1017, 356, 1583, 1654, 414, 449, 376, 1413, 58, 706, 963, 388, 1626, 131, 352, 1024, 1054, 2025, 1561, 77, 1589, 1486, 431, 1249, 1508], + [184, 2043, 169, 1673, 580, 162, 1752, 397, 1119, 2009, 697, 150, 1475, 157, 1523, 1402, 575, 86, 1373, 1230, 1564, 1308, 626, 1093, 1603, 1446, 1390, 1543, 1778, 1142, 1357, 1831], + [1484, 1987, 932, 1728, 1504, 1618, 291, 1865, 1151, 460, 1792, 141, 234, 2043, 829, 513, 435, 791, 1037, 1541, 65, 424, 1589, 1711, 312, 1306, 212, 686, 673, 984, 1914, 1549], + [513, 1536, 1844, 1319, 572, 1069, 121, 735, 1949, 1211, 1362, 1027, 105, 1379, 315, 1782, 706, 1658, 1510, 1989, 1443, 1690, 822, 1614, 1194, 1460, 992, 2040, 1178, 1474, 1110, 1326], + [1858, 194, 1594, 1935, 1622, 1892, 1577, 137, 1907, 2015, 757, 414, 1823, 836, 496, 530, 1385, 1503, 1065, 1554, 664, 525, 1031, 433, 69, 466, 1016, 1846, 1609, 1658, 911, 94], + [1134, 1744, 323, 691, 1837, 347, 1871, 172, 811, 91, 1883, 436, 1912, 23, 1336, 1684, 519, 1612, 1219, 1402, 728, 1953, 1658, 641, 27, 1340, 436, 139, 2008, 1030, 159, 324], + [1270, 1536, 1639, 414, 1387, 1170, 1067, 1701, 1414, 505, 1122, 36, 1731, 350, 1552, 1214, 1444, 30, 107, 172, 480, 1858, 655, 168, 1107, 691, 1272, 797, 1656, 548, 1407, 1375], + [1270, 286, 1371, 1552, 1622, 1739, 1348, 2018, 345, 1537, 1941, 2024, 1423, 740, 284, 513, 91, 1228, 2015, 385, 992, 39, 813, 803, 2025, 497, 663, 462, 1609, 334, 927, 1470], + [1718, 994, 265, 1421, 1622, 1098, 845, 1868, 832, 459, 447, 619, 1970, 929, 513, 63, 1448, 1509, 1219, 1942, 285, 1373, 1259, 1004, 11, 1040, 1984, 57, 188, 1687, 1475, 805], + [1157, 832, 480, 1225, 1019, 347, 326, 999, 125, 1542, 118, 1383, 1343, 1077, 1821, 1602, 1978, 1642, 618, 808, 692, 1953, 1353, 963, 619, 1291, 1016, 1458, 1995, 1688, 1872, 1718], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + ]]) + # fmt: on + + torch.testing.assert_close(output_tokens.cpu(), EXPECTED_OUTPUT_TOKENS) + + @slow + @require_torch_gpu + def test_1b_model_integration_generate_multiple_audio(self): + """ + Test the generated tokens match the ones from the original model implementation. + Such tokens are to be retreived using https://gist.github.com/eustlb/0c94de002e1325abb61d32217f74c0f8, which is a script that infers the original model. + """ + processor = AutoProcessor.from_pretrained(self.model_checkpoint) + + ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train") + conversation = [] + + # context + for text, audio, speaker_id in zip(ds[:4]["text"], ds[:4]["audio"], ds[:4]["speaker_id"]): + conversation.append( + { + "role": f"{speaker_id}", + "content": [{"type": "text", "text": text}, {"type": "audio", "path": audio["array"]}], + } + ) + + # text prompt + conversation.append({"role": f"{ds[4]['speaker_id']}", "content": [{"type": "text", "text": ds[4]["text"]}]}) + + inputs = processor.apply_chat_template( + conversation, + tokenize=True, + return_dict=True, + ).to(torch_device) + + model = CsmForConditionalGeneration.from_pretrained(self.model_checkpoint, device_map=torch_device) + output_tokens = model.generate(**inputs, do_sample=False, depth_decoder_do_sample=False) + + # fmt: off + EXPECTED_OUTPUT_TOKENS = torch.tensor([[ + [420, 1189, 1311, 318, 359, 694, 1550, 1044, 1614, 1437, 1978, 537, 554, 1681, 147, 1225, 422, 1357, 1681, 1619, 165, 641, 1132, 1975, 1568, 406, 756, 503, 1673, 1428, 762, 781], + [1848, 1412, 957, 1656, 871, 540, 1999, 175, 711, 1383, 1814, 104, 742, 1285, 733, 1251, 1165, 1915, 1392, 645, 1804, 913, 1772, 632, 376, 1507, 1132, 725, 716, 1121, 1769, 1509], + [429, 1138, 895, 1018, 1099, 257, 1395, 1015, 576, 1599, 497, 19, 1858, 1437, 282, 357, 1143, 828, 1481, 70, 985, 551, 935, 278, 1102, 1453, 1902, 755, 526, 498, 1441, 1733], + [546, 343, 1547, 879, 2039, 692, 1999, 1150, 1969, 1866, 1178, 199, 1913, 1738, 1530, 1728, 1193, 74, 695, 612, 1095, 1597, 1381, 683, 1385, 2045, 1069, 865, 438, 70, 1437, 318], + [1741, 1621, 733, 1580, 1006, 1790, 1031, 1563, 569, 1822, 1229, 854, 142, 1554, 792, 741, 147, 552, 731, 772, 908, 831, 1291, 1819, 296, 290, 1871, 100, 1904, 1420, 1903, 1653], + [1264, 1576, 963, 12, 1403, 453, 259, 1359, 1270, 466, 1744, 1579, 1081, 1691, 1495, 1293, 110, 1020, 2042, 189, 1358, 955, 784, 1317, 2, 1794, 388, 376, 327, 511, 866, 1308], + [1407, 1412, 1665, 1683, 284, 874, 1859, 326, 1491, 1343, 777, 695, 1424, 396, 274, 202, 178, 747, 470, 1805, 1414, 2000, 127, 1884, 531, 215, 1322, 1098, 1674, 1227, 1092, 204], + [584, 637, 1665, 1683, 1136, 1201, 212, 310, 1441, 1619, 190, 1611, 1629, 2011, 1754, 1587, 413, 1287, 1251, 1382, 1904, 444, 1665, 1047, 1982, 1169, 1200, 809, 117, 327, 958, 1877], + [471, 1469, 1679, 1184, 343, 974, 1442, 897, 1888, 1468, 1092, 1398, 1714, 963, 1577, 1797, 766, 565, 403, 920, 1806, 466, 1193, 446, 825, 775, 1886, 1095, 159, 1085, 858, 504], + [28, 1511, 1510, 1580, 447, 1934, 1031, 1439, 202, 1435, 474, 1731, 724, 1080, 1121, 421, 625, 1410, 95, 605, 815, 1825, 127, 785, 900, 1673, 178, 1242, 2033, 1230, 350, 139], + [20, 1215, 253, 955, 871, 1689, 1986, 24, 1648, 423, 562, 1937, 1146, 26, 1266, 346, 188, 318, 179, 1164, 1100, 1978, 478, 1192, 715, 392, 1837, 425, 1492, 766, 1651, 822], + [1879, 1401, 1444, 723, 1754, 732, 1307, 702, 1768, 2013, 1284, 577, 1287, 1532, 647, 189, 903, 587, 800, 152, 898, 182, 2016, 639, 1074, 1220, 1934, 264, 250, 745, 1652, 536], + [1874, 1526, 232, 1580, 1980, 988, 1623, 341, 1768, 956, 1430, 1667, 1687, 1289, 826, 1378, 173, 1466, 479, 835, 1786, 1671, 328, 131, 815, 871, 379, 1329, 440, 1117, 392, 272], + [1762, 426, 1350, 1590, 314, 190, 1514, 344, 1926, 822, 534, 523, 703, 36, 379, 494, 464, 1886, 1555, 1318, 1654, 1469, 1976, 304, 218, 655, 1826, 958, 502, 326, 1898, 861], + [1577, 386, 503, 1492, 698, 405, 1031, 349, 1804, 2012, 1450, 996, 1140, 26, 449, 33, 1917, 354, 702, 1255, 1942, 1184, 864, 2045, 514, 744, 466, 54, 37, 486, 362, 525], + [1109, 1920, 445, 1719, 1670, 1220, 745, 40, 171, 1921, 999, 104, 489, 1911, 883, 306, 649, 1751, 762, 1183, 1085, 1112, 1912, 2035, 1940, 1129, 1592, 1276, 1570, 1236, 738, 209], + [1837, 990, 1063, 318, 1398, 1838, 1678, 906, 754, 802, 562, 353, 1389, 207, 1319, 1188, 2013, 1079, 888, 1706, 1042, 657, 482, 953, 94, 2007, 871, 485, 1596, 275, 410, 1855], + [872, 974, 1344, 1798, 655, 805, 1604, 1913, 455, 615, 1827, 966, 1330, 1826, 1285, 359, 544, 221, 1538, 1658, 374, 1352, 1714, 1925, 235, 65, 350, 931, 1009, 1164, 218, 736], + [1547, 617, 1622, 740, 655, 265, 1324, 1265, 1449, 482, 1037, 105, 1128, 701, 1866, 1674, 1999, 1302, 985, 1942, 663, 449, 1881, 698, 805, 1446, 1742, 1192, 1623, 605, 948, 2], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + ]]) + # fmt: on + + torch.testing.assert_close(output_tokens.cpu(), EXPECTED_OUTPUT_TOKENS) + + @slow + @require_torch_gpu + def test_1b_model_integration_generate_batched(self): + """ + Test the generated tokens match the ones from the original model implementation. + Such tokens are to be retreived using https://gist.github.com/eustlb/bcc532b53161bc31da3d66cb07ae193f, which is a script that infers the original model. + """ + processor = AutoProcessor.from_pretrained(self.model_checkpoint) + + ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train") + conversation = [ + [ + { + "role": f"{ds[0]['speaker_id']}", + "content": [ + {"type": "text", "text": ds[0]["text"]}, + {"type": "audio", "path": ds[0]["audio"]["array"]}, + ], + }, + { + "role": f"{ds[1]['speaker_id']}", + "content": [ + {"type": "text", "text": ds[1]["text"]}, + ], + }, + ], + [ + { + "role": f"{ds[0]['speaker_id']}", + "content": [ + {"type": "text", "text": ds[0]["text"]}, + ], + } + ], + ] + + inputs = processor.apply_chat_template( + conversation, + tokenize=True, + return_dict=True, + ).to(torch_device) + + model = CsmForConditionalGeneration.from_pretrained(self.model_checkpoint, device_map=torch_device) + output_tokens = model.generate(**inputs, do_sample=False, depth_decoder_do_sample=False) + + # fmt: off + EXPECTED_OUTPUT_TOKENS = torch.tensor([ + [ + [1140, 10, 37, 1180, 1100, 1319, 601, 1482, 1918, 1739, 372, 856, 674, 1, 854, 459, 1843, 1191, 347, 349, 1087, 846, 759, 1690, 947, 1280, 580, 1909, 1192, 487, 1302, 1601], + [1494, 1412, 1824, 1852, 150, 928, 91, 326, 623, 1632, 1163, 1221, 1949, 999, 1779, 248, 693, 1149, 1423, 1503, 1656, 80, 1947, 1666, 933, 1950, 1544, 1577, 1612, 1791, 1883, 765], + [778, 645, 830, 1051, 524, 1704, 1805, 1438, 211, 906, 691, 814, 1798, 1642, 1042, 284, 1906, 1513, 520, 137, 1052, 1548, 423, 1564, 330, 873, 1381, 188, 317, 1503, 1707, 1744], + [1416, 864, 242, 1653, 604, 1577, 202, 1808, 926, 1867, 204, 134, 1096, 1765, 496, 1680, 268, 1796, 2024, 1989, 583, 183, 952, 105, 765, 1534, 669, 895, 2008, 11, 1199, 195], + [1356, 796, 25, 1580, 15, 344, 1730, 99, 1330, 315, 955, 1964, 1731, 543, 1159, 1860, 671, 732, 63, 382, 143, 395, 1749, 1421, 1640, 1340, 650, 100, 171, 1346, 41, 806], + [1860, 1835, 823, 388, 254, 1734, 1135, 324, 1508, 983, 937, 1703, 1541, 875, 1319, 799, 1259, 1175, 1295, 807, 261, 760, 1916, 1606, 1616, 1894, 1605, 441, 387, 167, 2016, 222], + [1165, 919, 1318, 54, 1727, 1766, 777, 1128, 623, 353, 1840, 241, 977, 424, 1055, 898, 395, 655, 1695, 1084, 1346, 616, 1028, 1927, 603, 858, 758, 1539, 0, 1655, 1853, 1661], + [902, 1746, 1318, 298, 1982, 1184, 775, 328, 1676, 871, 133, 1374, 1927, 1984, 698, 1037, 100, 1884, 1596, 429, 1794, 2046, 105, 2037, 1767, 178, 176, 1293, 1893, 1780, 1832, 1382], + [1932, 714, 1084, 1167, 624, 509, 1213, 651, 1000, 1686, 1537, 555, 461, 623, 1433, 1089, 1212, 1628, 834, 1111, 943, 1816, 1947, 1063, 354, 1843, 1741, 2015, 404, 928, 1488, 168], + [1437, 314, 1356, 404, 1274, 2016, 998, 1350, 155, 553, 368, 1501, 1431, 1563, 1105, 1353, 535, 908, 1305, 1214, 1656, 65, 1469, 1517, 480, 252, 1289, 696, 302, 632, 246, 72], + [724, 848, 1140, 927, 1669, 296, 447, 1708, 1898, 685, 1041, 1685, 708, 1510, 1623, 876, 11, 99, 43, 586, 1705, 1753, 1477, 1191, 583, 1249, 1613, 992, 1319, 677, 418, 668], + [925, 54, 1810, 674, 1306, 848, 573, 1772, 105, 301, 1753, 989, 440, 1057, 823, 1313, 1663, 750, 1477, 102, 1437, 1114, 399, 1440, 319, 118, 1827, 295, 1429, 139, 1594, 55], + [629, 149, 784, 838, 984, 604, 685, 1229, 1432, 859, 1526, 1336, 1949, 281, 988, 1260, 52, 6, 1216, 1542, 1426, 1938, 253, 280, 1319, 794, 901, 843, 615, 437, 814, 20], + [1281, 502, 1237, 404, 625, 1444, 397, 1999, 2016, 1686, 533, 1785, 1152, 1245, 579, 1906, 1204, 549, 1334, 536, 1351, 1979, 208, 111, 2011, 751, 677, 1948, 1772, 1525, 2038, 419], + [9, 490, 869, 2026, 1928, 1489, 587, 549, 1241, 460, 1458, 1636, 924, 222, 1246, 480, 706, 398, 75, 1717, 604, 1446, 333, 237, 805, 1446, 421, 1343, 78, 1260, 1872, 1116], + [1944, 755, 375, 332, 1464, 828, 1273, 579, 1457, 353, 1510, 1910, 1609, 705, 400, 1666, 227, 1544, 1270, 136, 1857, 1975, 1762, 2006, 1102, 221, 1965, 151, 2041, 198, 1830, 287], + [221, 502, 440, 247, 181, 1912, 42, 357, 1883, 596, 919, 953, 1774, 772, 915, 188, 438, 1226, 544, 1313, 726, 1298, 85, 677, 566, 1581, 30, 341, 878, 1732, 591, 1446], + [1178, 1690, 320, 1746, 1798, 685, 1941, 666, 832, 623, 1907, 128, 337, 1779, 824, 923, 1041, 287, 1165, 437, 1803, 1222, 870, 646, 358, 220, 2009, 735, 468, 1908, 1349, 1603], + [1432, 1286, 540, 1687, 1741, 951, 299, 1233, 1061, 1128, 985, 953, 1917, 198, 2031, 1559, 1096, 1455, 780, 437, 163, 1268, 649, 1029, 1081, 1518, 304, 1638, 814, 364, 140, 1385], + [905, 463, 1739, 1063, 351, 936, 1652, 101, 1323, 1731, 298, 1193, 266, 1554, 1837, 1659, 409, 1739, 1012, 725, 851, 1909, 213, 1918, 1759, 1561, 1250, 970, 1571, 352, 911, 195], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + ], + [ + [1375, 203, 265, 164, 200, 1867, 976, 924, 1972, 1637, 1048, 271, 1912, 1430, 853, 1942, 260, 1642, 400, 57, 1376, 1626, 1821, 1163, 619, 777, 1076, 951, 389, 1820, 84, 1417], + [914, 527, 286, 968, 305, 1314, 805, 1703, 87, 559, 1980, 1124, 1726, 36, 1139, 618, 1628, 519, 1943, 781, 400, 1265, 438, 113, 87, 856, 465, 162, 1099, 352, 1141, 274], + [1408, 6, 126, 2009, 90, 996, 934, 134, 1857, 126, 602, 876, 1092, 1962, 1205, 828, 707, 1063, 393, 1533, 123, 1086, 1749, 1324, 1, 1763, 1707, 1191, 34, 1323, 1017, 1787], + [1000, 683, 1630, 703, 1574, 587, 25, 1049, 213, 1270, 1641, 1072, 1892, 1634, 1603, 90, 867, 2037, 1021, 715, 206, 507, 1138, 959, 1822, 1785, 280, 1100, 1660, 251, 1903, 988], + [1657, 1981, 246, 1048, 1952, 451, 305, 423, 2000, 416, 756, 1748, 7, 748, 1866, 1795, 1682, 1832, 338, 212, 1685, 518, 154, 1407, 416, 765, 776, 25, 55, 458, 612, 262], + [1034, 564, 667, 1474, 1212, 350, 712, 941, 1151, 1182, 1280, 640, 924, 1722, 1816, 458, 226, 359, 1518, 102, 1203, 459, 676, 1788, 1110, 393, 1974, 1721, 795, 1459, 798, 1723], + [742, 1616, 119, 653, 441, 679, 246, 1432, 486, 1615, 1191, 500, 650, 223, 687, 1765, 1875, 963, 1385, 863, 151, 1771, 458, 1170, 737, 1932, 785, 1954, 1067, 16, 1986, 2029], + [1437, 1078, 1767, 1452, 1392, 45, 2010, 1664, 245, 2015, 1416, 1055, 457, 985, 740, 1594, 1562, 1838, 258, 1431, 701, 604, 1813, 352, 792, 632, 21, 895, 70, 609, 850, 1599], + [983, 1961, 54, 135, 846, 711, 473, 1630, 1373, 1094, 251, 525, 632, 1014, 1594, 1594, 1752, 398, 1266, 1357, 942, 1680, 191, 874, 483, 1291, 381, 1873, 1964, 1278, 1477, 122], + [1663, 1969, 1887, 113, 145, 251, 1133, 156, 245, 1641, 209, 1322, 2037, 836, 539, 667, 940, 797, 1758, 1357, 191, 1137, 587, 1699, 27, 701, 395, 99, 1682, 876, 762, 839], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ]) + # fmt: on + + torch.testing.assert_close(output_tokens.cpu(), EXPECTED_OUTPUT_TOKENS) diff --git a/tests/models/csm/test_processor_csm.py b/tests/models/csm/test_processor_csm.py new file mode 100644 index 0000000000..da96381246 --- /dev/null +++ b/tests/models/csm/test_processor_csm.py @@ -0,0 +1,140 @@ +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import shutil +import tempfile +import unittest + +import jinja2 +import numpy as np + +from transformers import CsmProcessor +from transformers.testing_utils import require_torch +from transformers.utils import is_torch_available + +from ...test_processing_common import ProcessorTesterMixin + + +if is_torch_available(): + import torch + + +@require_torch +class CsmProcessorTest(ProcessorTesterMixin, unittest.TestCase): + processor_class = CsmProcessor + + @classmethod + def setUpClass(cls): + # TODO: @eustlb, change for hf-internal-testing/csm-1b + cls.checkpoint = "eustlb/csm-1b" + processor = CsmProcessor.from_pretrained(cls.checkpoint) + cls.audio_token = processor.audio_token + cls.audio_token_id = processor.audio_token_id + cls.pad_token_id = processor.tokenizer.pad_token_id + cls.bos_token_id = processor.tokenizer.bos_token_id + cls.tmpdirname = tempfile.mkdtemp() + processor.save_pretrained(cls.tmpdirname) + + @classmethod + def tearDownClass(cls): + shutil.rmtree(cls.tmpdirname, ignore_errors=True) + + def prepare_processor_dict(self): + return {"chat_template": "\n{%- for message in messages %}\n {#-- Validate role is a stringified integer --#}\n {%- if not message['role'] is string or not message['role'].isdigit() %}\n {{- raise_exception(\"The role must be an integer or a stringified integer (e.g. '0') designating the speaker id\") }}\n {%- endif %}\n\n {#-- Validate content is a list --#}\n {%- set content = message['content'] %}\n {%- if content is not iterable or content is string %}\n {{- raise_exception(\"The content must be a list\") }}\n {%- endif %}\n\n {#-- Collect content types --#}\n {%- set content_types = content | map(attribute='type') | list %}\n {%- set is_last = loop.last %}\n\n {#-- Last message validation --#}\n {%- if is_last %}\n {%- if 'text' not in content_types %}\n {{- raise_exception(\"The last message must include one item of type 'text'\") }}\n {%- elif (content_types | select('equalto', 'text') | list | length > 1) or (content_types | select('equalto', 'audio') | list | length > 1) %}\n {{- raise_exception(\"At most two items are allowed in the last message: one 'text' and one 'audio'\") }}\n {%- endif %}\n\n {#-- All other messages validation --#}\n {%- else %}\n {%- if content_types | select('equalto', 'text') | list | length != 1\n or content_types | select('equalto', 'audio') | list | length != 1 %}\n {{- raise_exception(\"Each message (except the last) must contain exactly one 'text' and one 'audio' item\") }}\n {%- elif content_types | reject('in', ['text', 'audio']) | list | length > 0 %}\n {{- raise_exception(\"Only 'text' and 'audio' types are allowed in content\") }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n\n{%- for message in messages %}\n {{- bos_token }}\n {{- '[' + message['role'] + ']' }}\n {{- message['content'][0]['text'] }}\n {{- eos_token }}\n {%- if message['content']|length > 1 %}\n {{- '<|AUDIO|><|audio_eos|>' }}\n {%- endif %}\n{%- endfor %}\n"} # fmt: skip + + def test_chat_template_is_saved(self): + processor_loaded = self.processor_class.from_pretrained(self.tmpdirname) + processor_dict_loaded = json.loads(processor_loaded.to_json_string()) + # chat templates aren't serialized to json in processors + self.assertFalse("chat_template" in processor_dict_loaded.keys()) + + # they have to be saved as separate file and loaded back from that file + # so we check if the same template is loaded + processor_dict = self.prepare_processor_dict() + self.assertTrue(processor_loaded.chat_template == processor_dict.get("chat_template", None)) + + def test_apply_chat_template(self): + # Message contains content which a mix of lists with images and image urls and string + messages = [ + { + "role": "0", + "content": [ + {"type": "text", "text": "This is a test sentence 0."}, + {"type": "audio"}, + ], + }, + { + "role": "1", + "content": [ + {"type": "text", "text": "This is a test sentence 1."}, + {"type": "audio"}, + ], + }, + { + "role": "0", + "content": [ + {"type": "text", "text": "This is a prompt."}, + ], + }, + ] + processor = CsmProcessor.from_pretrained(self.tmpdirname) + rendered = processor.apply_chat_template(messages, tokenize=False) + + expected_rendered = ( + "<|begin_of_text|>[0]This is a test sentence 0.<|end_of_text|>" + "<|AUDIO|><|audio_eos|>" + "<|begin_of_text|>[1]This is a test sentence 1.<|end_of_text|>" + "<|AUDIO|><|audio_eos|>" + "<|begin_of_text|>[0]This is a prompt.<|end_of_text|>" + ) + self.assertEqual(rendered, expected_rendered) + + messages = [ + { + "role": "0", + "content": [ + {"type": "text", "text": "This is a test sentence."}, + ], + }, + { + "role": "1", + "content": [ + {"type": "text", "text": "This is a test sentence."}, + ], + }, + ] + + # this should raise an error because the CSM processor requires audio content in the messages expect the last one + with self.assertRaises(jinja2.exceptions.TemplateError): + input_ids = processor.apply_chat_template(messages, tokenize=False) + + # now let's very that it expands audio tokens correctly + messages = [ + { + "role": "0", + "content": [ + {"type": "text", "text": "This is a test sentence."}, + {"type": "audio", "audio": np.zeros(4096)}, + ], + }, + ] + + input_ids = processor.apply_chat_template(messages, tokenize=True) + + # 4096 audio input values should give 3 audio tokens + expected_ids = torch.tensor( + [[128000, 58, 15, 60, 2028, 374, 264, 1296, 11914, 13, 128001, 128002, 128002, 128002, 128003]] + ) + torch.testing.assert_close(input_ids, expected_ids) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 62508002fe..974bfb7b5a 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -4350,8 +4350,8 @@ class ModelTesterMixin: out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :] # last three tokens # comparing softmax-normalized logits: - normalized_0 = F.softmax(out_last_tokens) - normalized_1 = F.softmax(out_shared_prefix_last_tokens) + normalized_0 = F.softmax(out_last_tokens, dim=-1) + normalized_1 = F.softmax(out_shared_prefix_last_tokens, dim=-1) torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4) @slow @@ -4403,7 +4403,7 @@ class ModelTesterMixin: self.skipTest(reason="This model does not support `logits_to_keep` argument.") config, inputs = self.model_tester.prepare_config_and_inputs_for_common() - batch_size, sequence_length = inputs["input_ids"].shape + batch_size, sequence_length = inputs["input_ids"].shape[:2] vocab_size = config.get_text_config().vocab_size model = model_class(config).to(device=torch_device).eval() # some models have labels but `logits_to_keep` should not be used in train mode diff --git a/utils/check_repo.py b/utils/check_repo.py index 1164ac9db0..960a73e734 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -159,6 +159,9 @@ IGNORE_NON_TESTED = ( "InternVLVisionModel", # Building part of bigger (tested) model "JanusVisionModel", # Building part of bigger (tested) model "TimesFmModel", # Building part of bigger (tested) model + "CsmDepthDecoderForCausalLM", # Building part of bigger (tested) model. Tested implicitly through CsmForConditionalGenerationIntegrationTest. + "CsmDepthDecoderModel", # Building part of bigger (tested) model. Tested implicitly through CsmForConditionalGenerationIntegrationTest. + "CsmBackboneModel", # Building part of bigger (tested) model. Tested implicitly through CsmForConditionalGenerationIntegrationTest. ] ) @@ -368,6 +371,10 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [ "Qwen2_5OmniToken2WavModel", # Building part of a bigger model "Qwen2_5OmniToken2WavBigVGANModel", # Building part of a bigger model "Qwen2_5OmniToken2WavDiTModel", # Building part of a bigger model + "CsmBackboneModel", # Building part of a bigger model + "CsmDepthDecoderModel", # Building part of a bigger model + "CsmDepthDecoderForCausalLM", # Building part of a bigger model + "CsmForConditionalGeneration", # Building part of a bigger model ] # DO NOT edit this list!