Add Dia model (#38405)

* add dia model

* add tokenizer files

* cleanup some stuff

* brut copy paste code

* rough cleanup of the modeling code

* nuke some stuff

* more nuking

* more cleanups

* updates

* add mulitLayerEmbedding vectorization

* nits

* more modeling simplifications

* updates

* update rope

* update rope

* just fixup

* update configuration files

* more cleanup!

* default config values

* update

* forgotten comma

* another comma!

* update, more cleanups

* just more nits

* more config cleanups

* time for the encoder

* fix

* sa=mall nit

* nits

* n

* refacto a bit

* cleanup

* update cv scipt

* fix last issues

* fix last nits

* styling

* small fixes

* just run 1 generation

* fixes

* nits

* fix conversion

* fix

* more fixes

* full generate

* ouf!

* fixes!

* updates

* fix

* fix cvrt

* fixup

* nits

* delete wrong test

* update

* update

* test tokenization

* let's start changing things bit by bit - fix encoder step

* removing custom generation, moving to GenerationMixin

* add encoder decoder attention masks for generation

* mask changes, correctness checked against ad29837 in dia repo

* refactor a bit already --> next cache

* too important not to push :)

* minimal cleanup + more todos

* make main overwrite modeling utils

* add cfg filter & eos filter

* add eos countdown & delay pattern

* update eos countdown

* add max step eos countdown

* fix tests

* fix some things

* fix generation with testing

* move cfg & eos stuff to logits processor

* make RepetitionPenaltyLogitsProcessor flexible

- can accept 3D scores like (batch_size, channel, vocab)

* fix input_ids concatenation dimension in GenerationMixin for flexibility

* Add DiaHangoverLogitsProcessor and DiaExponentialDecayLengthPenalty classes; refactor logits processing in DiaForConditionalGeneration to utilize new configurations and improve flexibility.

* Add stopping criteria

* refactor

* move delay pattern from processor to modeling like musicgen.

- add docs
- change eos countdown to eos delay pattern

* fix processor & fix tests

* refactor types

* refactor imports

* format code

* fix docstring to pass ci

* add docstring to DiaConfig & add DiaModel to test

* fix docstring

* add docstring

* fix some bugs

* check

* porting / merging results from other branch - IMPORTANT: it very likely breaks generation, the goal is to have a proper forward path first

* experimental testing of left padding for first channel

* whoops

* Fix merge to make generation work

* fix cfg filter

* add position ids

* add todos, break things

* revert changes to generation --> we will force 2d but go 3d on custom stuff

* refactor a lot, change prepare decoder ids to work with left padding (needs testing), add todos

* some first fixes to get to 10. in generation

* some more generation fixes / adjustment

* style + rope fixes

* move cfg out, simplify a few things, more todos

* nit

* start working on custom logit processors

* nit

* quick fixes

* cfg top k

* more refactor of logits processing, needs a decision if gen config gets the new attributes or if we move it to config or similar

* lets keep changes to core code minimal, only eos scaling is questionable atm

* simpler eos delay logits processor

* that was for debugging :D

* proof of concept rope

* small fix on device mismatch

* cfg fixes + delay logits max len

* transformers rope

* modular dia

* more cleanup

* keep modeling consistently 3D, generate handles 2D internally

* decoder starts with bos if nothing

* post processing prototype

* style

* lol

* force sample / greedy + fixes on padding

* style

* fixup tokenization

* nits

* revert

* start working on dia tests

* fix a lot of tests

* more test fixes

* nit

* more test fixes + some features to simplify code more

* more cleanup

* forgot that one

* autodocs

* small consistency fixes

* fix regression

* small fixes

* dia feature extraction

* docs

* wip processor

* fix processor order

* processing goes brrr

* transpose before

* small fix

* fix major bug but needs now a closer look into the custom processors esp cfg

* small thing on logits

* nits

* simplify indices and shifts

* add simpler version of padding tests back (temporarily)

* add logit processor tests

* starting tests on processor

* fix mask application during generation

* some fixes on the weights conversion

* style + fixup logits order

* simplify conversion

* nit

* remove padding tests

* nits on modeling

* hmm

* fix tests

* trigger

* probably gonna be reverted, just a quick design around audio tokenizer

* fixup typing

