Add CSM model (#36719)
Some checks failed
Release - Conda / build_and_package (push) Has been cancelled
Secret Leaks / trufflehog (push) Has been cancelled

* draft structure

* depth decoder with forward pre hook

* full model forward draft

* draft update

* depth decoder update

* ConversationalSpeechModelForCausalLM udpates

* add generate

* max length criteria small fix

* udpate

* updates

* generation update

* update in loss compute

* conversion script

* update for correct input embeddings

* handle interleaved rope

* update

* update

* update

* support compile

* update training

* add doc

* update doc

* correct inits

* ConversationalSpeechModel -> Csm

* conf update

* name update

* tests CsmForCausalLMTest

* convert use cached_file

* conf + modeling updates

* generate utils handle third dim shape

* integration test

* modeling + conf updates

* common test handle more than 2 dims

* add nested audio list utils

* processing handle nested audio list

* csm processing draft

* mimi util

* init updates

* modular update

* convert modular

* processing update

* csm tests update

* generate tests handle third dim

* generate utils handle third dim

* propagate _get_initial_cache_position update

* tied_weight_keys update + convert correctly

* fix inputs_embeds

* revert audio nested list

* batch inference update + return audio

* audio_utils update

* processor update

* some more integration tests

* remove old test

* porcessing output labels

* improve

* fix

* update rope values with equivalent ones

* conversion update

* udpate tests

* handle depth decoder generation config

* remove default eos_token_id

* make style

* revert modeling_mimi

* add default generation_config

* remove sdpa since handled by default

* make

* fix conflict

* fix conflicts

* correct naming

* correct imports

* make

* causal -> conditional naming

* causal -> conditional naming

* auto update

* make

* make

* add doc

* test update

* fix weight init

* audio tokens offsets as buffer

* 4d mask in conditional class

* make

* doc update

* fix causal mask

* fix causal mask

* doc update

* doc update

* add processor doc

* update doc

* fix 4d causal mask

* update make_list_of_audio

* do not default to mutable

* remove duplicates

* remove useless reset_parameters

* use GradientCheckpointingLayer

* use can_return_tuple

* formatting

* prepend placeholder in _sample

* torch compile fix

* some more fixies

* convert modular

* fix

* default max_length in convert

* handle depth decoder generation config correctly

* clearer formulation

* handle output_loading_info

* handle softmax warning

* add doc

* propagate _get_initial_cache_position changes

* generation in its own module

* add processor tests

* fix compile witu cuda graphs

* fix compile with cuda graphs

* add csm.md

* include CSM loss

* doc nit

* doc nit

* doc nit

* Update docs/source/en/model_doc/csm.md

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* add save_audio to processor

* Update src/transformers/models/csm/modular_csm.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* doc update

* simplify audio_codes_mask computation

* doc update

* simplify loss computation

* fix static cache test

* fix

* remove comment

* simplify encoded length computation

* use hf-internal-testing

* doc update

* cast to float before numpy

* nit

* mem efficient codebook head

* nit

* cat input values with cutoffs

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
eustlb
2025-05-07 10:20:13 -04:00
committed by GitHub
parent c8607a17cb
commit 798f948e88
29 changed files with 5827 additions and 86 deletions

View File

@@ -501,9 +501,9 @@ class GenerationTesterMixin:
output_generate = self._greedy_generate(model=model, inputs_dict=inputs_dict)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
else:
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1])
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
@pytest.mark.generate
def test_greedy_generate_dict_outputs(self):
@@ -525,13 +525,13 @@ class GenerationTesterMixin:
)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
# Retrocompatibility check
self.assertIsInstance(output_generate, GreedySearchEncoderDecoderOutput)
else:
self.assertTrue(
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
output_generate.sequences.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1]
)
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
# Retrocompatibility check
@@ -565,10 +565,10 @@ class GenerationTesterMixin:
)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
else:
self.assertTrue(
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
output_generate.sequences.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1]
)
self._check_generate_outputs(output_generate, model.config, use_cache=True)
@@ -582,9 +582,9 @@ class GenerationTesterMixin:
output_generate = self._sample_generate(model=model, inputs_dict=inputs_dict, num_return_sequences=1)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
else:
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1])
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
@pytest.mark.generate
def test_sample_generate_dict_output(self):
@@ -607,13 +607,13 @@ class GenerationTesterMixin:
)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
# Retrocompatibility check
self.assertIsInstance(output_generate, SampleEncoderDecoderOutput)
else:
self.assertTrue(
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
output_generate.sequences.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1]
)
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
# Retrocompatibility check
@@ -632,9 +632,9 @@ class GenerationTesterMixin:
output_generate = self._beam_search_generate(model=model, inputs_dict=inputs_dict, beam_kwargs=beam_kwargs)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
else:
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1])
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
@pytest.mark.generate
def test_beam_search_generate_dict_output(self):
@@ -657,13 +657,13 @@ class GenerationTesterMixin:
use_cache=False,
)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
# Retrocompatibility check
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
else:
self.assertTrue(
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
output_generate.sequences.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1]
)
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
# Retrocompatibility check
@@ -706,10 +706,10 @@ class GenerationTesterMixin:
)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
else:
self.assertTrue(
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
output_generate.sequences.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1]
)
self._check_generate_outputs(
@@ -759,9 +759,9 @@ class GenerationTesterMixin:
)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
else:
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1])
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
@pytest.mark.generate
def test_beam_sample_generate_dict_output(self):
@@ -786,13 +786,13 @@ class GenerationTesterMixin:
)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
# Retrocompatibility check
self.assertIsInstance(output_generate, BeamSampleEncoderDecoderOutput)
else:
self.assertTrue(
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
output_generate.sequences.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1]
)
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
# Retrocompatibility check
@@ -840,9 +840,9 @@ class GenerationTesterMixin:
beam_kwargs=beam_kwargs,
)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
else:
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1])
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
# check `group_beam_search` for higher than 1 `num_return_sequences`
num_return_sequences = 2
@@ -853,9 +853,9 @@ class GenerationTesterMixin:
beam_kwargs=beam_kwargs,
)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
else:
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1])
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
@pytest.mark.generate
def test_group_beam_search_generate_dict_output(self):
@@ -878,13 +878,13 @@ class GenerationTesterMixin:
use_cache=False,
)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
# Retrocompatibility check
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
else:
self.assertTrue(
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
output_generate.sequences.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1]
)
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
# Retrocompatibility check
@@ -923,9 +923,9 @@ class GenerationTesterMixin:
)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
else:
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1])
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
for generation_output in output_generate:
self._check_sequence_inside_sequence(force_tokens, generation_output)
@@ -947,9 +947,9 @@ class GenerationTesterMixin:
)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
else:
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1])
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
for generation_output in output_generate:
self._check_sequence_inside_sequence(force_tokens, generation_output)
@@ -987,13 +987,13 @@ class GenerationTesterMixin:
)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
# Retrocompatibility check
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
else:
self.assertTrue(
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
output_generate.sequences.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1]
)
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
# Retrocompatibility check
@@ -1031,9 +1031,9 @@ class GenerationTesterMixin:
use_cache=True, # Enable cache
)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
else:
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1])
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
@pytest.mark.generate
def test_contrastive_generate_dict_outputs_use_cache(self):
@@ -1067,10 +1067,10 @@ class GenerationTesterMixin:
)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
else:
self.assertTrue(
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
output_generate.sequences.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1]
)
self._check_generate_outputs(output_generate, model.config, use_cache=True)
@@ -1499,7 +1499,7 @@ class GenerationTesterMixin:
position_ids.masked_fill_(attention_mask == 0, 1)
model_kwargs["position_ids"] = position_ids
if "cache_position" in signature:
cache_position = torch.arange(input_ids.shape[-1], device=torch_device)
cache_position = torch.arange(input_ids.shape[1], device=torch_device)
model_kwargs["cache_position"] = cache_position
return model_kwargs
@@ -1525,10 +1525,12 @@ class GenerationTesterMixin:
pad_token_id = (
config.get_text_config().pad_token_id if config.get_text_config().pad_token_id is not None else 0
)
pad_size = (input_ids.shape[0], 32)
pad_size = (input_ids.shape[0], 32, *input_ids.shape[2:])
padding = torch.ones(pad_size, dtype=input_ids.dtype, device=torch_device) * pad_token_id
padded_input_ids = torch.cat((padding, input_ids), dim=1)
padded_attention_mask = torch.cat((torch.zeros_like(padding), attention_mask), dim=1)
padded_attention_mask = torch.cat(
(torch.zeros(pad_size[:2], dtype=input_ids.dtype, device=torch_device), attention_mask), dim=1
)
model_kwargs = _prepare_model_kwargs(padded_input_ids, padded_attention_mask, signature)
next_logits_with_padding = model(**model_kwargs).logits[:, -1, :]
@@ -1587,7 +1589,7 @@ class GenerationTesterMixin:
else text_config.num_attention_heads
)
encoder_per_head_embed_dim = embed_dim // encoder_num_attention_heads
batch_size, seq_length = inputs["decoder_input_ids"].shape
batch_size, seq_length = inputs["decoder_input_ids"].shape[:2]
# The sequence length for the encoder K V depends on the model. Since it is not manipulated in
# autoregressive generation, we're keeping the test general and not checking the 3rd dim
default_cross_attention_shape = (
@@ -1606,7 +1608,7 @@ class GenerationTesterMixin:
for _ in range(num_decoder_layers)
]
else:
batch_size, seq_length = inputs["input_ids"].shape
batch_size, seq_length = inputs["input_ids"].shape[:2]
default_self_attention_shape = (batch_size, num_key_value_heads, seq_length, per_head_embed_dim)
all_cache_shapes = [
[default_self_attention_shape, default_self_attention_shape] for _ in range(num_decoder_layers)
@@ -1727,7 +1729,7 @@ class GenerationTesterMixin:
"min_new_tokens": 5, # generate exactly 5 tokens
}
outputs_from_ids = model.generate(input_ids, **generation_kwargs, **inputs_dict)
self.assertEqual(outputs_from_ids.sequences.shape, (input_ids.shape[0], input_ids.shape[1] + 5))
self.assertEqual(outputs_from_ids.sequences.shape[:2], (input_ids.shape[0], input_ids.shape[1] + 5))
# Same thing, but from input embeddings (`input_ids` is passed so the prompt is present in the output).
# The output of the two calls should be the same.
@@ -2262,11 +2264,11 @@ class GenerationTesterMixin:
self.assertTrue(hasattr(model, "_compiled_call")) # our auto compile should have been called
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
else:
self.assertTrue(
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
output_generate.sequences.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1]
)
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
@@ -2408,7 +2410,7 @@ class GenerationTesterMixin:
config = config.text_config if hasattr(config, "text_config") else config
generated_length = (
output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - prompt_length
output.sequences.shape[1] - 1 if config.is_encoder_decoder else output.sequences.shape[1] - prompt_length
)
decoder_past_key_values = getattr(output, "past_key_values", None)
if config.is_encoder_decoder and isinstance(decoder_past_key_values, EncoderDecoderCache):
@@ -2441,7 +2443,7 @@ class GenerationTesterMixin:
batch_size=internal_batch_size,
attentions=output.decoder_attentions,
prompt_length=1, # the BOS token
output_length=output.sequences.shape[-1],
output_length=output.sequences.shape[1],
config=config,
decoder_past_key_values=decoder_past_key_values,
)
@@ -2450,7 +2452,7 @@ class GenerationTesterMixin:
batch_size=internal_batch_size,
attentions=output.attentions,
prompt_length=prompt_length,
output_length=output.sequences.shape[-1],
output_length=output.sequences.shape[1],
config=config,
decoder_past_key_values=decoder_past_key_values,
)
@@ -2469,7 +2471,7 @@ class GenerationTesterMixin:
batch_size=internal_batch_size,
hidden_states=output.decoder_hidden_states,
prompt_length=1, # the BOS token
output_length=output.sequences.shape[-1],
output_length=output.sequences.shape[1],
config=config,
use_cache=use_cache,
)
@@ -2478,7 +2480,7 @@ class GenerationTesterMixin:
batch_size=internal_batch_size,
hidden_states=output.hidden_states,
prompt_length=prompt_length,
output_length=output.sequences.shape[-1],
output_length=output.sequences.shape[1],
config=config,
use_cache=use_cache,
)
@@ -2506,7 +2508,7 @@ class GenerationTesterMixin:
)
if has_standard_cache:
if use_cache:
cache_length = output.sequences.shape[-1] - 1
cache_length = output.sequences.shape[1] - 1
self._check_past_key_values_for_generate(
batch_size=internal_batch_size,
decoder_past_key_values=decoder_past_key_values,