diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index a3c6981861..9ed80cfb0b 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -839,6 +839,8 @@
title: CSM
- local: model_doc/dac
title: dac
+ - local: model_doc/dia
+ title: Dia
- local: model_doc/encodec
title: EnCodec
- local: model_doc/fastspeech2_conformer
diff --git a/docs/source/en/model_doc/auto.md b/docs/source/en/model_doc/auto.md
index adab8591e2..0a36c7c0a1 100644
--- a/docs/source/en/model_doc/auto.md
+++ b/docs/source/en/model_doc/auto.md
@@ -350,6 +350,10 @@ The following auto classes are available for the following audio tasks.
[[autodoc]] AutoModelForTextToWaveform
+### AutoModelForAudioTokenization
+
+[[autodoc]] AutoModelForAudioTokenization
+
## Multimodal
The following auto classes are available for the following multimodal tasks.
diff --git a/docs/source/en/model_doc/dia.md b/docs/source/en/model_doc/dia.md
new file mode 100644
index 0000000000..67c4a3be0b
--- /dev/null
+++ b/docs/source/en/model_doc/dia.md
@@ -0,0 +1,162 @@
+
+
+# Dia
+
+
+
+## Overview
+
+Dia is an opensource text-to-speech (TTS) model (1.6B parameters) developed by [Nari Labs](https://huggingface.co/nari-labs).
+It can generate highly realistic dialogue from transcript including nonverbal communications such as laughter and coughing.
+Furthermore, emotion and tone control is also possible via audio conditioning (voice cloning).
+
+**Model Architecture:**
+Dia is an encoder-decoder transformer based on the original transformer architecture. However, some more modern features such as
+rotational positional embeddings (RoPE) are also included. For its text portion (encoder), a byte tokenizer is utilized while
+for the audio portion (decoder), a pretrained codec model [DAC](./dac.md) is used - DAC encodes speech into discrete codebook
+tokens and decodes them back into audio.
+
+## Usage Tips
+
+### Generation with Text
+
+```python
+from transformers import AutoProcessor, DiaForConditionalGeneration
+
+torch_device = "cuda"
+model_checkpoint = "buttercrab/dia-v1-1.6b"
+
+text = ["[S1] Dia is an open weights text to dialogue model."]
+processor = AutoProcessor.from_pretrained(model_checkpoint)
+inputs = processor(text=text, padding=True, return_tensors="pt").to(torch_device)
+
+model = DiaForConditionalGeneration.from_pretrained(model_checkpoint).to(torch_device)
+outputs = model.generate(**inputs, max_new_tokens=256) # corresponds to around ~2s
+
+# save audio to a file
+outputs = processor.batch_decode(outputs)
+processor.save_audio(outputs, "example.wav")
+
+```
+
+### Generation with Text and Audio (Voice Cloning)
+
+```python
+from datasets import load_dataset, Audio
+from transformers import AutoProcessor, DiaForConditionalGeneration
+
+torch_device = "cuda"
+model_checkpoint = "buttercrab/dia-v1-1.6b"
+
+ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
+ds = ds.cast_column("audio", Audio(sampling_rate=44100))
+audio = ds[-1]["audio"]["array"]
+# text is a transcript of the audio + additional text you want as new audio
+text = ["[S1] I know. It's going to save me a lot of money, I hope. [S2] I sure hope so for you."]
+
+processor = AutoProcessor.from_pretrained(model_checkpoint)
+inputs = processor(text=text, audio=audio, padding=True, return_tensors="pt").to(torch_device)
+prompt_len = processor.get_audio_prompt_len(inputs["decoder_attention_mask"])
+
+model = DiaForConditionalGeneration.from_pretrained(model_checkpoint).to(torch_device)
+outputs = model.generate(**inputs, max_new_tokens=256) # corresponds to around ~2s
+
+# retrieve actually generated audio and save to a file
+outputs = processor.batch_decode(outputs, audio_prompt_len=prompt_len)
+processor.save_audio(outputs, "example_with_audio.wav")
+```
+
+### Training
+
+```python
+from datasets import load_dataset, Audio
+from transformers import AutoProcessor, DiaForConditionalGeneration
+
+torch_device = "cuda"
+model_checkpoint = "buttercrab/dia-v1-1.6b"
+
+ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
+ds = ds.cast_column("audio", Audio(sampling_rate=44100))
+audio = ds[-1]["audio"]["array"]
+# text is a transcript of the audio
+text = ["[S1] I know. It's going to save me a lot of money, I hope."]
+
+processor = AutoProcessor.from_pretrained(model_checkpoint)
+inputs = processor(
+ text=text,
+ audio=audio,
+ generation=False,
+ output_labels=True,
+ padding=True,
+ return_tensors="pt"
+).to(torch_device)
+
+model = DiaForConditionalGeneration.from_pretrained(model_checkpoint).to(torch_device)
+out = model(**inputs)
+out.loss.backward()
+```
+
+
+This model was contributed by [Jaeyong Sung](https://huggingface.co/buttercrab), [Arthur Zucker](https://huggingface.co/ArthurZ),
+and [Anton Vlasjuk](https://huggingface.co/AntonV). The original code can be found [here](https://github.com/nari-labs/dia/).
+
+
+## DiaConfig
+
+[[autodoc]] DiaConfig
+
+## DiaDecoderConfig
+
+[[autodoc]] DiaDecoderConfig
+
+## DiaEncoderConfig
+
+[[autodoc]] DiaEncoderConfig
+
+## DiaTokenizer
+
+[[autodoc]] DiaTokenizer
+ - __call__
+
+## DiaFeatureExtractor
+
+[[autodoc]] DiaFeatureExtractor
+ - __call__
+
+## DiaProcessor
+
+[[autodoc]] DiaProcessor
+ - __call__
+ - batch_decode
+ - decode
+
+## DiaModel
+
+[[autodoc]] DiaModel
+ - forward
+
+## DiaForConditionalGeneration
+
+[[autodoc]] DiaForConditionalGeneration
+ - forward
+ - generate
diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py
index 155491d6d5..54fa7c7e26 100755
--- a/src/transformers/configuration_utils.py
+++ b/src/transformers/configuration_utils.py
@@ -271,7 +271,6 @@ class PretrainedConfig(PushToHubMixin):
self.pad_token_id = kwargs.pop("pad_token_id", None)
self.eos_token_id = kwargs.pop("eos_token_id", None)
self.sep_token_id = kwargs.pop("sep_token_id", None)
-
self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)
# task specific arguments
diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py
index 8c72279b6e..d4c08e270b 100644
--- a/src/transformers/generation/logits_process.py
+++ b/src/transformers/generation/logits_process.py
@@ -2975,3 +2975,224 @@ class SynthIDTextWatermarkLogitsProcessor(LogitsProcessor):
The expected mean g-value for watermarked text.
"""
return coinflip_prob + coinflip_prob * (1 - coinflip_prob) * (1 - (1 / vocab_size))
+
+
+class DiaClassifierFreeGuidanceLogitsProcessor(LogitsProcessor):
+ r"""
+ [`LogitsProcessor`] for classifier free guidance (CFG). Similar to the original
+ `ClassifierFreeGuidanceLogitsProcessor` with some modifications on the overall
+ calculation, e.g. conditioned logits centered, and an additional top k selection
+ option.
+
+
+
+ This logits processor is exclusively compatible with
+ [Dia](https://huggingface.co/docs/transformers/main/en/model_doc/dia)
+
+
+
+ Args:
+ guidance_scale (float):
+ The guidance scale for classifier free guidance (CFG). CFG is enabled by setting `guidance_scale > 1`.
+ Higher guidance scale encourages the model to generate samples that are more closely linked to the input
+ prompt, usually at the expense of poorer quality.
+ guidance_top_k (int, *optional*):
+ The number of highest probability vocabulary tokens to keep for top-k-filtering. However, we do not keep
+ the logits of the combined CFG output, but the conditioned output only.
+ """
+
+ def __init__(self, guidance_scale: float, guidance_top_k: Optional[int] = None):
+ if guidance_scale > 1:
+ self.guidance_scale = guidance_scale
+ else:
+ raise ValueError(
+ "Require guidance scale >1 to use the classifier free guidance processor, got guidance scale "
+ f"{guidance_scale}."
+ )
+
+ self.guidance_top_k = guidance_top_k
+ if self.guidance_top_k is not None and self.guidance_top_k < 1:
+ raise ValueError(
+ f"`guidance_top_k` has to be a strictly positive integer if given, but is {self.guidance_top_k}"
+ )
+
+ @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+ # simple check to make sure we have compatible batch sizes between our
+ # logits scores (cond + uncond) and input ids (cond only)
+ if scores.shape[0] != 2 * input_ids.shape[0]:
+ raise ValueError(
+ f"Logits should have twice the batch size of the input ids, the first half of batches corresponding to "
+ f"the conditional inputs, and the second half of batches corresponding to the unconditional inputs. Got "
+ f"batch size {scores.shape[0]} for the logits and {input_ids.shape[0]} for the input ids."
+ )
+ # Base CFG with center on cond_logits
+ unguided_bsz = scores.shape[0] // 2
+ cond_logits, uncond_logits = scores.split(unguided_bsz, dim=0)
+ scores_processed = cond_logits + (cond_logits - uncond_logits) * self.guidance_scale
+
+ # Optional CFG top k filtering
+ if self.guidance_top_k is not None:
+ # Create top k based on the combined CFG output
+ _, top_k_indices = torch.topk(scores_processed, k=self.guidance_top_k, dim=-1)
+ top_k_mask = torch.ones_like(scores_processed, dtype=torch.bool)
+ top_k_mask = top_k_mask.scatter(dim=-1, index=top_k_indices, value=False)
+ # Only return conditioned logits with top k
+ scores_processed = cond_logits.masked_fill(top_k_mask, -float("inf"))
+
+ return scores_processed
+
+
+class DiaEOSChannelFilterLogitsProcessor(LogitsProcessor):
+ r"""Specialized processor that ensures certain properties around EOS sampling:
+ 1. Only channel 0 can generate EOS
+ 2. If channel 0 has EOS with highest logit, it will be the only candidate
+ 3. If channel 0 has EOS not with highest logit, it will be suppressed
+
+ 2. and 3. are especially important in contexts where we allow sampling to guarantee the
+ respective tokens to be (not) sampled.
+
+
+
+ This logits processor is exclusively compatible with
+ [Dia](https://huggingface.co/docs/transformers/en/model_doc/dia).
+
+
+
+ Args:
+ num_channels (`int`):
+ Number of audio codebooks. Simplifies access to the first channel on the logits.
+ eos_token_id (`int`):
+ The id of *end-of-sequence* token.
+ """
+
+ def __init__(self, num_channels: int, eos_token_id: int):
+ if num_channels < 1:
+ raise ValueError(f"Audio codebooks need at least one channel, but found {num_channels} channels.")
+ if eos_token_id < 1:
+ raise ValueError(f"Expected `eos_token_id` to be a positive integer, found {eos_token_id} instead.")
+
+ self.num_channels = num_channels
+ self.eos_id = eos_token_id
+
+ @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+ # Reshape for easier channel indexing [B, C, V]
+ scores = scores.reshape(-1, self.num_channels, scores.shape[-1])
+
+ # EOS filter
+ # 1. Condition: Only the first channel can generate the EOS token
+ # Side condition of disabling generation of special tokens (e.g. audio pad, bos, ...)
+ # (Assumes them to be greater than audio eos token position)
+ scores[:, 1:, self.eos_id :] = torch.full_like(
+ scores[:, 1:, self.eos_id :],
+ fill_value=-float("inf"),
+ )
+ scores[:, 0, self.eos_id + 1 :] = torch.full_like(
+ scores[:, 0, self.eos_id + 1 :],
+ fill_value=-float("inf"),
+ )
+
+ # 2+3 Conditions: Force/Suppress EOS if (not) highest logit
+ # Reshape back to original shape
+ scores = scores.view(-1, scores.shape[-1])
+
+ # Sample highest tokens
+ top_logit_indices = torch.argmax(scores, dim=-1)
+
+ # 2. Force EOS
+ eos_highest_mask = top_logit_indices == self.eos_id
+ mask_eos_highest = torch.zeros_like(scores, dtype=torch.bool)
+ mask_eos_highest[eos_highest_mask, : self.eos_id] = True
+ scores = scores.masked_fill(mask_eos_highest, -float("inf"))
+
+ # 3. Suppress EOS
+ eos_not_highest_mask = top_logit_indices != self.eos_id
+ mask_eos_unless_highest = torch.zeros_like(scores, dtype=torch.bool)
+ mask_eos_unless_highest[eos_not_highest_mask, self.eos_id] = True
+ scores = scores.masked_fill(mask_eos_unless_highest, -float("inf"))
+
+ return scores
+
+
+class DiaEOSDelayPatternLogitsProcessor(LogitsProcessor):
+ r"""Special logits processor to handle the generation of the EOS token in Dia.
+ This is due to the fact that Dia does not allow the generation of EOS in all
+ channels except the first channel (C0).
+
+ Hence, based on the delay pattern, an EOS is forced after the respective delays
+ in the channels. For example, if the delay pattern is [0, 2, 3, 4]:
+
+ s s+1 s+2 s+3 s+4 s+5 ...
+ | | | | | |
+ C0: EOS PAD PAD PAD PAD PAD ...
+ C1: x x EOS PAD PAD PAD ...
+ C2: x x x EOS PAD PAD ...
+ C3: x x x x EOS PAD ...
+
+ If the first channel generated EOS at step s, channels Cx are forced to generate
+ theirs at the respective delays (s+2, s+3, s+4). Subsequent padding tokens are
+ handled by the `EosTokenCriteria` when an EOS has been detected.
+
+
+
+ This logits processor is exclusively compatible with
+ [Dia](https://huggingface.co/docs/transformers/en/model_doc/dia).
+
+
+
+ Args:
+ delay_pattern (`List[int]`):
+ The delays per channel in the audio codebooks.
+ eos_token_id (`int`):
+ The id of *end-of-sequence* token.
+ max_generation_len (`int`):
+ The max sequence length that can be generated.
+ device (`str`, *optional*, defaults to `"cpu"`):
+ The device to allocate the tensors on.
+ """
+
+ def __init__(self, delay_pattern: list[int], eos_token_id: int, max_generation_len: int, device: str = "cpu"):
+ self.num_channels = len(delay_pattern)
+ # Update during first iteration
+ self.active_batches = None
+ self.delay_pattern = torch.tensor(delay_pattern, device=device, dtype=torch.int)[None, :]
+ self.eos_token_id = eos_token_id
+ self.max_generation_len = max_generation_len - max(delay_pattern) - 1
+ self.device = device
+
+ @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+ # Reshape for easier channel indexing [B, C, V]
+ scores = scores.reshape(-1, self.num_channels, scores.shape[-1])
+
+ # Initialize / expand values on first iteration
+ if self.active_batches is None:
+ self.delay_pattern = self.delay_pattern.repeat(scores.shape[0], 1)
+ self.active_batches = torch.zeros(size=(scores.shape[0],), device=self.device, dtype=torch.bool)
+
+ # Check if eos has been generated in any batch
+ channel_generated_eos = torch.argmax(scores, dim=-1)[:, 0] == self.eos_token_id
+ # Check if max len has been reached
+ reached_max_len = input_ids.shape[1] == self.max_generation_len
+
+ # Update active batches
+ self.active_batches |= channel_generated_eos
+ self.active_batches |= reached_max_len
+
+ # Find channels that need to force eos
+ forced_eos_channels = self.active_batches[:, None] & (self.delay_pattern == 0)
+ # Use indexing to avoid issues on all `False` by having empty tensors in that case
+ idx_bsz, idx_channel = forced_eos_channels.nonzero(as_tuple=True)
+
+ # Force eos if delay is kicking in
+ scores[idx_bsz, idx_channel, :] = -float("inf")
+ scores[idx_bsz, idx_channel, self.eos_token_id] = 0.0
+
+ # Reshape back to [B * C, V]
+ scores = scores.reshape(-1, scores.shape[-1])
+
+ # Update amount of delay left for each channel
+ self.delay_pattern -= self.active_batches[:, None].int()
+
+ return scores
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index a5d1be345d..ea2bd32aa3 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -26,6 +26,7 @@ import re
import shutil
import tempfile
import warnings
+from abc import abstractmethod
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed
from contextlib import contextmanager
@@ -5884,3 +5885,26 @@ class AttentionInterface(GeneralInterface):
# Global AttentionInterface shared by all models which do not need to overwrite any of the existing ones
ALL_ATTENTION_FUNCTIONS: AttentionInterface = AttentionInterface()
+
+
+class PreTrainedAudioTokenizerBase(PreTrainedModel):
+ """
+ Class that additionally defines the behavior of any `audio_tokenizer` to be added.
+ Characteristic for any of them:
+ 1. Encode raw audio into discrete audio codebooks (with x channels)
+ 2. Decode from discrete audio codebooks back to raw audio
+ It is possible that they can decode in different ways given a different representation
+ but they are forced to support 2. nonetheless, e.g. see `DAC`.
+ """
+
+ @abstractmethod
+ def encode(self, input_values: torch.Tensor, *args, **kwargs):
+ """
+ Encode raw audio retrieved from a respective `FeatureExtractor` into discrete audio codebooks (with x channels)
+ """
+ pass
+
+ @abstractmethod
+ def decode(self, audio_codes: torch.Tensor, *args, **kwargs):
+ """Decode from discrete audio codebooks back to raw audio"""
+ pass
diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py
index 3c0e649f8a..7b2332d89f 100644
--- a/src/transformers/models/__init__.py
+++ b/src/transformers/models/__init__.py
@@ -88,6 +88,7 @@ if TYPE_CHECKING:
from .depth_anything import *
from .depth_pro import *
from .detr import *
+ from .dia import *
from .dialogpt import *
from .diffllama import *
from .dinat import *
diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py
index 8d2109759d..71ad6eaade 100644
--- a/src/transformers/models/auto/configuration_auto.py
+++ b/src/transformers/models/auto/configuration_auto.py
@@ -106,6 +106,7 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str](
("depth_pro", "DepthProConfig"),
("deta", "DetaConfig"),
("detr", "DetrConfig"),
+ ("dia", "DiaConfig"),
("diffllama", "DiffLlamaConfig"),
("dinat", "DinatConfig"),
("dinov2", "Dinov2Config"),
@@ -478,6 +479,7 @@ MODEL_NAMES_MAPPING = OrderedDict[str, str](
("depth_pro", "DepthPro"),
("deta", "DETA"),
("detr", "DETR"),
+ ("dia", "Dia"),
("dialogpt", "DialoGPT"),
("diffllama", "DiffLlama"),
("dinat", "DiNAT"),
diff --git a/src/transformers/models/auto/feature_extraction_auto.py b/src/transformers/models/auto/feature_extraction_auto.py
index cf806f39a6..d54ca4b0f5 100644
--- a/src/transformers/models/auto/feature_extraction_auto.py
+++ b/src/transformers/models/auto/feature_extraction_auto.py
@@ -55,6 +55,7 @@ FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
("deformable_detr", "DeformableDetrFeatureExtractor"),
("deit", "DeiTFeatureExtractor"),
("detr", "DetrFeatureExtractor"),
+ ("dia", "DiaFeatureExtractor"),
("dinat", "ViTFeatureExtractor"),
("donut-swin", "DonutFeatureExtractor"),
("dpt", "DPTFeatureExtractor"),
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index 51a3c3fbbc..add9d09b0e 100644
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -99,6 +99,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("depth_pro", "DepthProModel"),
("deta", "DetaModel"),
("detr", "DetrModel"),
+ ("dia", "DiaModel"),
("diffllama", "DiffLlamaModel"),
("dinat", "DinatModel"),
("dinov2", "Dinov2Model"),
@@ -472,6 +473,7 @@ MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
("data2vec-text", "Data2VecTextForMaskedLM"),
("deberta", "DebertaForMaskedLM"),
("deberta-v2", "DebertaV2ForMaskedLM"),
+ ("dia", "DiaForConditionalGeneration"),
("distilbert", "DistilBertForMaskedLM"),
("electra", "ElectraForMaskedLM"),
("encoder-decoder", "EncoderDecoderModel"),
@@ -1059,6 +1061,7 @@ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
[
+ ("dia", "DiaForConditionalGeneration"),
("granite_speech", "GraniteSpeechForConditionalGeneration"),
("kyutai_speech_to_text", "KyutaiSpeechToTextForConditionalGeneration"),
("moonshine", "MoonshineForConditionalGeneration"),
@@ -1629,6 +1632,12 @@ MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES = OrderedDict(
]
)
+MODEL_FOR_AUDIO_TOKENIZATION_NAMES = OrderedDict(
+ [
+ ("dac", "DacModel"),
+ ]
+)
+
MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES)
MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_PRETRAINING_MAPPING_NAMES)
MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_WITH_LM_HEAD_MAPPING_NAMES)
@@ -1737,6 +1746,8 @@ MODEL_FOR_TIME_SERIES_PREDICTION_MAPPING = _LazyAutoMapping(
MODEL_FOR_IMAGE_TO_IMAGE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES)
+MODEL_FOR_AUDIO_TOKENIZATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_TOKENIZATION_NAMES)
+
class AutoModelForMaskGeneration(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_MASK_GENERATION_MAPPING
@@ -2034,6 +2045,15 @@ class AutoModelForMaskedImageModeling(_BaseAutoModelClass):
AutoModelForMaskedImageModeling = auto_class_update(AutoModelForMaskedImageModeling, head_doc="masked image modeling")
+class AutoModelForAudioTokenization(_BaseAutoModelClass):
+ _model_mapping = MODEL_FOR_AUDIO_TOKENIZATION_MAPPING
+
+
+AutoModelForAudioTokenization = auto_class_update(
+ AutoModelForAudioTokenization, head_doc="audio tokenization through codebooks"
+)
+
+
class AutoModelWithLMHead(_AutoModelWithLMHead):
@classmethod
def from_config(cls, config):
@@ -2059,6 +2079,7 @@ class AutoModelWithLMHead(_AutoModelWithLMHead):
__all__ = [
"MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING",
"MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING",
+ "MODEL_FOR_AUDIO_TOKENIZATION_MAPPING",
"MODEL_FOR_AUDIO_XVECTOR_MAPPING",
"MODEL_FOR_BACKBONE_MAPPING",
"MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING",
@@ -2106,6 +2127,7 @@ __all__ = [
"AutoBackbone",
"AutoModelForAudioClassification",
"AutoModelForAudioFrameClassification",
+ "AutoModelForAudioTokenization",
"AutoModelForAudioXVector",
"AutoModelForCausalLM",
"AutoModelForCTC",
diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py
index 372c0b249b..bccfe3e6d5 100644
--- a/src/transformers/models/auto/processing_auto.py
+++ b/src/transformers/models/auto/processing_auto.py
@@ -61,6 +61,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
("clvp", "ClvpProcessor"),
("colpali", "ColPaliProcessor"),
("colqwen2", "ColQwen2Processor"),
+ ("dia", "DiaProcessor"),
("emu3", "Emu3Processor"),
("flava", "FlavaProcessor"),
("fuyu", "FuyuProcessor"),
diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py
index 50a1a2732c..0456e1945c 100644
--- a/src/transformers/models/auto/tokenization_auto.py
+++ b/src/transformers/models/auto/tokenization_auto.py
@@ -177,6 +177,7 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
"LlamaTokenizerFast" if is_tokenizers_available() else None,
),
),
+ ("dia", ("DiaTokenizer", None)),
(
"diffllama",
(
diff --git a/src/transformers/models/dac/modeling_dac.py b/src/transformers/models/dac/modeling_dac.py
index b3bca5b63e..191e7af89e 100644
--- a/src/transformers/models/dac/modeling_dac.py
+++ b/src/transformers/models/dac/modeling_dac.py
@@ -23,7 +23,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
-from ...modeling_utils import PreTrainedModel
+from ...modeling_utils import PreTrainedAudioTokenizerBase
from ...utils import ModelOutput, auto_docstring
from .configuration_dac import DacConfig
@@ -471,7 +471,7 @@ class DacEncoder(nn.Module):
@auto_docstring
-class DacPreTrainedModel(PreTrainedModel):
+class DacPreTrainedModel(PreTrainedAudioTokenizerBase):
config_class = DacConfig
base_model_prefix = "dac"
main_input_name = "input_values"
diff --git a/src/transformers/models/dia/__init__.py b/src/transformers/models/dia/__init__.py
new file mode 100644
index 0000000000..d738fbc087
--- /dev/null
+++ b/src/transformers/models/dia/__init__.py
@@ -0,0 +1,31 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_dia import *
+ from .feature_extraction_dia import *
+ from .generation_dia import *
+ from .modeling_dia import *
+ from .processing_dia import *
+ from .tokenization_dia import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/src/transformers/models/dia/configuration_dia.py b/src/transformers/models/dia/configuration_dia.py
new file mode 100644
index 0000000000..90ace73b3c
--- /dev/null
+++ b/src/transformers/models/dia/configuration_dia.py
@@ -0,0 +1,376 @@
+# coding=utf-8
+# Copyright 2025 The Nari Labs and HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Dia model configuration"""
+
+from typing import Optional
+
+from ...configuration_utils import PretrainedConfig
+from ...modeling_rope_utils import rope_config_validation
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class DiaEncoderConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`DiaEncoder`]. It is used to instantiate a Dia
+ encoder according to the specified arguments, defining the encoder architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ max_position_embeddings (`int`, *optional*, defaults to 1024):
+ The maximum sequence length that this model might ever be used with.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ hidden_size (`int`, *optional*, defaults to 1024):
+ Dimensionality of the encoder layers and the pooler layer.
+ num_attention_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ num_key_value_heads (`int`, *optional*, defaults to 16):
+ Number of key and value heads for each attention layer in the Transformer encoder.
+ head_dim (`int`, *optional*, defaults to 128):
+ Dimensionality of the attention head.
+ intermediate_size (`int`, *optional*, defaults to 4096):
+ Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
+ norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the normalization layers.
+ vocab_size (`int`, *optional*, defaults to 256):
+ Vocabulary size of the Dia model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`DiaModel`].
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"swish"` and `"gelu_new"` are supported.
+ rope_theta (`float`, *optional*, defaults to 10000.0):
+ The base period of the RoPE embeddings.
+ rope_scaling (`dict`, *optional*):
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
+ and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
+ accordingly.
+ Expected contents:
+ `rope_type` (`str`):
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+ 'llama3'], with 'default' being the original RoPE implementation.
+ `factor` (`float`, *optional*):
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+ original maximum pre-trained length.
+ `original_max_position_embeddings` (`int`, *optional*):
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+ pretraining.
+ `attention_factor` (`float`, *optional*):
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
+ `factor` field to infer the suggested value.
+ `beta_fast` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+ ramp function. If unspecified, it defaults to 32.
+ `beta_slow` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+ ramp function. If unspecified, it defaults to 1.
+ `short_factor` (`List[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `long_factor` (`List[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (<
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `low_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
+ `high_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ """
+
+ model_type = "dia_encoder"
+
+ def __init__(
+ self,
+ max_position_embeddings: int = 1024,
+ num_hidden_layers: int = 12,
+ hidden_size: int = 1024,
+ num_attention_heads: int = 16,
+ num_key_value_heads: int = 16,
+ head_dim: int = 128,
+ intermediate_size: int = 4096,
+ norm_eps: float = 1e-5,
+ vocab_size: int = 256,
+ hidden_act: str = "silu",
+ rope_theta: float = 10000.0,
+ rope_scaling: Optional[dict] = None,
+ initializer_range: float = 0.02,
+ **kwargs,
+ ):
+ self.max_position_embeddings = max_position_embeddings
+ self.num_hidden_layers = num_hidden_layers
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_attention_heads = num_attention_heads
+ self.head_dim = head_dim
+ self.norm_eps = norm_eps
+ self.vocab_size = vocab_size
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ # Validate the correctness of rotary position embeddings parameters
+ # BC: if there is a 'type' field, copy it it to 'rope_type'.
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
+ rope_config_validation(self)
+ self.initializer_range = initializer_range
+ super().__init__(**kwargs)
+
+
+class DiaDecoderConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`DiaDecoder`]. It is used to instantiate a Dia
+ decoder according to the specified arguments, defining the decoder architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ max_position_embeddings (`int`, *optional*, defaults to 3072):
+ The maximum sequence length that this model might ever be used with.
+ num_hidden_layers (`int`, *optional*, defaults to 18):
+ Number of hidden layers in the Transformer decoder.
+ hidden_size (`int`, *optional*, defaults to 2048):
+ Dimensionality of the decoder layers and the pooler layer.
+ intermediate_size (`int`, *optional*, defaults to 8192):
+ Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer decoder.
+ num_attention_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ num_key_value_heads (`int`, *optional*, defaults to 4):
+ Number of key and value heads for each attention layer in the Transformer decoder.
+ head_dim (`int`, *optional*, defaults to 128):
+ Dimensionality of the attention head.
+ cross_num_attention_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each cross-attention layer in the Transformer decoder.
+ cross_head_dim (`int`, *optional*, defaults to 128):
+ Dimensionality of the cross-attention head.
+ cross_num_key_value_heads (`int`, *optional*, defaults to 16):
+ Number of key and value heads for each cross-attention layer in the Transformer decoder.
+ cross_hidden_size (`int`, *optional*, defaults to 1024):
+ Dimensionality of the cross-attention layers.
+ norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the normalization layers.
+ vocab_size (`int`, *optional*, defaults to 1028):
+ Vocabulary size of the Dia model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`DiaModel`].
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder. If string, `"gelu"`, `"relu"`,
+ `"swish"` and `"gelu_new"` are supported.
+ num_channels (`int`, *optional*, defaults to 9):
+ Number of channels for the Dia decoder.
+ rope_theta (`float`, *optional*, defaults to 10000.0):
+ The base period of the RoPE embeddings.
+ rope_scaling (`dict`, *optional*):
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
+ and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
+ accordingly.
+ Expected contents:
+ `rope_type` (`str`):
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+ 'llama3'], with 'default' being the original RoPE implementation.
+ `factor` (`float`, *optional*):
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+ original maximum pre-trained length.
+ `original_max_position_embeddings` (`int`, *optional*):
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+ pretraining.
+ `attention_factor` (`float`, *optional*):
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
+ `factor` field to infer the suggested value.
+ `beta_fast` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+ ramp function. If unspecified, it defaults to 32.
+ `beta_slow` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+ ramp function. If unspecified, it defaults to 1.
+ `short_factor` (`List[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `long_factor` (`List[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (<
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `low_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
+ `high_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models).
+ is_encoder_decoder (`bool`, *optional*, defaults to `True`):
+ Indicating that this model is part of an encoder-decoder architecture.
+ """
+
+ model_type = "dia_decoder"
+
+ def __init__(
+ self,
+ max_position_embeddings: int = 3072,
+ num_hidden_layers: int = 18,
+ hidden_size: int = 2048,
+ intermediate_size: int = 8192,
+ num_attention_heads: int = 16,
+ num_key_value_heads: int = 4,
+ head_dim: int = 128,
+ cross_num_attention_heads: int = 16,
+ cross_head_dim: int = 128,
+ cross_num_key_value_heads: int = 16,
+ cross_hidden_size: int = 1024,
+ norm_eps: float = 1e-5,
+ vocab_size: int = 1028,
+ hidden_act: str = "silu",
+ num_channels: int = 9,
+ rope_theta: float = 10000.0,
+ rope_scaling: Optional[dict] = None,
+ initializer_range: float = 0.02,
+ use_cache: bool = True,
+ is_encoder_decoder: bool = True,
+ **kwargs,
+ ):
+ self.max_position_embeddings = max_position_embeddings
+ self.num_hidden_layers = num_hidden_layers
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_attention_heads = num_attention_heads
+ self.num_key_value_heads = num_key_value_heads
+ self.head_dim = head_dim
+ self.cross_num_key_value_heads = cross_num_key_value_heads
+ self.cross_num_attention_heads = cross_num_attention_heads
+ self.cross_head_dim = cross_head_dim
+ self.cross_hidden_size = cross_hidden_size
+ self.norm_eps = norm_eps
+ self.vocab_size = vocab_size
+ self.hidden_act = hidden_act
+ self.num_channels = num_channels
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ # Validate the correctness of rotary position embeddings parameters
+ # BC: if there is a 'type' field, copy it it to 'rope_type'.
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
+ rope_config_validation(self)
+ self.initializer_range = initializer_range
+ self.use_cache = use_cache
+ super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
+
+
+class DiaConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`DiaModel`]. It is used to instantiate a
+ Dia model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the
+ [nari-labs/Dia-1.6B](https://huggingface.co/nari-labs/Dia-1.6B) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ encoder_config (`DiaEncoderConfig`, *optional*):
+ Configuration for the encoder part of the model. If not provided, a default `DiaEncoderConfig` will be used.
+ decoder_config (`DiaDecoderConfig`, *optional*):
+ Configuration for the decoder part of the model. If not provided, a default `DiaDecoderConfig` will be used.
+ norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the normalization layers.
+ is_encoder_decoder (`bool`, *optional*, defaults to `True`):
+ Indicating that this model uses an encoder-decoder architecture.
+ pad_token_id (`int`, *optional*, defaults to 1025):
+ Padding token id.
+ eos_token_id (`int`, *optional*, defaults to 1024):
+ End of stream token id.
+ bos_token_id (`int`, *optional*, defaults to 1026):
+ Beginning of stream token id.
+ delay_pattern (`list[int]`, *optional*, defaults to `[0, 8, 9, 10, 11, 12, 13, 14, 15]`):
+ The delay pattern for the decoder. The length of this list must match `decoder_config.num_channels`.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models).
+
+ Example:
+
+ ```python
+ >>> from transformers import DiaConfig, DiaModel
+
+ >>> # Initializing a DiaConfig with default values
+ >>> configuration = DiaConfig()
+
+ >>> # Initializing a DiaModel (with random weights) from the configuration
+ >>> model = DiaModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```
+ """
+
+ model_type = "dia"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ sub_configs = {"encoder_config": DiaEncoderConfig, "decoder_config": DiaDecoderConfig}
+
+ def __init__(
+ self,
+ encoder_config: Optional[DiaEncoderConfig] = None,
+ decoder_config: Optional[DiaDecoderConfig] = None,
+ norm_eps: float = 1e-5,
+ is_encoder_decoder: bool = True,
+ pad_token_id: int = 1025,
+ eos_token_id: int = 1024,
+ bos_token_id: int = 1026,
+ delay_pattern: Optional[list[int]] = None,
+ initializer_range: float = 0.02,
+ use_cache: bool = True,
+ **kwargs,
+ ):
+ if isinstance(encoder_config, dict):
+ encoder_config = DiaEncoderConfig(**encoder_config)
+ if isinstance(decoder_config, dict):
+ decoder_config = DiaDecoderConfig(**decoder_config)
+ self.encoder_config = encoder_config if encoder_config is not None else DiaEncoderConfig()
+ self.decoder_config = decoder_config if decoder_config is not None else DiaDecoderConfig()
+ self.norm_eps = norm_eps
+ self.delay_pattern = delay_pattern if delay_pattern is not None else [0, 8, 9, 10, 11, 12, 13, 14, 15]
+ self.initializer_range = initializer_range
+ self.use_cache = use_cache
+
+ assert self.decoder_config.num_channels == len(self.delay_pattern), (
+ "Number of channels must match delay pattern length."
+ )
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ eos_token_id=eos_token_id,
+ bos_token_id=bos_token_id,
+ is_encoder_decoder=is_encoder_decoder,
+ **kwargs,
+ )
+
+ def get_text_config(self, decoder=False):
+ """Defaulting to audio config as it's the decoder in this case which is usually the text backbone"""
+ return self.decoder_config
+
+
+__all__ = ["DiaConfig", "DiaEncoderConfig", "DiaDecoderConfig"]
diff --git a/src/transformers/models/dia/convert_dia_to_hf.py b/src/transformers/models/dia/convert_dia_to_hf.py
new file mode 100644
index 0000000000..3a33860f6b
--- /dev/null
+++ b/src/transformers/models/dia/convert_dia_to_hf.py
@@ -0,0 +1,199 @@
+# coding=utf-8
+# Copyright 2025 The Nari Labs and HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Converts a Dia model in Nari Labs format to Hugging Face format."""
+
+import argparse
+import os
+import re
+
+import torch
+from huggingface_hub import snapshot_download
+from safetensors.torch import load_file
+
+from transformers import (
+ DacModel,
+ DiaConfig,
+ DiaFeatureExtractor,
+ DiaForConditionalGeneration,
+ DiaProcessor,
+ DiaTokenizer,
+ GenerationConfig,
+)
+from transformers.utils.import_utils import _is_package_available
+
+
+# Provide just the list of layer keys you want to fix
+shape_mappings = [
+ "encoder.layers.*.mlp.gate_up_proj.weight",
+ "encoder.layers.*.mlp.down_proj.weight",
+ "encoder.layers.*.self_attention.q_proj.weight",
+ "encoder.layers.*.self_attention.k_proj.weight",
+ "encoder.layers.*.self_attention.v_proj.weight",
+ "encoder.layers.*.self_attention.o_proj.weight",
+ "decoder.layers.*.mlp.gate_up_proj.weight",
+ "decoder.layers.*.mlp.down_proj.weight",
+ "decoder.layers.*.self_attention.q_proj.weight",
+ "decoder.layers.*.self_attention.k_proj.weight",
+ "decoder.layers.*.self_attention.v_proj.weight",
+ "decoder.layers.*.self_attention.o_proj.weight",
+ "decoder.layers.*.cross_attention.q_proj.weight",
+ "decoder.layers.*.cross_attention.k_proj.weight",
+ "decoder.layers.*.cross_attention.v_proj.weight",
+ "decoder.layers.*.cross_attention.o_proj.weight",
+ "decoder.logits_dense.weight",
+]
+
+# Provide renamings here
+rename_mapping = {
+ "mlp.wo": "mlp.down_proj",
+ "mlp.wi_fused": "mlp.gate_up_proj",
+}
+
+
+def get_generation_config(config):
+ model_generation_config = GenerationConfig.from_model_config(config)
+ model_generation_config._from_model_config = False
+ model_generation_config.do_sample = True
+ model_generation_config.top_k = 45
+ model_generation_config.top_p = 0.95
+ model_generation_config.temperature = 1.2
+ model_generation_config.guidance_scale = 3.0
+ model_generation_config.max_length = 3072 # Decoder max length
+
+ return model_generation_config
+
+
+def convert_dia_model_to_hf(checkpoint_path, verbose=False):
+ """
+ Converts a Dia model in Nari Labs format to Hugging Face format.
+ Args:
+ checkpoint_path (`str`):
+ Path to the downloaded checkpoints.
+ verbose (`bool`, *optional*)
+ Whether to print information during conversion.
+ """
+ # Download from HF Hub if checkpoint_path is None
+ checkpoint_path = snapshot_download(repo_id=checkpoint_path, allow_patterns=["*.pth", "*.safetensors"])
+ print(f"Downloaded checkpoint from Hugging Face Hub: {checkpoint_path}")
+
+ # Initialize base model with default config == 1.6B model
+ with torch.device("meta"):
+ hf_model = DiaForConditionalGeneration(config=DiaConfig())
+ hf_model_dict = hf_model.state_dict()
+ hf_model_keys = hf_model_dict.keys()
+
+ # Iterate through dir to catch all respective files - prefers safetensors but allows pt
+ files = os.listdir(checkpoint_path)
+ for file in files:
+ if file.endswith(".safetensors"):
+ load_function = load_file
+ elif file.endswith(".pth"):
+ load_function = torch.load
+ checkpoint_path = os.path.join(checkpoint_path, files[0])
+ nari_state_dict = load_function(checkpoint_path, "cpu")
+
+ # Conversion starts here
+ converted_state_dict = {}
+ embeddings = {}
+ for key, tensor in nari_state_dict.items():
+ # add prefix
+ key = "model." + key
+
+ # rename some weights
+ for original, rename in rename_mapping.items():
+ if original in key:
+ key = re.sub(original, rename, key)
+
+ # decoder multi channel
+ if "embeddings" in key:
+ embeddings_key = key.rsplit(".", 2)[0] + ".embed.weight"
+ if embeddings_key in embeddings:
+ embeddings[embeddings_key] += [tensor]
+ else:
+ embeddings[embeddings_key] = [tensor]
+ continue
+ elif re.sub(r"\d+", "*", key).removeprefix("model.") in shape_mappings:
+ # add exception to the head
+ if "logits_dense" in key:
+ key = re.sub("decoder.logits_dense", "logits_dense", key).removeprefix("model.")
+
+ # dense general
+ if key in hf_model_keys:
+ tensor_shape = tensor.shape
+ target_shape = hf_model_dict[key].shape
+ try:
+ tensor = tensor.reshape(target_shape[1], target_shape[0]).T
+ if verbose:
+ print(f"{key}: transpose reshaped from {tensor_shape} to {target_shape}")
+ except Exception as e:
+ print(f"WARNING: Could not reshape {key}: {e}")
+
+ converted_state_dict[key] = tensor
+
+ # Combining the embeddings as last step
+ embeddings = {k: torch.cat(v, dim=0) for k, v in embeddings.items()}
+ converted_state_dict.update(embeddings)
+
+ # Load converted weights into HF model
+ hf_model.load_state_dict(converted_state_dict, assign=True)
+
+ # Overwrite generation config
+ hf_model.generation_config = get_generation_config(DiaConfig())
+
+ return hf_model
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ # # Required parameters
+ parser.add_argument(
+ "--checkpoint_path", type=str, default="nari-labs/Dia-1.6B", help="Path to the downloaded checkpoints"
+ )
+ parser.add_argument(
+ "--pytorch_dump_folder_path", default="AntonV/Dia-1.6B", type=str, help="Path to the output PyTorch model."
+ )
+ parser.add_argument(
+ "--convert_preprocessor",
+ type=bool,
+ default=True,
+ help="Whether or not the preprocessor (tokenizer + feature extractor) should be converted along with the model.",
+ )
+ parser.add_argument(
+ "--verbose",
+ type=bool,
+ default=True,
+ help="Whether or not to log information during conversion.",
+ )
+ args = parser.parse_args()
+
+ model = convert_dia_model_to_hf(args.checkpoint_path, args.verbose)
+ if args.convert_preprocessor:
+ try:
+ if not _is_package_available("tiktoken"):
+ raise ModuleNotFoundError(
+ """`tiktoken` is not installed, use `pip install tiktoken` to convert the tokenizer"""
+ )
+ except Exception as e:
+ print(e)
+ else:
+ processor = DiaProcessor(
+ DiaFeatureExtractor(sampling_rate=44100, hop_length=512),
+ DiaTokenizer(),
+ DacModel.from_pretrained("descript/dac_44khz"),
+ )
+ processor.save_pretrained(args.pytorch_dump_folder_path)
+
+ model.save_pretrained(args.pytorch_dump_folder_path)
+ print(f"Saved converted checkpoint to {args.pytorch_dump_folder_path}")
diff --git a/src/transformers/models/dia/feature_extraction_dia.py b/src/transformers/models/dia/feature_extraction_dia.py
new file mode 100644
index 0000000000..0d03ceff37
--- /dev/null
+++ b/src/transformers/models/dia/feature_extraction_dia.py
@@ -0,0 +1,183 @@
+# coding=utf-8
+# Copyright 2025 The Nari Labs and HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Feature extractor class for Dia"""
+
+from typing import Optional, Union
+
+import numpy as np
+
+from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
+from ...feature_extraction_utils import BatchFeature
+from ...utils import PaddingStrategy, TensorType, logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class DiaFeatureExtractor(SequenceFeatureExtractor):
+ r"""
+ Constructs an Dia feature extractor.
+
+ This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
+ most of the main methods. Users should refer to this superclass for more information regarding those methods.
+
+ Args:
+ feature_size (`int`, *optional*, defaults to 1):
+ The feature dimension of the extracted features. Use 1 for mono, 2 for stereo.
+ sampling_rate (`int`, *optional*, defaults to 16000):
+ The sampling rate at which the audio waveform should be digitalized, expressed in hertz (Hz).
+ padding_value (`float`, *optional*, defaults to 0.0):
+ The value that is used for padding.
+ hop_length (`int`, *optional*, defaults to 512):
+ Overlap length between successive windows.
+ """
+
+ model_input_names = ["input_values", "n_quantizers"]
+
+ def __init__(
+ self,
+ feature_size: int = 1,
+ sampling_rate: int = 16000,
+ padding_value: float = 0.0,
+ hop_length: int = 512,
+ **kwargs,
+ ):
+ super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
+ self.hop_length = hop_length
+
+ def __call__(
+ self,
+ raw_audio: Union[np.ndarray, list[float], list[np.ndarray], list[list[float]]],
+ padding: Optional[Union[bool, str, PaddingStrategy]] = None,
+ truncation: Optional[bool] = False,
+ max_length: Optional[int] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ sampling_rate: Optional[int] = None,
+ ) -> BatchFeature:
+ """
+ Main method to featurize and prepare for the model one or several sequence(s).
+
+ Args:
+ raw_audio (`np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[list[float]]`):
+ The sequence or batch of sequences to be processed. Each sequence can be a numpy array, a list of float
+ values, a list of numpy arrays or a list of list of float values. The numpy array must be of shape
+ `(num_samples,)` for mono audio (`feature_size = 1`), or `(2, num_samples)` for stereo audio
+ (`feature_size = 2`).
+ padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding
+ index) among:
+
+ - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
+ sequence if provided).
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
+ acceptable input length for the model if that argument is not provided.
+ - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
+ lengths).
+ truncation (`bool`, *optional*, defaults to `False`):
+ Activates truncation to cut input sequences longer than `max_length` to `max_length`.
+ max_length (`int`, *optional*):
+ Maximum length of the returned list and optionally padding length (see above).
+ return_tensors (`str` or [`~utils.TensorType`], *optional*, default to 'pt'):
+ If set, will return tensors instead of list of python integers. Acceptable values are:
+
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return Numpy `np.ndarray` objects.
+ sampling_rate (`int`, *optional*):
+ The sampling rate at which the `audio` input was sampled. It is strongly recommended to pass
+ `sampling_rate` at the forward call to prevent silent errors.
+ """
+ if sampling_rate is not None:
+ if sampling_rate != self.sampling_rate:
+ raise ValueError(
+ f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of"
+ f" {self.sampling_rate}. Please make sure that the provided audio input was sampled with"
+ f" {self.sampling_rate} and not {sampling_rate}."
+ )
+ else:
+ logger.warning(
+ f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. "
+ "Failing to do so can result in silent errors that might be hard to debug."
+ )
+
+ if padding and truncation:
+ raise ValueError("Both padding and truncation were set. Make sure you only set one.")
+ elif padding is None:
+ # by default let's pad the inputs
+ padding = True
+
+ is_batched = bool(
+ isinstance(raw_audio, (list, tuple)) and (isinstance(raw_audio[0], (np.ndarray, tuple, list)))
+ )
+
+ if is_batched:
+ raw_audio = [np.asarray(audio, dtype=np.float32).T for audio in raw_audio]
+ elif not is_batched and not isinstance(raw_audio, np.ndarray):
+ raw_audio = np.asarray(raw_audio, dtype=np.float32)
+ elif isinstance(raw_audio, np.ndarray) and raw_audio.dtype is np.dtype(np.float64):
+ raw_audio = raw_audio.astype(np.float32)
+
+ # always return batch
+ if not is_batched:
+ raw_audio = [np.asarray(raw_audio).T]
+
+ # convert stereo to mono if necessary, unique to Dia
+ for idx, example in enumerate(raw_audio):
+ if self.feature_size == 2 and example.ndim == 2:
+ raw_audio[idx] = np.mean(example, -1)
+
+ # verify inputs are valid
+ for idx, example in enumerate(raw_audio):
+ if example.ndim > 2:
+ raise ValueError(f"Expected input shape (channels, length) but got shape {example.shape}")
+ if self.feature_size == 1 and example.ndim != 1:
+ raise ValueError(f"Expected mono audio but example has {example.shape[-1]} channels")
+ if self.feature_size == 2 and example.ndim != 1: # note the conversion before
+ raise ValueError(f"Expected stereo audio but example has {example.shape[-1]} channels")
+
+ input_values = BatchFeature({"input_values": raw_audio})
+
+ # temporarily treat it as if we were mono as we also convert stereo to mono
+ origingal_feature_size = self.feature_size
+ self.feature_size = 1
+
+ # normal padding on batch
+ padded_inputs = self.pad(
+ input_values,
+ max_length=max_length,
+ truncation=truncation,
+ padding=padding,
+ return_attention_mask=True,
+ pad_to_multiple_of=self.hop_length,
+ )
+ padded_inputs["padding_mask"] = padded_inputs.pop("attention_mask")
+
+ input_values = []
+ for example in padded_inputs.pop("input_values"):
+ if self.feature_size == 1:
+ example = example[..., None]
+ input_values.append(example.T)
+
+ padded_inputs["input_values"] = input_values
+ if return_tensors is not None:
+ padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
+
+ # rewrite back to original feature size
+ self.feature_size = origingal_feature_size
+
+ return padded_inputs
+
+
+__all__ = ["DiaFeatureExtractor"]
diff --git a/src/transformers/models/dia/generation_dia.py b/src/transformers/models/dia/generation_dia.py
new file mode 100644
index 0000000000..0ca5998bf2
--- /dev/null
+++ b/src/transformers/models/dia/generation_dia.py
@@ -0,0 +1,464 @@
+# coding=utf-8
+# Copyright 2025 The Nari Labs and HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Callable, Optional, Union
+
+import torch
+import torch.distributed as dist
+
+from ...generation.logits_process import (
+ DiaClassifierFreeGuidanceLogitsProcessor,
+ DiaEOSChannelFilterLogitsProcessor,
+ DiaEOSDelayPatternLogitsProcessor,
+ LogitsProcessorList,
+ TemperatureLogitsWarper,
+)
+from ...generation.stopping_criteria import StoppingCriteriaList
+from ...generation.streamers import BaseStreamer
+from ...generation.utils import GenerateOutput, GenerationConfig, GenerationMixin, GenerationMode
+from ...integrations.deepspeed import is_deepspeed_zero3_enabled
+from ...integrations.fsdp import is_fsdp_managed_module
+from ...modeling_utils import PreTrainedModel
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class DiaGenerationMixin(GenerationMixin):
+ # Indicates CFG which needs preparation to be properly handled by repeats
+ _uses_cfg = None
+
+ def _get_logits_processor(
+ self,
+ generation_config: GenerationConfig,
+ input_ids_seq_length: Optional[int] = None,
+ encoder_input_ids: torch.LongTensor = None,
+ prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None,
+ logits_processor: Optional[LogitsProcessorList] = None,
+ device: Optional[str] = None,
+ model_kwargs: Optional[dict[str, Any]] = None,
+ negative_prompt_ids: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ ) -> LogitsProcessorList:
+ # Need either custom order or custom processor instead
+ # (Temporarily disabling those for the super function)
+ original_guidance_scale = generation_config.guidance_scale
+ original_temperature = generation_config.temperature
+ generation_config.guidance_scale = None
+ generation_config.temperature = None
+
+ # Get base processors and those we can integrate easily
+ custom_processors = LogitsProcessorList()
+
+ if original_temperature is not None and original_temperature != 1.0:
+ custom_processors.append(TemperatureLogitsWarper(original_temperature))
+
+ custom_processors.append(
+ DiaEOSChannelFilterLogitsProcessor(
+ num_channels=len(self.config.delay_pattern),
+ eos_token_id=self.config.eos_token_id,
+ )
+ )
+
+ merged_processors = super()._get_logits_processor(
+ generation_config=generation_config,
+ input_ids_seq_length=input_ids_seq_length,
+ encoder_input_ids=encoder_input_ids,
+ prefix_allowed_tokens_fn=None,
+ logits_processor=custom_processors,
+ device=device,
+ model_kwargs=model_kwargs,
+ negative_prompt_ids=negative_prompt_ids,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ )
+
+ # Custom processors we need at specific positions
+ if original_guidance_scale is not None and original_guidance_scale != 1:
+ cfg_processor = DiaClassifierFreeGuidanceLogitsProcessor(
+ guidance_scale=original_guidance_scale,
+ guidance_top_k=generation_config.top_k,
+ )
+ merged_processors.insert(0, cfg_processor)
+
+ merged_processors.append(
+ DiaEOSDelayPatternLogitsProcessor(
+ delay_pattern=self.config.delay_pattern,
+ eos_token_id=self.config.eos_token_id,
+ max_generation_len=generation_config.max_length,
+ device=device,
+ )
+ )
+
+ # Enable temporarily disabled values back
+ generation_config.guidance_scale = original_guidance_scale
+ generation_config.temperature = original_temperature
+
+ return merged_processors
+
+ def _prepare_generation_config(
+ self, generation_config: Optional[GenerationConfig], use_model_defaults: Optional[bool] = None, **kwargs: dict
+ ) -> tuple[GenerationConfig, dict]:
+ generation_config, model_kwargs = super()._prepare_generation_config(
+ generation_config, use_model_defaults, **kwargs
+ )
+
+ # We allow generation up to max length + max delay pattern
+ # (will revert back to max length after generation)
+ generation_config.max_length += max(self.config.delay_pattern)
+
+ # Internal flag to indicate CFG that needs to prepare unconditioned input
+ self._uses_cfg = generation_config.guidance_scale is not None and generation_config.guidance_scale != 1
+
+ return generation_config, model_kwargs
+
+ def _prepare_model_inputs(
+ self,
+ inputs: Optional[torch.Tensor] = None,
+ bos_token_id: Optional[torch.Tensor] = None,
+ model_kwargs: Optional[dict[str, torch.Tensor]] = None,
+ ) -> tuple[torch.Tensor, Optional[str], dict[str, torch.Tensor]]:
+ inputs, input_name, model_kwargs = super()._prepare_model_inputs(
+ inputs=inputs,
+ bos_token_id=bos_token_id,
+ model_kwargs=model_kwargs,
+ )
+
+ # If CFG is requested we fill in the unconditioned parts
+ if self._uses_cfg:
+ unconditioned_inputs = torch.zeros_like(inputs)
+ inputs = torch.cat([inputs, unconditioned_inputs], dim=0)
+
+ if model_kwargs.get("attention_mask", None) is not None:
+ model_kwargs["attention_mask"] = model_kwargs["attention_mask"].repeat(2, 1)
+
+ return inputs, input_name, model_kwargs
+
+ def _prepare_decoder_input_ids_for_generation(
+ self,
+ batch_size: int,
+ model_input_name: str,
+ model_kwargs: dict[str, torch.Tensor],
+ decoder_start_token_id: torch.Tensor,
+ device: Optional[torch.device] = None,
+ ) -> tuple[torch.LongTensor, dict[str, torch.Tensor]]:
+ """Prepares `decoder_input_ids` for generation with encoder-decoder models"""
+ # 1. Check whether the user has defined `decoder_input_ids` and `decoder_attention_mask`; if not error out
+ decoder_input_ids = decoder_attention_mask = None
+ if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
+ decoder_input_ids = model_kwargs.pop("decoder_input_ids")
+ if model_kwargs is not None and "decoder_attention_mask" in model_kwargs:
+ decoder_attention_mask = model_kwargs.pop("decoder_attention_mask")
+
+ # We allow generating without preparation (no proper delay) but discourage it
+ if decoder_input_ids is None or decoder_attention_mask is None:
+ logger.warning_once(
+ "In order to generate with Dia, we need the processed audio input: Got `decoder_input_ids`:"
+ f" {decoder_input_ids is not None} and got `decoder_attention_mask`={decoder_attention_mask is not None}."
+ f" This can be achieved via the [`DiaProcessor`] but now defaulting to non-delayed generation."
+ )
+
+ num_channels = self.config.decoder_config.num_channels
+ real_batch_size = batch_size // 2 if self._uses_cfg else batch_size
+
+ if decoder_input_ids is None:
+ decoder_input_ids = torch.full(
+ (real_batch_size, 1, num_channels), decoder_start_token_id, dtype=torch.long, device=device
+ )
+
+ decoder_attention_mask = torch.ones(
+ size=(real_batch_size, decoder_input_ids.shape[1]), dtype=torch.long, device=device
+ )
+
+ # 2. Determine the valid input and what works as mask within the input
+ delay_mask = decoder_input_ids.long()
+ valid_input_size = (
+ decoder_input_ids.shape[1] - (decoder_input_ids[:, :, 0] == self.config.pad_token_id).sum(dim=-1).max()
+ )
+ decoder_input_ids = delay_mask[:, :valid_input_size].transpose(1, 2).long()
+ decoder_attention_mask = decoder_attention_mask[:, :valid_input_size].long()
+
+ # 3. Overwrite into model kwargs
+ model_kwargs["decoder_attention_mask"] = decoder_attention_mask
+ model_kwargs["decoder_delay_mask"] = delay_mask
+
+ return decoder_input_ids, model_kwargs
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ encoder_outputs=None, # Using this to easily get the batch size
+ decoder_delay_mask=None,
+ **kwargs,
+ ):
+ # Reshape decoder input_ids to 3D to be compile friendly and to fit the expected model input shape
+ batch_size = encoder_outputs[0].shape[0] // 2 if self._uses_cfg else encoder_outputs[0].shape[0]
+ input_ids = input_ids.reshape(batch_size, self.config.decoder_config.num_channels, -1).transpose(1, 2)
+
+ # Base method handles most things except CFG and the delay pattern mask
+ model_inputs = super().prepare_inputs_for_generation(input_ids, encoder_outputs=encoder_outputs, **kwargs)
+
+ # Post processing for CFG and overwriting via delay pattern mask
+ # 1. Delay pattern mask -- force tokens if not allowed to predict (!= pad_token in mask)
+ model_inputs["decoder_input_ids"] = self.apply_delay_mask(
+ input_ids, self.config.pad_token_id, decoder_delay_mask
+ )
+
+ # Depending on cache usage we need to pass all or just one
+ if model_inputs.get("use_cache", False) and model_inputs["cache_position"][0] > 0:
+ model_inputs["decoder_input_ids"] = model_inputs["decoder_input_ids"][:, -1, :][:, None, :]
+
+ # Be compile friendly
+ model_inputs["decoder_input_ids"] = model_inputs["decoder_input_ids"].contiguous()
+
+ # 2. Apply CFG duplication if needed
+ if self._uses_cfg:
+ for key in ["decoder_input_ids", "decoder_attention_mask", "decoder_position_ids"]:
+ if model_inputs.get(key, None) is not None:
+ # double first dimension and keep everything else the same
+ repeat_pattern = tuple([2] + [1] * (model_inputs[key].ndim - 1))
+ model_inputs[key] = model_inputs[key].repeat(*repeat_pattern)
+
+ return model_inputs
+
+ @staticmethod
+ def apply_delay_mask(input_ids: torch.Tensor, pad_id: int, delay_mask: Optional[torch.Tensor]) -> torch.Tensor:
+ if delay_mask is None:
+ return input_ids
+
+ mask_len = min(input_ids.shape[1], delay_mask.shape[1])
+ valid_mask = delay_mask[:, :mask_len, :]
+ valid_input = input_ids[:, :mask_len, :]
+
+ # Overwrite the respective parts of the input
+ input_ids[:, :mask_len, :] = torch.where(valid_mask == pad_id, valid_input, valid_mask)
+
+ return input_ids
+
+ def _main_generate_loop(
+ self,
+ inputs: Optional[torch.Tensor] = None,
+ generation_config: Optional[GenerationConfig] = None,
+ logits_processor: Optional[LogitsProcessorList] = None,
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
+ prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None,
+ synced_gpus: Optional[bool] = None,
+ assistant_model: Optional["PreTrainedModel"] = None,
+ streamer: Optional["BaseStreamer"] = None,
+ negative_prompt_ids: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ use_model_defaults: Optional[bool] = None,
+ custom_generate: Optional[str] = None,
+ **kwargs,
+ ):
+ # ********** mostly taken from main generate function up to calling the different methods (see NOTE) **********
+ # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
+ tokenizer = kwargs.pop("tokenizer", None) # Pull this out first, we only use it for stopping criteria
+ assistant_tokenizer = kwargs.pop("assistant_tokenizer", None) # only used for assisted generation
+
+ generation_config, model_kwargs = self._prepare_generation_config(
+ generation_config, use_model_defaults, **kwargs
+ )
+ self._validate_model_kwargs(model_kwargs.copy())
+ self._validate_assistant(assistant_model, tokenizer, assistant_tokenizer)
+
+ # 2. Set generation parameters if not already defined
+ if synced_gpus is None:
+ synced_gpus = (is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)) and dist.get_world_size() > 1
+
+ logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
+ stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+
+ # 3. Define model inputs
+ kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
+ inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
+ inputs, generation_config.bos_token_id, model_kwargs
+ )
+ batch_size = inputs_tensor.shape[0]
+
+ device = inputs_tensor.device
+ self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)
+
+ # 4. Define other model kwargs
+ if "encoder_outputs" not in model_kwargs:
+ # if model is encoder decoder encoder_outputs are created and added to `model_kwargs`
+ model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
+ inputs_tensor, model_kwargs, model_input_name, generation_config
+ )
+
+ # 5. Prepare `input_ids` which will be used for auto-regressive generation
+ input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
+ batch_size=batch_size,
+ model_input_name=model_input_name,
+ model_kwargs=model_kwargs,
+ decoder_start_token_id=generation_config._decoder_start_token_tensor,
+ device=inputs_tensor.device,
+ )
+
+ if generation_config.token_healing:
+ input_ids = self.heal_tokens(input_ids, tokenizer)
+
+ if streamer is not None:
+ streamer.put(input_ids.cpu())
+
+ # 6. Prepare `max_length` depending on other stopping criteria.
+ # NOTE: incorrect `input_ids.shape[1]` previously
+ input_ids_length = input_ids.shape[-1]
+ has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
+ has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
+ generation_config = self._prepare_generated_length(
+ generation_config=generation_config,
+ has_default_max_length=has_default_max_length,
+ has_default_min_length=has_default_min_length,
+ model_input_name=model_input_name,
+ inputs_tensor=inputs_tensor,
+ input_ids_length=input_ids_length,
+ )
+
+ # If the model supports `logits_to_keep` in forward(), set it to 1 to avoid computing the whole
+ # logit matrix. This can save a lot of memory during the first forward pass. Note that assisted decoding
+ # dynamically overrides this value as it can need more than the last token logits
+ if self._supports_logits_to_keep() and "logits_to_keep" not in model_kwargs:
+ model_kwargs["logits_to_keep"] = 1
+
+ self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
+
+ # 7. Prepare the cache.
+ # - `model_kwargs` may be updated in place with a cache as defined by the parameters in `generation_config`.
+ # - different models have a different cache name expected by the model (default = "past_key_values")
+ # - `max_length`, prepared above, is used to determine the maximum cache length
+ max_cache_length = generation_config.max_length - 1
+ if (
+ inputs_tensor.shape[1] != input_ids_length
+ and model_input_name == "inputs_embeds"
+ and not self.config.is_encoder_decoder
+ ):
+ max_cache_length += inputs_tensor.shape[1]
+ self._prepare_cache_for_generation(
+ generation_config, model_kwargs, assistant_model, batch_size, max_cache_length, device
+ )
+
+ # 8. determine generation mode
+ generation_mode = generation_config.get_generation_mode(assistant_model)
+
+ if streamer is not None and (generation_config.num_beams > 1):
+ raise ValueError(
+ "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1."
+ )
+
+ # 9. prepare logits processors and stopping criteria
+ prepared_logits_processor = self._get_logits_processor(
+ generation_config=generation_config,
+ input_ids_seq_length=input_ids_length,
+ encoder_input_ids=inputs_tensor,
+ prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+ logits_processor=logits_processor,
+ device=inputs_tensor.device,
+ model_kwargs=model_kwargs,
+ negative_prompt_ids=negative_prompt_ids,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ )
+ prepared_stopping_criteria = self._get_stopping_criteria(
+ generation_config=generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer, **kwargs
+ )
+
+ # Set model_kwargs `use_cache` so we can use it later in forward runs
+ model_kwargs["use_cache"] = generation_config.use_cache
+ # ******************* taken from main generate function up to calling the different methods *******************
+
+ # Prepare inner 2D logic in generation loop
+ input_ids = input_ids.reshape(-1, input_ids.shape[-1])
+
+ # 10. go into different generation modes
+ if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
+ # 11. expand input_ids with `num_return_sequences` additional sequences per batch
+ if generation_config.num_return_sequences > 1:
+ raise ValueError("`num_return_sequences>1` is incompatible with Dia.")
+
+ # 12. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
+ return self._sample(
+ input_ids,
+ logits_processor=prepared_logits_processor,
+ stopping_criteria=prepared_stopping_criteria,
+ generation_config=generation_config,
+ synced_gpus=synced_gpus,
+ streamer=streamer,
+ **model_kwargs,
+ )
+ else:
+ raise ValueError(
+ "Got incompatible mode for generation, should be one of greedy or sampling. "
+ "Ensure that beam search is de-activated by setting `num_beams=1` and `num_beam_groups=1`."
+ )
+
+ @torch.no_grad()
+ def generate(
+ self,
+ inputs: Optional[torch.Tensor] = None,
+ generation_config: Optional[GenerationConfig] = None,
+ logits_processor: Optional[LogitsProcessorList] = None,
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
+ prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None,
+ synced_gpus: Optional[bool] = None,
+ assistant_model: Optional["PreTrainedModel"] = None,
+ streamer: Optional["BaseStreamer"] = None,
+ negative_prompt_ids: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ use_model_defaults: Optional[bool] = None,
+ custom_generate: Optional[str] = None,
+ **kwargs,
+ ) -> Union[GenerateOutput, torch.LongTensor]:
+ # We expect the initial input ids to be the complete mask (delayed input)
+ delay_mask = kwargs.get("decoder_input_ids", None)
+ if delay_mask is not None:
+ delay_mask = delay_mask.clone()
+
+ output = self._main_generate_loop(
+ inputs=inputs,
+ generation_config=generation_config,
+ logits_processor=logits_processor,
+ stopping_criteria=stopping_criteria,
+ prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+ synced_gpus=synced_gpus,
+ assistant_model=assistant_model,
+ streamer=streamer,
+ negative_prompt_ids=negative_prompt_ids,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ use_model_defaults=use_model_defaults,
+ custom_generate=custom_generate,
+ **kwargs,
+ )
+
+ return_dict_in_generate = not isinstance(output, torch.Tensor)
+
+ if return_dict_in_generate:
+ output_sequences = output.sequences
+ else:
+ output_sequences = output
+
+ # Reshape from 2D (bsz * channels, seq_len) to 3D (bsz, seq_len, channels)
+ num_channels = self.config.decoder_config.num_channels
+ bsz = output_sequences.shape[0] // num_channels
+ output_sequences = output_sequences.reshape(bsz, num_channels, -1).transpose(1, 2)
+
+ # Apply delay mask
+ output_sequences = self.apply_delay_mask(output_sequences, self.config.pad_token_id, delay_mask)
+
+ if return_dict_in_generate:
+ output.sequences = output_sequences
+ else:
+ output = output_sequences
+
+ return output
diff --git a/src/transformers/models/dia/modeling_dia.py b/src/transformers/models/dia/modeling_dia.py
new file mode 100644
index 0000000000..19cac3e8c3
--- /dev/null
+++ b/src/transformers/models/dia/modeling_dia.py
@@ -0,0 +1,963 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/dia/modular_dia.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_dia.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 The Nari Labs and HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Callable, Optional, Union
+
+import torch
+from torch import nn
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
+from ...integrations import use_kernel_forward_from_hub
+from ...masking_utils import create_causal_mask
+from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import (
+ BaseModelOutput,
+ BaseModelOutputWithPastAndCrossAttentions,
+ Seq2SeqLMOutput,
+ Seq2SeqModelOutput,
+)
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, is_torchdynamo_compiling, logging
+from .configuration_dia import DiaConfig, DiaDecoderConfig, DiaEncoderConfig
+from .generation_dia import DiaGenerationMixin
+
+
+if is_torch_flex_attn_available():
+ from ...integrations.flex_attention import make_flex_block_causal_mask
+
+
+logger = logging.get_logger(__name__)
+
+
+@auto_docstring
+class DiaPreTrainedModel(PreTrainedModel):
+ config_class = DiaConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _supports_flash_attn_2 = True
+ _supports_sdpa = True
+ _supports_flex_attn = True
+ _supports_cache_class = True
+ _supports_static_cache = True
+ main_input_name = "input_ids"
+ _no_split_modules = ["DiaEncoderLayer", "DiaDecoderLayer"]
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, DiaRMSNorm):
+ module.weight.data.fill_(1.0)
+
+
+class DiaMultiChannelEmbedding(nn.Module):
+ """In order to efficiently compute the audio embedding from the 9 different channels,
+ we vectorize the embedding process by using a single embedding layer and an offset.
+ Example:
+ - num_embeds = 4
+ - vocab_size = 8
+ - num_channels = 3
+ We would have offsets = [0, 8, 16]
+ If audio_codes = [0, 1, 2, 3], [1, 3, 4, 7], [5, 6, 7, 8],
+ then tokens = audio_codes + offsets
+ = [0, 1, 2, 3, 9, 11, 12, 15, 21, 22, 23, 24]
+ This allows us to use a single embedding layer for all channels.
+ """
+
+ def __init__(self, config: DiaDecoderConfig):
+ super().__init__()
+ self.embed = nn.Embedding(config.vocab_size * config.num_channels, config.hidden_size)
+ self.hidden_size = config.hidden_size
+ self.num_channels = config.num_channels
+ offsets = torch.arange(config.num_channels, dtype=torch.long) * config.vocab_size # (C,)
+ self.register_buffer("offsets", offsets, persistent=False)
+
+ def forward(self, audio_codes: torch.Tensor) -> torch.Tensor:
+ tokens = (audio_codes + self.offsets.to(audio_codes.device)).squeeze(1)
+ embeds = self.embed(tokens).view(tokens.shape[0], audio_codes.shape[1], -1, self.hidden_size)
+ return embeds.sum(dim=2)
+
+
+class DiaMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+
+ self.config = config
+ self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False)
+ self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
+ self.activation_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
+ up_states = self.gate_up_proj(hidden_states)
+
+ gate, up_states = up_states.chunk(2, dim=-1)
+ up_states = up_states * self.activation_fn(gate)
+
+ return self.down_proj(up_states)
+
+
+@use_kernel_forward_from_hub("RMSNorm")
+class DiaRMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ DiaRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+class DiaRotaryEmbedding(nn.Module):
+ def __init__(self, config: DiaConfig, device=None):
+ super().__init__()
+ # BC: "rope_type" was originally "type"
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
+
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def forward(self, x, position_ids):
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+ position_ids_expanded = position_ids[:, None, :].float()
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs,
+):
+ key_states = repeat_kv(key, module.num_key_value_groups)
+ value_states = repeat_kv(value, module.num_key_value_groups)
+
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+class DiaSelfAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: Union[DiaEncoderConfig, DiaDecoderConfig], layer_idx: int, is_causal: bool = False):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ self.hidden_size = config.hidden_size
+ self.num_heads = self.config.num_attention_heads
+ self.num_key_value_heads = self.config.num_key_value_heads or self.num_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // self.num_heads)
+ self.scaling = 1
+ self.attention_dropout = 0.0
+ self.is_causal = is_causal
+
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor],
+ past_key_value: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+class DiaCrossAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: DiaDecoderConfig, layer_idx: int):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ self.hidden_size = config.hidden_size
+ self.cross_hidden_size = config.cross_hidden_size
+ self.num_heads = self.config.cross_num_attention_heads
+ self.num_key_value_heads = self.config.cross_num_key_value_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.head_dim = config.cross_head_dim
+ self.scaling = 1
+ self.attention_dropout = 0.0
+ self.is_causal = False
+
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+ self.k_proj = nn.Linear(self.cross_hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+ self.v_proj = nn.Linear(self.cross_hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ cross_attention_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[EncoderDecoderCache] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+ cross_shape = (*cross_attention_states.shape[:-1], -1, self.head_dim)
+
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+ is_updated = past_key_values.is_updated.get(self.layer_idx) if past_key_values is not None else False
+ if past_key_values is not None and is_updated:
+ # reuse k,v, cross_attentions
+ key_states = past_key_values.cross_attention_cache.key_cache[self.layer_idx]
+ value_states = past_key_values.cross_attention_cache.value_cache[self.layer_idx]
+ else:
+ key_states = self.k_proj(cross_attention_states).view(cross_shape).transpose(1, 2)
+ value_states = self.v_proj(cross_attention_states).view(cross_shape).transpose(1, 2)
+
+ if past_key_values is not None:
+ # save all states to the cache
+ key_states, value_states = past_key_values.cross_attention_cache.update(
+ key_states,
+ value_states,
+ self.layer_idx,
+ )
+ # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
+ past_key_values.is_updated[self.layer_idx] = True
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ scaling=self.scaling,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape((*input_shape, -1)).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+class DiaEncoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: DiaEncoderConfig, layer_idx: int):
+ super().__init__()
+ self.pre_sa_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
+ self.self_attention = DiaSelfAttention(config, layer_idx, is_causal=False)
+ self.post_sa_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
+ self.mlp = DiaMLP(config)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ attention_mask: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ residual = hidden_states
+ normed_states = self.pre_sa_norm(hidden_states)
+ self_attn_output, self_attn_weights = self.self_attention(
+ normed_states,
+ position_embeddings=position_embeddings,
+ attention_mask=attention_mask,
+ **kwargs,
+ )
+ hidden_states = residual + self_attn_output
+
+ residual = hidden_states
+ normed_states = self.post_sa_norm(hidden_states)
+ mlp_out = self.mlp(normed_states)
+ hidden_states = residual + mlp_out
+
+ return hidden_states, self_attn_weights
+
+
+class DiaEncoder(DiaPreTrainedModel):
+ def __init__(self, config: DiaEncoderConfig):
+ super().__init__(config)
+ self.config = config
+
+ self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
+ self.layers = nn.ModuleList(
+ [DiaEncoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
+ self.rotary_embeddings = DiaRotaryEmbedding(config)
+
+ @auto_docstring
+ @can_return_tuple
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = False,
+ output_hidden_states: Optional[bool] = False,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> Union[BaseModelOutput, tuple]:
+ hidden_states = self.embedding(input_ids)
+
+ # RoPE
+ # Note: We expect right padding and hence always generate
+ # the position ids on the fly to reduce preparation overhead
+ position_ids = torch.arange(input_ids.shape[-1], device=input_ids.device)[None, :]
+ position_embeddings = self.rotary_embeddings(hidden_states, position_ids)
+
+ attention_mask = self._update_full_mask(
+ attention_mask,
+ hidden_states,
+ )
+
+ encoder_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+
+ for encoder_layer in self.layers:
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+
+ layer_outputs = encoder_layer(
+ hidden_states,
+ position_embeddings=position_embeddings,
+ attention_mask=attention_mask,
+ **kwargs,
+ )
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states)
+
+ if output_hidden_states:
+ encoder_states += (hidden_states,)
+
+ return BaseModelOutput(
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
+ )
+
+ # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask
+ def _update_full_mask(
+ self,
+ attention_mask: Union[torch.Tensor, None],
+ inputs_embeds: torch.Tensor,
+ ):
+ if attention_mask is not None:
+ if self.config._attn_implementation == "flash_attention_2":
+ attention_mask = attention_mask if 0 in attention_mask else None
+ elif self.config._attn_implementation == "sdpa":
+ # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
+ # the manual implementation that requires a 4D causal mask in all cases.
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
+ elif self.config._attn_implementation == "flex_attention":
+ if isinstance(attention_mask, torch.Tensor):
+ attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
+ else:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
+
+ return attention_mask
+
+
+class DiaDecoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: DiaDecoderConfig, layer_idx: int):
+ super().__init__()
+ self.embed_dim = config.hidden_size
+ self.self_attention = DiaSelfAttention(config, layer_idx, is_causal=True)
+ self.cross_attention = DiaCrossAttention(config, layer_idx)
+ self.pre_sa_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
+ self.pre_ca_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
+ self.pre_mlp_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
+ self.mlp = DiaMLP(config)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[EncoderDecoderCache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+ self_attn_cache = past_key_values
+ if isinstance(self_attn_cache, EncoderDecoderCache):
+ self_attn_cache = self_attn_cache.self_attention_cache
+
+ residual = hidden_states
+ normed_states = self.pre_sa_norm(hidden_states)
+ self_attn_output, self_attn_weights = self.self_attention(
+ normed_states,
+ position_embeddings,
+ attention_mask,
+ # Needs to be an arg in order to function properly
+ # on inplace operations to be carried (e.g. compile)
+ self_attn_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+ hidden_states = residual + self_attn_output
+
+ residual = hidden_states
+ normed_states = self.pre_ca_norm(hidden_states)
+ cross_states, cross_attn_weights = self.cross_attention(
+ normed_states,
+ encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ past_key_values=past_key_values,
+ **kwargs,
+ )
+ hidden_states = residual + cross_states
+
+ residual = hidden_states
+ normed_states = self.pre_mlp_norm(hidden_states)
+ mlp_out = self.mlp(normed_states)
+ hidden_states = residual + mlp_out
+
+ return hidden_states, self_attn_weights, cross_attn_weights
+
+
+class DiaDecoder(DiaPreTrainedModel):
+ """Transformer Decoder Stack using DenseGeneral."""
+
+ def __init__(self, config: DiaDecoderConfig):
+ super().__init__(config)
+ self.num_channels = config.num_channels
+ self.vocab_size = config.vocab_size
+ self.embeddings = DiaMultiChannelEmbedding(config)
+ self.rotary_embeddings = DiaRotaryEmbedding(config)
+ self.layers = nn.ModuleList(
+ [DiaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
+
+ @auto_docstring
+ @can_return_tuple
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[EncoderDecoderCache] = None,
+ output_attentions: Optional[bool] = False,
+ output_hidden_states: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> Union[BaseModelOutputWithPastAndCrossAttentions, tuple]:
+ r"""
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`):
+ The original `decoder_input_ids` in 3D shape to facilitate more efficient computations.
+
+ [What are input IDs?](../glossary#input-ids)
+ """
+
+ batch_size, seq_length = input_ids.size()[:-1]
+ past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
+ if cache_position is None:
+ cache_position = torch.arange(
+ past_key_values_length, past_key_values_length + seq_length, device=input_ids.device
+ )
+ if position_ids is None:
+ position_ids = cache_position[None, :]
+
+ # RoPE
+ hidden_states = self.embeddings(input_ids)
+ position_embeddings = self.rotary_embeddings(hidden_states, position_ids)
+
+ if attention_mask is None and not is_torchdynamo_compiling():
+ # required mask seq length can be calculated via length of past cache
+ mask_seq_length = past_key_values_length + seq_length
+ attention_mask = torch.ones(batch_size, mask_seq_length, device=input_ids.device)
+
+ attention_mask = create_causal_mask(
+ config=self.config,
+ input_embeds=hidden_states,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ past_key_values=past_key_values,
+ )
+ encoder_attention_mask = self._update_cross_attn_mask(
+ encoder_hidden_states,
+ encoder_attention_mask,
+ hidden_states.shape[:2],
+ hidden_states,
+ )
+
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
+
+ for layer in self.layers:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ layer_outputs = layer(
+ hidden_states,
+ position_embeddings,
+ attention_mask,
+ encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ past_key_values=past_key_values,
+ cache_position=cache_position,
+ **kwargs,
+ )
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_self_attns = all_self_attns + (layer_outputs[1],)
+
+ if encoder_hidden_states is not None:
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
+
+ hidden_states = self.norm(hidden_states)
+
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ cross_attentions=all_cross_attentions,
+ )
+
+ # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask
+ def _update_cross_attn_mask(
+ self,
+ encoder_hidden_states: Union[torch.Tensor, None],
+ encoder_attention_mask: Union[torch.Tensor, None],
+ input_shape: torch.Size,
+ inputs_embeds: torch.Tensor,
+ ):
+ # expand encoder attention mask
+ if encoder_hidden_states is not None and encoder_attention_mask is not None:
+ if self.config._attn_implementation == "flash_attention_2":
+ encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
+ elif self.config._attn_implementation == "sdpa":
+ # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
+ # the manual implementation that requires a 4D causal mask in all cases.
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
+ encoder_attention_mask,
+ inputs_embeds.dtype,
+ tgt_len=input_shape[-1],
+ )
+ elif self.config._attn_implementation == "flex_attention":
+ if isinstance(encoder_attention_mask, torch.Tensor):
+ encoder_attention_mask = make_flex_block_causal_mask(
+ encoder_attention_mask,
+ query_length=input_shape[-1],
+ is_causal=False,
+ )
+ else:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ encoder_attention_mask = _prepare_4d_attention_mask(
+ encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+ )
+
+ return encoder_attention_mask
+
+
+@auto_docstring(
+ custom_intro="""
+ The bare Dia model outputting raw hidden-states without any specific head on top.
+ """
+)
+class DiaModel(DiaPreTrainedModel):
+ def __init__(self, config: DiaConfig):
+ super().__init__(config)
+ self.config = config
+ self.encoder = DiaEncoder(config.encoder_config)
+ self.decoder = DiaDecoder(config.decoder_config)
+ self.post_init()
+
+ def get_encoder(self):
+ return self.encoder
+
+ def get_decoder(self):
+ return self.decoder
+
+ @auto_docstring
+ @can_return_tuple
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ decoder_input_ids: Optional[torch.LongTensor] = None,
+ decoder_position_ids: Optional[torch.LongTensor] = None,
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
+ encoder_outputs: Optional[Union[BaseModelOutput, tuple]] = None,
+ past_key_values: Optional[EncoderDecoderCache] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> Union[tuple, Seq2SeqModelOutput]:
+ r"""
+ decoder_input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, target_sequence_length)
+ or (batch_size, target_sequence_length, num_codebooks)`, *optional*):
+ 1. (batch_size * num_codebooks, target_sequence_length): corresponds to the general use case where
+ the audio input codebooks are flattened into the batch dimension. This also aligns with the flat-
+ tened audio logits which are used to calculate the loss.
+
+ 2. (batch_size, sequence_length, num_codebooks): corresponds to the internally used shape of
+ Dia to calculate embeddings and subsequent steps more efficiently.
+
+ If no `decoder_input_ids` are provided, it will create a tensor of `bos_token_id` with shape
+ `(batch_size, 1, num_codebooks)`. Indices can be obtained using the [`DiaProcessor`]. See
+ [`DiaProcessor.__call__`] for more details.
+
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
+ decoder_position_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`):
+ Indices of positions of each input sequence tokens in the position embeddings.
+ Used to calculate the position embeddings up to `config.decoder_config.max_position_embeddings`.
+
+ [What are position IDs?](../glossary#position-ids)
+ """
+
+ if input_ids is None and encoder_outputs is None:
+ raise ValueError(
+ "You should either provide text ids or the cached text encodings. Neither has been found."
+ )
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ if self.is_gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ if use_cache and past_key_values is None:
+ past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
+
+ if encoder_outputs is None:
+ encoder_outputs = self.encoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ **kwargs,
+ )
+ # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput
+ elif not isinstance(encoder_outputs, BaseModelOutput):
+ encoder_outputs = BaseModelOutput(
+ last_hidden_state=encoder_outputs[0],
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+ )
+
+ # On default we initialize the decoder with bos tokens if nothing has been provided
+ bsz, seq_len, channels = (encoder_outputs[0].shape[0], -1, self.config.decoder_config.num_channels)
+ if decoder_input_ids is None:
+ decoder_input_ids = torch.full(
+ size=(bsz, 1, channels), fill_value=self.config.bos_token_id, device=self.device
+ )
+ # Ensure 3D
+ if decoder_input_ids.ndim == 2:
+ decoder_input_ids = decoder_input_ids.reshape(bsz, channels, seq_len).transpose(1, 2)
+
+ decoder_outputs = self.decoder(
+ input_ids=decoder_input_ids,
+ position_ids=decoder_position_ids,
+ attention_mask=decoder_attention_mask,
+ encoder_hidden_states=encoder_outputs[0],
+ encoder_attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ return Seq2SeqModelOutput(
+ last_hidden_state=decoder_outputs.last_hidden_state,
+ past_key_values=decoder_outputs.past_key_values,
+ decoder_hidden_states=decoder_outputs.hidden_states,
+ decoder_attentions=decoder_outputs.attentions,
+ cross_attentions=decoder_outputs.cross_attentions,
+ encoder_last_hidden_state=encoder_outputs[0],
+ encoder_hidden_states=encoder_outputs.hidden_states,
+ encoder_attentions=encoder_outputs.attentions,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ The Dia model consisting of a (byte) text encoder and audio decoder with a prediction head on top.
+ """
+)
+class DiaForConditionalGeneration(DiaPreTrainedModel, DiaGenerationMixin):
+ base_model_prefix = "model"
+
+ def __init__(self, config: DiaConfig):
+ super().__init__(config)
+ self.config = config
+ self.model = DiaModel(config)
+
+ self.num_channels = config.decoder_config.num_channels
+ self.vocab_size = config.decoder_config.vocab_size
+ self.logits_dense = nn.Linear(
+ config.decoder_config.hidden_size, (self.num_channels * self.vocab_size), bias=False
+ )
+ self.loss_type = "ForMaskedLM"
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_encoder(self):
+ return self.model.get_encoder()
+
+ def get_decoder(self):
+ return self.model.get_decoder()
+
+ @auto_docstring
+ @can_return_tuple
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ decoder_input_ids: Optional[torch.LongTensor] = None,
+ decoder_position_ids: Optional[torch.LongTensor] = None,
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
+ encoder_outputs: Optional[Union[BaseModelOutput, tuple]] = None,
+ past_key_values: Optional[EncoderDecoderCache] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ labels: Optional[torch.LongTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> Union[tuple, Seq2SeqLMOutput]:
+ r"""
+ decoder_input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, target_sequence_length)
+ or (batch_size, target_sequence_length, num_codebooks)`, *optional*):
+ 1. (batch_size * num_codebooks, target_sequence_length): corresponds to the general use case where
+ the audio input codebooks are flattened into the batch dimension. This also aligns with the flat-
+ tened audio logits which are used to calculate the loss.
+
+ 2. (batch_size, sequence_length, num_codebooks): corresponds to the internally used shape of
+ Dia to calculate embeddings and subsequent steps more efficiently.
+
+ If no `decoder_input_ids` are provided, it will create a tensor of `bos_token_id` with shape
+ `(batch_size, 1, num_codebooks)`. Indices can be obtained using the [`DiaProcessor`]. See
+ [`DiaProcessor.__call__`] for more details.
+
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
+ decoder_position_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`):
+ Indices of positions of each input sequence tokens in the position embeddings.
+ Used to calculate the position embeddings up to `config.decoder_config.max_position_embeddings`.
+
+ [What are position IDs?](../glossary#position-ids)
+ labels (`torch.LongTensor` of shape `(batch_size * num_codebooks,)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in
+ `[0, ..., config.decoder_config.vocab_size - 1]` or -100. Tokens with indices set to `-100`
+ are ignored (masked).
+ """
+
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ decoder_input_ids=decoder_input_ids,
+ decoder_position_ids=decoder_position_ids,
+ decoder_attention_mask=decoder_attention_mask,
+ encoder_outputs=encoder_outputs,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ last_hidden_state = outputs[0]
+ batch_size = last_hidden_state.shape[0]
+ # 3D <-> 2D makes it necessary to prioritize channel dim
+ audio_logits = (
+ self.logits_dense(last_hidden_state)
+ .view((batch_size, -1, self.num_channels, self.vocab_size))
+ .transpose(1, 2)
+ .contiguous()
+ .view(batch_size * self.num_channels, -1, self.vocab_size)
+ )
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits=audio_logits, labels=labels, vocab_size=self.vocab_size, **kwargs)
+
+ return Seq2SeqLMOutput(
+ loss=loss,
+ logits=audio_logits,
+ past_key_values=outputs.past_key_values,
+ decoder_hidden_states=outputs.decoder_hidden_states,
+ decoder_attentions=outputs.decoder_attentions,
+ cross_attentions=outputs.cross_attentions,
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+ encoder_hidden_states=outputs.encoder_hidden_states,
+ encoder_attentions=outputs.encoder_attentions,
+ )
+
+
+__all__ = ["DiaModel", "DiaPreTrainedModel", "DiaForConditionalGeneration"]
diff --git a/src/transformers/models/dia/modular_dia.py b/src/transformers/models/dia/modular_dia.py
new file mode 100644
index 0000000000..fe437fde84
--- /dev/null
+++ b/src/transformers/models/dia/modular_dia.py
@@ -0,0 +1,789 @@
+# coding=utf-8
+# Copyright 2025 The Nari Labs and HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Dia model."""
+
+from typing import Callable, Optional, Union
+
+import torch
+from torch import nn
+
+from ...cache_utils import DynamicCache, EncoderDecoderCache
+from ...masking_utils import create_causal_mask
+from ...modeling_attn_mask_utils import (
+ _prepare_4d_attention_mask,
+ _prepare_4d_attention_mask_for_sdpa,
+)
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import (
+ BaseModelOutput,
+ BaseModelOutputWithPastAndCrossAttentions,
+ Seq2SeqLMOutput,
+ Seq2SeqModelOutput,
+)
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, is_torchdynamo_compiling, logging
+from ..llama.modeling_llama import (
+ LlamaAttention,
+ LlamaRMSNorm,
+ LlamaRotaryEmbedding,
+ eager_attention_forward,
+)
+from ..phi3.modeling_phi3 import Phi3MLP
+from .configuration_dia import DiaConfig, DiaDecoderConfig, DiaEncoderConfig
+from .generation_dia import DiaGenerationMixin
+
+
+if is_torch_flex_attn_available():
+ from ...integrations.flex_attention import make_flex_block_causal_mask
+
+
+logger = logging.get_logger(__name__)
+
+
+@auto_docstring
+class DiaPreTrainedModel(PreTrainedModel):
+ config_class = DiaConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _supports_flash_attn_2 = True
+ _supports_sdpa = True
+ _supports_flex_attn = True
+ _supports_cache_class = True
+ _supports_static_cache = True
+ main_input_name = "input_ids"
+ _no_split_modules = ["DiaEncoderLayer", "DiaDecoderLayer"]
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, DiaRMSNorm):
+ module.weight.data.fill_(1.0)
+
+
+class DiaMultiChannelEmbedding(nn.Module):
+ """In order to efficiently compute the audio embedding from the 9 different channels,
+ we vectorize the embedding process by using a single embedding layer and an offset.
+ Example:
+ - num_embeds = 4
+ - vocab_size = 8
+ - num_channels = 3
+ We would have offsets = [0, 8, 16]
+ If audio_codes = [0, 1, 2, 3], [1, 3, 4, 7], [5, 6, 7, 8],
+ then tokens = audio_codes + offsets
+ = [0, 1, 2, 3, 9, 11, 12, 15, 21, 22, 23, 24]
+ This allows us to use a single embedding layer for all channels.
+ """
+
+ def __init__(self, config: DiaDecoderConfig):
+ super().__init__()
+ self.embed = nn.Embedding(config.vocab_size * config.num_channels, config.hidden_size)
+ self.hidden_size = config.hidden_size
+ self.num_channels = config.num_channels
+ offsets = torch.arange(config.num_channels, dtype=torch.long) * config.vocab_size # (C,)
+ self.register_buffer("offsets", offsets, persistent=False)
+
+ def forward(self, audio_codes: torch.Tensor) -> torch.Tensor:
+ tokens = (audio_codes + self.offsets.to(audio_codes.device)).squeeze(1)
+ embeds = self.embed(tokens).view(tokens.shape[0], audio_codes.shape[1], -1, self.hidden_size)
+ return embeds.sum(dim=2)
+
+
+class DiaMLP(Phi3MLP):
+ pass
+
+
+class DiaRMSNorm(LlamaRMSNorm):
+ pass
+
+
+class DiaRotaryEmbedding(LlamaRotaryEmbedding):
+ pass
+
+
+class DiaSelfAttention(LlamaAttention, nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: Union[DiaEncoderConfig, DiaDecoderConfig], layer_idx: int, is_causal: bool = False):
+ nn.Module.__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ self.hidden_size = config.hidden_size
+ self.num_heads = self.config.num_attention_heads
+ self.num_key_value_heads = self.config.num_key_value_heads or self.num_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // self.num_heads)
+ self.scaling = 1
+ self.attention_dropout = 0.0
+ self.is_causal = is_causal
+
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+
+
+class DiaCrossAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: DiaDecoderConfig, layer_idx: int):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ self.hidden_size = config.hidden_size
+ self.cross_hidden_size = config.cross_hidden_size
+ self.num_heads = self.config.cross_num_attention_heads
+ self.num_key_value_heads = self.config.cross_num_key_value_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.head_dim = config.cross_head_dim
+ self.scaling = 1
+ self.attention_dropout = 0.0
+ self.is_causal = False
+
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+ self.k_proj = nn.Linear(self.cross_hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+ self.v_proj = nn.Linear(self.cross_hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ cross_attention_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[EncoderDecoderCache] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+ cross_shape = (*cross_attention_states.shape[:-1], -1, self.head_dim)
+
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+ is_updated = past_key_values.is_updated.get(self.layer_idx) if past_key_values is not None else False
+ if past_key_values is not None and is_updated:
+ # reuse k,v, cross_attentions
+ key_states = past_key_values.cross_attention_cache.key_cache[self.layer_idx]
+ value_states = past_key_values.cross_attention_cache.value_cache[self.layer_idx]
+ else:
+ key_states = self.k_proj(cross_attention_states).view(cross_shape).transpose(1, 2)
+ value_states = self.v_proj(cross_attention_states).view(cross_shape).transpose(1, 2)
+
+ if past_key_values is not None:
+ # save all states to the cache
+ key_states, value_states = past_key_values.cross_attention_cache.update(
+ key_states,
+ value_states,
+ self.layer_idx,
+ )
+ # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
+ past_key_values.is_updated[self.layer_idx] = True
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ scaling=self.scaling,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape((*input_shape, -1)).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+class DiaEncoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: DiaEncoderConfig, layer_idx: int):
+ super().__init__()
+ self.pre_sa_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
+ self.self_attention = DiaSelfAttention(config, layer_idx, is_causal=False)
+ self.post_sa_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
+ self.mlp = DiaMLP(config)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ attention_mask: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ residual = hidden_states
+ normed_states = self.pre_sa_norm(hidden_states)
+ self_attn_output, self_attn_weights = self.self_attention(
+ normed_states,
+ position_embeddings=position_embeddings,
+ attention_mask=attention_mask,
+ **kwargs,
+ )
+ hidden_states = residual + self_attn_output
+
+ residual = hidden_states
+ normed_states = self.post_sa_norm(hidden_states)
+ mlp_out = self.mlp(normed_states)
+ hidden_states = residual + mlp_out
+
+ return hidden_states, self_attn_weights
+
+
+class DiaEncoder(DiaPreTrainedModel):
+ def __init__(self, config: DiaEncoderConfig):
+ super().__init__(config)
+ self.config = config
+
+ self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
+ self.layers = nn.ModuleList(
+ [DiaEncoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
+ self.rotary_embeddings = DiaRotaryEmbedding(config)
+
+ @auto_docstring
+ @can_return_tuple
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = False,
+ output_hidden_states: Optional[bool] = False,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> Union[BaseModelOutput, tuple]:
+ hidden_states = self.embedding(input_ids)
+
+ # RoPE
+ # Note: We expect right padding and hence always generate
+ # the position ids on the fly to reduce preparation overhead
+ position_ids = torch.arange(input_ids.shape[-1], device=input_ids.device)[None, :]
+ position_embeddings = self.rotary_embeddings(hidden_states, position_ids)
+
+ attention_mask = self._update_full_mask(
+ attention_mask,
+ hidden_states,
+ )
+
+ encoder_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+
+ for encoder_layer in self.layers:
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+
+ layer_outputs = encoder_layer(
+ hidden_states,
+ position_embeddings=position_embeddings,
+ attention_mask=attention_mask,
+ **kwargs,
+ )
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states)
+
+ if output_hidden_states:
+ encoder_states += (hidden_states,)
+
+ return BaseModelOutput(
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
+ )
+
+ # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask
+ def _update_full_mask(
+ self,
+ attention_mask: Union[torch.Tensor, None],
+ inputs_embeds: torch.Tensor,
+ ):
+ if attention_mask is not None:
+ if self.config._attn_implementation == "flash_attention_2":
+ attention_mask = attention_mask if 0 in attention_mask else None
+ elif self.config._attn_implementation == "sdpa":
+ # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
+ # the manual implementation that requires a 4D causal mask in all cases.
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
+ elif self.config._attn_implementation == "flex_attention":
+ if isinstance(attention_mask, torch.Tensor):
+ attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
+ else:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
+
+ return attention_mask
+
+
+class DiaDecoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: DiaDecoderConfig, layer_idx: int):
+ super().__init__()
+ self.embed_dim = config.hidden_size
+ self.self_attention = DiaSelfAttention(config, layer_idx, is_causal=True)
+ self.cross_attention = DiaCrossAttention(config, layer_idx)
+ self.pre_sa_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
+ self.pre_ca_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
+ self.pre_mlp_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
+ self.mlp = DiaMLP(config)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[EncoderDecoderCache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+ self_attn_cache = past_key_values
+ if isinstance(self_attn_cache, EncoderDecoderCache):
+ self_attn_cache = self_attn_cache.self_attention_cache
+
+ residual = hidden_states
+ normed_states = self.pre_sa_norm(hidden_states)
+ self_attn_output, self_attn_weights = self.self_attention(
+ normed_states,
+ position_embeddings,
+ attention_mask,
+ # Needs to be an arg in order to function properly
+ # on inplace operations to be carried (e.g. compile)
+ self_attn_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+ hidden_states = residual + self_attn_output
+
+ residual = hidden_states
+ normed_states = self.pre_ca_norm(hidden_states)
+ cross_states, cross_attn_weights = self.cross_attention(
+ normed_states,
+ encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ past_key_values=past_key_values,
+ **kwargs,
+ )
+ hidden_states = residual + cross_states
+
+ residual = hidden_states
+ normed_states = self.pre_mlp_norm(hidden_states)
+ mlp_out = self.mlp(normed_states)
+ hidden_states = residual + mlp_out
+
+ return hidden_states, self_attn_weights, cross_attn_weights
+
+
+class DiaDecoder(DiaPreTrainedModel):
+ """Transformer Decoder Stack using DenseGeneral."""
+
+ def __init__(self, config: DiaDecoderConfig):
+ super().__init__(config)
+ self.num_channels = config.num_channels
+ self.vocab_size = config.vocab_size
+ self.embeddings = DiaMultiChannelEmbedding(config)
+ self.rotary_embeddings = DiaRotaryEmbedding(config)
+ self.layers = nn.ModuleList(
+ [DiaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
+
+ @auto_docstring
+ @can_return_tuple
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[EncoderDecoderCache] = None,
+ output_attentions: Optional[bool] = False,
+ output_hidden_states: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> Union[BaseModelOutputWithPastAndCrossAttentions, tuple]:
+ r"""
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`):
+ The original `decoder_input_ids` in 3D shape to facilitate more efficient computations.
+
+ [What are input IDs?](../glossary#input-ids)
+ """
+
+ batch_size, seq_length = input_ids.size()[:-1]
+ past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
+ if cache_position is None:
+ cache_position = torch.arange(
+ past_key_values_length, past_key_values_length + seq_length, device=input_ids.device
+ )
+ if position_ids is None:
+ position_ids = cache_position[None, :]
+
+ # RoPE
+ hidden_states = self.embeddings(input_ids)
+ position_embeddings = self.rotary_embeddings(hidden_states, position_ids)
+
+ if attention_mask is None and not is_torchdynamo_compiling():
+ # required mask seq length can be calculated via length of past cache
+ mask_seq_length = past_key_values_length + seq_length
+ attention_mask = torch.ones(batch_size, mask_seq_length, device=input_ids.device)
+
+ attention_mask = create_causal_mask(
+ config=self.config,
+ input_embeds=hidden_states,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ past_key_values=past_key_values,
+ )
+ encoder_attention_mask = self._update_cross_attn_mask(
+ encoder_hidden_states,
+ encoder_attention_mask,
+ hidden_states.shape[:2],
+ hidden_states,
+ )
+
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
+
+ for layer in self.layers:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ layer_outputs = layer(
+ hidden_states,
+ position_embeddings,
+ attention_mask,
+ encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ past_key_values=past_key_values,
+ cache_position=cache_position,
+ **kwargs,
+ )
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_self_attns = all_self_attns + (layer_outputs[1],)
+
+ if encoder_hidden_states is not None:
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
+
+ hidden_states = self.norm(hidden_states)
+
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ cross_attentions=all_cross_attentions,
+ )
+
+ # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask
+ def _update_cross_attn_mask(
+ self,
+ encoder_hidden_states: Union[torch.Tensor, None],
+ encoder_attention_mask: Union[torch.Tensor, None],
+ input_shape: torch.Size,
+ inputs_embeds: torch.Tensor,
+ ):
+ # expand encoder attention mask
+ if encoder_hidden_states is not None and encoder_attention_mask is not None:
+ if self.config._attn_implementation == "flash_attention_2":
+ encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
+ elif self.config._attn_implementation == "sdpa":
+ # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
+ # the manual implementation that requires a 4D causal mask in all cases.
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
+ encoder_attention_mask,
+ inputs_embeds.dtype,
+ tgt_len=input_shape[-1],
+ )
+ elif self.config._attn_implementation == "flex_attention":
+ if isinstance(encoder_attention_mask, torch.Tensor):
+ encoder_attention_mask = make_flex_block_causal_mask(
+ encoder_attention_mask,
+ query_length=input_shape[-1],
+ is_causal=False,
+ )
+ else:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ encoder_attention_mask = _prepare_4d_attention_mask(
+ encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+ )
+
+ return encoder_attention_mask
+
+
+@auto_docstring(
+ custom_intro="""
+ The bare Dia model outputting raw hidden-states without any specific head on top.
+ """
+)
+class DiaModel(DiaPreTrainedModel):
+ def __init__(self, config: DiaConfig):
+ super().__init__(config)
+ self.config = config
+ self.encoder = DiaEncoder(config.encoder_config)
+ self.decoder = DiaDecoder(config.decoder_config)
+ self.post_init()
+
+ def get_encoder(self):
+ return self.encoder
+
+ def get_decoder(self):
+ return self.decoder
+
+ @auto_docstring
+ @can_return_tuple
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ decoder_input_ids: Optional[torch.LongTensor] = None,
+ decoder_position_ids: Optional[torch.LongTensor] = None,
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
+ encoder_outputs: Optional[Union[BaseModelOutput, tuple]] = None,
+ past_key_values: Optional[EncoderDecoderCache] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> Union[tuple, Seq2SeqModelOutput]:
+ r"""
+ decoder_input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, target_sequence_length)
+ or (batch_size, target_sequence_length, num_codebooks)`, *optional*):
+ 1. (batch_size * num_codebooks, target_sequence_length): corresponds to the general use case where
+ the audio input codebooks are flattened into the batch dimension. This also aligns with the flat-
+ tened audio logits which are used to calculate the loss.
+
+ 2. (batch_size, sequence_length, num_codebooks): corresponds to the internally used shape of
+ Dia to calculate embeddings and subsequent steps more efficiently.
+
+ If no `decoder_input_ids` are provided, it will create a tensor of `bos_token_id` with shape
+ `(batch_size, 1, num_codebooks)`. Indices can be obtained using the [`DiaProcessor`]. See
+ [`DiaProcessor.__call__`] for more details.
+
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
+ decoder_position_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`):
+ Indices of positions of each input sequence tokens in the position embeddings.
+ Used to calculate the position embeddings up to `config.decoder_config.max_position_embeddings`.
+
+ [What are position IDs?](../glossary#position-ids)
+ """
+
+ if input_ids is None and encoder_outputs is None:
+ raise ValueError(
+ "You should either provide text ids or the cached text encodings. Neither has been found."
+ )
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ if self.is_gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ if use_cache and past_key_values is None:
+ past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
+
+ if encoder_outputs is None:
+ encoder_outputs = self.encoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ **kwargs,
+ )
+ # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput
+ elif not isinstance(encoder_outputs, BaseModelOutput):
+ encoder_outputs = BaseModelOutput(
+ last_hidden_state=encoder_outputs[0],
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+ )
+
+ # On default we initialize the decoder with bos tokens if nothing has been provided
+ bsz, seq_len, channels = (encoder_outputs[0].shape[0], -1, self.config.decoder_config.num_channels)
+ if decoder_input_ids is None:
+ decoder_input_ids = torch.full(
+ size=(bsz, 1, channels), fill_value=self.config.bos_token_id, device=self.device
+ )
+ # Ensure 3D
+ if decoder_input_ids.ndim == 2:
+ decoder_input_ids = decoder_input_ids.reshape(bsz, channels, seq_len).transpose(1, 2)
+
+ decoder_outputs = self.decoder(
+ input_ids=decoder_input_ids,
+ position_ids=decoder_position_ids,
+ attention_mask=decoder_attention_mask,
+ encoder_hidden_states=encoder_outputs[0],
+ encoder_attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ return Seq2SeqModelOutput(
+ last_hidden_state=decoder_outputs.last_hidden_state,
+ past_key_values=decoder_outputs.past_key_values,
+ decoder_hidden_states=decoder_outputs.hidden_states,
+ decoder_attentions=decoder_outputs.attentions,
+ cross_attentions=decoder_outputs.cross_attentions,
+ encoder_last_hidden_state=encoder_outputs[0],
+ encoder_hidden_states=encoder_outputs.hidden_states,
+ encoder_attentions=encoder_outputs.attentions,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ The Dia model consisting of a (byte) text encoder and audio decoder with a prediction head on top.
+ """
+)
+class DiaForConditionalGeneration(DiaPreTrainedModel, DiaGenerationMixin):
+ base_model_prefix = "model"
+
+ def __init__(self, config: DiaConfig):
+ super().__init__(config)
+ self.config = config
+ self.model = DiaModel(config)
+
+ self.num_channels = config.decoder_config.num_channels
+ self.vocab_size = config.decoder_config.vocab_size
+ self.logits_dense = nn.Linear(
+ config.decoder_config.hidden_size, (self.num_channels * self.vocab_size), bias=False
+ )
+ self.loss_type = "ForMaskedLM"
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_encoder(self):
+ return self.model.get_encoder()
+
+ def get_decoder(self):
+ return self.model.get_decoder()
+
+ @auto_docstring
+ @can_return_tuple
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ decoder_input_ids: Optional[torch.LongTensor] = None,
+ decoder_position_ids: Optional[torch.LongTensor] = None,
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
+ encoder_outputs: Optional[Union[BaseModelOutput, tuple]] = None,
+ past_key_values: Optional[EncoderDecoderCache] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ labels: Optional[torch.LongTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> Union[tuple, Seq2SeqLMOutput]:
+ r"""
+ decoder_input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, target_sequence_length)
+ or (batch_size, target_sequence_length, num_codebooks)`, *optional*):
+ 1. (batch_size * num_codebooks, target_sequence_length): corresponds to the general use case where
+ the audio input codebooks are flattened into the batch dimension. This also aligns with the flat-
+ tened audio logits which are used to calculate the loss.
+
+ 2. (batch_size, sequence_length, num_codebooks): corresponds to the internally used shape of
+ Dia to calculate embeddings and subsequent steps more efficiently.
+
+ If no `decoder_input_ids` are provided, it will create a tensor of `bos_token_id` with shape
+ `(batch_size, 1, num_codebooks)`. Indices can be obtained using the [`DiaProcessor`]. See
+ [`DiaProcessor.__call__`] for more details.
+
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
+ decoder_position_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`):
+ Indices of positions of each input sequence tokens in the position embeddings.
+ Used to calculate the position embeddings up to `config.decoder_config.max_position_embeddings`.
+
+ [What are position IDs?](../glossary#position-ids)
+ labels (`torch.LongTensor` of shape `(batch_size * num_codebooks,)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in
+ `[0, ..., config.decoder_config.vocab_size - 1]` or -100. Tokens with indices set to `-100`
+ are ignored (masked).
+ """
+
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ decoder_input_ids=decoder_input_ids,
+ decoder_position_ids=decoder_position_ids,
+ decoder_attention_mask=decoder_attention_mask,
+ encoder_outputs=encoder_outputs,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ last_hidden_state = outputs[0]
+ batch_size = last_hidden_state.shape[0]
+ # 3D <-> 2D makes it necessary to prioritize channel dim
+ audio_logits = (
+ self.logits_dense(last_hidden_state)
+ .view((batch_size, -1, self.num_channels, self.vocab_size))
+ .transpose(1, 2)
+ .contiguous()
+ .view(batch_size * self.num_channels, -1, self.vocab_size)
+ )
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits=audio_logits, labels=labels, vocab_size=self.vocab_size, **kwargs)
+
+ return Seq2SeqLMOutput(
+ loss=loss,
+ logits=audio_logits,
+ past_key_values=outputs.past_key_values,
+ decoder_hidden_states=outputs.decoder_hidden_states,
+ decoder_attentions=outputs.decoder_attentions,
+ cross_attentions=outputs.cross_attentions,
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+ encoder_hidden_states=outputs.encoder_hidden_states,
+ encoder_attentions=outputs.encoder_attentions,
+ )
+
+
+__all__ = ["DiaModel", "DiaPreTrainedModel", "DiaForConditionalGeneration"]
diff --git a/src/transformers/models/dia/processing_dia.py b/src/transformers/models/dia/processing_dia.py
new file mode 100644
index 0000000000..e50ef5de67
--- /dev/null
+++ b/src/transformers/models/dia/processing_dia.py
@@ -0,0 +1,484 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Processor class for Dia"""
+
+import math
+from pathlib import Path
+from typing import Optional, Union
+
+from ...audio_utils import AudioInput, make_list_of_audio
+from ...feature_extraction_utils import BatchFeature
+from ...processing_utils import AudioKwargs, ProcessingKwargs, ProcessorMixin, Unpack
+from ...utils import is_soundfile_available, is_torch_available
+
+
+if is_torch_available():
+ import torch
+
+if is_soundfile_available():
+ import soundfile as sf
+
+
+class DiaAudioKwargs(AudioKwargs, total=False):
+ bos_token_id: int
+ eos_token_id: int
+ pad_token_id: int
+ delay_pattern: list[int]
+ generation: bool
+
+
+class DiaProcessorKwargs(ProcessingKwargs, total=False):
+ audio_kwargs: DiaAudioKwargs
+ _defaults = {
+ "text_kwargs": {
+ "padding": True,
+ "padding_side": "right",
+ "add_special_tokens": False,
+ },
+ "audio_kwargs": {
+ "eos_token_id": 1024,
+ "pad_token_id": 1025,
+ "bos_token_id": 1026,
+ "delay_pattern": [0, 8, 9, 10, 11, 12, 13, 14, 15],
+ "generation": True,
+ "sampling_rate": 44100,
+ },
+ "common_kwargs": {"return_tensors": "pt"},
+ }
+
+
+class DiaProcessor(ProcessorMixin):
+ r"""
+ Constructs a Dia processor which wraps a [`DiaFeatureExtractor`], [`DiaTokenizer`], and a [`DacModel`] into
+ a single processor. It inherits, the audio feature extraction, tokenizer, and audio encode/decode functio-
+ nalities. See [`~DiaProcessor.__call__`], [`~DiaProcessor.encode`], and [`~DiaProcessor.decode`] for more
+ information.
+
+ Args:
+ feature_extractor (`DiaFeatureExtractor`):
+ An instance of [`DiaFeatureExtractor`]. The feature extractor is a required input.
+ tokenizer (`DiaTokenizer`):
+ An instance of [`DiaTokenizer`]. The tokenizer is a required input.
+ audio_tokenizer (`DacModel`):
+ An instance of [`DacModel`] used to encode/decode audio into/from codebooks. It is is a required input.
+ """
+
+ feature_extractor_class = "DiaFeatureExtractor"
+ tokenizer_class = "DiaTokenizer"
+ audio_tokenizer_class = "DacModel"
+
+ def __init__(self, feature_extractor, tokenizer, audio_tokenizer):
+ super().__init__(feature_extractor, tokenizer, audio_tokenizer=audio_tokenizer)
+
+ @property
+ def model_input_names(self):
+ """
+ We no longer pass the raw audio values but the codebooks encoded by the `audio_tokenizer`.
+ Conventions may differ between audio models due to architectural choices.
+ """
+ tokenizer_input_names = self.tokenizer.model_input_names
+ audio_tokenizer_input_names = ["decoder_input_ids", "decoder_attention_mask"]
+ return list(dict.fromkeys(tokenizer_input_names + audio_tokenizer_input_names))
+
+ def __call__(
+ self,
+ text: Union[str, list[str]],
+ audio: Optional[AudioInput] = None,
+ output_labels: Optional[bool] = False,
+ **kwargs: Unpack[DiaProcessorKwargs],
+ ):
+ """
+ Main method to prepare text(s) and audio to be fed as input to the model. The `audio` argument is
+ forwarded to the DiaFeatureExtractor's [`~DiaFeatureExtractor.__call__`] and subsequently to the
+ DacModel's [`~DacModel.encode`]. The `text` argument to [`~DiaTokenizer.__call__`]. Please refer
+ to the docstring of the above methods for more information.
+ """
+ if not is_torch_available():
+ raise ValueError(
+ "The `DiaProcessor` relies on the `audio_tokenizer` which requires `torch` but we couldn't "
+ "find it in your environment. You can install torch via `pip install torch`."
+ )
+
+ if text is None:
+ raise ValueError("You need to specify the `text` input to process.")
+
+ output_kwargs = self._merge_kwargs(
+ DiaProcessorKwargs,
+ **kwargs,
+ )
+
+ text_kwargs = output_kwargs["text_kwargs"]
+ audio_kwargs = output_kwargs["audio_kwargs"]
+ common_kwargs = output_kwargs["common_kwargs"]
+
+ return_tensors = common_kwargs.pop("return_tensors", None)
+ if return_tensors != "pt":
+ raise ValueError(f"{self.__class__.__name__} only supports `return_tensors='pt'`.")
+
+ data = {}
+
+ # Text
+ if isinstance(text, str):
+ text = [text]
+ elif not (isinstance(text, (list, tuple)) and all(isinstance(t, str) for t in text)):
+ raise ValueError("Invalid input text. Please provide a string, or a list of strings")
+
+ encodings = self.tokenizer(text, **text_kwargs)
+ data.update(encodings)
+
+ # Audio
+ delay_pattern = audio_kwargs.pop("delay_pattern", None)
+ audio_bos_token_id = audio_kwargs.pop("bos_token_id", None)
+ audio_eos_token_id = audio_kwargs.pop("eos_token_id", None)
+ audio_pad_token_id = audio_kwargs.pop("pad_token_id", None)
+ generation = audio_kwargs.pop("generation", True)
+ if (
+ audio_bos_token_id is None
+ or audio_eos_token_id is None
+ or audio_pad_token_id is None
+ or delay_pattern is None
+ ):
+ raise ValueError(
+ "To enable processing for Dia, we need the `bos_token_id`, `eos_token_id`, "
+ "`pad_token_id`, and `delay_pattern`. You may have accidentally overwritten one of those."
+ )
+
+ if generation and output_labels:
+ raise ValueError(
+ f"Labels with `generation` is incompatible, got generation={generation}, output_labels={output_labels}."
+ )
+
+ batch_size = data["input_ids"].shape[0]
+ num_channels = len(delay_pattern)
+ max_delay = max(delay_pattern)
+
+ # Voice cloning generation / general training
+ if audio is not None:
+ audio = make_list_of_audio(audio)
+ input_audios = self.feature_extractor(audio, **audio_kwargs)
+
+ compression_rate = math.prod(self.audio_tokenizer.config.downsampling_ratios)
+ max_encoded_sequence_len = input_audios["padding_mask"][0].shape[-1] // compression_rate
+
+ decoder_input_ids = []
+ decoder_attention_mask = []
+ # TODO: dac with batching is currently broken, but non-batch is working
+ # refer to https://gist.github.com/vasqu/643a45b680cf39fd7467271ee2eb6f80 for a validation script
+ for padding_mask, audio in zip(input_audios["padding_mask"], input_audios["input_values"]):
+ # get current length with hop length in mind (as if it were sampled as a single audio)
+ base_pad_len = self.feature_extractor.hop_length
+ current_audio_len = math.ceil(padding_mask.sum(dim=-1) / base_pad_len) * base_pad_len
+
+ encoded_sequence_len = current_audio_len // compression_rate
+ padding_len = max_encoded_sequence_len - encoded_sequence_len
+
+ # compute non-padded forward pass; one extra bos (and eos if training) is added
+ with torch.no_grad():
+ audio = audio[None, ..., :current_audio_len].to(self.audio_tokenizer.device)
+ input_ids = self.audio_tokenizer.encode(audio).audio_codes.transpose(1, 2)
+
+ if not generation:
+ input_ids = torch.nn.functional.pad(
+ input_ids, pad=(0, 0, 0, 1, 0, 0), mode="constant", value=audio_eos_token_id
+ )
+
+ # apply padding
+ # +1 for the bos within the real sequence
+ input_ids = torch.nn.functional.pad(
+ input_ids, pad=(0, 0, padding_len + 1, 0, 0, 0), mode="constant", value=audio_bos_token_id
+ )
+ num_valid_inputs = encoded_sequence_len + 1 + max_delay # sequence + bos + delay
+ num_valid_inputs += 0 if generation else 1 # eos if training
+ attention_mask = torch.tensor([0] * padding_len + [1] * num_valid_inputs, dtype=torch.long)[None, :]
+
+ decoder_input_ids.append(input_ids)
+ decoder_attention_mask.append(attention_mask)
+
+ decoder_input_ids = torch.cat(decoder_input_ids, dim=0)
+ decoder_attention_mask = torch.cat(decoder_attention_mask, dim=0)
+ # TTS generation
+ elif generation:
+ # all bos to start with TTS
+ decoder_input_ids = torch.full((batch_size, 1, num_channels), audio_bos_token_id, dtype=torch.long)
+
+ # we preemptively add the delay
+ decoder_attention_mask = torch.ones(size=(batch_size, 1 + max_delay), dtype=torch.long)
+ else:
+ raise ValueError("If you try to train, you should provide audio data as well.")
+
+ if batch_size != decoder_input_ids.shape[0]:
+ raise ValueError(
+ f"Need the same amount of samples for both text and audio, but got text samples={batch_size} and "
+ f"audio samples = {decoder_input_ids.shape[0]} instead."
+ )
+
+ # prepare shift indices per delay
+ max_seq_len = decoder_attention_mask.shape[-1]
+ max_audio_len = max_seq_len - max_delay
+ precomputed_idx = self.build_indices(
+ bsz=batch_size,
+ seq_len=max_seq_len,
+ num_channels=num_channels,
+ delay_pattern=delay_pattern,
+ revert=False,
+ )
+
+ # create delay pattern input
+ # the pad token will be used for masking which input is valid for prediction during generation
+ prefill = torch.full(
+ (batch_size, max_seq_len, num_channels),
+ fill_value=audio_pad_token_id,
+ dtype=torch.int,
+ )
+ prefill[:, :max_audio_len] = decoder_input_ids
+
+ delayed_decoder_input_ids = self.apply_audio_delay(
+ audio=prefill,
+ pad_token_id=audio_pad_token_id,
+ bos_token_id=audio_bos_token_id,
+ precomputed_idx=precomputed_idx,
+ )
+
+ data.update({"decoder_input_ids": delayed_decoder_input_ids, "decoder_attention_mask": decoder_attention_mask})
+
+ if output_labels:
+ # Base idea is to shift on the sequence dim
+ labels = data["decoder_input_ids"].clone()[:, 1:]
+ labels[labels == audio_pad_token_id] = -100
+ labels[labels == audio_bos_token_id] = -100
+
+ data["labels"] = labels.transpose(1, 2).reshape(batch_size * num_channels, -1).contiguous().long()
+ data["decoder_input_ids"] = data["decoder_input_ids"][:, :-1]
+ data["decoder_attention_mask"] = data["decoder_attention_mask"][:, :-1]
+
+ return BatchFeature(data=data, tensor_type=return_tensors)
+
+ def batch_decode(
+ self,
+ decoder_input_ids: "torch.Tensor",
+ audio_prompt_len: Optional[int] = None,
+ **kwargs: Unpack[DiaProcessorKwargs],
+ ) -> list["torch.Tensor"]:
+ """
+ Decodes a batch of audio codebook sequences into their respective audio waveforms via the
+ `audio_tokenizer`. See [`~DacModel.decode`] for more information.
+
+ Args:
+ decoder_input_ids (`torch.Tensor`): The complete output sequence of the decoder.
+ audio_prompt_len (`int`): The audio prefix length (e.g. when using voice cloning).
+ """
+ output_kwargs = self._merge_kwargs(
+ DiaProcessorKwargs,
+ **kwargs,
+ )
+ audio_kwargs = output_kwargs["audio_kwargs"]
+
+ delay_pattern = audio_kwargs.pop("delay_pattern", None)
+ audio_bos_token_id = audio_kwargs.pop("bos_token_id", None)
+ audio_pad_token_id = audio_kwargs.pop("pad_token_id", None)
+ if audio_bos_token_id is None or audio_pad_token_id is None or delay_pattern is None:
+ raise ValueError(
+ "To enable decoding for Dia, we need the `bos_token_id`, `pad_token_id`, "
+ "and `delay_pattern`. You may have accidentally overwritten one of those."
+ )
+
+ # either decode the whole audio sequence or only the generated parts
+ if audio_prompt_len is not None:
+ audio_prompt_len = torch.tensor(audio_prompt_len, device=decoder_input_ids.device, dtype=torch.long)
+ start_of_generation_idx = audio_prompt_len[None].expand(decoder_input_ids.shape[0])
+ else:
+ start_of_generation_idx = (decoder_input_ids[:, :, 0] == audio_bos_token_id).sum(dim=-1)
+ # -1 for the eos token
+ end_of_generation_idx = (
+ decoder_input_ids.shape[1] - (decoder_input_ids[:, :, 0] == audio_pad_token_id).sum(dim=-1) - 1
+ )
+
+ # revert delay
+ bsz, seq_len, num_channels = decoder_input_ids.shape
+ precomputed_idx = self.build_indices(
+ bsz=bsz,
+ seq_len=seq_len,
+ num_channels=num_channels,
+ delay_pattern=delay_pattern,
+ revert=True,
+ )
+
+ output_sequences = self.apply_audio_delay(
+ audio=decoder_input_ids,
+ # We do not care about these values as we cut them out
+ # with `start_of_generation_idx` and `end_of_generation_idx`
+ pad_token_id=-1,
+ bos_token_id=-1,
+ precomputed_idx=precomputed_idx,
+ ).transpose(1, 2)
+
+ # retrieve the correct sequences each
+ audios = []
+ # TODO: see above, dac doesn't work in batches yet
+ with torch.no_grad():
+ for i in range(start_of_generation_idx.shape[0]):
+ output_i = output_sequences[i, :, start_of_generation_idx[i] : end_of_generation_idx[i]][None, ...]
+ output_i = output_i.to(self.audio_tokenizer.device)
+ audio_i = self.audio_tokenizer.decode(audio_codes=output_i).audio_values.cpu().squeeze()
+ audios.append(audio_i)
+
+ return audios
+
+ def decode(
+ self,
+ decoder_input_ids: "torch.Tensor",
+ audio_prompt_len: Optional[int] = None,
+ **kwargs: Unpack[DiaProcessorKwargs],
+ ) -> "torch.Tensor":
+ """
+ Decodes a single sequence of audio codebooks into the respective audio waveform via the
+ `audio_tokenizer`. See [`~DacModel.decode`] and [`~DiaProcessor.batch_decode`] for more information.
+ """
+ if decoder_input_ids.shape[0] != 1:
+ raise ValueError(
+ f"Expecting a single output to be decoded but received {decoder_input_ids.shape[0]} samples instead."
+ )
+
+ return self.batch_decode(decoder_input_ids, audio_prompt_len, **kwargs)[0]
+
+ def get_audio_prompt_len(
+ self,
+ decoder_attention_mask: "torch.Tensor",
+ **kwargs: Unpack[DiaProcessorKwargs],
+ ) -> int:
+ """Utility function to get the audio prompt length."""
+ output_kwargs = self._merge_kwargs(
+ DiaProcessorKwargs,
+ **kwargs,
+ )
+ audio_kwargs = output_kwargs["audio_kwargs"]
+
+ delay_pattern = audio_kwargs.pop("delay_pattern", None)
+ if delay_pattern is None:
+ raise ValueError(
+ "To enable the utility of retrieving the prompt length for Dia, we need the "
+ "`delay_pattern`. You may have accidentally overwritten this."
+ )
+ return decoder_attention_mask.shape[1] - max(delay_pattern)
+
+ # Copied from transformers.models.csm.processing_csm.CsmProcessor.save_audio with Csm->Dia
+ def save_audio(
+ self,
+ audio: AudioInput,
+ saving_path: Union[str, Path, list[Union[str, Path]]],
+ **kwargs: Unpack[DiaProcessorKwargs],
+ ):
+ # TODO: @eustlb, this should be in AudioProcessor
+ if not is_soundfile_available():
+ raise ImportError("Please install `soundfile` to save audio files.")
+
+ # ensure correct audio input
+ audio = make_list_of_audio(audio)
+
+ # ensure correct saving path
+ if isinstance(saving_path, (str, Path)):
+ saving_path = [saving_path]
+ elif not (isinstance(saving_path, (list, tuple)) and all(isinstance(p, (str, Path)) for p in saving_path)):
+ raise ValueError("Invalid input path. Please provide a string, or a list of strings")
+
+ if len(audio) != len(saving_path):
+ raise ValueError("The number of audio and saving paths must be the same")
+
+ output_kwargs = self._merge_kwargs(
+ DiaProcessorKwargs,
+ **kwargs,
+ )
+ audio_kwargs = output_kwargs["audio_kwargs"]
+ sampling_rate = audio_kwargs["sampling_rate"]
+
+ for audio_value, p in zip(audio, saving_path):
+ if isinstance(audio_value, torch.Tensor):
+ audio_value = audio_value.cpu().float().numpy()
+ sf.write(p, audio_value, sampling_rate)
+
+ @staticmethod
+ def build_indices(
+ bsz: int,
+ seq_len: int,
+ num_channels: int,
+ delay_pattern: list[int],
+ revert: bool = False,
+ ) -> tuple["torch.Tensor", "torch.Tensor"]:
+ """
+ Precompute (sequence_idx, all_idx) so that out[seq, channel] = in[seq - delay[channel], channel]
+ or in[seq, channel] = out[seq + delay[channel], channel] if `revert`.
+ Negative sequence_idx => BOS; sequence_idx >= seq_len => PAD.
+ """
+ delay_array = torch.tensor(delay_pattern, dtype=torch.int32)
+
+ # (0..seq_len-1)
+ sequence_idx = torch.arange(seq_len, dtype=torch.int32)[None, :].expand(bsz, seq_len)[..., None]
+ # + or - delay depending if we delay or revert the delay
+ if not revert:
+ sequence_idx = sequence_idx - delay_array[None, None, :]
+ else:
+ sequence_idx = sequence_idx + delay_array[None, None, :]
+ # if delay goes over the range we clamp back to valid values
+ valid_sequence_idx = torch.clamp(sequence_idx, 0, seq_len - 1)
+
+ batch_idx = torch.arange(bsz, dtype=torch.int32)[:, None, None].expand(bsz, seq_len, num_channels)
+ channel_idx = torch.arange(num_channels, dtype=torch.int32)[None, None, :].expand(bsz, seq_len, num_channels)
+
+ all_idx = torch.stack(
+ [batch_idx.reshape(-1), valid_sequence_idx.reshape(-1), channel_idx.reshape(-1)],
+ dim=1,
+ ).long()
+
+ return sequence_idx, all_idx
+
+ @staticmethod
+ def apply_audio_delay(
+ audio: "torch.Tensor",
+ pad_token_id: int,
+ bos_token_id: int,
+ precomputed_idx: tuple["torch.Tensor", "torch.Tensor"],
+ ) -> "torch.Tensor":
+ """
+ Applies or reverts the delay pattern to batched audio tokens using precomputed indices,
+ inserting BOS where sequence_idx < 0 and PAD where sequence_idx >= seq_len.
+
+ Args:
+ audio: audio tokens of shape [bsz, seq_len, num_channels]
+ pad_token_id: the PAD token
+ bos_token_id: the BOS token
+ precomputed_idx: from `build_indices`
+
+ Returns:
+ final_audio: delayed or reverted audio tokens of shape [bsz, seq_len, num_channels]
+ """
+ # Move everything to the same device
+ device = audio.device
+ sequence_idx, all_idx = precomputed_idx
+ sequence_idx = sequence_idx.to(device)
+ all_idx = all_idx.to(device)
+
+ # Gather per precomputed indices
+ batch_idx, valid_sequence_idx, channel_idx = torch.unbind(all_idx, dim=-1)
+ gathered_audio = audio[batch_idx, valid_sequence_idx, channel_idx].view(audio.size())
+
+ # Mask according to negative sequence_idx => BOS; sequence_idx >= seq_len => PAD
+ mask_bos = sequence_idx < 0
+ mask_pad = sequence_idx >= audio.shape[1]
+ final_audio = torch.where(mask_bos, bos_token_id, torch.where(mask_pad, pad_token_id, gathered_audio))
+
+ return final_audio
+
+
+__all__ = ["DiaProcessor"]
diff --git a/src/transformers/models/dia/tokenization_dia.py b/src/transformers/models/dia/tokenization_dia.py
new file mode 100644
index 0000000000..4e205906ea
--- /dev/null
+++ b/src/transformers/models/dia/tokenization_dia.py
@@ -0,0 +1,118 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization class for Dia."""
+
+from typing import Optional
+
+from ...tokenization_utils import AddedToken, PreTrainedTokenizer
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class DiaTokenizer(PreTrainedTokenizer):
+ """
+ Construct a Dia tokenizer. Dia simply uses raw bytes utf-8 encoding except for special tokens `[S1]` and `[S2]`.
+
+ This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
+ refer to this superclass for more information regarding those methods.
+
+ Args:
+ pad_token (`str`, *optional*, defaults to `""`):
+ The token used for padding, for example when batching sequences of different lengths.
+ unk_token (`str`, *optional*, defaults to `""`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ max_length (`int`, *optional*, defaults to 1024):
+ The maximum length of the sequences when encoding. Sequences longer than this will be truncated.
+ offset (`int`, *optional*, defaults to 0):
+ The offset of the tokenizer.
+ """
+
+ model_input_names = ["input_ids", "attention_mask"]
+
+ def __init__(
+ self,
+ pad_token: Optional[str] = "",
+ unk_token: Optional[str] = "",
+ max_length: Optional[int] = 1024,
+ offset: int = 0,
+ **kwargs,
+ ):
+ # We have no eos/bos tokens but allow padding -- no l/r strip as we treat them as tokens as well
+ pad_token = AddedToken(pad_token) if isinstance(pad_token, str) else pad_token
+ unk_token = AddedToken(unk_token) if isinstance(unk_token, str) else unk_token
+
+ self._utf_vocab_size = 2**8 # utf is 8 bits
+ self._added_tokens_decoder = {0: pad_token, 1: AddedToken("[S1]"), 2: AddedToken("[S2]")}
+ self.offset = offset
+ super().__init__(
+ unk_token=unk_token,
+ pad_token=pad_token,
+ max_length=max_length,
+ **kwargs,
+ )
+
+ @property
+ def vocab_size(self):
+ return self._utf_vocab_size
+
+ def get_vocab(self):
+ vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size + self.offset)}
+ vocab.update(self.added_tokens_encoder)
+ return vocab
+
+ def _tokenize(self, text: str) -> list[str]:
+ """Take as input a string and return a list of strings (tokens) for words/sub-words"""
+ tokens = [chr(i) for i in text.encode("utf-8")]
+ return tokens
+
+ def _convert_token_to_id(self, token):
+ """Converts a token (str) in an id using the vocab."""
+
+ if len(token) != 1:
+ token_id = None
+ else:
+ token_id = ord(token) + self.offset
+
+ return token_id
+
+ def _convert_id_to_token(self, index):
+ """Converts an index (integer) in a token (str) using the vocab."""
+ token = chr(index - self.offset)
+ return token
+
+ def convert_tokens_to_string(self, tokens: list[str]) -> str:
+ """Converts a sequence of tokens (string) in a single string."""
+ bstring = b""
+ for token in tokens:
+ if token in self.added_tokens_decoder:
+ added_token_obj = self.added_tokens_decoder[token]
+ tok_string = str(added_token_obj).encode("utf-8")
+ elif token in self.added_tokens_encoder:
+ tok_string = token.encode("utf-8")
+ else:
+ tok_string = token.encode("utf-8") # Assume general string token
+ bstring += tok_string
+ string = bstring.decode("utf-8", errors="ignore")
+ return string
+
+ # No vocab file
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
+ return ()
+
+
+__all__ = ["DiaTokenizer"]
diff --git a/src/transformers/pipelines/text_to_audio.py b/src/transformers/pipelines/text_to_audio.py
index 79b0e9b35f..afeae13ae7 100644
--- a/src/transformers/pipelines/text_to_audio.py
+++ b/src/transformers/pipelines/text_to_audio.py
@@ -80,15 +80,21 @@ class TextToAudioPipeline(Pipeline):
See the list of available models on [huggingface.co/models](https://huggingface.co/models?filter=text-to-speech).
"""
+ # Introducing the processor at load time for new behaviour
+ _load_processor = True
+
_pipeline_calls_generate = True
# Make sure the docstring is updated when the default generation config is changed
_default_generation_config = GenerationConfig(
max_new_tokens=256,
)
- def __init__(self, *args, vocoder=None, sampling_rate=None, **kwargs):
+ def __init__(self, *args, vocoder=None, sampling_rate=None, no_processor=True, **kwargs):
super().__init__(*args, **kwargs)
+ # Legacy behaviour just uses the tokenizer while new models use the processor as a whole at any given time
+ self.no_processor = no_processor
+
if self.framework == "tf":
raise ValueError("The TextToAudioPipeline is only available in PyTorch.")
@@ -117,6 +123,10 @@ class TextToAudioPipeline(Pipeline):
if sampling_rate is not None:
self.sampling_rate = sampling_rate
+ # last fallback to get the sampling rate based on processor
+ if self.sampling_rate is None and not self.no_processor and hasattr(self.processor, "feature_extractor"):
+ self.sampling_rate = self.processor.feature_extractor.sampling_rate
+
def preprocess(self, text, **kwargs):
if isinstance(text, str):
text = [text]
@@ -136,7 +146,8 @@ class TextToAudioPipeline(Pipeline):
kwargs = new_kwargs
- output = self.tokenizer(text, **kwargs, return_tensors="pt")
+ preprocessor = self.tokenizer if self.no_processor else self.processor
+ output = preprocessor(text, **kwargs, return_tensors="pt")
return output
@@ -228,12 +239,21 @@ class TextToAudioPipeline(Pipeline):
return preprocess_params, params, postprocess_params
- def postprocess(self, waveform):
+ def postprocess(self, audio):
output_dict = {}
- if isinstance(waveform, dict):
- waveform = waveform["waveform"]
- elif isinstance(waveform, tuple):
- waveform = waveform[0]
+
+ # We directly get the waveform
+ if self.no_processor:
+ if isinstance(audio, dict):
+ waveform = audio["waveform"]
+ elif isinstance(audio, tuple):
+ waveform = audio[0]
+ else:
+ waveform = audio
+ # Or we need to postprocess to get the waveform
+ else:
+ waveform = self.processor.decode(audio)
+
output_dict["audio"] = waveform.to(device="cpu", dtype=torch.float).numpy()
output_dict["sampling_rate"] = self.sampling_rate
diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py
index e7adab4b1d..2a97cde3cc 100644
--- a/src/transformers/processing_utils.py
+++ b/src/transformers/processing_utils.py
@@ -49,6 +49,7 @@ from .tokenization_utils_base import (
TruncationStrategy,
)
from .utils import (
+ AUDIO_TOKENIZER_NAME,
CHAT_TEMPLATE_DIR,
CHAT_TEMPLATE_FILE,
LEGACY_PROCESSOR_CHAT_TEMPLATE_FILE,
@@ -61,12 +62,17 @@ from .utils import (
download_url,
is_offline_mode,
is_remote_url,
+ is_torch_available,
list_repo_templates,
logging,
)
from .utils.deprecation import deprecate_kwarg
+if is_torch_available():
+ from .modeling_utils import PreTrainedAudioTokenizerBase
+
+
logger = logging.get_logger(__name__)
# Dynamically import the Transformers module to grab the attribute classes of the processor from their names.
@@ -499,7 +505,7 @@ class ProcessorMixin(PushToHubMixin):
"""
attributes = ["feature_extractor", "tokenizer"]
- optional_attributes = ["chat_template"]
+ optional_attributes = ["chat_template", "audio_tokenizer"]
optional_call_args: list[str] = []
# Names need to be attr_class for attr in attributes
feature_extractor_class = None
@@ -511,7 +517,19 @@ class ProcessorMixin(PushToHubMixin):
# First, extract optional attributes from kwargs if present
# Optional attributes can never be positional arguments
for optional_attribute in self.optional_attributes:
- setattr(self, optional_attribute, kwargs.pop(optional_attribute, None))
+ optional_attribute_value = kwargs.pop(optional_attribute, None)
+ setattr(self, optional_attribute, optional_attribute_value)
+
+ # Check audio tokenizer for its class but do not treat it as attr to avoid saving weights
+ if optional_attribute == "audio_tokenizer" and optional_attribute_value is not None:
+ proper_class = self.check_argument_for_proper_class(optional_attribute, optional_attribute_value)
+
+ if not (is_torch_available() and isinstance(optional_attribute_value, PreTrainedAudioTokenizerBase)):
+ raise ValueError(
+ f"Tried to use `{proper_class}` for audio tokenization. However, this class is not"
+ " registered for audio tokenization."
+ )
+
# Sanitize args and kwargs
for key in kwargs:
if key not in self.attributes:
@@ -530,21 +548,30 @@ class ProcessorMixin(PushToHubMixin):
# Check each arg is of the proper class (this will also catch a user initializing in the wrong order)
for attribute_name, arg in kwargs.items():
- class_name = getattr(self, f"{attribute_name}_class")
- # Nothing is ever going to be an instance of "AutoXxx", in that case we check the base class.
- class_name = AUTO_TO_BASE_CLASS_MAPPING.get(class_name, class_name)
- if isinstance(class_name, tuple):
- proper_class = tuple(self.get_possibly_dynamic_module(n) for n in class_name if n is not None)
- else:
- proper_class = self.get_possibly_dynamic_module(class_name)
-
- if not isinstance(arg, proper_class):
- raise TypeError(
- f"Received a {type(arg).__name__} for argument {attribute_name}, but a {class_name} was expected."
- )
-
+ self.check_argument_for_proper_class(attribute_name, arg)
setattr(self, attribute_name, arg)
+ def check_argument_for_proper_class(self, argument_name, argument):
+ """
+ Checks the passed argument's class against the expected transformers class. In case of an unexpected
+ mismatch between expected and actual class, an error is raise. Otherwise, the proper retrieved class
+ is returned.
+ """
+ class_name = getattr(self, f"{argument_name}_class")
+ # Nothing is ever going to be an instance of "AutoXxx", in that case we check the base class.
+ class_name = AUTO_TO_BASE_CLASS_MAPPING.get(class_name, class_name)
+ if isinstance(class_name, tuple):
+ proper_class = tuple(self.get_possibly_dynamic_module(n) for n in class_name if n is not None)
+ else:
+ proper_class = self.get_possibly_dynamic_module(class_name)
+
+ if not isinstance(argument, proper_class):
+ raise TypeError(
+ f"Received a {type(argument).__name__} for argument {argument_name}, but a {class_name} was expected."
+ )
+
+ return proper_class
+
def to_dict(self) -> dict[str, Any]:
"""
Serializes this instance to a Python dictionary.
@@ -577,6 +604,8 @@ class ProcessorMixin(PushToHubMixin):
del output["feature_extractor"]
if "chat_template" in output:
del output["chat_template"]
+ if "audio_tokenizer" in output:
+ del output["audio_tokenizer"]
# Some attributes have different names but containing objects that are not simple strings
output = {
@@ -695,6 +724,7 @@ class ProcessorMixin(PushToHubMixin):
save_directory, LEGACY_PROCESSOR_CHAT_TEMPLATE_FILE
) # Legacy filename
chat_template_dir = os.path.join(save_directory, CHAT_TEMPLATE_DIR)
+ output_audio_tokenizer_file = os.path.join(save_directory, AUDIO_TOKENIZER_NAME)
processor_dict = self.to_dict()
# Save `chat_template` in its own file. We can't get it from `processor_dict` as we popped it in `to_dict`
@@ -737,6 +767,19 @@ class ProcessorMixin(PushToHubMixin):
"separate files using the `save_jinja_files` argument."
)
+ if self.audio_tokenizer is not None:
+ audio_tokenizer_class = self.audio_tokenizer.__class__.__name__
+ audio_tokenizer_name_or_path = self.audio_tokenizer.name_or_path
+
+ audio_tokenizer_dict = {
+ "audio_tokenizer_class": audio_tokenizer_class,
+ "audio_tokenizer_name_or_path": audio_tokenizer_name_or_path,
+ }
+ audio_tokenizer_json = json.dumps(audio_tokenizer_dict, indent=2, sort_keys=True) + "\n"
+
+ with open(output_audio_tokenizer_file, "w", encoding="utf-8") as writer:
+ writer.write(audio_tokenizer_json)
+
# For now, let's not save to `processor_config.json` if the processor doesn't have extra attributes and
# `auto_map` is not specified.
if set(processor_dict.keys()) != {"processor_class"}:
@@ -774,6 +817,9 @@ class ProcessorMixin(PushToHubMixin):
Returns:
`tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the processor object.
"""
+ # holding a copy for optionally loading the audio tokenizer (if available)
+ audio_tokenizer_kwargs = copy.deepcopy(kwargs)
+
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", None)
@@ -803,16 +849,18 @@ class ProcessorMixin(PushToHubMixin):
resolved_additional_chat_template_files = {}
if os.path.isfile(pretrained_model_name_or_path):
resolved_processor_file = pretrained_model_name_or_path
- # can't load chat-template when given a file as pretrained_model_name_or_path
+ # can't load chat-template and audio tokenizer when given a file as pretrained_model_name_or_path
resolved_chat_template_file = None
resolved_raw_chat_template_file = None
+ resolved_audio_tokenizer_file = None
is_local = True
elif is_remote_url(pretrained_model_name_or_path):
processor_file = pretrained_model_name_or_path
resolved_processor_file = download_url(pretrained_model_name_or_path)
- # can't load chat-template when given a file url as pretrained_model_name_or_path
+ # can't load chat-template and audio tokenizer when given a file url as pretrained_model_name_or_path
resolved_chat_template_file = None
resolved_raw_chat_template_file = None
+ resolved_audio_tokenizer_file = None
else:
if is_local:
template_dir = Path(pretrained_model_name_or_path, CHAT_TEMPLATE_DIR)
@@ -899,6 +947,21 @@ class ProcessorMixin(PushToHubMixin):
)
for template_name, template_file in additional_chat_template_files.items()
}
+
+ resolved_audio_tokenizer_file = cached_file(
+ pretrained_model_name_or_path,
+ AUDIO_TOKENIZER_NAME,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ resume_download=resume_download,
+ local_files_only=local_files_only,
+ token=token,
+ user_agent=user_agent,
+ revision=revision,
+ subfolder=subfolder,
+ _raise_exceptions_for_missing_entries=False,
+ )
except OSError:
# Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to
# the original exception.
@@ -939,6 +1002,22 @@ class ProcessorMixin(PushToHubMixin):
if chat_templates:
kwargs["chat_template"] = chat_templates
+ # Same as chat template, adding as kwarg after loading the model
+ audio_tokenizer = None
+ if resolved_audio_tokenizer_file is not None:
+ with open(resolved_audio_tokenizer_file, "r", encoding="utf-8") as reader:
+ # The json contains the references we need to init the correct model
+ audio_tokenizer_references = json.load(reader)
+ audio_tokenizer_class = cls.get_possibly_dynamic_module(
+ audio_tokenizer_references["audio_tokenizer_class"]
+ )
+ audio_tokenizer_path = audio_tokenizer_references["audio_tokenizer_name_or_path"]
+
+ audio_tokenizer = audio_tokenizer_class.from_pretrained(audio_tokenizer_path, **audio_tokenizer_kwargs)
+
+ if audio_tokenizer is not None:
+ kwargs["audio_tokenizer"] = audio_tokenizer
+
# Existing processors on the Hub created before #27761 being merged don't have `processor_config.json` (if not
# updated afterward), and we need to keep `from_pretrained` work. So here it fallbacks to the empty dict.
# (`cached_file` called using `_raise_exceptions_for_missing_entries=False` to avoid exception)
@@ -947,7 +1026,9 @@ class ProcessorMixin(PushToHubMixin):
# In any case we need to pass `chat_template` if it is available
processor_dict = {}
if "chat_template" in kwargs:
- processor_dict = {"chat_template": kwargs.pop("chat_template")}
+ processor_dict["chat_template"] = kwargs.pop("chat_template")
+ if "audio_tokenizer" in kwargs:
+ processor_dict["audio_tokenizer"] = kwargs.pop("audio_tokenizer")
return processor_dict, kwargs
try:
@@ -972,6 +1053,8 @@ class ProcessorMixin(PushToHubMixin):
if "chat_template" in kwargs:
processor_dict["chat_template"] = kwargs.pop("chat_template")
+ if "audio_tokenizer" in kwargs:
+ processor_dict["audio_tokenizer"] = kwargs.pop("audio_tokenizer")
return processor_dict, kwargs
@@ -1276,6 +1359,7 @@ class ProcessorMixin(PushToHubMixin):
attribute_class = cls.get_possibly_dynamic_module(class_name)
args.append(attribute_class.from_pretrained(pretrained_model_name_or_path, **kwargs))
+
return args
@staticmethod
@@ -1287,6 +1371,7 @@ class ProcessorMixin(PushToHubMixin):
transformers_module.VIDEO_PROCESSOR_MAPPING,
transformers_module.TOKENIZER_MAPPING,
transformers_module.FEATURE_EXTRACTOR_MAPPING,
+ transformers_module.MODEL_FOR_AUDIO_TOKENIZATION_MAPPING,
]
for lookup_location in lookup_locations:
for custom_class in lookup_location._extra_content.values():
diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py
index 7ca4c35528..4943e91e73 100644
--- a/src/transformers/utils/__init__.py
+++ b/src/transformers/utils/__init__.py
@@ -292,6 +292,7 @@ CONFIG_NAME = "config.json"
FEATURE_EXTRACTOR_NAME = "preprocessor_config.json"
IMAGE_PROCESSOR_NAME = "preprocessor_config.json"
VIDEO_PROCESSOR_NAME = "video_preprocessor_config.json"
+AUDIO_TOKENIZER_NAME = "audio_tokenizer_config.json"
PROCESSOR_NAME = "processor_config.json"
GENERATION_CONFIG_NAME = "generation_config.json"
MODEL_CARD_NAME = "modelcard.json"
diff --git a/tests/generation/test_logits_process.py b/tests/generation/test_logits_process.py
index ea0a7581e5..834c502b1a 100644
--- a/tests/generation/test_logits_process.py
+++ b/tests/generation/test_logits_process.py
@@ -56,7 +56,12 @@ if is_torch_available():
UnbatchedClassifierFreeGuidanceLogitsProcessor,
WatermarkLogitsProcessor,
)
- from transformers.generation.logits_process import BarkEosPrioritizerLogitsProcessor
+ from transformers.generation.logits_process import (
+ BarkEosPrioritizerLogitsProcessor,
+ DiaClassifierFreeGuidanceLogitsProcessor,
+ DiaEOSChannelFilterLogitsProcessor,
+ DiaEOSDelayPatternLogitsProcessor,
+ )
@require_torch
@@ -1211,3 +1216,145 @@ class LogitsProcessorTest(unittest.TestCase):
)
)
self.assertTrue(is_close)
+
+ def test_dia_classifier_free_guidance(self):
+ input_ids = torch.LongTensor([[0]])
+ logits_uncond = torch.tensor([[1.0, 0, 1.5]])
+ logits_cond = torch.tensor([[1.0, 1.0, 1.0]])
+
+ # base cfg with conditioned as center
+ cfg = DiaClassifierFreeGuidanceLogitsProcessor(guidance_scale=1.5)
+ out = cfg(input_ids, torch.cat([logits_cond, logits_uncond], dim=0))
+
+ res = logits_cond + 1.5 * (logits_cond - logits_uncond)
+
+ self.assertAlmostEqual(out[0, 0].item(), res[0, 0].item())
+ self.assertAlmostEqual(out[0, 1].item(), res[0, 1].item())
+ self.assertAlmostEqual(out[0, 2].item(), res[0, 2].item())
+
+ # additional top k (on cond logits)
+ cfg = DiaClassifierFreeGuidanceLogitsProcessor(guidance_scale=1.5, guidance_top_k=1)
+ out = cfg(input_ids, torch.cat([logits_cond, logits_uncond], dim=0))
+
+ res = logits_cond + 1.5 * (logits_cond - logits_uncond)
+ mask = res == res.max()
+ res = logits_cond.clone()
+ res[~mask.bool()] = -float("inf")
+
+ self.assertAlmostEqual(out[0, 0].item(), res[0, 0].item())
+ self.assertAlmostEqual(out[0, 1].item(), res[0, 1].item())
+ self.assertAlmostEqual(out[0, 2].item(), res[0, 2].item())
+
+ def test_dia_channel_filter(self):
+ eos = 2
+ bsz, channels, vocab = 2, 2, 4
+
+ input_ids = torch.LongTensor([[0]])
+ logits = torch.zeros(size=(bsz, channels, vocab)).view(bsz * channels, vocab)
+ logits[0, eos] = 1 # Eos max (forced)
+ logits[1, eos] = 1 # Eos max (forced) but not channel 0
+
+ channel_filter = DiaEOSChannelFilterLogitsProcessor(num_channels=channels, eos_token_id=eos)
+ out = channel_filter(input_ids, logits).view(bsz, channels, vocab)
+
+ for i in range(vocab):
+ if i > eos:
+ # special tokens are not to be predicted
+ self.assertTrue((out[:, :, i] == -float("inf")).all())
+ elif i == eos:
+ # Eos forced on channel 0
+ self.assertTrue(out[0, 0, i] == 1)
+ # Eos suppressed on everything else (even if max before)
+ self.assertTrue(out[0, 1, i] == -float("inf"))
+ self.assertTrue((out[1, :, i] == -float("inf")).all())
+ else:
+ # Eos forced on channel 0
+ self.assertTrue(out[0, 0, i] == -float("inf"))
+ # previous values
+ self.assertTrue(out[0, 1, i] == 0)
+ self.assertTrue((out[1, :, i] == 0).all())
+
+ def test_dia_delay_pattern(self):
+ def check_eos_logits(out, logits, batch, channel, eos):
+ for i in range(vocab):
+ if i == eos:
+ self.assertTrue(out[batch, channel, i] == 0)
+ else:
+ self.assertTrue(out[batch, channel, i] == -float("inf"))
+
+ for c in range(channel):
+ if c != channel:
+ self.assertTrue((out[batch, c] == logits[batch, c]).all())
+
+ eos = 2
+ delay_pattern = [0, 2, 3]
+ max_generation_len = 10
+ bsz, channels, vocab = 2, 3, 4
+
+ input_ids = torch.LongTensor([[0]])
+ logits = torch.zeros(size=(bsz, channels, vocab))
+ # Ensure that argmax can not result in eos
+ logits[:, :, eos] = -1
+
+ delay_pattern_processor = DiaEOSDelayPatternLogitsProcessor(
+ delay_pattern=delay_pattern, eos_token_id=eos, max_generation_len=max_generation_len
+ )
+ out = delay_pattern_processor(input_ids, logits.clone()).view(bsz, channels, vocab)
+
+ # Nothing should happen except for init of some attributes
+ self.assertTrue((out == logits).all())
+ self.assertTrue((~delay_pattern_processor.active_batches).all())
+ self.assertTrue(
+ (delay_pattern_processor.delay_pattern == torch.tensor([delay_pattern for _ in range(bsz)])).all()
+ )
+
+ # Make first batch end
+ logits[0, 0, eos] = 1
+
+ # Go through the complete delay pattern
+ for i in range(max(delay_pattern) + 1):
+ out = delay_pattern_processor(input_ids, logits.clone()).view(bsz, channels, vocab)
+
+ # no delay should kick in
+ if i == 1:
+ self.assertTrue((out == logits).all())
+ else:
+ j = i if i == 0 else i - 1
+ check_eos_logits(out=out, logits=logits, batch=0, channel=j, eos=eos)
+ self.assertTrue((out[1] == logits[1]).all())
+ self.assertTrue(delay_pattern_processor.active_batches[0])
+ self.assertFalse(delay_pattern_processor.active_batches[1])
+ self.assertTrue(
+ (
+ delay_pattern_processor.delay_pattern[0]
+ == torch.tensor([delay - (i + 1) for delay in delay_pattern])
+ ).all()
+ )
+ self.assertTrue((delay_pattern_processor.delay_pattern[1] == torch.tensor(delay_pattern)).all())
+
+ # Make second batch end
+ logits[1, 0, eos] = 1
+
+ # Just to check if other batches could work
+ out = delay_pattern_processor(input_ids, logits.clone()).view(bsz, channels, vocab)
+
+ self.assertTrue((out[0] == logits[0]).all())
+ self.assertTrue(delay_pattern_processor.active_batches.all())
+ self.assertTrue(
+ (delay_pattern_processor.delay_pattern[0] == torch.tensor([delay - 5 for delay in delay_pattern])).all()
+ )
+ self.assertTrue(
+ (delay_pattern_processor.delay_pattern[1] == torch.tensor([delay - 1 for delay in delay_pattern])).all()
+ )
+
+ # Last check on max generation length reached (with delay in mind until last channel produces eos)
+ input_ids = torch.LongTensor([[0] * (max_generation_len - max(delay_pattern) - 1)])
+ delay_pattern_processor = DiaEOSDelayPatternLogitsProcessor(
+ delay_pattern=delay_pattern, eos_token_id=eos, max_generation_len=max_generation_len
+ )
+ out = delay_pattern_processor(input_ids, logits.clone()).view(bsz, channels, vocab)
+
+ check_eos_logits(out=out, logits=logits, batch=0, channel=0, eos=eos)
+ check_eos_logits(out=out, logits=logits, batch=1, channel=0, eos=eos)
+ self.assertTrue(delay_pattern_processor.active_batches.all())
+ self.assertTrue((delay_pattern_processor.delay_pattern == torch.tensor(delay_pattern) - 1).all())
diff --git a/tests/models/auto/test_processor_auto.py b/tests/models/auto/test_processor_auto.py
index 2a1bc30dbb..60500001a3 100644
--- a/tests/models/auto/test_processor_auto.py
+++ b/tests/models/auto/test_processor_auto.py
@@ -26,6 +26,7 @@ import transformers
from transformers import (
CONFIG_MAPPING,
FEATURE_EXTRACTOR_MAPPING,
+ MODEL_FOR_AUDIO_TOKENIZATION_MAPPING,
PROCESSOR_MAPPING,
TOKENIZER_MAPPING,
AutoConfig,
@@ -265,6 +266,8 @@ class AutoFeatureExtractorTest(unittest.TestCase):
del TOKENIZER_MAPPING._extra_content[CustomConfig]
if CustomConfig in PROCESSOR_MAPPING._extra_content:
del PROCESSOR_MAPPING._extra_content[CustomConfig]
+ if CustomConfig in MODEL_FOR_AUDIO_TOKENIZATION_MAPPING._extra_content:
+ del MODEL_FOR_AUDIO_TOKENIZATION_MAPPING._extra_content[CustomConfig]
def test_from_pretrained_dynamic_processor_conflict(self):
class NewFeatureExtractor(Wav2Vec2FeatureExtractor):
@@ -317,6 +320,8 @@ class AutoFeatureExtractorTest(unittest.TestCase):
del TOKENIZER_MAPPING._extra_content[CustomConfig]
if CustomConfig in PROCESSOR_MAPPING._extra_content:
del PROCESSOR_MAPPING._extra_content[CustomConfig]
+ if CustomConfig in MODEL_FOR_AUDIO_TOKENIZATION_MAPPING._extra_content:
+ del MODEL_FOR_AUDIO_TOKENIZATION_MAPPING._extra_content[CustomConfig]
def test_from_pretrained_dynamic_processor_with_extra_attributes(self):
class NewFeatureExtractor(Wav2Vec2FeatureExtractor):
@@ -356,6 +361,8 @@ class AutoFeatureExtractorTest(unittest.TestCase):
del TOKENIZER_MAPPING._extra_content[CustomConfig]
if CustomConfig in PROCESSOR_MAPPING._extra_content:
del PROCESSOR_MAPPING._extra_content[CustomConfig]
+ if CustomConfig in MODEL_FOR_AUDIO_TOKENIZATION_MAPPING._extra_content:
+ del MODEL_FOR_AUDIO_TOKENIZATION_MAPPING._extra_content[CustomConfig]
def test_dynamic_processor_with_specific_dynamic_subcomponents(self):
class NewFeatureExtractor(Wav2Vec2FeatureExtractor):
@@ -390,6 +397,8 @@ class AutoFeatureExtractorTest(unittest.TestCase):
del TOKENIZER_MAPPING._extra_content[CustomConfig]
if CustomConfig in PROCESSOR_MAPPING._extra_content:
del PROCESSOR_MAPPING._extra_content[CustomConfig]
+ if CustomConfig in MODEL_FOR_AUDIO_TOKENIZATION_MAPPING._extra_content:
+ del MODEL_FOR_AUDIO_TOKENIZATION_MAPPING._extra_content[CustomConfig]
def test_auto_processor_creates_tokenizer(self):
processor = AutoProcessor.from_pretrained("hf-internal-testing/tiny-random-bert")
diff --git a/tests/models/dia/__init__.py b/tests/models/dia/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/models/dia/test_feature_extraction_dia.py b/tests/models/dia/test_feature_extraction_dia.py
new file mode 100644
index 0000000000..6243dc4791
--- /dev/null
+++ b/tests/models/dia/test_feature_extraction_dia.py
@@ -0,0 +1,231 @@
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for the Dia feature extractor."""
+
+import itertools
+import random
+import unittest
+
+import numpy as np
+
+from transformers import DiaFeatureExtractor
+from transformers.testing_utils import require_torch
+from transformers.utils.import_utils import is_torch_available
+
+from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin
+
+
+if is_torch_available():
+ import torch
+
+
+global_rng = random.Random()
+
+
+# Copied from tests.models.whisper.test_feature_extraction_whisper.floats_list
+def floats_list(shape, scale=1.0, rng=None, name=None):
+ """Creates a random float32 tensor"""
+ if rng is None:
+ rng = global_rng
+
+ values = []
+ for batch_idx in range(shape[0]):
+ values.append([])
+ for _ in range(shape[1]):
+ values[-1].append(rng.random() * scale)
+
+ return values
+
+
+@require_torch
+class DiaFeatureExtractionTester:
+ # Copied from tests.models.dac.test_feature_extraction_dac.DacFeatureExtractionTester.__init__
+ def __init__(
+ self,
+ parent,
+ batch_size=7,
+ min_seq_length=400,
+ max_seq_length=2000,
+ feature_size=1,
+ padding_value=0.0,
+ sampling_rate=16000,
+ hop_length=512,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.min_seq_length = min_seq_length
+ self.max_seq_length = max_seq_length
+ self.hop_length = hop_length
+ self.seq_length_diff = (self.max_seq_length - self.min_seq_length) // (self.batch_size - 1)
+ self.feature_size = feature_size
+ self.padding_value = padding_value
+ self.sampling_rate = sampling_rate
+
+ # Copied from tests.models.dac.test_feature_extraction_dac.DacFeatureExtractionTester.prepare_feat_extract_dict
+ def prepare_feat_extract_dict(self):
+ return {
+ "feature_size": self.feature_size,
+ "padding_value": self.padding_value,
+ "sampling_rate": self.sampling_rate,
+ "hop_length": self.hop_length,
+ }
+
+ # Copied from tests.models.encodec.test_feature_extraction_encodec.EnCodecFeatureExtractionTester.prepare_inputs_for_common
+ def prepare_inputs_for_common(self, equal_length=False, numpify=False):
+ def _flatten(list_of_lists):
+ return list(itertools.chain(*list_of_lists))
+
+ if equal_length:
+ audio_inputs = floats_list((self.batch_size, self.max_seq_length))
+ else:
+ # make sure that inputs increase in size
+ audio_inputs = [
+ _flatten(floats_list((x, self.feature_size)))
+ for x in range(self.min_seq_length, self.max_seq_length, self.seq_length_diff)
+ ]
+
+ if numpify:
+ audio_inputs = [np.asarray(x) for x in audio_inputs]
+
+ return audio_inputs
+
+
+@require_torch
+class DiaFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
+ feature_extraction_class = DiaFeatureExtractor
+
+ def setUp(self):
+ self.feat_extract_tester = DiaFeatureExtractionTester(self)
+
+ # Copied from tests.models.dac.test_feature_extraction_dac.DacFeatureExtractionTest.test_call
+ def test_call(self):
+ # Tests that all call wrap to encode_plus and batch_encode_plus
+ feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
+ # create three inputs of length 800, 1000, and 1200
+ audio_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
+ np_audio_inputs = [np.asarray(audio_input) for audio_input in audio_inputs]
+
+ # Test not batched input
+ encoded_sequences_1 = feat_extract(audio_inputs[0], return_tensors="np").input_values
+ encoded_sequences_2 = feat_extract(np_audio_inputs[0], return_tensors="np").input_values
+ self.assertTrue(np.allclose(encoded_sequences_1, encoded_sequences_2, atol=1e-3))
+
+ # Test batched
+ encoded_sequences_1 = feat_extract(audio_inputs, padding=True, return_tensors="np").input_values
+ encoded_sequences_2 = feat_extract(np_audio_inputs, padding=True, return_tensors="np").input_values
+ for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
+ self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
+
+ # Copied from tests.models.dac.test_feature_extraction_dac.DacFeatureExtractionTest.test_double_precision_pad
+ def test_double_precision_pad(self):
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
+ np_audio_inputs = np.random.rand(100).astype(np.float64)
+ py_audio_inputs = np_audio_inputs.tolist()
+
+ for inputs in [py_audio_inputs, np_audio_inputs]:
+ np_processed = feature_extractor.pad([{"input_values": inputs}], return_tensors="np")
+ self.assertTrue(np_processed.input_values.dtype == np.float32)
+ pt_processed = feature_extractor.pad([{"input_values": inputs}], return_tensors="pt")
+ self.assertTrue(pt_processed.input_values.dtype == torch.float32)
+
+ # Copied from tests.models.dac.test_feature_extraction_dac.DacFeatureExtractionTest._load_datasamples
+ def _load_datasamples(self, num_samples):
+ from datasets import load_dataset
+
+ ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+ # automatic decoding with librispeech
+ audio_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
+
+ return [x["array"] for x in audio_samples]
+
+ # Copied from tests.models.dac.test_feature_extraction_dac.DacFeatureExtractionTest.test_integration with Dac->Dia
+ def test_integration(self):
+ # fmt: off
+ EXPECTED_INPUT_VALUES = torch.tensor(
+ [ 2.3803711e-03, 2.0751953e-03, 1.9836426e-03, 2.1057129e-03,
+ 1.6174316e-03, 3.0517578e-04, 9.1552734e-05, 3.3569336e-04,
+ 9.7656250e-04, 1.8310547e-03, 2.0141602e-03, 2.1057129e-03,
+ 1.7395020e-03, 4.5776367e-04, -3.9672852e-04, 4.5776367e-04,
+ 1.0070801e-03, 9.1552734e-05, 4.8828125e-04, 1.1596680e-03,
+ 7.3242188e-04, 9.4604492e-04, 1.8005371e-03, 1.8310547e-03,
+ 8.8500977e-04, 4.2724609e-04, 4.8828125e-04, 7.3242188e-04,
+ 1.0986328e-03, 2.1057129e-03]
+ )
+ # fmt: on
+ input_audio = self._load_datasamples(1)
+ feature_extractor = DiaFeatureExtractor()
+ input_values = feature_extractor(input_audio, return_tensors="pt")["input_values"]
+ self.assertEqual(input_values.shape, (1, 1, 93696))
+ torch.testing.assert_close(input_values[0, 0, :30], EXPECTED_INPUT_VALUES, rtol=1e-4, atol=1e-4)
+ audio_input_end = torch.tensor(input_audio[0][-30:], dtype=torch.float32)
+ torch.testing.assert_close(input_values[0, 0, -46:-16], audio_input_end, rtol=1e-4, atol=1e-4)
+
+ def test_integration_stereo(self):
+ # fmt: off
+ EXPECTED_INPUT_VALUES = torch.tensor(
+ [2.3804e-03, 2.0752e-03, 1.9836e-03, 2.1057e-03, 1.6174e-03,
+ 3.0518e-04, 9.1553e-05, 3.3569e-04, 9.7656e-04, 1.8311e-03,
+ 2.0142e-03, 2.1057e-03, 1.7395e-03, 4.5776e-04, -3.9673e-04,
+ 4.5776e-04, 1.0071e-03, 9.1553e-05, 4.8828e-04, 1.1597e-03,
+ 7.3242e-04, 9.4604e-04, 1.8005e-03, 1.8311e-03, 8.8501e-04,
+ 4.2725e-04, 4.8828e-04, 7.3242e-04, 1.0986e-03, 2.1057e-03]
+ )
+ # fmt: on
+ input_audio = self._load_datasamples(1)
+ input_audio = [np.tile(input_audio[0][None], reps=(2, 1))]
+ feature_extractor = DiaFeatureExtractor(feature_size=2)
+ input_values = feature_extractor(input_audio, return_tensors="pt").input_values
+ self.assertEqual(input_values.shape, (1, 1, 93696))
+ torch.testing.assert_close(input_values[0, 0, :30], EXPECTED_INPUT_VALUES, rtol=1e-4, atol=1e-4)
+
+ # Copied from tests.models.dac.test_feature_extraction_dac.DacFeatureExtractionTest.test_truncation_and_padding with Dac->Dia
+ def test_truncation_and_padding(self):
+ input_audio = self._load_datasamples(2)
+ # would be easier if the stride was like
+ feature_extractor = DiaFeatureExtractor()
+
+ # pad and trunc raise an error ?
+ with self.assertRaisesRegex(
+ ValueError,
+ "^Both padding and truncation were set. Make sure you only set one.$",
+ ):
+ truncated_outputs = feature_extractor(
+ input_audio, padding="max_length", truncation=True, return_tensors="pt"
+ ).input_values
+
+ # force truncate to max_length
+ truncated_outputs = feature_extractor(
+ input_audio, truncation=True, max_length=48000, return_tensors="pt"
+ ).input_values
+ self.assertEqual(truncated_outputs.shape, (2, 1, 48128))
+
+ # pad:
+ padded_outputs = feature_extractor(input_audio, padding=True, return_tensors="pt").input_values
+ self.assertEqual(padded_outputs.shape, (2, 1, 93696))
+
+ # force pad to max length
+ truncated_outputs = feature_extractor(
+ input_audio, padding="max_length", max_length=100000, return_tensors="pt"
+ ).input_values
+ self.assertEqual(truncated_outputs.shape, (2, 1, 100352))
+
+ # force no pad
+ with self.assertRaisesRegex(
+ ValueError,
+ "^Unable to create tensor, you should probably activate padding with 'padding=True' to have batched tensors with the same length.$",
+ ):
+ truncated_outputs = feature_extractor(input_audio, padding=False, return_tensors="pt").input_values
+
+ truncated_outputs = feature_extractor(input_audio[0], padding=False, return_tensors="pt").input_values
+ self.assertEqual(truncated_outputs.shape, (1, 1, 93680))
diff --git a/tests/models/dia/test_modeling_dia.py b/tests/models/dia/test_modeling_dia.py
new file mode 100644
index 0000000000..f9427160c2
--- /dev/null
+++ b/tests/models/dia/test_modeling_dia.py
@@ -0,0 +1,752 @@
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Testing suite for the PyTorch Dia model."""
+
+import copy
+import pathlib
+import tempfile
+import unittest
+
+import pytest
+
+from transformers.models.dia import DiaConfig, DiaDecoderConfig, DiaEncoderConfig
+from transformers.testing_utils import (
+ cleanup,
+ is_flaky,
+ require_torch,
+ require_torch_accelerator,
+ require_torch_sdpa,
+ slow,
+ torch_device,
+)
+from transformers.utils import is_soundfile_available, is_torch_available, is_torchaudio_available
+from transformers.utils.import_utils import is_datasets_available
+
+from ...generation.test_utils import GenerationTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, ids_tensor
+from ...test_pipeline_mixin import PipelineTesterMixin
+
+
+if is_datasets_available():
+ from datasets import Audio, load_dataset
+
+if is_torch_available():
+ import torch
+
+ from transformers import (
+ DiaForConditionalGeneration,
+ DiaModel,
+ DiaProcessor,
+ PretrainedConfig,
+ PreTrainedModel,
+ )
+ from transformers.cache_utils import (
+ Cache,
+ StaticCache,
+ )
+ from transformers.models.dia.modeling_dia import DiaDecoder, DiaEncoder
+
+if is_torchaudio_available():
+ import torchaudio
+
+if is_soundfile_available():
+ import soundfile as sf
+
+
+@require_torch
+class DiaModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=3, # need batch_size != num_hidden_layers
+ seq_length=7,
+ max_length=50,
+ is_training=True,
+ vocab_size=100,
+ hidden_size=16,
+ intermediate_size=37,
+ num_hidden_layers=2,
+ num_attention_heads=2,
+ head_dim=8,
+ decoder_hidden_size=32, # typically larger than encoder
+ hidden_act="silu",
+ eos_token_id=97, # special tokens all occur after eos
+ pad_token_id=98,
+ bos_token_id=99,
+ delay_pattern=None,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.seq_length = seq_length
+ self.max_length = max_length
+ self.is_training = is_training
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.head_dim = head_dim
+ self.decoder_hidden_size = decoder_hidden_size
+ self.hidden_act = hidden_act
+ self.eos_token_id = eos_token_id
+ self.pad_token_id = pad_token_id
+ self.bos_token_id = bos_token_id
+ # Set default delay pattern if not provided
+ self.delay_pattern = delay_pattern if delay_pattern is not None else [0, 1, 2]
+ self.num_channels = len(self.delay_pattern)
+
+ def get_config(self):
+ encoder_config = DiaEncoderConfig(
+ max_position_embeddings=self.max_length,
+ num_hidden_layers=self.num_hidden_layers,
+ hidden_size=self.hidden_size,
+ num_attention_heads=self.num_attention_heads,
+ num_key_value_heads=self.num_attention_heads, # same as num_attention_heads for testing
+ head_dim=self.head_dim,
+ intermediate_size=self.intermediate_size,
+ vocab_size=self.vocab_size,
+ hidden_act=self.hidden_act,
+ )
+
+ decoder_config = DiaDecoderConfig(
+ max_position_embeddings=self.max_length,
+ num_hidden_layers=self.num_hidden_layers,
+ hidden_size=self.decoder_hidden_size,
+ intermediate_size=self.intermediate_size,
+ num_attention_heads=self.num_attention_heads,
+ num_key_value_heads=1, # GQA
+ head_dim=self.head_dim,
+ cross_num_attention_heads=self.num_attention_heads,
+ cross_head_dim=self.head_dim,
+ cross_num_key_value_heads=1, # GQA
+ cross_hidden_size=self.hidden_size, # match encoder hidden size
+ vocab_size=self.vocab_size,
+ hidden_act=self.hidden_act,
+ num_channels=self.num_channels,
+ )
+
+ config = DiaConfig(
+ encoder_config=encoder_config,
+ decoder_config=decoder_config,
+ eos_token_id=self.eos_token_id,
+ pad_token_id=self.pad_token_id,
+ bos_token_id=self.bos_token_id,
+ delay_pattern=self.delay_pattern,
+ )
+
+ return config
+
+ def prepare_config_and_inputs(self) -> tuple[DiaConfig, dict]:
+ input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+ attention_mask = input_ids.ne(self.pad_token_id)
+
+ decoder_input_ids = ids_tensor([self.batch_size, self.seq_length, self.num_channels], self.vocab_size)
+ decoder_attention_mask = decoder_input_ids[..., 0].ne(self.pad_token_id)
+
+ config = self.get_config()
+ inputs_dict = {
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ "decoder_input_ids": decoder_input_ids,
+ "decoder_attention_mask": decoder_attention_mask,
+ }
+ return config, inputs_dict
+
+ def prepare_config_and_inputs_for_common(self) -> tuple[DiaConfig, dict]:
+ config, inputs_dict = self.prepare_config_and_inputs()
+ return config, inputs_dict
+
+ def create_and_check_model_forward(self, config, inputs_dict):
+ model = DiaModel(config=config).to(torch_device).eval()
+
+ input_ids = inputs_dict["input_ids"]
+ decoder_input_ids = inputs_dict["decoder_input_ids"]
+
+ # first forward pass
+ last_hidden_state = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids).last_hidden_state
+
+ self.parent.assertTrue(
+ last_hidden_state.shape, (self.batch_size, self.seq_length, config.decoder_config.hidden_size)
+ )
+
+ def check_encoder_decoder_model_standalone(self, config, inputs_dict):
+ model = DiaModel(config=config).to(torch_device).eval()
+ outputs = model(**inputs_dict)
+
+ encoder_last_hidden_state = outputs.encoder_last_hidden_state
+ last_hidden_state = outputs.last_hidden_state
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ encoder = model.get_encoder()
+ encoder.save_pretrained(tmpdirname)
+ encoder = DiaEncoder.from_pretrained(tmpdirname).to(torch_device)
+
+ encoder_last_hidden_state_2 = encoder(
+ input_ids=inputs_dict["input_ids"], attention_mask=inputs_dict["attention_mask"]
+ )[0]
+
+ self.parent.assertTrue((encoder_last_hidden_state_2 - encoder_last_hidden_state).abs().max().item() < 3e-3)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ decoder = model.get_decoder()
+ decoder.save_pretrained(tmpdirname)
+ decoder = DiaDecoder.from_pretrained(tmpdirname).to(torch_device)
+
+ last_hidden_state_2 = decoder(
+ input_ids=inputs_dict["decoder_input_ids"],
+ attention_mask=inputs_dict["decoder_attention_mask"],
+ encoder_hidden_states=encoder_last_hidden_state,
+ )[0]
+
+ self.parent.assertTrue((last_hidden_state_2 - last_hidden_state).abs().max().item() < 3e-3)
+
+
+@require_torch
+class DiaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
+ all_model_classes = (DiaModel, DiaForConditionalGeneration) if is_torch_available() else ()
+ # We only allow greedy search / sampling with one sequence; see `skip_non_greedy_generate`
+ all_generative_model_classes = (DiaForConditionalGeneration,)
+ # TODO: support new pipeline behavior in tests
+ pipeline_model_mapping = {}
+ # pipeline_model_mapping = {"text-to-audio": DiaForConditionalGeneration} if is_torch_available() else {}
+ test_pruning = False
+ test_head_masking = False
+ test_resize_embeddings = False
+ is_encoder_decoder = True
+ # Indicates VLMs usually but there are many audio models which are also composite
+ _is_composite = True
+
+ def setUp(self):
+ self.model_tester = DiaModelTester(self)
+ # Skipping `has_text_modality` but manually testing down below
+ self.config_tester = ConfigTester(self, has_text_modality=False, config_class=DiaConfig)
+ self.skip_non_greedy_generate()
+
+ def skip_non_greedy_generate(self):
+ skippable_tests = [
+ "test_sample_generate_dict_output", # return sequences > 1
+ "test_beam",
+ "test_group_beam",
+ "test_constrained_beam",
+ "test_contrastive",
+ "test_assisted",
+ "test_dola",
+ "test_prompt_lookup",
+ "test_model_parallel_beam_search",
+ "test_generate_without_input_ids",
+ "test_generate_with_head_masking",
+ ]
+
+ for test in skippable_tests:
+ if self._testMethodName.startswith(test):
+ self.skipTest(reason="Dia only supports greedy search / sampling with one sequence.")
+
+ def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
+ """Overriden to account for the 2D flattened structure"""
+ inputs_dict = copy.deepcopy(inputs_dict)
+
+ if return_labels:
+ inputs_dict["labels"] = torch.ones(
+ (
+ self.model_tester.batch_size * self.model_tester.num_channels,
+ self.model_tester.seq_length,
+ ),
+ dtype=torch.long,
+ device=torch_device,
+ )
+
+ return inputs_dict
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ # Manual testing because of composite configs
+ config = self.model_tester.prepare_config_and_inputs()[0]
+ self.assertTrue(hasattr(config.encoder_config, "vocab_size"), msg="Encoder `vocab_size` does not exist")
+ self.assertTrue(hasattr(config.decoder_config, "vocab_size"), msg="Decoder `vocab_size` does not exist")
+
+ def test_model_forward(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model_forward(*config_and_inputs)
+
+ @is_flaky
+ def test_encoder_decoder_model_standalone(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
+ self.model_tester.check_encoder_decoder_model_standalone(*config_and_inputs)
+
+ # Overriding shape checks as Dia has different shapes on encoder/decoder using a composite config
+ # + additional special cases where 3D x 2D meshes confuse the expected shape
+ def _check_logits(self, batch_size, logits, config):
+ batch_size *= len(config.delay_pattern) # Account for flattening
+ vocab_size = config.decoder_config.vocab_size
+ self.assertIsInstance(logits, tuple)
+ self.assertListEqual([iter_logits.shape[0] for iter_logits in logits], [batch_size] * len(logits))
+ # vocabulary difference equal to one (imagegptmodel?) or zero (all other models)
+ vocab_diff = vocab_size - logits[0].shape[-1]
+ self.assertTrue(vocab_diff in [0, 1])
+ self.assertListEqual([vocab_size - score.shape[-1] for score in logits], [vocab_diff] * len(logits))
+
+ def _check_attentions_for_generate(
+ self, batch_size, attentions, prompt_length, output_length, config, decoder_past_key_values
+ ):
+ self.assertIsInstance(attentions, tuple)
+ self.assertListEqual(
+ [isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions)
+ )
+ self.assertEqual(len(attentions), (output_length - prompt_length))
+
+ use_cache = decoder_past_key_values is not None
+ has_static_cache = isinstance(decoder_past_key_values, StaticCache)
+
+ # When `output_attentions=True`, each iteration of generate appends the attentions corresponding to the new
+ # token(s)
+ for generated_length, iter_attentions in enumerate(attentions):
+ # regardless of using cache, the first forward pass will have the full prompt as input
+ if use_cache and generated_length > 0:
+ model_input_length = 1
+ else:
+ model_input_length = prompt_length + generated_length
+ query_length = (
+ prompt_length + generated_length
+ if not has_static_cache
+ else decoder_past_key_values.get_max_cache_shape()
+ )
+
+ expected_shape = (
+ batch_size,
+ config.decoder_config.num_attention_heads, # Decoder config
+ model_input_length,
+ query_length,
+ )
+ # check attn size
+ self.assertListEqual(
+ [layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions)
+ )
+
+ def _check_encoder_attention_for_generate(self, attentions, batch_size, config, prompt_length):
+ # Encoder config
+ encoder_expected_shape = (batch_size, config.encoder_config.num_attention_heads, prompt_length, prompt_length)
+ self.assertIsInstance(attentions, tuple)
+ self.assertListEqual(
+ [layer_attentions.shape for layer_attentions in attentions],
+ [encoder_expected_shape] * len(attentions),
+ )
+
+ def _check_hidden_states_for_generate(
+ self, batch_size, hidden_states, prompt_length, output_length, config, use_cache=False
+ ):
+ self.assertIsInstance(hidden_states, tuple)
+ self.assertListEqual(
+ [isinstance(iter_hidden_states, tuple) for iter_hidden_states in hidden_states],
+ [True] * len(hidden_states),
+ )
+ self.assertEqual(len(hidden_states), (output_length - prompt_length))
+
+ # When `output_hidden_states=True`, each iteration of generate appends the hidden states corresponding to the
+ # new token(s)
+ for generated_length, iter_hidden_states in enumerate(hidden_states):
+ # regardless of using cache, the first forward pass will have the full prompt as input
+ if use_cache and generated_length > 0:
+ model_input_length = 1
+ else:
+ model_input_length = prompt_length + generated_length
+
+ # check hidden size
+ # we can have different hidden sizes between encoder and decoder --> check both
+ expected_shape_encoder = (batch_size, model_input_length, config.encoder_config.hidden_size)
+ expected_shape_decoder = (batch_size, model_input_length, config.decoder_config.hidden_size)
+ self.assertTrue(
+ [layer_hidden_states.shape for layer_hidden_states in iter_hidden_states]
+ == [expected_shape_encoder] * len(iter_hidden_states)
+ or [layer_hidden_states.shape for layer_hidden_states in iter_hidden_states]
+ == [expected_shape_decoder] * len(iter_hidden_states)
+ )
+
+ def _check_encoder_hidden_states_for_generate(self, hidden_states, batch_size, config, prompt_length):
+ # Encoder config
+ encoder_expected_shape = (batch_size, prompt_length, config.encoder_config.hidden_size)
+ self.assertIsInstance(hidden_states, tuple)
+ self.assertListEqual(
+ [layer_hidden_states.shape for layer_hidden_states in hidden_states],
+ [encoder_expected_shape] * len(hidden_states),
+ )
+
+ def _check_past_key_values_for_generate(self, batch_size, decoder_past_key_values, cache_length, config):
+ self.assertIsInstance(decoder_past_key_values, (tuple, Cache))
+
+ # we need the decoder config here
+ config = config.decoder_config
+
+ # (batch, head, seq_length, head_features)
+ expected_shape = (
+ batch_size,
+ config.num_key_value_heads if hasattr(config, "num_key_value_heads") else config.num_attention_heads,
+ cache_length,
+ config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads,
+ )
+
+ if isinstance(decoder_past_key_values, Cache):
+ self.assertListEqual(
+ [key_tensor.shape for key_tensor in decoder_past_key_values.key_cache],
+ [expected_shape] * len(decoder_past_key_values.key_cache),
+ )
+ self.assertListEqual(
+ [value_tensor.shape for value_tensor in decoder_past_key_values.value_cache],
+ [expected_shape] * len(decoder_past_key_values.value_cache),
+ )
+
+ def _check_scores(self, batch_size, scores, generated_length, config):
+ # Special case where Dia keeps score in a 2D mesh of (bsz * channels, vocab)
+ vocab_size = config.decoder_config.vocab_size
+ expected_shape = (batch_size * len(config.delay_pattern), vocab_size)
+ self.assertIsInstance(scores, tuple)
+ self.assertEqual(len(scores), generated_length)
+ self.assertListEqual([iter_scores.shape for iter_scores in scores], [expected_shape] * len(scores))
+
+ @require_torch_sdpa
+ def test_sdpa_can_dispatch_composite_models(self):
+ """
+ Overwritten as it relies on hardcoded namings atm - checking for our case here specifically
+ """
+ for model_class in self.all_model_classes:
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model = model_class.from_pretrained(tmpdirname)
+
+ sub_models_supporting_sdpa = [
+ (module._supports_sdpa or module._supports_attention_backend)
+ for name, module in model.named_modules()
+ if isinstance(module, PreTrainedModel) and name != ""
+ ]
+ supports_sdpa_all_modules = (
+ all(sub_models_supporting_sdpa)
+ if len(sub_models_supporting_sdpa) > 0
+ else (model._supports_sdpa or model._supports_attention_backend)
+ )
+
+ if not supports_sdpa_all_modules:
+ with self.assertRaises(ValueError):
+ model_sdpa = model_class.from_pretrained(tmpdirname, attn_implementation="sdpa")
+ else:
+ model_sdpa = model_class.from_pretrained(tmpdirname, attn_implementation="sdpa")
+ for key in model_sdpa.config:
+ if isinstance(getattr(model_sdpa.config, key), PretrainedConfig):
+ sub_config = getattr(model_sdpa.config, key)
+ self.assertTrue(sub_config._attn_implementation == "sdpa")
+
+ @pytest.mark.generate
+ @unittest.skip(reason="Custom processor `DiaEOSDelayPatternLogitsProcessor` forces eos token.")
+ def test_generate_continue_from_past_key_values(self):
+ """Only a small change due to the expected shapes"""
+ # Tests that we can continue generating from past key values, returned from a previous `generate` call
+ for model_class in self.all_generative_model_classes:
+ config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
+
+ # Let's make it always:
+ # 1. use cache (for obvious reasons)
+ # 2. generate to max length (which can be achieved by setting the eos token to an invalid value), which
+ # would make the test flaky (e.g. EOS is generated on iteration 1 on both generations, but the
+ # continuation would force it to generate beyond an EOS token)
+ # 3. ignore `token_type_ids` for simplicity
+ # 4. ignore `forced_eos_token_id`, which requires further manipulation of the continuation inputs and is
+ # active by default on some models
+ # 5. ignore `encoder_no_repeat_ngram_size`, which is set by default in some encoder-decoder models. When
+ # we use their decoder as a stand-alone model, `encoder_no_repeat_ngram_size` actually prevents
+ # repetition exclusively from the prompt. This test relies on comparing one call vs 2 calls
+ # with cache, what is considered a prompt is different in the two cases.
+
+ if "token_type_ids" in inputs:
+ del inputs["token_type_ids"]
+
+ model = model_class(config).to(torch_device)
+ model.eval()
+
+ generate_kwargs = {
+ "pad_token_id": -1,
+ "eos_token_id": -1,
+ "forced_eos_token_id": None,
+ "encoder_no_repeat_ngram_size": 0,
+ "use_cache": True,
+ "do_sample": False,
+ "return_dict_in_generate": True,
+ "output_scores": True,
+ }
+
+ # Traditional way of generating text, with `return_dict_in_generate` to return the past key values
+ outputs = model.generate(**inputs, **generate_kwargs, max_new_tokens=4)
+
+ # Let's generate again, but passing the past key values in between (3 + 1 = 4 tokens). Note that the
+ # inputs may need to be tweaked across `generate` calls (like the attention mask).
+ outputs_cached = model.generate(**inputs, **generate_kwargs, max_new_tokens=3)
+
+ # Continue from the tokens generated above, preparing the inputs accordingly
+ inputs["past_key_values"] = outputs_cached.past_key_values
+ new_attention_len = outputs_cached.sequences.shape[1] # the only real modification in this test
+ inputs["decoder_input_ids"] = outputs_cached.sequences
+ if "decoder_attention_mask" in inputs:
+ inputs["decoder_attention_mask"] = torch.nn.functional.pad(
+ inputs["decoder_attention_mask"],
+ (0, new_attention_len - inputs["decoder_attention_mask"].shape[1]),
+ mode="constant",
+ value=1,
+ )
+
+ first_caches_scores = outputs_cached.scores
+ outputs_cached = model.generate(**inputs, **generate_kwargs, max_new_tokens=1)
+ full_cached_scores = first_caches_scores + outputs_cached.scores
+ outputs_cached.scores = full_cached_scores
+
+ # The two sets of generated text and past kv should be equal to each other
+ self._check_similar_generate_outputs(outputs, outputs_cached)
+ for layer_idx in range(len(outputs_cached.past_key_values)):
+ for kv_idx in range(len(outputs_cached.past_key_values[layer_idx])):
+ self.assertTrue(
+ torch.allclose(
+ outputs.past_key_values[layer_idx][kv_idx],
+ outputs_cached.past_key_values[layer_idx][kv_idx],
+ )
+ )
+
+ @unittest.skip(reason="Indirectly checked in Dia through the generate methods.")
+ def test_past_key_values_format(self, custom_all_cache_shapes=None):
+ pass
+
+ @unittest.skip(reason="Indirectly checked in Dia through the generate methods.")
+ def test_hidden_states_output(self):
+ pass
+
+ @unittest.skip(
+ reason="Dia has too many mixed embedding types which would cause unintentional side effects, e.g. attempts at tying embeddings"
+ )
+ def test_model_get_set_embeddings(self):
+ pass
+
+ @unittest.skip(reason="Theoretically works but kernel library causes issues.")
+ def test_torchscript_output_hidden_state(self):
+ pass
+
+ @unittest.skip(reason="Theoretically works but kernel library causes issues.")
+ def test_torchscript_simple(self):
+ pass
+
+ @unittest.skip(reason="Encoder-Decoder cache can not be initialized.")
+ def test_multi_gpu_data_parallel_forward(self):
+ pass
+
+
+class DiaForConditionalGenerationIntegrationTest(unittest.TestCase):
+ """
+ See https://gist.github.com/vasqu/0e3b06360373a4e612aa3b9a7c09185e for generating the integration tests
+
+ NOTE: We add a single `eos` line for the last channel which is skipped in the original Dia
+ (It doesn't change the behaviour as we cut by the eos token position)
+ """
+
+ def setUp(self):
+ # it's a dummy ckpt but should suffice for testing purposes
+ self.model_checkpoint = "AntonV/Dia-1.6B"
+ self.sampling_rate = 44100
+
+ # prepare audio
+ librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+ librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=self.sampling_rate))
+ audio_sample_1 = librispeech_dummy[-1]["audio"]["array"]
+ audio_sample_2 = librispeech_dummy[-2]["audio"]["array"]
+ # 10 and 5 codebooks as prefix - saved as files as we need wav files for the original Dia
+ dac_chunk_len = 512
+ self.audio_prompt_1_path = "/tmp/dia_test_sample_1.mp3"
+ self.audio_prompt_2_path = "/tmp/dia_test_sample_2.mp3"
+ sf.write(self.audio_prompt_1_path, audio_sample_1[: (dac_chunk_len * 10)], self.sampling_rate)
+ sf.write(self.audio_prompt_2_path, audio_sample_2[: (dac_chunk_len * 5)], self.sampling_rate)
+
+ def tearDown(self):
+ pathlib.Path(self.audio_prompt_1_path).unlink()
+ pathlib.Path(self.audio_prompt_2_path).unlink()
+ cleanup(torch_device, gc_collect=True)
+
+ @slow
+ @require_torch_accelerator
+ def test_dia_model_integration_generate_tts(self):
+ text = ["[S1] Dia is an open weights text to dialogue model.", "This is a test"]
+ processor = DiaProcessor.from_pretrained(self.model_checkpoint)
+ inputs = processor(text=text, padding=True, return_tensors="pt").to(torch_device)
+
+ model = DiaForConditionalGeneration.from_pretrained(self.model_checkpoint).to(torch_device)
+ outputs = model.generate(**inputs, max_new_tokens=32, do_sample=False)
+
+ # fmt: off
+ EXPECTED_OUTPUT_TOKENS = torch.tensor([[[1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
+ [ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
+ [ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
+ [ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
+ [ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
+ [ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
+ [ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
+ [ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
+ [ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
+ [ 568, 778, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
+ [ 568, 778, 338, 1026, 1026, 1026, 1026, 1026, 1026],
+ [ 568, 804, 10, 524, 1026, 1026, 1026, 1026, 1026],
+ [ 568, 804, 10, 674, 967, 1026, 1026, 1026, 1026],
+ [ 568, 804, 10, 674, 364, 360, 1026, 1026, 1026],
+ [ 568, 804, 10, 674, 364, 981, 728, 1026, 1026],
+ [ 568, 804, 10, 674, 364, 981, 741, 550, 1026],
+ [ 568, 804, 10, 674, 364, 981, 568, 378, 90],
+ [1024, 804, 10, 674, 364, 981, 568, 378, 731],
+ [1025, 804, 10, 674, 364, 981, 568, 378, 731],
+ [1025, 804, 10, 674, 364, 981, 568, 378, 731],
+ [1025, 804, 10, 674, 364, 981, 568, 378, 731],
+ [1025, 804, 10, 674, 364, 981, 568, 378, 731],
+ [1025, 804, 10, 674, 364, 981, 568, 378, 731],
+ [1025, 804, 10, 674, 364, 981, 568, 378, 731],
+ [1025, 804, 10, 674, 364, 981, 568, 378, 731],
+ [1025, 1024, 10, 674, 364, 981, 568, 378, 731],
+ [1025, 1025, 1024, 674, 364, 981, 568, 378, 731],
+ [1025, 1025, 1025, 1024, 364, 981, 568, 378, 731],
+ [1025, 1025, 1025, 1025, 1024, 981, 568, 378, 731],
+ [1025, 1025, 1025, 1025, 1025, 1024, 568, 378, 731],
+ [1025, 1025, 1025, 1025, 1025, 1025, 1024, 378, 731],
+ [1025, 1025, 1025, 1025, 1025, 1025, 1025, 1024, 731],
+ [1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1024]],
+
+ [[1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
+ [ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
+ [ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
+ [ 698, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
+ [ 592, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
+ [ 592, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
+ [ 592, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
+ [ 592, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
+ [ 592, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
+ [ 592, 778, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
+ [ 592, 778, 338, 1026, 1026, 1026, 1026, 1026, 1026],
+ [ 592, 697, 10, 524, 1026, 1026, 1026, 1026, 1026],
+ [ 592, 288, 476, 649, 967, 1026, 1026, 1026, 1026],
+ [ 592, 740, 386, 674, 364, 360, 1026, 1026, 1026],
+ [ 592, 402, 386, 347, 362, 981, 728, 1026, 1026],
+ [ 592, 402, 721, 728, 327, 981, 741, 550, 1026],
+ [ 592, 402, 721, 728, 460, 62, 676, 378, 90],
+ [1024, 402, 721, 728, 837, 595, 195, 982, 784],
+ [1025, 402, 721, 677, 497, 102, 692, 24, 330],
+ [1025, 402, 721, 677, 511, 102, 503, 871, 609],
+ [1025, 402, 721, 677, 511, 96, 801, 871, 894],
+ [1025, 402, 721, 677, 511, 745, 314, 498, 775],
+ [1025, 402, 721, 677, 511, 745, 314, 498, 105],
+ [1025, 402, 721, 677, 511, 745, 314, 861, 889],
+ [1025, 893, 721, 677, 511, 744, 314, 871, 353],
+ [1025, 1024, 888, 677, 511, 744, 314, 871, 332],
+ [1025, 1025, 1024, 518, 511, 744, 314, 871, 366],
+ [1025, 1025, 1025, 1024, 611, 744, 314, 871, 366],
+ [1025, 1025, 1025, 1025, 1024, 980, 314, 871, 366],
+ [1025, 1025, 1025, 1025, 1025, 1024, 45, 124, 366],
+ [1025, 1025, 1025, 1025, 1025, 1025, 1024, 871, 366],
+ [1025, 1025, 1025, 1025, 1025, 1025, 1025, 1024, 719],
+ [1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1024]]])
+ # fmt: on
+
+ torch.testing.assert_close(outputs.cpu(), EXPECTED_OUTPUT_TOKENS)
+
+ @slow
+ @require_torch_accelerator
+ def test_dia_model_integration_generate_audio_context(self):
+ text = ["[S1] Dia is an open weights text to dialogue model.", "This is a test"]
+ audio_sample_1 = torchaudio.load(self.audio_prompt_1_path, channels_first=True)[0].squeeze().numpy()
+ audio_sample_2 = torchaudio.load(self.audio_prompt_2_path, channels_first=True)[0].squeeze().numpy()
+ audio = [audio_sample_1, audio_sample_2]
+
+ processor = DiaProcessor.from_pretrained(self.model_checkpoint)
+ inputs = processor(text=text, audio=audio, padding=True, return_tensors="pt").to(torch_device)
+
+ model = DiaForConditionalGeneration.from_pretrained(self.model_checkpoint).to(torch_device)
+ # dia has right padding while we have left padding (for faster prefill)
+ # additionally we have new tokens vs dia's max tokens (hence we compare each in the respective settings)
+ outputs_1 = model.generate(**inputs, max_new_tokens=22, do_sample=False)
+ outputs_2 = model.generate(**inputs, max_new_tokens=27, do_sample=False)
+
+ # fmt: off
+ EXPECTED_OUTPUT_TOKENS_1 = torch.tensor([[1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
+ [ 578, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
+ [ 592, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
+ [ 494, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
+ [ 330, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
+ [ 330, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
+ [ 330, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
+ [ 330, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
+ [ 330, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
+ [ 330, 501, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
+ [ 330, 204, 34, 1026, 1026, 1026, 1026, 1026, 1026],
+ [ 330, 254, 915, 863, 1026, 1026, 1026, 1026, 1026],
+ [ 330, 215, 458, 313, 50, 1026, 1026, 1026, 1026],
+ [ 330, 615, 529, 216, 801, 237, 1026, 1026, 1026],
+ [ 330, 580, 563, 233, 337, 37, 1018, 1026, 1026],
+ [ 330, 567, 530, 753, 607, 179, 954, 242, 1026],
+ [ 330, 627, 6, 1010, 500, 189, 598, 858, 247],
+ [1024, 432, 480, 530, 122, 3, 788, 149, 814],
+ [1025, 875, 826, 458, 98, 540, 181, 122, 608],
+ [1025, 495, 840, 413, 337, 784, 591, 150, 1017],
+ [1025, 808, 189, 137, 445, 0, 227, 658, 345],
+ [1025, 397, 89, 753, 1016, 173, 984, 0, 910],
+ [1025, 875, 460, 934, 50, 335, 670, 818, 722],
+ [1025, 875, 460, 762, 119, 372, 503, 858, 584],
+ [1025, 348, 555, 475, 469, 458, 963, 41, 664],
+ [1025, 1024, 852, 683, 761, 193, 595, 895, 885],
+ [1025, 1025, 1024, 135, 761, 902, 163, 623, 385],
+ [1025, 1025, 1025, 1024, 852, 282, 581, 623, 70],
+ [1025, 1025, 1025, 1025, 1024, 41, 661, 790, 977],
+ [1025, 1025, 1025, 1025, 1025, 1024, 580, 401, 464],
+ [1025, 1025, 1025, 1025, 1025, 1025, 1024, 756, 61],
+ [1025, 1025, 1025, 1025, 1025, 1025, 1025, 1024, 752],
+ [1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1024]])
+
+ EXPECTED_OUTPUT_TOKENS_2 = torch.tensor([[1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
+ [ 619, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
+ [ 315, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
+ [ 315, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
+ [ 315, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
+ [ 315, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
+ [ 315, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
+ [ 315, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
+ [ 315, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
+ [ 315, 968, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
+ [ 315, 1007, 458, 1026, 1026, 1026, 1026, 1026, 1026],
+ [ 315, 35, 266, 68, 1026, 1026, 1026, 1026, 1026],
+ [ 315, 359, 285, 811, 154, 1026, 1026, 1026, 1026],
+ [ 315, 906, 407, 297, 785, 649, 1026, 1026, 1026],
+ [ 315, 249, 678, 868, 899, 257, 950, 1026, 1026],
+ [ 315, 249, 217, 471, 292, 908, 196, 469, 1026],
+ [ 315, 249, 825, 771, 839, 802, 633, 590, 531],
+ [1024, 249, 150, 53, 126, 76, 794, 626, 442],
+ [1025, 249, 825, 218, 359, 864, 526, 626, 770],
+ [1025, 249, 150, 137, 530, 845, 877, 600, 111],
+ [1025, 249, 150, 287, 730, 991, 135, 259, 39],
+ [1025, 249, 825, 104, 198, 1020, 719, 625, 208],
+ [1025, 249, 825, 997, 602, 256, 859, 322, 518],
+ [1025, 668, 825, 979, 584, 256, 98, 665, 589],
+ [1025, 954, 458, 54, 206, 52, 244, 822, 599],
+ [1025, 1024, 104, 914, 435, 579, 860, 92, 661],
+ [1025, 1025, 1024, 848, 126, 74, 304, 92, 753],
+ [1025, 1025, 1025, 1024, 362, 376, 304, 586, 753],
+ [1025, 1025, 1025, 1025, 1024, 633, 996, 586, 83],
+ [1025, 1025, 1025, 1025, 1025, 1024, 179, 898, 928],
+ [1025, 1025, 1025, 1025, 1025, 1025, 1024, 506, 102],
+ [1025, 1025, 1025, 1025, 1025, 1025, 1025, 1024, 79],
+ [1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1024]])
+ # fmt: on
+
+ torch.testing.assert_close(outputs_1[0].cpu(), EXPECTED_OUTPUT_TOKENS_1)
+ torch.testing.assert_close(outputs_2[1, 5:].cpu(), EXPECTED_OUTPUT_TOKENS_2) # left padding
diff --git a/tests/models/dia/test_processor_dia.py b/tests/models/dia/test_processor_dia.py
new file mode 100644
index 0000000000..8ce15f4330
--- /dev/null
+++ b/tests/models/dia/test_processor_dia.py
@@ -0,0 +1,269 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import shutil
+import tempfile
+import unittest
+
+import numpy as np
+from parameterized import parameterized
+
+from transformers import DacModel, DiaFeatureExtractor, DiaProcessor, DiaTokenizer
+from transformers.testing_utils import require_torch
+from transformers.utils import is_torch_available
+
+
+if is_torch_available:
+ import torch
+
+
+# Copied from tests.utils.test_modeling_utils.check_models_equal
+def check_models_equal(model1, model2):
+ models_are_equal = True
+ for model1_p, model2_p in zip(model1.parameters(), model2.parameters()):
+ if model1_p.data.ne(model2_p.data).sum() > 0:
+ models_are_equal = False
+
+ return models_are_equal
+
+
+@require_torch
+class DiaProcessorTest(unittest.TestCase):
+ def setUp(self):
+ self.checkpoint = "AntonV/Dia-1.6B"
+ self.audio_tokenizer_checkpoint = "descript/dac_44khz"
+ self.tmpdirname = tempfile.mkdtemp()
+
+ # Audio tokenizer is a bigger model so we will reuse this if possible
+ self.processor = DiaProcessor(
+ tokenizer=self.get_tokenizer(),
+ feature_extractor=self.get_feature_extractor(),
+ audio_tokenizer=self.get_audio_tokenizer(),
+ )
+
+ # Default audio values based on Dia and Dac
+ self.pad_id = 1025
+ self.bos_id = 1026
+ self.dac_chunk_len = 512
+ self.delay_pattern = [0, 8, 9, 10, 11, 12, 13, 14, 15]
+
+ def get_tokenizer(self, **kwargs):
+ return DiaTokenizer.from_pretrained(self.checkpoint, **kwargs)
+
+ def get_feature_extractor(self, **kwargs):
+ return DiaFeatureExtractor.from_pretrained(self.checkpoint, **kwargs)
+
+ def get_audio_tokenizer(self, **kwargs):
+ return DacModel.from_pretrained(self.audio_tokenizer_checkpoint, **kwargs)
+
+ def tearDown(self):
+ shutil.rmtree(self.tmpdirname)
+ del self.processor
+
+ def test_save_load_pretrained_default(self):
+ tokenizer = self.get_tokenizer()
+ feature_extractor = self.get_feature_extractor()
+ audio_tokenizer = self.get_audio_tokenizer()
+
+ processor = DiaProcessor(
+ tokenizer=tokenizer, feature_extractor=feature_extractor, audio_tokenizer=audio_tokenizer
+ )
+
+ processor.save_pretrained(self.tmpdirname)
+ processor = DiaProcessor.from_pretrained(self.tmpdirname)
+
+ self.assertEqual(processor.tokenizer.get_vocab(), tokenizer.get_vocab())
+ self.assertIsInstance(processor.tokenizer, DiaTokenizer)
+
+ self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor.to_json_string())
+ self.assertIsInstance(processor.feature_extractor, DiaFeatureExtractor)
+
+ self.assertEqual(processor.audio_tokenizer.__class__.__name__, audio_tokenizer.__class__.__name__)
+ self.assertEqual(processor.audio_tokenizer.name_or_path, audio_tokenizer.name_or_path)
+ self.assertTrue(check_models_equal(processor.audio_tokenizer, audio_tokenizer))
+ self.assertIsInstance(processor.audio_tokenizer, DacModel)
+
+ def test_save_load_pretrained_additional_features(self):
+ processor = DiaProcessor(
+ tokenizer=self.get_tokenizer(),
+ feature_extractor=self.get_feature_extractor(),
+ audio_tokenizer=self.get_audio_tokenizer(),
+ )
+ processor.save_pretrained(self.tmpdirname)
+
+ tokenizer_add_kwargs = self.get_tokenizer()
+ feature_extractor_add_kwargs = self.get_feature_extractor()
+ audio_tokenizer_add_kwargs = self.get_audio_tokenizer()
+
+ processor = DiaProcessor.from_pretrained(self.tmpdirname)
+
+ self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab())
+ self.assertIsInstance(processor.tokenizer, DiaTokenizer)
+
+ self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string())
+ self.assertIsInstance(processor.feature_extractor, DiaFeatureExtractor)
+
+ self.assertEqual(processor.audio_tokenizer.__class__.__name__, audio_tokenizer_add_kwargs.__class__.__name__)
+ self.assertEqual(processor.audio_tokenizer.name_or_path, audio_tokenizer_add_kwargs.name_or_path)
+ self.assertTrue(check_models_equal(processor.audio_tokenizer, audio_tokenizer_add_kwargs))
+ self.assertIsInstance(processor.audio_tokenizer, DacModel)
+
+ def test_model_input_names(self):
+ tokenizer = self.get_tokenizer()
+
+ self.assertListEqual(
+ self.processor.model_input_names,
+ list(dict.fromkeys(tokenizer.model_input_names + ["decoder_input_ids", "decoder_attention_mask"])),
+ msg="`processor` model input names do not match the expected names.",
+ )
+
+ def test_tokenize(self):
+ tokenizer = self.get_tokenizer()
+ random_text = ["This is a processing test for tokenization", "[S1] Dia template style [S2] Nice"]
+
+ input_tokenizer = tokenizer(random_text, padding=True, return_tensors="pt")
+ input_processor = self.processor(random_text)
+
+ for key in input_tokenizer.keys():
+ self.assertTrue((input_tokenizer[key] == input_processor[key]).all())
+
+ def test_no_audio(self):
+ random_text = ["Dummy Input"] * 2
+ input_processor = self.processor(random_text)
+ audio_tokens, audio_mask = input_processor["decoder_input_ids"], input_processor["decoder_attention_mask"]
+
+ # full mask with +1 for bos
+ self.assertTrue(audio_mask.sum() == (max(self.delay_pattern) + 1) * len(random_text))
+ self.assertTrue(
+ audio_tokens.shape
+ == (
+ len(random_text),
+ max(self.delay_pattern) + 1,
+ len(self.delay_pattern),
+ )
+ )
+
+ for channel_idx, delay in enumerate(self.delay_pattern):
+ expected_sequence = torch.ones(size=(audio_tokens.shape[:-1])) * self.pad_id
+ expected_sequence[:, : delay + 1] = self.bos_id
+ self.assertTrue((audio_tokens[..., channel_idx] == expected_sequence).all())
+
+ def test_audio(self):
+ audio_tokenizer = self.get_audio_tokenizer()
+ feature_extractor = self.get_feature_extractor()
+
+ random_text = ["Dummy Input"] * 2
+ # Dac only starts accepting audio from a certain length (ensured via >=1024)
+ raw_speeches = [np.random.rand(2048).astype(np.float32), np.random.rand(1024).astype(np.float32)]
+ input_processor = self.processor(random_text, raw_speeches)
+ audio_tokens, audio_mask = input_processor["decoder_input_ids"], input_processor["decoder_attention_mask"]
+
+ sequence_len = audio_mask.shape[1]
+ for batch_idx, speech in enumerate(raw_speeches):
+ raw_audio = feature_extractor(speech, return_tensors="pt")["input_values"]
+ codebooks = audio_tokenizer(raw_audio).audio_codes.transpose(1, 2)
+
+ pad_len = sequence_len - audio_mask.sum(dim=-1)[batch_idx]
+ for channel_idx, delay in enumerate(self.delay_pattern):
+ # Left padding filled bos, right padding (delay) are pad
+ start_idx = pad_len + delay + 1
+ end_idx = start_idx + codebooks.shape[1]
+
+ encoded_sequence = audio_tokens[batch_idx, :, channel_idx]
+ expected_sequence = torch.ones(size=(sequence_len,)) * self.pad_id
+ expected_sequence[:start_idx] = self.bos_id
+ expected_sequence[start_idx:end_idx] = codebooks[0, :, channel_idx]
+
+ self.assertTrue((encoded_sequence == expected_sequence).all())
+
+ # Just to make sure the masking correctly only ignores bos tokens
+ self.assertTrue((audio_tokens[~audio_mask.bool()] == self.bos_id).all())
+
+ @parameterized.expand([([1, 1],), ([1, 5],), ([2, 4, 6],)])
+ def test_decode_audio(self, audio_lens):
+ feature_extractor = self.get_feature_extractor()
+ audio_tokenizer = self.get_audio_tokenizer()
+
+ random_text = ["Dummy Input"] * len(audio_lens)
+ raw_speeches = [np.random.rand(self.dac_chunk_len * l).astype(np.float32) for l in audio_lens]
+ # we need eos (given if training) to decode properly, also enforced via custom logits processor
+ input_processor = self.processor(random_text, raw_speeches, generation=False)
+ audio_tokens = input_processor["decoder_input_ids"]
+
+ decoded_speeches = self.processor.batch_decode(audio_tokens)
+ for batch_idx, speech in enumerate(raw_speeches):
+ raw_audio = feature_extractor(speech, return_tensors="pt")["input_values"]
+ codebooks = audio_tokenizer(raw_audio).audio_codes
+
+ decoded_audio = decoded_speeches[batch_idx]
+ expected_audio = audio_tokenizer.decode(audio_codes=codebooks).audio_values
+
+ self.assertTrue((expected_audio == decoded_audio).all())
+ self.assertTrue(decoded_speeches[batch_idx].shape[-1] == audio_lens[batch_idx] * self.dac_chunk_len)
+
+ @parameterized.expand([(1, 2, [0, 1, 4]), (2, 4, [1, 3, 2]), (4, 8, [0, 5, 7])])
+ def test_delay_in_audio(self, bsz, seq_len, delay_pattern):
+ # static functions which are crucial, hence we also test them here
+ build_indices_fn = DiaProcessor.build_indices
+ delay_fn = DiaProcessor.apply_audio_delay
+
+ bos, pad = -2, -1
+ num_channels = len(delay_pattern)
+
+ audio_input = torch.arange(bsz * seq_len * num_channels).view(bsz, seq_len, num_channels)
+ # imitate a delay mask with zeroes
+ audio_input = torch.cat([audio_input, torch.zeros(size=(bsz, max(delay_pattern), num_channels))], dim=1)
+
+ precomputed_idx = build_indices_fn(
+ bsz=bsz,
+ seq_len=seq_len + max(delay_pattern),
+ num_channels=num_channels,
+ delay_pattern=delay_pattern,
+ revert=False,
+ )
+ delayed_audio_out = delay_fn(
+ audio=audio_input,
+ pad_token_id=pad,
+ bos_token_id=bos,
+ precomputed_idx=precomputed_idx,
+ )
+
+ # every channel idx is shifted by delay_pattern[idx]
+ delayed_audio_res = audio_input.clone()
+ for idx, delay in enumerate(delay_pattern):
+ delayed_audio_res[:, :delay, idx] = bos
+ remaining_input = seq_len + max(delay_pattern) - delay
+ delayed_audio_res[:, delay:, idx] = audio_input[:, :remaining_input, idx]
+
+ self.assertTrue((delayed_audio_out == delayed_audio_res).all())
+
+ # we should get back to the original audio we had (when removing the delay pad)
+ bsz, new_seq_len, num_channels = delayed_audio_out.shape
+ precomputed_idx = build_indices_fn(
+ bsz=bsz,
+ seq_len=new_seq_len,
+ num_channels=num_channels,
+ delay_pattern=delay_pattern,
+ revert=True,
+ )
+ reverted_audio_out = delay_fn(
+ audio=delayed_audio_out,
+ pad_token_id=pad,
+ bos_token_id=bos,
+ precomputed_idx=precomputed_idx,
+ )
+
+ reverted_audio_res = audio_input.clone()[:, :seq_len]
+
+ self.assertTrue((reverted_audio_out[:, :seq_len] == reverted_audio_res).all())
diff --git a/tests/models/dia/test_tokenization_dia.py b/tests/models/dia/test_tokenization_dia.py
new file mode 100644
index 0000000000..4ade611f68
--- /dev/null
+++ b/tests/models/dia/test_tokenization_dia.py
@@ -0,0 +1,123 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from transformers.models.dia import DiaTokenizer
+from transformers.testing_utils import slow
+
+from ...test_tokenization_common import TokenizerTesterMixin
+
+
+# Special tokens
+PAD = 0
+S1 = 1
+S2 = 2
+
+
+class DiaTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
+ tokenizer_class = DiaTokenizer
+ test_rust_tokenizer = False
+
+ @classmethod
+ def setUpClass(cls):
+ super().setUpClass()
+ tokenizer = DiaTokenizer()
+ tokenizer.save_pretrained(cls.tmpdirname)
+
+ def test_convert_token_and_id(self):
+ """Test ``_convert_token_to_id`` and ``_convert_id_to_token``."""
+ token = "i"
+ token_id = 105
+
+ self.assertEqual(self.get_tokenizer()._convert_token_to_id(token), token_id)
+ self.assertEqual(self.get_tokenizer()._convert_id_to_token(token_id), token)
+
+ def test_get_vocab(self):
+ vocab_keys = list(self.get_tokenizer().get_vocab().keys())
+
+ self.assertEqual(vocab_keys[PAD], "")
+ self.assertEqual(vocab_keys[S1], "[S1]")
+ self.assertEqual(vocab_keys[S2], "[S2]")
+ self.assertEqual(len(vocab_keys), 256)
+
+ def test_vocab_size(self):
+ # utf-8 == 2**8 == 256
+ self.assertEqual(self.get_tokenizer().vocab_size, 256)
+
+ def test_full_tokenizer(self):
+ tokenizer = DiaTokenizer.from_pretrained(self.tmpdirname)
+
+ tokens = tokenizer.tokenize("Hello, world!")
+ self.assertListEqual(tokens, ["H", "e", "l", "l", "o", ",", " ", "w", "o", "r", "l", "d", "!"])
+ ids = tokenizer.convert_tokens_to_ids(tokens)
+ self.assertListEqual(ids, [72, 101, 108, 108, 111, 44, 32, 119, 111, 114, 108, 100, 33])
+ back_tokens = tokenizer.convert_ids_to_tokens(ids)
+ self.assertListEqual(back_tokens, ["H", "e", "l", "l", "o", ",", " ", "w", "o", "r", "l", "d", "!"])
+
+ tokens = tokenizer.tokenize("[S1] Hello [S2] Hello")
+ self.assertListEqual(
+ tokens,
+ ["[S1]", " ", "H", "e", "l", "l", "o", " ", "[S2]", " ", "H", "e", "l", "l", "o", ""],
+ )
+ ids = tokenizer.convert_tokens_to_ids(tokens)
+ self.assertListEqual(ids, [S1, 32, 72, 101, 108, 108, 111, 32, S2, 32, 72, 101, 108, 108, 111, PAD])
+ back_tokens = tokenizer.convert_ids_to_tokens(ids)
+ self.assertListEqual(
+ back_tokens, ["[S1]", " ", "H", "e", "l", "l", "o", " ", "[S2]", " ", "H", "e", "l", "l", "o", ""]
+ )
+
+ @slow
+ def test_tokenizer_integration(self):
+ # Overwritten as decoding will lead to all single bytes (i.e. characters) while usually the string format is expected
+ expected_encoding = {'input_ids': [[84, 114, 97, 110, 115, 102, 111, 114, 109, 101, 114, 115, 32, 40, 102, 111, 114, 109, 101, 114, 108, 121, 32, 107, 110, 111, 119, 110, 32, 97, 115, 32, 112, 121, 116, 111, 114, 99, 104, 45, 116, 114, 97, 110, 115, 102, 111, 114, 109, 101, 114, 115, 32, 97, 110, 100, 32, 112, 121, 116, 111, 114, 99, 104, 45, 112, 114, 101, 116, 114, 97, 105, 110, 101, 100, 45, 98, 101, 114, 116, 41, 32, 112, 114, 111, 118, 105, 100, 101, 115, 32, 103, 101, 110, 101, 114, 97, 108, 45, 112, 117, 114, 112, 111, 115, 101, 32, 97, 114, 99, 104, 105, 116, 101, 99, 116, 117, 114, 101, 115, 32, 40, 66, 69, 82, 84, 44, 32, 71, 80, 84, 45, 50, 44, 32, 82, 111, 66, 69, 82, 84, 97, 44, 32, 88, 76, 77, 44, 32, 68, 105, 115, 116, 105, 108, 66, 101, 114, 116, 44, 32, 88, 76, 78, 101, 116, 46, 46, 46, 41, 32, 102, 111, 114, 32, 78, 97, 116, 117, 114, 97, 108, 32, 76, 97, 110, 103, 117, 97, 103, 101, 32, 85, 110, 100, 101, 114, 115, 116, 97, 110, 100, 105, 110, 103, 32, 40, 78, 76, 85, 41, 32, 97, 110, 100, 32, 78, 97, 116, 117, 114, 97, 108, 32, 76, 97, 110, 103, 117, 97, 103, 101, 32, 71, 101, 110, 101, 114, 97, 116, 105, 111, 110, 32, 40, 78, 76, 71, 41, 32, 119, 105, 116, 104, 32, 111, 118, 101, 114, 32, 51, 50, 43, 32, 112, 114, 101, 116, 114, 97, 105, 110, 101, 100, 32, 109, 111, 100, 101, 108, 115, 32, 105, 110, 32, 49, 48, 48, 43, 32, 108, 97, 110, 103, 117, 97, 103, 101, 115, 32, 97, 110, 100, 32, 100, 101, 101, 112, 32, 105, 110, 116, 101, 114, 111, 112, 101, 114, 97, 98, 105, 108, 105, 116, 121, 32, 98, 101, 116, 119, 101, 101, 110, 32, 74, 97, 120, 44, 32, 80, 121, 84, 111, 114, 99, 104, 32, 97, 110, 100, 32, 84, 101, 110, 115, 111, 114, 70, 108, 111, 119, 46], [66, 69, 82, 84, 32, 105, 115, 32, 100, 101, 115, 105, 103, 110, 101, 100, 32, 116, 111, 32, 112, 114, 101, 45, 116, 114, 97, 105, 110, 32, 100, 101, 101, 112, 32, 98, 105, 100, 105, 114, 101, 99, 116, 105, 111, 110, 97, 108, 32, 114, 101, 112, 114, 101, 115, 101, 110, 116, 97, 116, 105, 111, 110, 115, 32, 102, 114, 111, 109, 32, 117, 110, 108, 97, 98, 101, 108, 101, 100, 32, 116, 101, 120, 116, 32, 98, 121, 32, 106, 111, 105, 110, 116, 108, 121, 32, 99, 111, 110, 100, 105, 116, 105, 111, 110, 105, 110, 103, 32, 111, 110, 32, 98, 111, 116, 104, 32, 108, 101, 102, 116, 32, 97, 110, 100, 32, 114, 105, 103, 104, 116, 32, 99, 111, 110, 116, 101, 120, 116, 32, 105, 110, 32, 97, 108, 108, 32, 108, 97, 121, 101, 114, 115, 46], [84, 104, 101, 32, 113, 117, 105, 99, 107, 32, 98, 114, 111, 119, 110, 32, 102, 111, 120, 32, 106, 117, 109, 112, 115, 32, 111, 118, 101, 114, 32, 116, 104, 101, 32, 108, 97, 122, 121, 32, 100, 111, 103, 46]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]} # fmt: skip
+
+ sequences = [
+ "Transformers (formerly known as pytorch-transformers and pytorch-pretrained-bert) provides "
+ "general-purpose architectures (BERT, GPT-2, RoBERTa, XLM, DistilBert, XLNet...) for Natural "
+ "Language Understanding (NLU) and Natural Language Generation (NLG) with over 32+ pretrained "
+ "models in 100+ languages and deep interoperability between Jax, PyTorch and TensorFlow.",
+ "BERT is designed to pre-train deep bidirectional representations from unlabeled text by jointly "
+ "conditioning on both left and right context in all layers.",
+ "The quick brown fox jumps over the lazy dog.",
+ ]
+
+ tokenizer_classes = [self.tokenizer_class]
+ if self.test_rust_tokenizer:
+ tokenizer_classes.append(self.rust_tokenizer_class)
+
+ for tokenizer_class in tokenizer_classes:
+ tokenizer = tokenizer_class.from_pretrained("AntonV/Dia-1.6B")
+
+ encoding = tokenizer(sequences)
+ encoding_data = encoding.data
+ self.assertDictEqual(encoding_data, expected_encoding)
+
+ # Byte decoding leads to characters so we need to join them
+ decoded_sequences = [
+ "".join(tokenizer.decode(seq, skip_special_tokens=True)) for seq in encoding["input_ids"]
+ ]
+
+ for expected, decoded in zip(sequences, decoded_sequences):
+ if self.test_sentencepiece_ignore_case:
+ expected = expected.lower()
+ self.assertEqual(expected, decoded)
+
+ @unittest.skip(reason="Dia relies on whole input string due to the byte-level nature.")
+ def test_pretokenized_inputs(self):
+ pass
+
+ @unittest.skip
+ def test_tokenizer_slow_store_full_signature(self):
+ pass
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index a5d9c90068..d3f8456f54 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -4574,6 +4574,11 @@ class ModelTesterMixin:
head_dim = config.head_dim
config.head_dim = max(16, config.head_dim)
+ cross_head_dim = None
+ if hasattr(config, "cross_head_dim") and config.cross_head_dim is not None:
+ cross_head_dim = config.cross_head_dim
+ config.cross_head_dim = max(16, config.cross_head_dim)
+
if (
getattr(config, "hidden_size", None) is not None
and getattr(config, "num_attention_heads", None) is not None
@@ -4588,6 +4593,17 @@ class ModelTesterMixin:
decoder_head_dim = config.decoder_hidden_size // config.decoder_num_attention_heads
config.decoder_hidden_size *= max(16 // decoder_head_dim, 1)
+ if (
+ getattr(config, "cross_hidden_size", None) is not None
+ and getattr(config, "cross_num_attention_heads", None) is not None
+ ):
+ cross_head_dim = (
+ cross_head_dim
+ if cross_head_dim is not None
+ else config.cross_hidden_size // config.cross_num_attention_heads
+ )
+ config.cross_hidden_size *= max(16 // cross_head_dim, 1)
+
# Set default attention to flex and update config values
update_config_for_flex(config)
for key in config.sub_configs:
diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py
index 46c2bb1a9f..22d6b033af 100644
--- a/utils/check_config_attributes.py
+++ b/utils/check_config_attributes.py
@@ -32,6 +32,10 @@ transformers = direct_transformers_import(PATH_TO_TRANSFORMERS)
CONFIG_MAPPING = transformers.models.auto.configuration_auto.CONFIG_MAPPING
SPECIAL_CASES_TO_ALLOW = {
+ # used internally during generation to provide the custom logit processors with their necessary information
+ "DiaConfig": [
+ "delay_pattern",
+ ],
# 'max_position_embeddings' is not used in modeling file, but needed for eval frameworks like Huggingface's lighteval (https://github.com/huggingface/lighteval/blob/af24080ea4f16eaf1683e353042a2dfc9099f038/src/lighteval/models/base_model.py#L264).
# periods and offsets are not used in modeling file, but used in the configuration file to define `layers_block_type` and `layers_num_experts`.
"BambaConfig": [