* post merge + more typing

* initial design for audio tokenizer

* more design changes

* nit

* more processor tests and style related things

* add to init

* protect import

* not sure why tbh

* add another protect

* more fixes

* wow

* it aint stopping :D

* another missed type issue

* ...

* change design around audio tokenizer to prioritize init and go for auto - in regards to the review

* change to new causal mask function + docstrings

* change ternary

* docs

* remove todo, i dont think its essential tbh

* remove pipeline as current pipelines do not fit in the current scheme, same as csm

* closer to wrapping up the processor

* text to audio, just for demo purposes (will likely be reverted)

* check if it's this

* save audio function

* ensure no grad

* fixes on prefixed audio, hop length is used via preprocess dac, device fixes

* integration tests (tested locally on a100) + some processor utils / fixes

* style

* nits

* another round of smaller things

* docs + some fixes (generate one might be big)

* msytery solved

* small fix on conversion

* add abstract audio tokenizer, change init check to abstract class

* nits

* update docs + fix some processing :D

* change inheritance scheme for audio tokenizer

* delete dead / unnecessary code in copied generate loop

* last nits on new pipeline behavior (+ todo on tests) + style

* trigger

---------

Co-authored-by: Arthur Zucker <arthur.zucker@gmail.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Vasqu <antonprogamer@gmail.com>
This commit is contained in:
Jaeyong Sung
2025-06-26 20:04:23 +09:00
committed by GitHub
parent 5995cfa0a0
commit 583db52bc6
34 changed files with 5733 additions and 29 deletions

View File

@@ -56,7 +56,12 @@ if is_torch_available():
UnbatchedClassifierFreeGuidanceLogitsProcessor,
WatermarkLogitsProcessor,
)
from transformers.generation.logits_process import BarkEosPrioritizerLogitsProcessor
from transformers.generation.logits_process import (
BarkEosPrioritizerLogitsProcessor,
DiaClassifierFreeGuidanceLogitsProcessor,
DiaEOSChannelFilterLogitsProcessor,
DiaEOSDelayPatternLogitsProcessor,
)
@require_torch
@@ -1211,3 +1216,145 @@ class LogitsProcessorTest(unittest.TestCase):
)
)
self.assertTrue(is_close)
def test_dia_classifier_free_guidance(self):
input_ids = torch.LongTensor([[0]])
logits_uncond = torch.tensor([[1.0, 0, 1.5]])
logits_cond = torch.tensor([[1.0, 1.0, 1.0]])
# base cfg with conditioned as center
cfg = DiaClassifierFreeGuidanceLogitsProcessor(guidance_scale=1.5)
out = cfg(input_ids, torch.cat([logits_cond, logits_uncond], dim=0))
res = logits_cond + 1.5 * (logits_cond - logits_uncond)
self.assertAlmostEqual(out[0, 0].item(), res[0, 0].item())
self.assertAlmostEqual(out[0, 1].item(), res[0, 1].item())
self.assertAlmostEqual(out[0, 2].item(), res[0, 2].item())
# additional top k (on cond logits)
cfg = DiaClassifierFreeGuidanceLogitsProcessor(guidance_scale=1.5, guidance_top_k=1)
out = cfg(input_ids, torch.cat([logits_cond, logits_uncond], dim=0))
res = logits_cond + 1.5 * (logits_cond - logits_uncond)
mask = res == res.max()
res = logits_cond.clone()
res[~mask.bool()] = -float("inf")
self.assertAlmostEqual(out[0, 0].item(), res[0, 0].item())
self.assertAlmostEqual(out[0, 1].item(), res[0, 1].item())
self.assertAlmostEqual(out[0, 2].item(), res[0, 2].item())
def test_dia_channel_filter(self):
eos = 2
bsz, channels, vocab = 2, 2, 4
input_ids = torch.LongTensor([[0]])
logits = torch.zeros(size=(bsz, channels, vocab)).view(bsz * channels, vocab)
logits[0, eos] = 1 # Eos max (forced)
logits[1, eos] = 1 # Eos max (forced) but not channel 0
channel_filter = DiaEOSChannelFilterLogitsProcessor(num_channels=channels, eos_token_id=eos)
out = channel_filter(input_ids, logits).view(bsz, channels, vocab)
for i in range(vocab):
if i > eos:
# special tokens are not to be predicted
self.assertTrue((out[:, :, i] == -float("inf")).all())
elif i == eos:
# Eos forced on channel 0
self.assertTrue(out[0, 0, i] == 1)
# Eos suppressed on everything else (even if max before)
self.assertTrue(out[0, 1, i] == -float("inf"))
self.assertTrue((out[1, :, i] == -float("inf")).all())
else:
# Eos forced on channel 0
self.assertTrue(out[0, 0, i] == -float("inf"))
# previous values
self.assertTrue(out[0, 1, i] == 0)
self.assertTrue((out[1, :, i] == 0).all())
def test_dia_delay_pattern(self):
def check_eos_logits(out, logits, batch, channel, eos):
for i in range(vocab):
if i == eos:
self.assertTrue(out[batch, channel, i] == 0)
else:
self.assertTrue(out[batch, channel, i] == -float("inf"))
for c in range(channel):
if c != channel:
self.assertTrue((out[batch, c] == logits[batch, c]).all())
eos = 2
delay_pattern = [0, 2, 3]
max_generation_len = 10
bsz, channels, vocab = 2, 3, 4
input_ids = torch.LongTensor([[0]])
logits = torch.zeros(size=(bsz, channels, vocab))
# Ensure that argmax can not result in eos
logits[:, :, eos] = -1
delay_pattern_processor = DiaEOSDelayPatternLogitsProcessor(
delay_pattern=delay_pattern, eos_token_id=eos, max_generation_len=max_generation_len
)
out = delay_pattern_processor(input_ids, logits.clone()).view(bsz, channels, vocab)
# Nothing should happen except for init of some attributes
self.assertTrue((out == logits).all())
self.assertTrue((~delay_pattern_processor.active_batches).all())
self.assertTrue(
(delay_pattern_processor.delay_pattern == torch.tensor([delay_pattern for _ in range(bsz)])).all()
)
# Make first batch end
logits[0, 0, eos] = 1
# Go through the complete delay pattern
for i in range(max(delay_pattern) + 1):
out = delay_pattern_processor(input_ids, logits.clone()).view(bsz, channels, vocab)
# no delay should kick in
if i == 1:
self.assertTrue((out == logits).all())
else:
j = i if i == 0 else i - 1
check_eos_logits(out=out, logits=logits, batch=0, channel=j, eos=eos)
self.assertTrue((out[1] == logits[1]).all())
self.assertTrue(delay_pattern_processor.active_batches[0])
self.assertFalse(delay_pattern_processor.active_batches[1])
self.assertTrue(
(
delay_pattern_processor.delay_pattern[0]
== torch.tensor([delay - (i + 1) for delay in delay_pattern])
).all()
)
self.assertTrue((delay_pattern_processor.delay_pattern[1] == torch.tensor(delay_pattern)).all())
# Make second batch end
logits[1, 0, eos] = 1
# Just to check if other batches could work
out = delay_pattern_processor(input_ids, logits.clone()).view(bsz, channels, vocab)
self.assertTrue((out[0] == logits[0]).all())
self.assertTrue(delay_pattern_processor.active_batches.all())
self.assertTrue(
(delay_pattern_processor.delay_pattern[0] == torch.tensor([delay - 5 for delay in delay_pattern])).all()
)
self.assertTrue(
(delay_pattern_processor.delay_pattern[1] == torch.tensor([delay - 1 for delay in delay_pattern])).all()
)
# Last check on max generation length reached (with delay in mind until last channel produces eos)
input_ids = torch.LongTensor([[0] * (max_generation_len - max(delay_pattern) - 1)])
delay_pattern_processor = DiaEOSDelayPatternLogitsProcessor(
delay_pattern=delay_pattern, eos_token_id=eos, max_generation_len=max_generation_len
)
out = delay_pattern_processor(input_ids, logits.clone()).view(bsz, channels, vocab)
check_eos_logits(out=out, logits=logits, batch=0, channel=0, eos=eos)
check_eos_logits(out=out, logits=logits, batch=1, channel=0, eos=eos)
self.assertTrue(delay_pattern_processor.active_batches.all())
self.assertTrue((delay_pattern_processor.delay_pattern == torch.tensor(delay_pattern) - 1).all())