diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index a3c6981861..9ed80cfb0b 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -839,6 +839,8 @@ title: CSM - local: model_doc/dac title: dac + - local: model_doc/dia + title: Dia - local: model_doc/encodec title: EnCodec - local: model_doc/fastspeech2_conformer diff --git a/docs/source/en/model_doc/auto.md b/docs/source/en/model_doc/auto.md index adab8591e2..0a36c7c0a1 100644 --- a/docs/source/en/model_doc/auto.md +++ b/docs/source/en/model_doc/auto.md @@ -350,6 +350,10 @@ The following auto classes are available for the following audio tasks. [[autodoc]] AutoModelForTextToWaveform +### AutoModelForAudioTokenization + +[[autodoc]] AutoModelForAudioTokenization + ## Multimodal The following auto classes are available for the following multimodal tasks. diff --git a/docs/source/en/model_doc/dia.md b/docs/source/en/model_doc/dia.md new file mode 100644 index 0000000000..67c4a3be0b --- /dev/null +++ b/docs/source/en/model_doc/dia.md @@ -0,0 +1,162 @@ + + +# Dia + +
+
+ PyTorch + FlashAttention + SDPA +
+
+ +## Overview + +Dia is an open-source text-to-speech (TTS) model (1.6B parameters) developed by [Nari Labs](https://huggingface.co/nari-labs). +It can generate highly realistic dialogue from a transcript including nonverbal communications such as laughter and coughing. +Furthermore, emotion and tone control is also possible via audio conditioning (voice cloning). + +**Model Architecture:** +Dia is an encoder-decoder transformer based on the original transformer architecture. However, some more modern features such as +rotary positional embeddings (RoPE) are also included. For its text portion (encoder), a byte tokenizer is utilized while +for the audio portion (decoder), a pretrained codec model [DAC](./dac.md) is used - DAC encodes speech into discrete codebook +tokens and decodes them back into audio. + +## Usage Tips + +### Generation with Text + +```python +from transformers import AutoProcessor, DiaForConditionalGeneration + +torch_device = "cuda" +model_checkpoint = "buttercrab/dia-v1-1.6b" + +text = ["[S1] Dia is an open weights text to dialogue model."] +processor = AutoProcessor.from_pretrained(model_checkpoint) +inputs = processor(text=text, padding=True, return_tensors="pt").to(torch_device) + +model = DiaForConditionalGeneration.from_pretrained(model_checkpoint).to(torch_device) +outputs = model.generate(**inputs, max_new_tokens=256) # corresponds to ~2s + +# save audio to a file +outputs = processor.batch_decode(outputs) +processor.save_audio(outputs, "example.wav") + +``` + +### Generation with Text and Audio (Voice Cloning) + +```python +from datasets import load_dataset, Audio +from transformers import AutoProcessor, DiaForConditionalGeneration + +torch_device = "cuda" +model_checkpoint = "buttercrab/dia-v1-1.6b" + +ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train") +ds = ds.cast_column("audio", Audio(sampling_rate=44100)) +audio = ds[-1]["audio"]["array"] +# text is a transcript of the audio + additional text you want
as new audio +text = ["[S1] I know. It's going to save me a lot of money, I hope. [S2] I sure hope so for you."] + +processor = AutoProcessor.from_pretrained(model_checkpoint) +inputs = processor(text=text, audio=audio, padding=True, return_tensors="pt").to(torch_device) +prompt_len = processor.get_audio_prompt_len(inputs["decoder_attention_mask"]) + +model = DiaForConditionalGeneration.from_pretrained(model_checkpoint).to(torch_device) +outputs = model.generate(**inputs, max_new_tokens=256) # corresponds to around ~2s + +# retrieve actually generated audio and save to a file +outputs = processor.batch_decode(outputs, audio_prompt_len=prompt_len) +processor.save_audio(outputs, "example_with_audio.wav") +``` + +### Training + +```python +from datasets import load_dataset, Audio +from transformers import AutoProcessor, DiaForConditionalGeneration + +torch_device = "cuda" +model_checkpoint = "buttercrab/dia-v1-1.6b" + +ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train") +ds = ds.cast_column("audio", Audio(sampling_rate=44100)) +audio = ds[-1]["audio"]["array"] +# text is a transcript of the audio +text = ["[S1] I know. It's going to save me a lot of money, I hope."] + +processor = AutoProcessor.from_pretrained(model_checkpoint) +inputs = processor( + text=text, + audio=audio, + generation=False, + output_labels=True, + padding=True, + return_tensors="pt" +).to(torch_device) + +model = DiaForConditionalGeneration.from_pretrained(model_checkpoint).to(torch_device) +out = model(**inputs) +out.loss.backward() +``` + + +This model was contributed by [Jaeyong Sung](https://huggingface.co/buttercrab), [Arthur Zucker](https://huggingface.co/ArthurZ), +and [Anton Vlasjuk](https://huggingface.co/AntonV). The original code can be found [here](https://github.com/nari-labs/dia/). 
+ + +## DiaConfig + +[[autodoc]] DiaConfig + +## DiaDecoderConfig + +[[autodoc]] DiaDecoderConfig + +## DiaEncoderConfig + +[[autodoc]] DiaEncoderConfig + +## DiaTokenizer + +[[autodoc]] DiaTokenizer + - __call__ + +## DiaFeatureExtractor + +[[autodoc]] DiaFeatureExtractor + - __call__ + +## DiaProcessor + +[[autodoc]] DiaProcessor + - __call__ + - batch_decode + - decode + +## DiaModel + +[[autodoc]] DiaModel + - forward + +## DiaForConditionalGeneration + +[[autodoc]] DiaForConditionalGeneration + - forward + - generate diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 155491d6d5..54fa7c7e26 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -271,7 +271,6 @@ class PretrainedConfig(PushToHubMixin): self.pad_token_id = kwargs.pop("pad_token_id", None) self.eos_token_id = kwargs.pop("eos_token_id", None) self.sep_token_id = kwargs.pop("sep_token_id", None) - self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None) # task specific arguments diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index 8c72279b6e..d4c08e270b 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -2975,3 +2975,224 @@ class SynthIDTextWatermarkLogitsProcessor(LogitsProcessor): The expected mean g-value for watermarked text. """ return coinflip_prob + coinflip_prob * (1 - coinflip_prob) * (1 - (1 / vocab_size)) + + +class DiaClassifierFreeGuidanceLogitsProcessor(LogitsProcessor): + r""" + [`LogitsProcessor`] for classifier free guidance (CFG). Similar to the original + `ClassifierFreeGuidanceLogitsProcessor` with some modifications on the overall + calculation, e.g. conditioned logits centered, and an additional top k selection + option. 
+ + + + This logits processor is exclusively compatible with + [Dia](https://huggingface.co/docs/transformers/main/en/model_doc/dia) + + + + Args: + guidance_scale (float): + The guidance scale for classifier free guidance (CFG). CFG is enabled by setting `guidance_scale > 1`. + Higher guidance scale encourages the model to generate samples that are more closely linked to the input + prompt, usually at the expense of poorer quality. + guidance_top_k (int, *optional*): + The number of highest probability vocabulary tokens to keep for top-k-filtering. However, we do not keep + the logits of the combined CFG output, but the conditioned output only. + """ + + def __init__(self, guidance_scale: float, guidance_top_k: Optional[int] = None): + if guidance_scale > 1: + self.guidance_scale = guidance_scale + else: + raise ValueError( + "Require guidance scale >1 to use the classifier free guidance processor, got guidance scale " + f"{guidance_scale}." + ) + + self.guidance_top_k = guidance_top_k + if self.guidance_top_k is not None and self.guidance_top_k < 1: + raise ValueError( + f"`guidance_top_k` has to be a strictly positive integer if given, but is {self.guidance_top_k}" + ) + + @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + # simple check to make sure we have compatible batch sizes between our + # logits scores (cond + uncond) and input ids (cond only) + if scores.shape[0] != 2 * input_ids.shape[0]: + raise ValueError( + f"Logits should have twice the batch size of the input ids, the first half of batches corresponding to " + f"the conditional inputs, and the second half of batches corresponding to the unconditional inputs. Got " + f"batch size {scores.shape[0]} for the logits and {input_ids.shape[0]} for the input ids." 
+ ) + # Base CFG with center on cond_logits + unguided_bsz = scores.shape[0] // 2 + cond_logits, uncond_logits = scores.split(unguided_bsz, dim=0) + scores_processed = cond_logits + (cond_logits - uncond_logits) * self.guidance_scale + + # Optional CFG top k filtering + if self.guidance_top_k is not None: + # Create top k based on the combined CFG output + _, top_k_indices = torch.topk(scores_processed, k=self.guidance_top_k, dim=-1) + top_k_mask = torch.ones_like(scores_processed, dtype=torch.bool) + top_k_mask = top_k_mask.scatter(dim=-1, index=top_k_indices, value=False) + # Only return conditioned logits with top k + scores_processed = cond_logits.masked_fill(top_k_mask, -float("inf")) + + return scores_processed + + +class DiaEOSChannelFilterLogitsProcessor(LogitsProcessor): + r"""Specialized processor that ensures certain properties around EOS sampling: + 1. Only channel 0 can generate EOS + 2. If channel 0 has EOS with highest logit, it will be the only candidate + 3. If channel 0 has EOS not with highest logit, it will be suppressed + + 2. and 3. are especially important in contexts where we allow sampling to guarantee the + respective tokens to be (not) sampled. + + + + This logits processor is exclusively compatible with + [Dia](https://huggingface.co/docs/transformers/en/model_doc/dia). + + + + Args: + num_channels (`int`): + Number of audio codebooks. Simplifies access to the first channel on the logits. + eos_token_id (`int`): + The id of *end-of-sequence* token. 
+ """ + + def __init__(self, num_channels: int, eos_token_id: int): + if num_channels < 1: + raise ValueError(f"Audio codebooks need at least one channel, but found {num_channels} channels.") + if eos_token_id < 1: + raise ValueError(f"Expected `eos_token_id` to be a positive integer, found {eos_token_id} instead.") + + self.num_channels = num_channels + self.eos_id = eos_token_id + + @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + # Reshape for easier channel indexing [B, C, V] + scores = scores.reshape(-1, self.num_channels, scores.shape[-1]) + + # EOS filter + # 1. Condition: Only the first channel can generate the EOS token + # Side condition of disabling generation of special tokens (e.g. audio pad, bos, ...) + # (Assumes them to be greater than audio eos token position) + scores[:, 1:, self.eos_id :] = torch.full_like( + scores[:, 1:, self.eos_id :], + fill_value=-float("inf"), + ) + scores[:, 0, self.eos_id + 1 :] = torch.full_like( + scores[:, 0, self.eos_id + 1 :], + fill_value=-float("inf"), + ) + + # 2+3 Conditions: Force/Suppress EOS if (not) highest logit + # Reshape back to original shape + scores = scores.view(-1, scores.shape[-1]) + + # Sample highest tokens + top_logit_indices = torch.argmax(scores, dim=-1) + + # 2. Force EOS + eos_highest_mask = top_logit_indices == self.eos_id + mask_eos_highest = torch.zeros_like(scores, dtype=torch.bool) + mask_eos_highest[eos_highest_mask, : self.eos_id] = True + scores = scores.masked_fill(mask_eos_highest, -float("inf")) + + # 3. 
Suppress EOS + eos_not_highest_mask = top_logit_indices != self.eos_id + mask_eos_unless_highest = torch.zeros_like(scores, dtype=torch.bool) + mask_eos_unless_highest[eos_not_highest_mask, self.eos_id] = True + scores = scores.masked_fill(mask_eos_unless_highest, -float("inf")) + + return scores + + +class DiaEOSDelayPatternLogitsProcessor(LogitsProcessor): + r"""Special logits processor to handle the generation of the EOS token in Dia. + This is due to the fact that Dia does not allow the generation of EOS in any + channel except the first channel (C0). + + Hence, based on the delay pattern, an EOS is forced after the respective delays + in the channels. For example, if the delay pattern is [0, 2, 3, 4]: + + s s+1 s+2 s+3 s+4 s+5 ... + | | | | | | + C0: EOS PAD PAD PAD PAD PAD ... + C1: x x EOS PAD PAD PAD ... + C2: x x x EOS PAD PAD ... + C3: x x x x EOS PAD ... + + If the first channel generated EOS at step s, channels Cx are forced to generate + theirs at the respective delays (s+2, s+3, s+4). Subsequent padding tokens are + handled by the `EosTokenCriteria` when an EOS has been detected. + + + + This logits processor is exclusively compatible with + [Dia](https://huggingface.co/docs/transformers/en/model_doc/dia). + + + + Args: + delay_pattern (`List[int]`): + The delays per channel in the audio codebooks. + eos_token_id (`int`): + The id of *end-of-sequence* token. + max_generation_len (`int`): + The max sequence length that can be generated. + device (`str`, *optional*, defaults to `"cpu"`): + The device to allocate the tensors on.
+ """ + + def __init__(self, delay_pattern: list[int], eos_token_id: int, max_generation_len: int, device: str = "cpu"): + self.num_channels = len(delay_pattern) + # Update during first iteration + self.active_batches = None + self.delay_pattern = torch.tensor(delay_pattern, device=device, dtype=torch.int)[None, :] + self.eos_token_id = eos_token_id + self.max_generation_len = max_generation_len - max(delay_pattern) - 1 + self.device = device + + @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + # Reshape for easier channel indexing [B, C, V] + scores = scores.reshape(-1, self.num_channels, scores.shape[-1]) + + # Initialize / expand values on first iteration + if self.active_batches is None: + self.delay_pattern = self.delay_pattern.repeat(scores.shape[0], 1) + self.active_batches = torch.zeros(size=(scores.shape[0],), device=self.device, dtype=torch.bool) + + # Check if eos has been generated in any batch + channel_generated_eos = torch.argmax(scores, dim=-1)[:, 0] == self.eos_token_id + # Check if max len has been reached + reached_max_len = input_ids.shape[1] == self.max_generation_len + + # Update active batches + self.active_batches |= channel_generated_eos + self.active_batches |= reached_max_len + + # Find channels that need to force eos + forced_eos_channels = self.active_batches[:, None] & (self.delay_pattern == 0) + # Use indexing to avoid issues on all `False` by having empty tensors in that case + idx_bsz, idx_channel = forced_eos_channels.nonzero(as_tuple=True) + + # Force eos if delay is kicking in + scores[idx_bsz, idx_channel, :] = -float("inf") + scores[idx_bsz, idx_channel, self.eos_token_id] = 0.0 + + # Reshape back to [B * C, V] + scores = scores.reshape(-1, scores.shape[-1]) + + # Update amount of delay left for each channel + self.delay_pattern -= self.active_batches[:, None].int() + + return scores diff --git 
a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index a5d1be345d..ea2bd32aa3 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -26,6 +26,7 @@ import re import shutil import tempfile import warnings +from abc import abstractmethod from collections import defaultdict from concurrent.futures import ThreadPoolExecutor, as_completed from contextlib import contextmanager @@ -5884,3 +5885,26 @@ class AttentionInterface(GeneralInterface): # Global AttentionInterface shared by all models which do not need to overwrite any of the existing ones ALL_ATTENTION_FUNCTIONS: AttentionInterface = AttentionInterface() + + +class PreTrainedAudioTokenizerBase(PreTrainedModel): + """ + Class that additionally defines the behavior of any `audio_tokenizer` to be added. + Characteristic for any of them: + 1. Encode raw audio into discrete audio codebooks (with x channels) + 2. Decode from discrete audio codebooks back to raw audio + It is possible that they can decode in different ways given a different representation + but they are forced to support 2. nonetheless, e.g. see `DAC`. 
+ """ + + @abstractmethod + def encode(self, input_values: torch.Tensor, *args, **kwargs): + """ + Encode raw audio retrieved from a respective `FeatureExtractor` into discrete audio codebooks (with x channels) + """ + pass + + @abstractmethod + def decode(self, audio_codes: torch.Tensor, *args, **kwargs): + """Decode from discrete audio codebooks back to raw audio""" + pass diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 3c0e649f8a..7b2332d89f 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -88,6 +88,7 @@ if TYPE_CHECKING: from .depth_anything import * from .depth_pro import * from .detr import * + from .dia import * from .dialogpt import * from .diffllama import * from .dinat import * diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 8d2109759d..71ad6eaade 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -106,6 +106,7 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str]( ("depth_pro", "DepthProConfig"), ("deta", "DetaConfig"), ("detr", "DetrConfig"), + ("dia", "DiaConfig"), ("diffllama", "DiffLlamaConfig"), ("dinat", "DinatConfig"), ("dinov2", "Dinov2Config"), @@ -478,6 +479,7 @@ MODEL_NAMES_MAPPING = OrderedDict[str, str]( ("depth_pro", "DepthPro"), ("deta", "DETA"), ("detr", "DETR"), + ("dia", "Dia"), ("dialogpt", "DialoGPT"), ("diffllama", "DiffLlama"), ("dinat", "DiNAT"), diff --git a/src/transformers/models/auto/feature_extraction_auto.py b/src/transformers/models/auto/feature_extraction_auto.py index cf806f39a6..d54ca4b0f5 100644 --- a/src/transformers/models/auto/feature_extraction_auto.py +++ b/src/transformers/models/auto/feature_extraction_auto.py @@ -55,6 +55,7 @@ FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict( ("deformable_detr", "DeformableDetrFeatureExtractor"), ("deit", "DeiTFeatureExtractor"), ("detr", 
"DetrFeatureExtractor"), + ("dia", "DiaFeatureExtractor"), ("dinat", "ViTFeatureExtractor"), ("donut-swin", "DonutFeatureExtractor"), ("dpt", "DPTFeatureExtractor"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 51a3c3fbbc..add9d09b0e 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -99,6 +99,7 @@ MODEL_MAPPING_NAMES = OrderedDict( ("depth_pro", "DepthProModel"), ("deta", "DetaModel"), ("detr", "DetrModel"), + ("dia", "DiaModel"), ("diffllama", "DiffLlamaModel"), ("dinat", "DinatModel"), ("dinov2", "Dinov2Model"), @@ -472,6 +473,7 @@ MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict( ("data2vec-text", "Data2VecTextForMaskedLM"), ("deberta", "DebertaForMaskedLM"), ("deberta-v2", "DebertaV2ForMaskedLM"), + ("dia", "DiaForConditionalGeneration"), ("distilbert", "DistilBertForMaskedLM"), ("electra", "ElectraForMaskedLM"), ("encoder-decoder", "EncoderDecoderModel"), @@ -1059,6 +1061,7 @@ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict( MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict( [ + ("dia", "DiaForConditionalGeneration"), ("granite_speech", "GraniteSpeechForConditionalGeneration"), ("kyutai_speech_to_text", "KyutaiSpeechToTextForConditionalGeneration"), ("moonshine", "MoonshineForConditionalGeneration"), @@ -1629,6 +1632,12 @@ MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES = OrderedDict( ] ) +MODEL_FOR_AUDIO_TOKENIZATION_NAMES = OrderedDict( + [ + ("dac", "DacModel"), + ] +) + MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES) MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_PRETRAINING_MAPPING_NAMES) MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_WITH_LM_HEAD_MAPPING_NAMES) @@ -1737,6 +1746,8 @@ MODEL_FOR_TIME_SERIES_PREDICTION_MAPPING = _LazyAutoMapping( MODEL_FOR_IMAGE_TO_IMAGE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, 
MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES) +MODEL_FOR_AUDIO_TOKENIZATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_TOKENIZATION_NAMES) + class AutoModelForMaskGeneration(_BaseAutoModelClass): _model_mapping = MODEL_FOR_MASK_GENERATION_MAPPING @@ -2034,6 +2045,15 @@ class AutoModelForMaskedImageModeling(_BaseAutoModelClass): AutoModelForMaskedImageModeling = auto_class_update(AutoModelForMaskedImageModeling, head_doc="masked image modeling") +class AutoModelForAudioTokenization(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_AUDIO_TOKENIZATION_MAPPING + + +AutoModelForAudioTokenization = auto_class_update( + AutoModelForAudioTokenization, head_doc="audio tokenization through codebooks" +) + + class AutoModelWithLMHead(_AutoModelWithLMHead): @classmethod def from_config(cls, config): @@ -2059,6 +2079,7 @@ class AutoModelWithLMHead(_AutoModelWithLMHead): __all__ = [ "MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING", "MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING", + "MODEL_FOR_AUDIO_TOKENIZATION_MAPPING", "MODEL_FOR_AUDIO_XVECTOR_MAPPING", "MODEL_FOR_BACKBONE_MAPPING", "MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING", @@ -2106,6 +2127,7 @@ __all__ = [ "AutoBackbone", "AutoModelForAudioClassification", "AutoModelForAudioFrameClassification", + "AutoModelForAudioTokenization", "AutoModelForAudioXVector", "AutoModelForCausalLM", "AutoModelForCTC", diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index 372c0b249b..bccfe3e6d5 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -61,6 +61,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict( ("clvp", "ClvpProcessor"), ("colpali", "ColPaliProcessor"), ("colqwen2", "ColQwen2Processor"), + ("dia", "DiaProcessor"), ("emu3", "Emu3Processor"), ("flava", "FlavaProcessor"), ("fuyu", "FuyuProcessor"), diff --git a/src/transformers/models/auto/tokenization_auto.py 
b/src/transformers/models/auto/tokenization_auto.py index 50a1a2732c..0456e1945c 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -177,6 +177,7 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]]( "LlamaTokenizerFast" if is_tokenizers_available() else None, ), ), + ("dia", ("DiaTokenizer", None)), ( "diffllama", ( diff --git a/src/transformers/models/dac/modeling_dac.py b/src/transformers/models/dac/modeling_dac.py index b3bca5b63e..191e7af89e 100644 --- a/src/transformers/models/dac/modeling_dac.py +++ b/src/transformers/models/dac/modeling_dac.py @@ -23,7 +23,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import PreTrainedAudioTokenizerBase from ...utils import ModelOutput, auto_docstring from .configuration_dac import DacConfig @@ -471,7 +471,7 @@ class DacEncoder(nn.Module): @auto_docstring -class DacPreTrainedModel(PreTrainedModel): +class DacPreTrainedModel(PreTrainedAudioTokenizerBase): config_class = DacConfig base_model_prefix = "dac" main_input_name = "input_values" diff --git a/src/transformers/models/dia/__init__.py b/src/transformers/models/dia/__init__.py new file mode 100644 index 0000000000..d738fbc087 --- /dev/null +++ b/src/transformers/models/dia/__init__.py @@ -0,0 +1,31 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_dia import * + from .feature_extraction_dia import * + from .generation_dia import * + from .modeling_dia import * + from .processing_dia import * + from .tokenization_dia import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/dia/configuration_dia.py b/src/transformers/models/dia/configuration_dia.py new file mode 100644 index 0000000000..90ace73b3c --- /dev/null +++ b/src/transformers/models/dia/configuration_dia.py @@ -0,0 +1,376 @@ +# coding=utf-8 +# Copyright 2025 The Nari Labs and HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Dia model configuration""" + +from typing import Optional + +from ...configuration_utils import PretrainedConfig +from ...modeling_rope_utils import rope_config_validation +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class DiaEncoderConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`DiaEncoder`]. 
It is used to instantiate a Dia + encoder according to the specified arguments, defining the encoder architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + max_position_embeddings (`int`, *optional*, defaults to 1024): + The maximum sequence length that this model might ever be used with. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + hidden_size (`int`, *optional*, defaults to 1024): + Dimensionality of the encoder layers and the pooler layer. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 16): + Number of key and value heads for each attention layer in the Transformer encoder. + head_dim (`int`, *optional*, defaults to 128): + Dimensionality of the attention head. + intermediate_size (`int`, *optional*, defaults to 4096): + Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. + norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the normalization layers. + vocab_size (`int`, *optional*, defaults to 256): + Vocabulary size of the Dia model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`DiaModel`]. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"swish"` and `"gelu_new"` are supported. + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. 
NOTE: if you apply a new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to the value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (> + `original_max_position_embeddings`).
Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + """ + + model_type = "dia_encoder" + + def __init__( + self, + max_position_embeddings: int = 1024, + num_hidden_layers: int = 12, + hidden_size: int = 1024, + num_attention_heads: int = 16, + num_key_value_heads: int = 16, + head_dim: int = 128, + intermediate_size: int = 4096, + norm_eps: float = 1e-5, + vocab_size: int = 256, + hidden_act: str = "silu", + rope_theta: float = 10000.0, + rope_scaling: Optional[dict] = None, + initializer_range: float = 0.02, + **kwargs, + ): + self.max_position_embeddings = max_position_embeddings + self.num_hidden_layers = num_hidden_layers + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_attention_heads = num_attention_heads + self.head_dim = head_dim + self.norm_eps = norm_eps + self.vocab_size = vocab_size + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, copy it to 'rope_type'. + if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) + self.initializer_range = initializer_range + super().__init__(**kwargs) + + +class DiaDecoderConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`DiaDecoder`].
It is used to instantiate a Dia + decoder according to the specified arguments, defining the decoder architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + max_position_embeddings (`int`, *optional*, defaults to 3072): + The maximum sequence length that this model might ever be used with. + num_hidden_layers (`int`, *optional*, defaults to 18): + Number of hidden layers in the Transformer decoder. + hidden_size (`int`, *optional*, defaults to 2048): + Dimensionality of the decoder layers and the pooler layer. + intermediate_size (`int`, *optional*, defaults to 8192): + Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*, defaults to 4): + Number of key and value heads for each attention layer in the Transformer decoder. + head_dim (`int`, *optional*, defaults to 128): + Dimensionality of the attention head. + cross_num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each cross-attention layer in the Transformer decoder. + cross_head_dim (`int`, *optional*, defaults to 128): + Dimensionality of the cross-attention head. + cross_num_key_value_heads (`int`, *optional*, defaults to 16): + Number of key and value heads for each cross-attention layer in the Transformer decoder. + cross_hidden_size (`int`, *optional*, defaults to 1024): + Dimensionality of the cross-attention layers. + norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the normalization layers. + vocab_size (`int`, *optional*, defaults to 1028): + Vocabulary size of the Dia model. 
Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`DiaModel`]. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. If string, `"gelu"`, `"relu"`, + `"swish"` and `"gelu_new"` are supported. + num_channels (`int`, *optional*, defaults to 9): + Number of channels for the Dia decoder. + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. 
Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). + is_encoder_decoder (`bool`, *optional*, defaults to `True`): + Indicating that this model is part of an encoder-decoder architecture. 
+ """ + + model_type = "dia_decoder" + + def __init__( + self, + max_position_embeddings: int = 3072, + num_hidden_layers: int = 18, + hidden_size: int = 2048, + intermediate_size: int = 8192, + num_attention_heads: int = 16, + num_key_value_heads: int = 4, + head_dim: int = 128, + cross_num_attention_heads: int = 16, + cross_head_dim: int = 128, + cross_num_key_value_heads: int = 16, + cross_hidden_size: int = 1024, + norm_eps: float = 1e-5, + vocab_size: int = 1028, + hidden_act: str = "silu", + num_channels: int = 9, + rope_theta: float = 10000.0, + rope_scaling: Optional[dict] = None, + initializer_range: float = 0.02, + use_cache: bool = True, + is_encoder_decoder: bool = True, + **kwargs, + ): + self.max_position_embeddings = max_position_embeddings + self.num_hidden_layers = num_hidden_layers + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.head_dim = head_dim + self.cross_num_key_value_heads = cross_num_key_value_heads + self.cross_num_attention_heads = cross_num_attention_heads + self.cross_head_dim = cross_head_dim + self.cross_hidden_size = cross_hidden_size + self.norm_eps = norm_eps + self.vocab_size = vocab_size + self.hidden_act = hidden_act + self.num_channels = num_channels + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, copy it it to 'rope_type'. + if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) + self.initializer_range = initializer_range + self.use_cache = use_cache + super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs) + + +class DiaConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`DiaModel`]. 
It is used to instantiate a + Dia model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the + [nari-labs/Dia-1.6B](https://huggingface.co/nari-labs/Dia-1.6B) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + encoder_config (`DiaEncoderConfig`, *optional*): + Configuration for the encoder part of the model. If not provided, a default `DiaEncoderConfig` will be used. + decoder_config (`DiaDecoderConfig`, *optional*): + Configuration for the decoder part of the model. If not provided, a default `DiaDecoderConfig` will be used. + norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the normalization layers. + is_encoder_decoder (`bool`, *optional*, defaults to `True`): + Indicating that this model uses an encoder-decoder architecture. + pad_token_id (`int`, *optional*, defaults to 1025): + Padding token id. + eos_token_id (`int`, *optional*, defaults to 1024): + End of stream token id. + bos_token_id (`int`, *optional*, defaults to 1026): + Beginning of stream token id. + delay_pattern (`list[int]`, *optional*, defaults to `[0, 8, 9, 10, 11, 12, 13, 14, 15]`): + The delay pattern for the decoder. The length of this list must match `decoder_config.num_channels`. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). 
+ + Example: + + ```python + >>> from transformers import DiaConfig, DiaModel + + >>> # Initializing a DiaConfig with default values + >>> configuration = DiaConfig() + + >>> # Initializing a DiaModel (with random weights) from the configuration + >>> model = DiaModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "dia" + keys_to_ignore_at_inference = ["past_key_values"] + sub_configs = {"encoder_config": DiaEncoderConfig, "decoder_config": DiaDecoderConfig} + + def __init__( + self, + encoder_config: Optional[DiaEncoderConfig] = None, + decoder_config: Optional[DiaDecoderConfig] = None, + norm_eps: float = 1e-5, + is_encoder_decoder: bool = True, + pad_token_id: int = 1025, + eos_token_id: int = 1024, + bos_token_id: int = 1026, + delay_pattern: Optional[list[int]] = None, + initializer_range: float = 0.02, + use_cache: bool = True, + **kwargs, + ): + if isinstance(encoder_config, dict): + encoder_config = DiaEncoderConfig(**encoder_config) + if isinstance(decoder_config, dict): + decoder_config = DiaDecoderConfig(**decoder_config) + self.encoder_config = encoder_config if encoder_config is not None else DiaEncoderConfig() + self.decoder_config = decoder_config if decoder_config is not None else DiaDecoderConfig() + self.norm_eps = norm_eps + self.delay_pattern = delay_pattern if delay_pattern is not None else [0, 8, 9, 10, 11, 12, 13, 14, 15] + self.initializer_range = initializer_range + self.use_cache = use_cache + + assert self.decoder_config.num_channels == len(self.delay_pattern), ( + "Number of channels must match delay pattern length." 
+ ) + + super().__init__( + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + bos_token_id=bos_token_id, + is_encoder_decoder=is_encoder_decoder, + **kwargs, + ) + + def get_text_config(self, decoder=False): + """Defaulting to audio config as it's the decoder in this case which is usually the text backbone""" + return self.decoder_config + + +__all__ = ["DiaConfig", "DiaEncoderConfig", "DiaDecoderConfig"] diff --git a/src/transformers/models/dia/convert_dia_to_hf.py b/src/transformers/models/dia/convert_dia_to_hf.py new file mode 100644 index 0000000000..3a33860f6b --- /dev/null +++ b/src/transformers/models/dia/convert_dia_to_hf.py @@ -0,0 +1,199 @@ +# coding=utf-8 +# Copyright 2025 The Nari Labs and HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
"""Converts a Dia model in Nari Labs format to Hugging Face format."""

import argparse
import os
import re

import torch
from huggingface_hub import snapshot_download
from safetensors.torch import load_file

from transformers import (
    DacModel,
    DiaConfig,
    DiaFeatureExtractor,
    DiaForConditionalGeneration,
    DiaProcessor,
    DiaTokenizer,
    GenerationConfig,
)
from transformers.utils.import_utils import _is_package_available


# Weight names (layer indices replaced by `*`, `model.` prefix dropped) whose tensors
# are reshaped and transposed to the shape of the matching Hugging Face parameter.
shape_mappings = [
    "encoder.layers.*.mlp.gate_up_proj.weight",
    "encoder.layers.*.mlp.down_proj.weight",
    "encoder.layers.*.self_attention.q_proj.weight",
    "encoder.layers.*.self_attention.k_proj.weight",
    "encoder.layers.*.self_attention.v_proj.weight",
    "encoder.layers.*.self_attention.o_proj.weight",
    "decoder.layers.*.mlp.gate_up_proj.weight",
    "decoder.layers.*.mlp.down_proj.weight",
    "decoder.layers.*.self_attention.q_proj.weight",
    "decoder.layers.*.self_attention.k_proj.weight",
    "decoder.layers.*.self_attention.v_proj.weight",
    "decoder.layers.*.self_attention.o_proj.weight",
    "decoder.layers.*.cross_attention.q_proj.weight",
    "decoder.layers.*.cross_attention.k_proj.weight",
    "decoder.layers.*.cross_attention.v_proj.weight",
    "decoder.layers.*.cross_attention.o_proj.weight",
    "decoder.logits_dense.weight",
]

# Literal substring renamings applied to every original weight name.
rename_mapping = {
    "mlp.wo": "mlp.down_proj",
    "mlp.wi_fused": "mlp.gate_up_proj",
}


def get_generation_config(config):
    """
    Builds the generation config that is saved alongside the converted model.

    Args:
        config:
            Model config handed to `GenerationConfig.from_model_config`.

    Returns:
        The derived `GenerationConfig` with sampling enabled and the sampling / guidance
        values set below.
    """
    model_generation_config = GenerationConfig.from_model_config(config)
    model_generation_config._from_model_config = False
    model_generation_config.do_sample = True
    model_generation_config.top_k = 45
    model_generation_config.top_p = 0.95
    model_generation_config.temperature = 1.2
    model_generation_config.guidance_scale = 3.0
    model_generation_config.max_length = 3072  # Decoder max length

    return model_generation_config


def _str_to_bool(value):
    """Parses a command line boolean; `type=bool` would turn every non-empty string (even "False") into `True`."""
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value (e.g. True/False) but got {value!r}")


def _select_checkpoint_file(checkpoint_dir):
    """Returns `(file_path, load_function)` for the checkpoint in `checkpoint_dir`, preferring safetensors over pth."""
    # Sorted so the choice does not depend on the arbitrary order of `os.listdir`
    candidates = sorted(os.listdir(checkpoint_dir))
    for suffix, load_function in ((".safetensors", load_file), (".pth", torch.load)):
        for file_name in candidates:
            if file_name.endswith(suffix):
                return os.path.join(checkpoint_dir, file_name), load_function
    raise FileNotFoundError(f"No `.safetensors` or `.pth` checkpoint file found in {checkpoint_dir}")


def convert_dia_model_to_hf(checkpoint_path, verbose=False):
    """
    Converts a Dia model in Nari Labs format to Hugging Face format.

    Args:
        checkpoint_path (`str`):
            Either a local directory holding the original checkpoint or a Hugging Face Hub repo id
            from which the `*.pth` / `*.safetensors` files are downloaded.
        verbose (`bool`, *optional*, defaults to `False`):
            Whether to print information during conversion.

    Returns:
        `DiaForConditionalGeneration`: The model holding the converted weights and the generation
        config from `get_generation_config`.

    Raises:
        FileNotFoundError: If the checkpoint directory has neither a `.safetensors` nor a `.pth` file.
    """
    # Use a local directory as is, otherwise treat the argument as a Hub repo id
    if os.path.isdir(checkpoint_path):
        checkpoint_dir = checkpoint_path
    else:
        checkpoint_dir = snapshot_download(repo_id=checkpoint_path, allow_patterns=["*.pth", "*.safetensors"])
        print(f"Downloaded checkpoint from Hugging Face Hub: {checkpoint_dir}")

    # Initialize base model with default config == 1.6B model
    # (on the meta device, the real tensors are assigned when loading the state dict)
    with torch.device("meta"):
        hf_model = DiaForConditionalGeneration(config=DiaConfig())
    hf_model_dict = hf_model.state_dict()
    hf_model_keys = hf_model_dict.keys()

    # The loader always matches the file that is actually loaded
    checkpoint_file, load_function = _select_checkpoint_file(checkpoint_dir)
    # NOTE: `torch.load` unpickles `.pth` files - only convert checkpoints from a trusted source
    nari_state_dict = load_function(checkpoint_file, "cpu")

    # Conversion starts here
    converted_state_dict = {}
    embeddings = {}
    for key, tensor in nari_state_dict.items():
        # add prefix
        key = "model." + key

        # rename some weights (plain substring replacement, the names are not regex patterns)
        for original, rename in rename_mapping.items():
            key = key.replace(original, rename)

        # decoder multi channel: collect the per channel tensors, they are concatenated below
        if "embeddings" in key:
            embeddings_key = key.rsplit(".", 2)[0] + ".embed.weight"
            embeddings.setdefault(embeddings_key, []).append(tensor)
            continue

        if re.sub(r"\d+", "*", key).removeprefix("model.") in shape_mappings:
            # add exception to the head which is stored without the `model.decoder` prefix
            if "logits_dense" in key:
                key = key.replace("decoder.logits_dense", "logits_dense").removeprefix("model.")

            # dense general
            if key in hf_model_keys:
                tensor_shape = tensor.shape
                target_shape = hf_model_dict[key].shape
                try:
                    tensor = tensor.reshape(target_shape[1], target_shape[0]).T
                    if verbose:
                        print(f"{key}: transpose reshaped from {tensor_shape} to {target_shape}")
                except (RuntimeError, IndexError) as e:
                    # best effort: keep the original tensor, loading the state dict reports real mismatches
                    print(f"WARNING: Could not reshape {key}: {e}")

        converted_state_dict[key] = tensor

    # Combining the embeddings as last step
    converted_state_dict.update({name: torch.cat(parts, dim=0) for name, parts in embeddings.items()})

    # Load converted weights into HF model
    hf_model.load_state_dict(converted_state_dict, assign=True)

    # Overwrite generation config
    hf_model.generation_config = get_generation_config(DiaConfig())

    return hf_model


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--checkpoint_path", type=str, default="nari-labs/Dia-1.6B", help="Path to the downloaded checkpoints"
    )
    parser.add_argument(
        "--pytorch_dump_folder_path", default="AntonV/Dia-1.6B", type=str, help="Path to the output PyTorch model."
    )
    parser.add_argument(
        "--convert_preprocessor",
        type=_str_to_bool,
        default=True,
        help="Whether or not the preprocessor (tokenizer + feature extractor) should be converted along with the model.",
    )
    parser.add_argument(
        "--verbose",
        type=_str_to_bool,
        default=True,
        help="Whether or not to log information during conversion.",
    )
    args = parser.parse_args()

    model = convert_dia_model_to_hf(args.checkpoint_path, args.verbose)
    if args.convert_preprocessor:
        if not _is_package_available("tiktoken"):
            # The preprocessor is skipped in this case, the model itself is still saved
            print("`tiktoken` is not installed, use `pip install tiktoken` to convert the tokenizer")
        else:
            processor = DiaProcessor(
                DiaFeatureExtractor(sampling_rate=44100, hop_length=512),
                DiaTokenizer(),
                DacModel.from_pretrained("descript/dac_44khz"),
            )
            processor.save_pretrained(args.pytorch_dump_folder_path)

    model.save_pretrained(args.pytorch_dump_folder_path)
    print(f"Saved converted checkpoint to {args.pytorch_dump_folder_path}")
"""Feature extractor class for Dia"""

from typing import Optional, Union

import numpy as np

from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
from ...feature_extraction_utils import BatchFeature
from ...utils import PaddingStrategy, TensorType, logging


logger = logging.get_logger(__name__)


class DiaFeatureExtractor(SequenceFeatureExtractor):
    r"""
    Constructs a Dia feature extractor.

    This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
    most of the main methods. Users should refer to this superclass for more information regarding those methods.

    Args:
        feature_size (`int`, *optional*, defaults to 1):
            The feature dimension of the extracted features. Use 1 for mono, 2 for stereo.
        sampling_rate (`int`, *optional*, defaults to 16000):
            The sampling rate at which the audio waveform should be digitalized, expressed in hertz (Hz).
        padding_value (`float`, *optional*, defaults to 0.0):
            The value that is used for padding.
        hop_length (`int`, *optional*, defaults to 512):
            Overlap length between successive windows. Padded inputs are padded to a multiple of this value.
    """

    model_input_names = ["input_values", "n_quantizers"]

    def __init__(
        self,
        feature_size: int = 1,
        sampling_rate: int = 16000,
        padding_value: float = 0.0,
        hop_length: int = 512,
        **kwargs,
    ):
        super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
        self.hop_length = hop_length

    def __call__(
        self,
        raw_audio: Union[np.ndarray, list[float], list[np.ndarray], list[list[float]]],
        padding: Optional[Union[bool, str, PaddingStrategy]] = None,
        truncation: Optional[bool] = False,
        max_length: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        sampling_rate: Optional[int] = None,
    ) -> BatchFeature:
        """
        Main method to featurize and prepare for the model one or several sequence(s).

        Args:
            raw_audio (`np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[list[float]]`):
                The sequence or batch of sequences to be processed. Each sequence can be a numpy array, a list of float
                values, a list of numpy arrays or a list of list of float values. The numpy array must be of shape
                `(num_samples,)` for mono audio (`feature_size = 1`), or `(2, num_samples)` for stereo audio
                (`feature_size = 2`). Stereo audio is averaged to mono.
            padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*):
                Select a strategy to pad the returned sequences (according to the model's padding side and padding
                index) among:

                - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
                  sequence if provided).
                - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
                  acceptable input length for the model if that argument is not provided.
                - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different
                  lengths).

                If left unset (`None`), `True` is used.
            truncation (`bool`, *optional*, defaults to `False`):
                Activates truncation to cut input sequences longer than `max_length` to `max_length`.
            max_length (`int`, *optional*):
                Maximum length of the returned list and optionally padding length (see above).
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors instead of list of python integers. Acceptable values are:

                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return Numpy `np.ndarray` objects.
            sampling_rate (`int`, *optional*):
                The sampling rate at which the `audio` input was sampled. It is strongly recommended to pass
                `sampling_rate` at the forward call to prevent silent errors.

        Returns:
            [`BatchFeature`]: Holds `input_values` (one `(1, padded_length)` entry per example) and `padding_mask`
            (the attention mask returned by the padding step).

        Raises:
            ValueError: If `sampling_rate` differs from the sampling rate of the feature extractor, if both
                `padding` and `truncation` are set or if an example has an unsupported shape.
        """
        if sampling_rate is not None:
            if sampling_rate != self.sampling_rate:
                raise ValueError(
                    f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of"
                    f" {self.sampling_rate}. Please make sure that the provided audio input was sampled with"
                    f" {self.sampling_rate} and not {sampling_rate}."
                )
        else:
            logger.warning(
                f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. "
                "Failing to do so can result in silent errors that might be hard to debug."
            )

        if padding and truncation:
            raise ValueError("Both padding and truncation were set. Make sure you only set one.")
        elif padding is None:
            # by default let's pad the inputs
            padding = True

        is_batched = bool(
            isinstance(raw_audio, (list, tuple)) and (isinstance(raw_audio[0], (np.ndarray, tuple, list)))
        )

        if is_batched:
            raw_audio = [np.asarray(audio, dtype=np.float32).T for audio in raw_audio]
        elif not isinstance(raw_audio, np.ndarray):
            raw_audio = np.asarray(raw_audio, dtype=np.float32)
        elif raw_audio.dtype == np.dtype(np.float64):
            # dtypes are compared by value, identity is an implementation detail
            raw_audio = raw_audio.astype(np.float32)

        # always return batch
        if not is_batched:
            raw_audio = [np.asarray(raw_audio).T]

        # convert stereo to mono if necessary, unique to Dia
        # (examples were transposed above so the channels are on the last axis)
        for idx, example in enumerate(raw_audio):
            if self.feature_size == 2 and example.ndim == 2:
                raw_audio[idx] = np.mean(example, -1)

        # verify inputs are valid
        for example in raw_audio:
            if example.ndim == 0 or example.ndim > 2:
                raise ValueError(f"Expected input shape (channels, length) but got shape {example.shape}")
            # Stereo extractors never reach this with 2D examples as those were averaged to mono above
            if self.feature_size == 1 and example.ndim != 1:
                raise ValueError(f"Expected mono audio but example has {example.shape[-1]} channels")

        input_values = BatchFeature({"input_values": raw_audio})

        # temporarily treat it as if we were mono as we also convert stereo to mono;
        # restored in `finally` so a failure while padding cannot leave the extractor modified
        original_feature_size = self.feature_size
        self.feature_size = 1
        try:
            # normal padding on batch
            padded_inputs = self.pad(
                input_values,
                max_length=max_length,
                truncation=truncation,
                padding=padding,
                return_attention_mask=True,
                pad_to_multiple_of=self.hop_length,
            )
        finally:
            self.feature_size = original_feature_size

        padded_inputs["padding_mask"] = padded_inputs.pop("attention_mask")

        # every example is mono at this point: add the channel axis and put it first
        padded_inputs["input_values"] = [example[..., None].T for example in padded_inputs.pop("input_values")]

        if return_tensors is not None:
            padded_inputs = padded_inputs.convert_to_tensors(return_tensors)

        return padded_inputs


__all__ = ["DiaFeatureExtractor"]
+ +from typing import Any, Callable, Optional, Union + +import torch +import torch.distributed as dist + +from ...generation.logits_process import ( + DiaClassifierFreeGuidanceLogitsProcessor, + DiaEOSChannelFilterLogitsProcessor, + DiaEOSDelayPatternLogitsProcessor, + LogitsProcessorList, + TemperatureLogitsWarper, +) +from ...generation.stopping_criteria import StoppingCriteriaList +from ...generation.streamers import BaseStreamer +from ...generation.utils import GenerateOutput, GenerationConfig, GenerationMixin, GenerationMode +from ...integrations.deepspeed import is_deepspeed_zero3_enabled +from ...integrations.fsdp import is_fsdp_managed_module +from ...modeling_utils import PreTrainedModel +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class DiaGenerationMixin(GenerationMixin): + # Indicates CFG which needs preparation to be properly handled by repeats + _uses_cfg = None + + def _get_logits_processor( + self, + generation_config: GenerationConfig, + input_ids_seq_length: Optional[int] = None, + encoder_input_ids: torch.LongTensor = None, + prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None, + logits_processor: Optional[LogitsProcessorList] = None, + device: Optional[str] = None, + model_kwargs: Optional[dict[str, Any]] = None, + negative_prompt_ids: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + ) -> LogitsProcessorList: + # Need either custom order or custom processor instead + # (Temporarily disabling those for the super function) + original_guidance_scale = generation_config.guidance_scale + original_temperature = generation_config.temperature + generation_config.guidance_scale = None + generation_config.temperature = None + + # Get base processors and those we can integrate easily + custom_processors = LogitsProcessorList() + + if original_temperature is not None and original_temperature != 1.0: + 
custom_processors.append(TemperatureLogitsWarper(original_temperature)) + + custom_processors.append( + DiaEOSChannelFilterLogitsProcessor( + num_channels=len(self.config.delay_pattern), + eos_token_id=self.config.eos_token_id, + ) + ) + + merged_processors = super()._get_logits_processor( + generation_config=generation_config, + input_ids_seq_length=input_ids_seq_length, + encoder_input_ids=encoder_input_ids, + prefix_allowed_tokens_fn=None, + logits_processor=custom_processors, + device=device, + model_kwargs=model_kwargs, + negative_prompt_ids=negative_prompt_ids, + negative_prompt_attention_mask=negative_prompt_attention_mask, + ) + + # Custom processors we need at specific positions + if original_guidance_scale is not None and original_guidance_scale != 1: + cfg_processor = DiaClassifierFreeGuidanceLogitsProcessor( + guidance_scale=original_guidance_scale, + guidance_top_k=generation_config.top_k, + ) + merged_processors.insert(0, cfg_processor) + + merged_processors.append( + DiaEOSDelayPatternLogitsProcessor( + delay_pattern=self.config.delay_pattern, + eos_token_id=self.config.eos_token_id, + max_generation_len=generation_config.max_length, + device=device, + ) + ) + + # Enable temporarily disabled values back + generation_config.guidance_scale = original_guidance_scale + generation_config.temperature = original_temperature + + return merged_processors + + def _prepare_generation_config( + self, generation_config: Optional[GenerationConfig], use_model_defaults: Optional[bool] = None, **kwargs: dict + ) -> tuple[GenerationConfig, dict]: + generation_config, model_kwargs = super()._prepare_generation_config( + generation_config, use_model_defaults, **kwargs + ) + + # We allow generation up to max length + max delay pattern + # (will revert back to max length after generation) + generation_config.max_length += max(self.config.delay_pattern) + + # Internal flag to indicate CFG that needs to prepare unconditioned input + self._uses_cfg = 
generation_config.guidance_scale is not None and generation_config.guidance_scale != 1 + + return generation_config, model_kwargs + + def _prepare_model_inputs( + self, + inputs: Optional[torch.Tensor] = None, + bos_token_id: Optional[torch.Tensor] = None, + model_kwargs: Optional[dict[str, torch.Tensor]] = None, + ) -> tuple[torch.Tensor, Optional[str], dict[str, torch.Tensor]]: + inputs, input_name, model_kwargs = super()._prepare_model_inputs( + inputs=inputs, + bos_token_id=bos_token_id, + model_kwargs=model_kwargs, + ) + + # If CFG is requested we fill in the unconditioned parts + if self._uses_cfg: + unconditioned_inputs = torch.zeros_like(inputs) + inputs = torch.cat([inputs, unconditioned_inputs], dim=0) + + if model_kwargs.get("attention_mask", None) is not None: + model_kwargs["attention_mask"] = model_kwargs["attention_mask"].repeat(2, 1) + + return inputs, input_name, model_kwargs + + def _prepare_decoder_input_ids_for_generation( + self, + batch_size: int, + model_input_name: str, + model_kwargs: dict[str, torch.Tensor], + decoder_start_token_id: torch.Tensor, + device: Optional[torch.device] = None, + ) -> tuple[torch.LongTensor, dict[str, torch.Tensor]]: + """Prepares `decoder_input_ids` for generation with encoder-decoder models""" + # 1. 
Check whether the user has defined `decoder_input_ids` and `decoder_attention_mask`; if not error out + decoder_input_ids = decoder_attention_mask = None + if model_kwargs is not None and "decoder_input_ids" in model_kwargs: + decoder_input_ids = model_kwargs.pop("decoder_input_ids") + if model_kwargs is not None and "decoder_attention_mask" in model_kwargs: + decoder_attention_mask = model_kwargs.pop("decoder_attention_mask") + + # We allow generating without preparation (no proper delay) but discourage it + if decoder_input_ids is None or decoder_attention_mask is None: + logger.warning_once( + "In order to generate with Dia, we need the processed audio input: Got `decoder_input_ids`:" + f" {decoder_input_ids is not None} and got `decoder_attention_mask`={decoder_attention_mask is not None}." + f" This can be achieved via the [`DiaProcessor`] but now defaulting to non-delayed generation." + ) + + num_channels = self.config.decoder_config.num_channels + real_batch_size = batch_size // 2 if self._uses_cfg else batch_size + + if decoder_input_ids is None: + decoder_input_ids = torch.full( + (real_batch_size, 1, num_channels), decoder_start_token_id, dtype=torch.long, device=device + ) + + decoder_attention_mask = torch.ones( + size=(real_batch_size, decoder_input_ids.shape[1]), dtype=torch.long, device=device + ) + + # 2. Determine the valid input and what works as mask within the input + delay_mask = decoder_input_ids.long() + valid_input_size = ( + decoder_input_ids.shape[1] - (decoder_input_ids[:, :, 0] == self.config.pad_token_id).sum(dim=-1).max() + ) + decoder_input_ids = delay_mask[:, :valid_input_size].transpose(1, 2).long() + decoder_attention_mask = decoder_attention_mask[:, :valid_input_size].long() + + # 3. 
Overwrite into model kwargs + model_kwargs["decoder_attention_mask"] = decoder_attention_mask + model_kwargs["decoder_delay_mask"] = delay_mask + + return decoder_input_ids, model_kwargs + + def prepare_inputs_for_generation( + self, + input_ids, + encoder_outputs=None, # Using this to easily get the batch size + decoder_delay_mask=None, + **kwargs, + ): + # Reshape decoder input_ids to 3D to be compile friendly and to fit the expected model input shape + batch_size = encoder_outputs[0].shape[0] // 2 if self._uses_cfg else encoder_outputs[0].shape[0] + input_ids = input_ids.reshape(batch_size, self.config.decoder_config.num_channels, -1).transpose(1, 2) + + # Base method handles most things except CFG and the delay pattern mask + model_inputs = super().prepare_inputs_for_generation(input_ids, encoder_outputs=encoder_outputs, **kwargs) + + # Post processing for CFG and overwriting via delay pattern mask + # 1. Delay pattern mask -- force tokens if not allowed to predict (!= pad_token in mask) + model_inputs["decoder_input_ids"] = self.apply_delay_mask( + input_ids, self.config.pad_token_id, decoder_delay_mask + ) + + # Depending on cache usage we need to pass all or just one + if model_inputs.get("use_cache", False) and model_inputs["cache_position"][0] > 0: + model_inputs["decoder_input_ids"] = model_inputs["decoder_input_ids"][:, -1, :][:, None, :] + + # Be compile friendly + model_inputs["decoder_input_ids"] = model_inputs["decoder_input_ids"].contiguous() + + # 2. 
@staticmethod
def apply_delay_mask(input_ids: torch.Tensor, pad_id: int, delay_mask: Optional[torch.Tensor]) -> torch.Tensor:
    """Force the tokens prescribed by `delay_mask` into `input_ids` (in place).

    Only the overlapping prefix along dim 1 is touched. Inside that prefix, positions where the
    mask equals `pad_id` keep their current value; every other position is overwritten with the
    mask value. `input_ids` itself is modified and returned; with `delay_mask=None` it is
    returned untouched.
    """
    if delay_mask is not None:
        overlap = min(input_ids.shape[1], delay_mask.shape[1])
        forced_tokens = delay_mask[:, :overlap, :]
        current_tokens = input_ids[:, :overlap, :]

        # `pad_id` in the mask means "free to predict", anything else is a forced token
        free_to_predict = forced_tokens == pad_id
        input_ids[:, :overlap, :] = torch.where(free_to_predict, current_tokens, forced_tokens)

    return input_ids
Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call + tokenizer = kwargs.pop("tokenizer", None) # Pull this out first, we only use it for stopping criteria + assistant_tokenizer = kwargs.pop("assistant_tokenizer", None) # only used for assisted generation + + generation_config, model_kwargs = self._prepare_generation_config( + generation_config, use_model_defaults, **kwargs + ) + self._validate_model_kwargs(model_kwargs.copy()) + self._validate_assistant(assistant_model, tokenizer, assistant_tokenizer) + + # 2. Set generation parameters if not already defined + if synced_gpus is None: + synced_gpus = (is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)) and dist.get_world_size() > 1 + + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + + # 3. Define model inputs + kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None + inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs( + inputs, generation_config.bos_token_id, model_kwargs + ) + batch_size = inputs_tensor.shape[0] + + device = inputs_tensor.device + self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device) + + # 4. Define other model kwargs + if "encoder_outputs" not in model_kwargs: + # if model is encoder decoder encoder_outputs are created and added to `model_kwargs` + model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation( + inputs_tensor, model_kwargs, model_input_name, generation_config + ) + + # 5. 
Prepare `input_ids` which will be used for auto-regressive generation + input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation( + batch_size=batch_size, + model_input_name=model_input_name, + model_kwargs=model_kwargs, + decoder_start_token_id=generation_config._decoder_start_token_tensor, + device=inputs_tensor.device, + ) + + if generation_config.token_healing: + input_ids = self.heal_tokens(input_ids, tokenizer) + + if streamer is not None: + streamer.put(input_ids.cpu()) + + # 6. Prepare `max_length` depending on other stopping criteria. + # NOTE: incorrect `input_ids.shape[1]` previously + input_ids_length = input_ids.shape[-1] + has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None + has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None + generation_config = self._prepare_generated_length( + generation_config=generation_config, + has_default_max_length=has_default_max_length, + has_default_min_length=has_default_min_length, + model_input_name=model_input_name, + inputs_tensor=inputs_tensor, + input_ids_length=input_ids_length, + ) + + # If the model supports `logits_to_keep` in forward(), set it to 1 to avoid computing the whole + # logit matrix. This can save a lot of memory during the first forward pass. Note that assisted decoding + # dynamically overrides this value as it can need more than the last token logits + if self._supports_logits_to_keep() and "logits_to_keep" not in model_kwargs: + model_kwargs["logits_to_keep"] = 1 + + self._validate_generated_length(generation_config, input_ids_length, has_default_max_length) + + # 7. Prepare the cache. + # - `model_kwargs` may be updated in place with a cache as defined by the parameters in `generation_config`. 
+ # - different models have a different cache name expected by the model (default = "past_key_values") + # - `max_length`, prepared above, is used to determine the maximum cache length + max_cache_length = generation_config.max_length - 1 + if ( + inputs_tensor.shape[1] != input_ids_length + and model_input_name == "inputs_embeds" + and not self.config.is_encoder_decoder + ): + max_cache_length += inputs_tensor.shape[1] + self._prepare_cache_for_generation( + generation_config, model_kwargs, assistant_model, batch_size, max_cache_length, device + ) + + # 8. determine generation mode + generation_mode = generation_config.get_generation_mode(assistant_model) + + if streamer is not None and (generation_config.num_beams > 1): + raise ValueError( + "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1." + ) + + # 9. prepare logits processors and stopping criteria + prepared_logits_processor = self._get_logits_processor( + generation_config=generation_config, + input_ids_seq_length=input_ids_length, + encoder_input_ids=inputs_tensor, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + logits_processor=logits_processor, + device=inputs_tensor.device, + model_kwargs=model_kwargs, + negative_prompt_ids=negative_prompt_ids, + negative_prompt_attention_mask=negative_prompt_attention_mask, + ) + prepared_stopping_criteria = self._get_stopping_criteria( + generation_config=generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer, **kwargs + ) + + # Set model_kwargs `use_cache` so we can use it later in forward runs + model_kwargs["use_cache"] = generation_config.use_cache + # ******************* taken from main generate function up to calling the different methods ******************* + + # Prepare inner 2D logic in generation loop + input_ids = input_ids.reshape(-1, input_ids.shape[-1]) + + # 10. go into different generation modes + if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH): + # 11. 
@torch.no_grad()
def generate(
    self,
    inputs: Optional[torch.Tensor] = None,
    generation_config: Optional[GenerationConfig] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None,
    synced_gpus: Optional[bool] = None,
    assistant_model: Optional["PreTrainedModel"] = None,
    streamer: Optional["BaseStreamer"] = None,
    negative_prompt_ids: Optional[torch.Tensor] = None,
    negative_prompt_attention_mask: Optional[torch.Tensor] = None,
    use_model_defaults: Optional[bool] = None,
    custom_generate: Optional[str] = None,
    **kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
    """Run `_main_generate_loop` and convert its flat 2D sequences back to the 3D audio layout.

    All arguments are forwarded unchanged to `_main_generate_loop`. The sequences it returns
    (directly, or as `.sequences` of an output object) are reshaped from
    `(bsz * channels, seq_len)` to `(bsz, seq_len, channels)`, and the delay mask taken from the
    user-supplied `decoder_input_ids` is applied once more before returning.
    """
    # The initial decoder input ids are expected to be the complete mask (delayed input).
    # Clone them up front, before the generation loop runs.
    delay_mask = kwargs.get("decoder_input_ids", None)
    if delay_mask is not None:
        delay_mask = delay_mask.clone()

    output = self._main_generate_loop(
        inputs=inputs,
        generation_config=generation_config,
        logits_processor=logits_processor,
        stopping_criteria=stopping_criteria,
        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
        synced_gpus=synced_gpus,
        assistant_model=assistant_model,
        streamer=streamer,
        negative_prompt_ids=negative_prompt_ids,
        negative_prompt_attention_mask=negative_prompt_attention_mask,
        use_model_defaults=use_model_defaults,
        custom_generate=custom_generate,
        **kwargs,
    )

    # Anything that is not a plain tensor is treated as an output object carrying `.sequences`
    is_plain_tensor = isinstance(output, torch.Tensor)
    flat_sequences = output if is_plain_tensor else output.sequences

    # Reshape from 2D (bsz * channels, seq_len) to 3D (bsz, seq_len, channels)
    channels = self.config.decoder_config.num_channels
    batch = flat_sequences.shape[0] // channels
    sequences = flat_sequences.reshape(batch, channels, -1).transpose(1, 2)

    # Apply delay mask
    sequences = self.apply_delay_mask(sequences, self.config.pad_token_id, delay_mask)

    if is_plain_tensor:
        return sequences

    output.sequences = sequences
    return output
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Optional, Union + +import torch +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache +from ...integrations import use_kernel_forward_from_hub +from ...masking_utils import create_causal_mask +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, +) +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, is_torchdynamo_compiling, logging +from .configuration_dia import DiaConfig, DiaDecoderConfig, DiaEncoderConfig +from .generation_dia import DiaGenerationMixin + + +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import make_flex_block_causal_mask + + +logger = logging.get_logger(__name__) + + +@auto_docstring +class DiaPreTrainedModel(PreTrainedModel): + config_class = DiaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True + _supports_cache_class = True + _supports_static_cache = True + 
class DiaMultiChannelEmbedding(nn.Module):
    """Embed all audio codebook channels with a single embedding table.

    Instead of one embedding layer per channel, every channel gets its own slice of one shared
    table: channel `c` owns rows `[c * vocab_size, (c + 1) * vocab_size)`. Adding the per-channel
    offset to the raw codes turns them into row indices of that table, so one lookup covers all
    channels. The per-channel embeddings are then summed into one vector per position.

    Example:
        - vocab_size = 8
        - num_channels = 3
        - 4 codes per channel
        We would have offsets = [0, 8, 16]
        If the per-channel audio codes are [0, 1, 2, 3], [1, 3, 4, 7], [4, 5, 6, 7],
        then tokens = audio_codes + offsets
                    = [0, 1, 2, 3], [9, 11, 12, 15], [20, 21, 22, 23]
        Every token stays below vocab_size * num_channels = 24, the size of the shared table.
    """

    def __init__(self, config: DiaDecoderConfig):
        super().__init__()
        self.embed = nn.Embedding(config.vocab_size * config.num_channels, config.hidden_size)
        self.hidden_size = config.hidden_size
        self.num_channels = config.num_channels
        offsets = torch.arange(config.num_channels, dtype=torch.long) * config.vocab_size  # (C,)
        # Non-persistent: the offsets are derived from the config and never stored in checkpoints
        self.register_buffer("offsets", offsets, persistent=False)

    def forward(self, audio_codes: torch.Tensor) -> torch.Tensor:
        """Return the channel-summed embeddings for `audio_codes`.

        Args:
            audio_codes: integer codes whose last dimension is the channel dimension
                (the offsets of shape `(num_channels,)` are broadcast against it).

        Returns:
            The embeddings summed over dim 2 after regrouping to
            `(audio_codes.shape[0], audio_codes.shape[1], -1, hidden_size)`.
        """
        # Shift each channel into its own slice of the shared table. The previous `.squeeze(1)`
        # on the tokens was a no-op for the result because the lookup is regrouped to the
        # `audio_codes` shape right below, so it is dropped.
        tokens = audio_codes + self.offsets.to(audio_codes.device)
        embeds = self.embed(tokens).view(audio_codes.shape[0], audio_codes.shape[1], -1, self.hidden_size)
        return embeds.sum(dim=2)
def rotate_half(x):
    """Swap the two halves of the last dimension, negating the half that moves to the front."""
    split_at = x.shape[-1] // 2
    front, back = x[..., :split_at], x[..., split_at:]
    return torch.cat((-back, front), dim=-1)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head `n_rep` times along the head dimension, i.e. the equivalent of
    torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from
    (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim).
    With `n_rep == 1` the input tensor itself is returned.
    """
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep != 1:
        # Insert a repeat axis right after the head axis, broadcast it, then fold it into the heads
        expanded = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim)
        hidden_states = expanded.reshape(bsz, kv_heads * n_rep, seq_len, dim)
    return hidden_states
self.attention_dropout = 0.0 + self.is_causal = is_causal + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + 
attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class DiaCrossAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: DiaDecoderConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + self.cross_hidden_size = config.cross_hidden_size + self.num_heads = self.config.cross_num_attention_heads + self.num_key_value_heads = self.config.cross_num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.head_dim = config.cross_head_dim + self.scaling = 1 + self.attention_dropout = 0.0 + self.is_causal = False + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(self.cross_hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(self.cross_hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + cross_attention_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[EncoderDecoderCache] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + cross_shape = (*cross_attention_states.shape[:-1], -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + is_updated = past_key_values.is_updated.get(self.layer_idx) if past_key_values is not None else False + if past_key_values is not None and is_updated: + # reuse k,v, cross_attentions + key_states = past_key_values.cross_attention_cache.key_cache[self.layer_idx] + 
class DiaEncoderLayer(GradientCheckpointingLayer):
    """Pre-norm encoder block: non-causal self-attention followed by an MLP, each with a residual."""

    def __init__(self, config: DiaEncoderConfig, layer_idx: int):
        super().__init__()
        eps = config.norm_eps
        width = config.hidden_size
        self.pre_sa_norm = DiaRMSNorm(width, eps=eps)
        self.self_attention = DiaSelfAttention(config, layer_idx, is_causal=False)
        self.post_sa_norm = DiaRMSNorm(width, eps=eps)
        self.mlp = DiaMLP(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # required by the attention; default kept for BC
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Return the updated hidden states and the self-attention weights."""
        # Self-attention sub-block (norm -> attention -> residual add)
        attn_out, self_attn_weights = self.self_attention(
            self.pre_sa_norm(hidden_states),
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            **kwargs,
        )
        hidden_states = hidden_states + attn_out

        # Feed-forward sub-block (norm -> MLP -> residual add)
        hidden_states = hidden_states + self.mlp(self.post_sa_norm(hidden_states))

        return hidden_states, self_attn_weights
class DiaDecoderLayer(GradientCheckpointingLayer):
    """Pre-norm decoder block: causal self-attention, cross-attention and an MLP, each with a residual."""

    def __init__(self, config: DiaDecoderConfig, layer_idx: int):
        super().__init__()
        eps = config.norm_eps
        width = config.hidden_size
        self.embed_dim = width
        self.self_attention = DiaSelfAttention(config, layer_idx, is_causal=True)
        self.cross_attention = DiaCrossAttention(config, layer_idx)
        self.pre_sa_norm = DiaRMSNorm(width, eps=eps)
        self.pre_ca_norm = DiaRMSNorm(width, eps=eps)
        self.pre_mlp_norm = DiaRMSNorm(width, eps=eps)
        self.mlp = DiaMLP(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # required by the attention; default kept for BC
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[EncoderDecoderCache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Return the updated hidden states plus the self- and cross-attention weights."""
        # Self-attention only sees the self-attention part of a combined encoder-decoder cache
        if isinstance(past_key_values, EncoderDecoderCache):
            self_attn_cache = past_key_values.self_attention_cache
        else:
            self_attn_cache = past_key_values

        # Self-attention sub-block
        attn_out, self_attn_weights = self.self_attention(
            self.pre_sa_norm(hidden_states),
            position_embeddings,
            attention_mask,
            # Needs to be a positional arg in order to function properly
            # on inplace operations to be carried (e.g. compile)
            self_attn_cache,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = hidden_states + attn_out

        # Cross-attention sub-block; it receives the full cache object, not the unwrapped one
        cross_out, cross_attn_weights = self.cross_attention(
            self.pre_ca_norm(hidden_states),
            encoder_hidden_states,
            attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            **kwargs,
        )
        hidden_states = hidden_states + cross_out

        # Feed-forward sub-block
        hidden_states = hidden_states + self.mlp(self.pre_mlp_norm(hidden_states))

        return hidden_states, self_attn_weights, cross_attn_weights
range(config.num_hidden_layers)] + ) + self.norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps) + + @auto_docstring + @can_return_tuple + def forward( + self, + input_ids: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.LongTensor] = None, + past_key_values: Optional[EncoderDecoderCache] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Union[BaseModelOutputWithPastAndCrossAttentions, tuple]: + r""" + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`): + The original `decoder_input_ids` in 3D shape to facilitate more efficient computations. + + [What are input IDs?](../glossary#input-ids) + """ + + batch_size, seq_length = input_ids.size()[:-1] + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=input_ids.device + ) + if position_ids is None: + position_ids = cache_position[None, :] + + # RoPE + hidden_states = self.embeddings(input_ids) + position_embeddings = self.rotary_embeddings(hidden_states, position_ids) + + if attention_mask is None and not is_torchdynamo_compiling(): + # required mask seq length can be calculated via length of past cache + mask_seq_length = past_key_values_length + seq_length + attention_mask = torch.ones(batch_size, mask_seq_length, device=input_ids.device) + + attention_mask = create_causal_mask( + config=self.config, + input_embeds=hidden_states, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + ) + encoder_attention_mask = self._update_cross_attn_mask( + encoder_hidden_states, + 
encoder_attention_mask, + hidden_states.shape[:2], + hidden_states, + ) + + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + + for layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = layer( + hidden_states, + position_embeddings, + attention_mask, + encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + cache_position=cache_position, + **kwargs, + ) + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns = all_self_attns + (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + hidden_states = self.norm(hidden_states) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask + def _update_cross_attn_mask( + self, + encoder_hidden_states: Union[torch.Tensor, None], + encoder_attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + ): + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all 
cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + elif self.config._attn_implementation == "flex_attention": + if isinstance(encoder_attention_mask, torch.Tensor): + encoder_attention_mask = make_flex_block_causal_mask( + encoder_attention_mask, + query_length=input_shape[-1], + is_causal=False, + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + return encoder_attention_mask + + +@auto_docstring( + custom_intro=""" + The bare Dia model outputting raw hidden-states without any specific head on top. + """ +) +class DiaModel(DiaPreTrainedModel): + def __init__(self, config: DiaConfig): + super().__init__(config) + self.config = config + self.encoder = DiaEncoder(config.encoder_config) + self.decoder = DiaDecoder(config.decoder_config) + self.post_init() + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + @auto_docstring + @can_return_tuple + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_position_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[Union[BaseModelOutput, tuple]] = None, + past_key_values: Optional[EncoderDecoderCache] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Union[tuple, Seq2SeqModelOutput]: + r""" + decoder_input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, target_sequence_length) + or (batch_size, 
target_sequence_length, num_codebooks)`, *optional*): + 1. (batch_size * num_codebooks, target_sequence_length): corresponds to the general use case where + the audio input codebooks are flattened into the batch dimension. This also aligns with the flat- + tened audio logits which are used to calculate the loss. + + 2. (batch_size, sequence_length, num_codebooks): corresponds to the internally used shape of + Dia to calculate embeddings and subsequent steps more efficiently. + + If no `decoder_input_ids` are provided, it will create a tensor of `bos_token_id` with shape + `(batch_size, 1, num_codebooks)`. Indices can be obtained using the [`DiaProcessor`]. See + [`DiaProcessor.__call__`] for more details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + decoder_position_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`): + Indices of positions of each input sequence tokens in the position embeddings. + Used to calculate the position embeddings up to `config.decoder_config.max_position_embeddings`. + + [What are position IDs?](../glossary#position-ids) + """ + + if input_ids is None and encoder_outputs is None: + raise ValueError( + "You should either provide text ids or the cached text encodings. Neither has been found." + ) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if self.is_gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + if use_cache and past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + **kwargs, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput + elif not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + # On default we initialize the decoder with bos tokens if nothing has been provided + bsz, seq_len, channels = (encoder_outputs[0].shape[0], -1, self.config.decoder_config.num_channels) + if decoder_input_ids is None: + decoder_input_ids = torch.full( + size=(bsz, 1, channels), fill_value=self.config.bos_token_id, device=self.device + ) + # Ensure 3D + if decoder_input_ids.ndim == 2: + decoder_input_ids = decoder_input_ids.reshape(bsz, channels, seq_len).transpose(1, 2) + + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + position_ids=decoder_position_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=attention_mask, + past_key_values=past_key_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs[0], + 
encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + The Dia model consisting of a (byte) text encoder and audio decoder with a prediction head on top. + """ +) +class DiaForConditionalGeneration(DiaPreTrainedModel, DiaGenerationMixin): + base_model_prefix = "model" + + def __init__(self, config: DiaConfig): + super().__init__(config) + self.config = config + self.model = DiaModel(config) + + self.num_channels = config.decoder_config.num_channels + self.vocab_size = config.decoder_config.vocab_size + self.logits_dense = nn.Linear( + config.decoder_config.hidden_size, (self.num_channels * self.vocab_size), bias=False + ) + self.loss_type = "ForMaskedLM" + + # Initialize weights and apply final processing + self.post_init() + + def get_encoder(self): + return self.model.get_encoder() + + def get_decoder(self): + return self.model.get_decoder() + + @auto_docstring + @can_return_tuple + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_position_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[Union[BaseModelOutput, tuple]] = None, + past_key_values: Optional[EncoderDecoderCache] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + labels: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Union[tuple, Seq2SeqLMOutput]: + r""" + decoder_input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, target_sequence_length) + or (batch_size, target_sequence_length, num_codebooks)`, *optional*): + 1. 
(batch_size * num_codebooks, target_sequence_length): corresponds to the general use case where + the audio input codebooks are flattened into the batch dimension. This also aligns with the flat- + tened audio logits which are used to calculate the loss. + + 2. (batch_size, sequence_length, num_codebooks): corresponds to the internally used shape of + Dia to calculate embeddings and subsequent steps more efficiently. + + If no `decoder_input_ids` are provided, it will create a tensor of `bos_token_id` with shape + `(batch_size, 1, num_codebooks)`. Indices can be obtained using the [`DiaProcessor`]. See + [`DiaProcessor.__call__`] for more details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + decoder_position_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`): + Indices of positions of each input sequence tokens in the position embeddings. + Used to calculate the position embeddings up to `config.decoder_config.max_position_embeddings`. + + [What are position IDs?](../glossary#position-ids) + labels (`torch.LongTensor` of shape `(batch_size * num_codebooks,)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in + `[0, ..., config.decoder_config.vocab_size - 1]` or -100. Tokens with indices set to `-100` + are ignored (masked). 
+ """ + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_position_ids=decoder_position_ids, + decoder_attention_mask=decoder_attention_mask, + encoder_outputs=encoder_outputs, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **kwargs, + ) + + last_hidden_state = outputs[0] + batch_size = last_hidden_state.shape[0] + # 3D <-> 2D makes it necessary to prioritize channel dim + audio_logits = ( + self.logits_dense(last_hidden_state) + .view((batch_size, -1, self.num_channels, self.vocab_size)) + .transpose(1, 2) + .contiguous() + .view(batch_size * self.num_channels, -1, self.vocab_size) + ) + + loss = None + if labels is not None: + loss = self.loss_function(logits=audio_logits, labels=labels, vocab_size=self.vocab_size, **kwargs) + + return Seq2SeqLMOutput( + loss=loss, + logits=audio_logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + +__all__ = ["DiaModel", "DiaPreTrainedModel", "DiaForConditionalGeneration"] diff --git a/src/transformers/models/dia/modular_dia.py b/src/transformers/models/dia/modular_dia.py new file mode 100644 index 0000000000..fe437fde84 --- /dev/null +++ b/src/transformers/models/dia/modular_dia.py @@ -0,0 +1,789 @@ +# coding=utf-8 +# Copyright 2025 The Nari Labs and HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Dia model.""" + +from typing import Callable, Optional, Union + +import torch +from torch import nn + +from ...cache_utils import DynamicCache, EncoderDecoderCache +from ...masking_utils import create_causal_mask +from ...modeling_attn_mask_utils import ( + _prepare_4d_attention_mask, + _prepare_4d_attention_mask_for_sdpa, +) +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, +) +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, is_torchdynamo_compiling, logging +from ..llama.modeling_llama import ( + LlamaAttention, + LlamaRMSNorm, + LlamaRotaryEmbedding, + eager_attention_forward, +) +from ..phi3.modeling_phi3 import Phi3MLP +from .configuration_dia import DiaConfig, DiaDecoderConfig, DiaEncoderConfig +from .generation_dia import DiaGenerationMixin + + +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import make_flex_block_causal_mask + + +logger = logging.get_logger(__name__) + + +@auto_docstring +class DiaPreTrainedModel(PreTrainedModel): + config_class = DiaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True + _supports_cache_class = True + 
_supports_static_cache = True + main_input_name = "input_ids" + _no_split_modules = ["DiaEncoderLayer", "DiaDecoderLayer"] + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, DiaRMSNorm): + module.weight.data.fill_(1.0) + + +class DiaMultiChannelEmbedding(nn.Module): + """In order to efficiently compute the audio embedding from the 9 different channels, + we vectorize the embedding process by using a single embedding layer and an offset. + Example: + - num_embeds = 4 + - vocab_size = 8 + - num_channels = 3 + We would have offsets = [0, 8, 16] + If audio_codes = [0, 1, 2, 3], [1, 3, 4, 7], [5, 6, 7, 8], + then tokens = audio_codes + offsets + = [0, 1, 2, 3, 9, 11, 12, 15, 21, 22, 23, 24] + This allows us to use a single embedding layer for all channels. 
+ """ + + def __init__(self, config: DiaDecoderConfig): + super().__init__() + self.embed = nn.Embedding(config.vocab_size * config.num_channels, config.hidden_size) + self.hidden_size = config.hidden_size + self.num_channels = config.num_channels + offsets = torch.arange(config.num_channels, dtype=torch.long) * config.vocab_size # (C,) + self.register_buffer("offsets", offsets, persistent=False) + + def forward(self, audio_codes: torch.Tensor) -> torch.Tensor: + tokens = (audio_codes + self.offsets.to(audio_codes.device)).squeeze(1) + embeds = self.embed(tokens).view(tokens.shape[0], audio_codes.shape[1], -1, self.hidden_size) + return embeds.sum(dim=2) + + +class DiaMLP(Phi3MLP): + pass + + +class DiaRMSNorm(LlamaRMSNorm): + pass + + +class DiaRotaryEmbedding(LlamaRotaryEmbedding): + pass + + +class DiaSelfAttention(LlamaAttention, nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Union[DiaEncoderConfig, DiaDecoderConfig], layer_idx: int, is_causal: bool = False): + nn.Module.__init__() + self.config = config + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + self.num_heads = self.config.num_attention_heads + self.num_key_value_heads = self.config.num_key_value_heads or self.num_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.head_dim = getattr(config, "head_dim", config.hidden_size // self.num_heads) + self.scaling = 1 + self.attention_dropout = 0.0 + self.is_causal = is_causal + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + +class DiaCrossAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You 
Need' paper""" + + def __init__(self, config: DiaDecoderConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + self.cross_hidden_size = config.cross_hidden_size + self.num_heads = self.config.cross_num_attention_heads + self.num_key_value_heads = self.config.cross_num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.head_dim = config.cross_head_dim + self.scaling = 1 + self.attention_dropout = 0.0 + self.is_causal = False + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(self.cross_hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(self.cross_hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + cross_attention_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[EncoderDecoderCache] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + cross_shape = (*cross_attention_states.shape[:-1], -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + is_updated = past_key_values.is_updated.get(self.layer_idx) if past_key_values is not None else False + if past_key_values is not None and is_updated: + # reuse k,v, cross_attentions + key_states = past_key_values.cross_attention_cache.key_cache[self.layer_idx] + value_states = past_key_values.cross_attention_cache.value_cache[self.layer_idx] + else: + key_states = self.k_proj(cross_attention_states).view(cross_shape).transpose(1, 2) + value_states = 
self.v_proj(cross_attention_states).view(cross_shape).transpose(1, 2) + + if past_key_values is not None: + # save all states to the cache + key_states, value_states = past_key_values.cross_attention_cache.update( + key_states, + value_states, + self.layer_idx, + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + past_key_values.is_updated[self.layer_idx] = True + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape((*input_shape, -1)).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class DiaEncoderLayer(GradientCheckpointingLayer): + def __init__(self, config: DiaEncoderConfig, layer_idx: int): + super().__init__() + self.pre_sa_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps) + self.self_attention = DiaSelfAttention(config, layer_idx, is_causal=False) + self.post_sa_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps) + self.mlp = DiaMLP(config) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + attention_mask: Optional[torch.Tensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + residual = hidden_states + normed_states = self.pre_sa_norm(hidden_states) + self_attn_output, self_attn_weights = self.self_attention( + normed_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + **kwargs, + ) + hidden_states = residual + self_attn_output + + residual = hidden_states + normed_states = self.post_sa_norm(hidden_states) + 
mlp_out = self.mlp(normed_states) + hidden_states = residual + mlp_out + + return hidden_states, self_attn_weights + + +class DiaEncoder(DiaPreTrainedModel): + def __init__(self, config: DiaEncoderConfig): + super().__init__(config) + self.config = config + + self.embedding = nn.Embedding(config.vocab_size, config.hidden_size) + self.layers = nn.ModuleList( + [DiaEncoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps) + self.rotary_embeddings = DiaRotaryEmbedding(config) + + @auto_docstring + @can_return_tuple + def forward( + self, + input_ids: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[BaseModelOutput, tuple]: + hidden_states = self.embedding(input_ids) + + # RoPE + # Note: We expect right padding and hence always generate + # the position ids on the fly to reduce preparation overhead + position_ids = torch.arange(input_ids.shape[-1], device=input_ids.device)[None, :] + position_embeddings = self.rotary_embeddings(hidden_states, position_ids) + + attention_mask = self._update_full_mask( + attention_mask, + hidden_states, + ) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + for encoder_layer in self.layers: + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + layer_outputs = encoder_layer( + hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + **kwargs, + ) + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + if output_hidden_states: + encoder_states += (hidden_states,) + + return BaseModelOutput( + last_hidden_state=hidden_states, 
hidden_states=encoder_states, attentions=all_attentions + ) + + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + + +class DiaDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: DiaDecoderConfig, layer_idx: int): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attention = DiaSelfAttention(config, layer_idx, is_causal=True) + self.cross_attention = DiaCrossAttention(config, layer_idx) + self.pre_sa_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps) + self.pre_ca_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps) + self.pre_mlp_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps) + self.mlp = DiaMLP(config) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + 
encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[EncoderDecoderCache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + self_attn_cache = past_key_values + if isinstance(self_attn_cache, EncoderDecoderCache): + self_attn_cache = self_attn_cache.self_attention_cache + + residual = hidden_states + normed_states = self.pre_sa_norm(hidden_states) + self_attn_output, self_attn_weights = self.self_attention( + normed_states, + position_embeddings, + attention_mask, + # Needs to be an arg in order to function properly + # on inplace operations to be carried (e.g. compile) + self_attn_cache, + cache_position=cache_position, + **kwargs, + ) + hidden_states = residual + self_attn_output + + residual = hidden_states + normed_states = self.pre_ca_norm(hidden_states) + cross_states, cross_attn_weights = self.cross_attention( + normed_states, + encoder_hidden_states, + attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + **kwargs, + ) + hidden_states = residual + cross_states + + residual = hidden_states + normed_states = self.pre_mlp_norm(hidden_states) + mlp_out = self.mlp(normed_states) + hidden_states = residual + mlp_out + + return hidden_states, self_attn_weights, cross_attn_weights + + +class DiaDecoder(DiaPreTrainedModel): + """Transformer Decoder Stack using DenseGeneral.""" + + def __init__(self, config: DiaDecoderConfig): + super().__init__(config) + self.num_channels = config.num_channels + self.vocab_size = config.vocab_size + self.embeddings = DiaMultiChannelEmbedding(config) + self.rotary_embeddings = DiaRotaryEmbedding(config) + self.layers = nn.ModuleList( + [DiaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps) + + @auto_docstring + @can_return_tuple + def forward( + self, + input_ids: torch.Tensor, + 
position_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.LongTensor] = None, + past_key_values: Optional[EncoderDecoderCache] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Union[BaseModelOutputWithPastAndCrossAttentions, tuple]: + r""" + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`): + The original `decoder_input_ids` in 3D shape to facilitate more efficient computations. + + [What are input IDs?](../glossary#input-ids) + """ + + batch_size, seq_length = input_ids.size()[:-1] + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=input_ids.device + ) + if position_ids is None: + position_ids = cache_position[None, :] + + # RoPE + hidden_states = self.embeddings(input_ids) + position_embeddings = self.rotary_embeddings(hidden_states, position_ids) + + if attention_mask is None and not is_torchdynamo_compiling(): + # required mask seq length can be calculated via length of past cache + mask_seq_length = past_key_values_length + seq_length + attention_mask = torch.ones(batch_size, mask_seq_length, device=input_ids.device) + + attention_mask = create_causal_mask( + config=self.config, + input_embeds=hidden_states, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + ) + encoder_attention_mask = self._update_cross_attn_mask( + encoder_hidden_states, + encoder_attention_mask, + hidden_states.shape[:2], + hidden_states, + ) + + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + 
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + + for layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = layer( + hidden_states, + position_embeddings, + attention_mask, + encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + cache_position=cache_position, + **kwargs, + ) + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns = all_self_attns + (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + hidden_states = self.norm(hidden_states) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask + def _update_cross_attn_mask( + self, + encoder_hidden_states: Union[torch.Tensor, None], + encoder_attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + ): + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. 
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + elif self.config._attn_implementation == "flex_attention": + if isinstance(encoder_attention_mask, torch.Tensor): + encoder_attention_mask = make_flex_block_causal_mask( + encoder_attention_mask, + query_length=input_shape[-1], + is_causal=False, + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + return encoder_attention_mask + + +@auto_docstring( + custom_intro=""" + The bare Dia model outputting raw hidden-states without any specific head on top. + """ +) +class DiaModel(DiaPreTrainedModel): + def __init__(self, config: DiaConfig): + super().__init__(config) + self.config = config + self.encoder = DiaEncoder(config.encoder_config) + self.decoder = DiaDecoder(config.decoder_config) + self.post_init() + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + @auto_docstring + @can_return_tuple + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_position_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[Union[BaseModelOutput, tuple]] = None, + past_key_values: Optional[EncoderDecoderCache] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Union[tuple, Seq2SeqModelOutput]: + r""" + decoder_input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, target_sequence_length) + or (batch_size, 
target_sequence_length, num_codebooks)`, *optional*): + 1. (batch_size * num_codebooks, target_sequence_length): corresponds to the general use case where + the audio input codebooks are flattened into the batch dimension. This also aligns with the flat- + tened audio logits which are used to calculate the loss. + + 2. (batch_size, sequence_length, num_codebooks): corresponds to the internally used shape of + Dia to calculate embeddings and subsequent steps more efficiently. + + If no `decoder_input_ids` are provided, it will create a tensor of `bos_token_id` with shape + `(batch_size, 1, num_codebooks)`. Indices can be obtained using the [`DiaProcessor`]. See + [`DiaProcessor.__call__`] for more details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + decoder_position_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`): + Indices of positions of each input sequence tokens in the position embeddings. + Used to calculate the position embeddings up to `config.decoder_config.max_position_embeddings`. + + [What are position IDs?](../glossary#position-ids) + """ + + if input_ids is None and encoder_outputs is None: + raise ValueError( + "You should either provide text ids or the cached text encodings. Neither has been found." + ) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if self.is_gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + if use_cache and past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + **kwargs, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput + elif not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + # On default we initialize the decoder with bos tokens if nothing has been provided + bsz, seq_len, channels = (encoder_outputs[0].shape[0], -1, self.config.decoder_config.num_channels) + if decoder_input_ids is None: + decoder_input_ids = torch.full( + size=(bsz, 1, channels), fill_value=self.config.bos_token_id, device=self.device + ) + # Ensure 3D + if decoder_input_ids.ndim == 2: + decoder_input_ids = decoder_input_ids.reshape(bsz, channels, seq_len).transpose(1, 2) + + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + position_ids=decoder_position_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=attention_mask, + past_key_values=past_key_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs[0], + 
encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + The Dia model consisting of a (byte) text encoder and audio decoder with a prediction head on top. + """ +) +class DiaForConditionalGeneration(DiaPreTrainedModel, DiaGenerationMixin): + base_model_prefix = "model" + + def __init__(self, config: DiaConfig): + super().__init__(config) + self.config = config + self.model = DiaModel(config) + + self.num_channels = config.decoder_config.num_channels + self.vocab_size = config.decoder_config.vocab_size + self.logits_dense = nn.Linear( + config.decoder_config.hidden_size, (self.num_channels * self.vocab_size), bias=False + ) + self.loss_type = "ForMaskedLM" + + # Initialize weights and apply final processing + self.post_init() + + def get_encoder(self): + return self.model.get_encoder() + + def get_decoder(self): + return self.model.get_decoder() + + @auto_docstring + @can_return_tuple + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_position_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[Union[BaseModelOutput, tuple]] = None, + past_key_values: Optional[EncoderDecoderCache] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + labels: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Union[tuple, Seq2SeqLMOutput]: + r""" + decoder_input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, target_sequence_length) + or (batch_size, target_sequence_length, num_codebooks)`, *optional*): + 1. 
(batch_size * num_codebooks, target_sequence_length): corresponds to the general use case where + the audio input codebooks are flattened into the batch dimension. This also aligns with the flat- + tened audio logits which are used to calculate the loss. + + 2. (batch_size, sequence_length, num_codebooks): corresponds to the internally used shape of + Dia to calculate embeddings and subsequent steps more efficiently. + + If no `decoder_input_ids` are provided, it will create a tensor of `bos_token_id` with shape + `(batch_size, 1, num_codebooks)`. Indices can be obtained using the [`DiaProcessor`]. See + [`DiaProcessor.__call__`] for more details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + decoder_position_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`): + Indices of positions of each input sequence tokens in the position embeddings. + Used to calculate the position embeddings up to `config.decoder_config.max_position_embeddings`. + + [What are position IDs?](../glossary#position-ids) + labels (`torch.LongTensor` of shape `(batch_size * num_codebooks,)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in + `[0, ..., config.decoder_config.vocab_size - 1]` or -100. Tokens with indices set to `-100` + are ignored (masked). 
+ """ + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_position_ids=decoder_position_ids, + decoder_attention_mask=decoder_attention_mask, + encoder_outputs=encoder_outputs, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **kwargs, + ) + + last_hidden_state = outputs[0] + batch_size = last_hidden_state.shape[0] + # 3D <-> 2D makes it necessary to prioritize channel dim + audio_logits = ( + self.logits_dense(last_hidden_state) + .view((batch_size, -1, self.num_channels, self.vocab_size)) + .transpose(1, 2) + .contiguous() + .view(batch_size * self.num_channels, -1, self.vocab_size) + ) + + loss = None + if labels is not None: + loss = self.loss_function(logits=audio_logits, labels=labels, vocab_size=self.vocab_size, **kwargs) + + return Seq2SeqLMOutput( + loss=loss, + logits=audio_logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + +__all__ = ["DiaModel", "DiaPreTrainedModel", "DiaForConditionalGeneration"] diff --git a/src/transformers/models/dia/processing_dia.py b/src/transformers/models/dia/processing_dia.py new file mode 100644 index 0000000000..e50ef5de67 --- /dev/null +++ b/src/transformers/models/dia/processing_dia.py @@ -0,0 +1,484 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Processor class for Dia""" + +import math +from pathlib import Path +from typing import Optional, Union + +from ...audio_utils import AudioInput, make_list_of_audio +from ...feature_extraction_utils import BatchFeature +from ...processing_utils import AudioKwargs, ProcessingKwargs, ProcessorMixin, Unpack +from ...utils import is_soundfile_available, is_torch_available + + +if is_torch_available(): + import torch + +if is_soundfile_available(): + import soundfile as sf + + +class DiaAudioKwargs(AudioKwargs, total=False): + bos_token_id: int + eos_token_id: int + pad_token_id: int + delay_pattern: list[int] + generation: bool + + +class DiaProcessorKwargs(ProcessingKwargs, total=False): + audio_kwargs: DiaAudioKwargs + _defaults = { + "text_kwargs": { + "padding": True, + "padding_side": "right", + "add_special_tokens": False, + }, + "audio_kwargs": { + "eos_token_id": 1024, + "pad_token_id": 1025, + "bos_token_id": 1026, + "delay_pattern": [0, 8, 9, 10, 11, 12, 13, 14, 15], + "generation": True, + "sampling_rate": 44100, + }, + "common_kwargs": {"return_tensors": "pt"}, + } + + +class DiaProcessor(ProcessorMixin): + r""" + Constructs a Dia processor which wraps a [`DiaFeatureExtractor`], [`DiaTokenizer`], and a [`DacModel`] into + a single processor. It inherits the audio feature extraction, tokenizer, and audio encode/decode functio- + nalities. See [`~DiaProcessor.__call__`], [`~DiaProcessor.encode`], and [`~DiaProcessor.decode`] for more + information. + + Args: + feature_extractor (`DiaFeatureExtractor`): + An instance of [`DiaFeatureExtractor`].
The feature extractor is a required input. + tokenizer (`DiaTokenizer`): + An instance of [`DiaTokenizer`]. The tokenizer is a required input. + audio_tokenizer (`DacModel`): + An instance of [`DacModel`] used to encode/decode audio into/from codebooks. It is a required input. + """ + + feature_extractor_class = "DiaFeatureExtractor" + tokenizer_class = "DiaTokenizer" + audio_tokenizer_class = "DacModel" + + def __init__(self, feature_extractor, tokenizer, audio_tokenizer): + super().__init__(feature_extractor, tokenizer, audio_tokenizer=audio_tokenizer) + + @property + def model_input_names(self): + """ + We no longer pass the raw audio values but the codebooks encoded by the `audio_tokenizer`. + Conventions may differ between audio models due to architectural choices. + """ + tokenizer_input_names = self.tokenizer.model_input_names + audio_tokenizer_input_names = ["decoder_input_ids", "decoder_attention_mask"] + return list(dict.fromkeys(tokenizer_input_names + audio_tokenizer_input_names)) + + def __call__( + self, + text: Union[str, list[str]], + audio: Optional[AudioInput] = None, + output_labels: Optional[bool] = False, + **kwargs: Unpack[DiaProcessorKwargs], + ): + """ + Main method to prepare text(s) and audio to be fed as input to the model. The `audio` argument is + forwarded to the DiaFeatureExtractor's [`~DiaFeatureExtractor.__call__`] and subsequently to the + DacModel's [`~DacModel.encode`]. The `text` argument is forwarded to [`~DiaTokenizer.__call__`]. Please refer + to the docstring of the above methods for more information. + """ + if not is_torch_available(): + raise ValueError( + "The `DiaProcessor` relies on the `audio_tokenizer` which requires `torch` but we couldn't " + "find it in your environment. You can install torch via `pip install torch`."
+ ) + + if text is None: + raise ValueError("You need to specify the `text` input to process.") + + output_kwargs = self._merge_kwargs( + DiaProcessorKwargs, + **kwargs, + ) + + text_kwargs = output_kwargs["text_kwargs"] + audio_kwargs = output_kwargs["audio_kwargs"] + common_kwargs = output_kwargs["common_kwargs"] + + return_tensors = common_kwargs.pop("return_tensors", None) + if return_tensors != "pt": + raise ValueError(f"{self.__class__.__name__} only supports `return_tensors='pt'`.") + + data = {} + + # Text + if isinstance(text, str): + text = [text] + elif not (isinstance(text, (list, tuple)) and all(isinstance(t, str) for t in text)): + raise ValueError("Invalid input text. Please provide a string, or a list of strings") + + encodings = self.tokenizer(text, **text_kwargs) + data.update(encodings) + + # Audio + delay_pattern = audio_kwargs.pop("delay_pattern", None) + audio_bos_token_id = audio_kwargs.pop("bos_token_id", None) + audio_eos_token_id = audio_kwargs.pop("eos_token_id", None) + audio_pad_token_id = audio_kwargs.pop("pad_token_id", None) + generation = audio_kwargs.pop("generation", True) + if ( + audio_bos_token_id is None + or audio_eos_token_id is None + or audio_pad_token_id is None + or delay_pattern is None + ): + raise ValueError( + "To enable processing for Dia, we need the `bos_token_id`, `eos_token_id`, " + "`pad_token_id`, and `delay_pattern`. You may have accidentally overwritten one of those." + ) + + if generation and output_labels: + raise ValueError( + f"Labels with `generation` is incompatible, got generation={generation}, output_labels={output_labels}." 
+ ) + + batch_size = data["input_ids"].shape[0] + num_channels = len(delay_pattern) + max_delay = max(delay_pattern) + + # Voice cloning generation / general training + if audio is not None: + audio = make_list_of_audio(audio) + input_audios = self.feature_extractor(audio, **audio_kwargs) + + compression_rate = math.prod(self.audio_tokenizer.config.downsampling_ratios) + max_encoded_sequence_len = input_audios["padding_mask"][0].shape[-1] // compression_rate + + decoder_input_ids = [] + decoder_attention_mask = [] + # TODO: dac with batching is currently broken, but non-batch is working + # refer to https://gist.github.com/vasqu/643a45b680cf39fd7467271ee2eb6f80 for a validation script + for padding_mask, audio in zip(input_audios["padding_mask"], input_audios["input_values"]): + # get current length with hop length in mind (as if it were sampled as a single audio) + base_pad_len = self.feature_extractor.hop_length + current_audio_len = math.ceil(padding_mask.sum(dim=-1) / base_pad_len) * base_pad_len + + encoded_sequence_len = current_audio_len // compression_rate + padding_len = max_encoded_sequence_len - encoded_sequence_len + + # compute non-padded forward pass; one extra bos (and eos if training) is added + with torch.no_grad(): + audio = audio[None, ..., :current_audio_len].to(self.audio_tokenizer.device) + input_ids = self.audio_tokenizer.encode(audio).audio_codes.transpose(1, 2) + + if not generation: + input_ids = torch.nn.functional.pad( + input_ids, pad=(0, 0, 0, 1, 0, 0), mode="constant", value=audio_eos_token_id + ) + + # apply padding + # +1 for the bos within the real sequence + input_ids = torch.nn.functional.pad( + input_ids, pad=(0, 0, padding_len + 1, 0, 0, 0), mode="constant", value=audio_bos_token_id + ) + num_valid_inputs = encoded_sequence_len + 1 + max_delay # sequence + bos + delay + num_valid_inputs += 0 if generation else 1 # eos if training + attention_mask = torch.tensor([0] * padding_len + [1] * num_valid_inputs, dtype=torch.long)[None, 
:] + + decoder_input_ids.append(input_ids) + decoder_attention_mask.append(attention_mask) + + decoder_input_ids = torch.cat(decoder_input_ids, dim=0) + decoder_attention_mask = torch.cat(decoder_attention_mask, dim=0) + # TTS generation + elif generation: + # all bos to start with TTS + decoder_input_ids = torch.full((batch_size, 1, num_channels), audio_bos_token_id, dtype=torch.long) + + # we preemptively add the delay + decoder_attention_mask = torch.ones(size=(batch_size, 1 + max_delay), dtype=torch.long) + else: + raise ValueError("If you try to train, you should provide audio data as well.") + + if batch_size != decoder_input_ids.shape[0]: + raise ValueError( + f"Need the same amount of samples for both text and audio, but got text samples={batch_size} and " + f"audio samples = {decoder_input_ids.shape[0]} instead." + ) + + # prepare shift indices per delay + max_seq_len = decoder_attention_mask.shape[-1] + max_audio_len = max_seq_len - max_delay + precomputed_idx = self.build_indices( + bsz=batch_size, + seq_len=max_seq_len, + num_channels=num_channels, + delay_pattern=delay_pattern, + revert=False, + ) + + # create delay pattern input + # the pad token will be used for masking which input is valid for prediction during generation + prefill = torch.full( + (batch_size, max_seq_len, num_channels), + fill_value=audio_pad_token_id, + dtype=torch.int, + ) + prefill[:, :max_audio_len] = decoder_input_ids + + delayed_decoder_input_ids = self.apply_audio_delay( + audio=prefill, + pad_token_id=audio_pad_token_id, + bos_token_id=audio_bos_token_id, + precomputed_idx=precomputed_idx, + ) + + data.update({"decoder_input_ids": delayed_decoder_input_ids, "decoder_attention_mask": decoder_attention_mask}) + + if output_labels: + # Base idea is to shift on the sequence dim + labels = data["decoder_input_ids"].clone()[:, 1:] + labels[labels == audio_pad_token_id] = -100 + labels[labels == audio_bos_token_id] = -100 + + data["labels"] = labels.transpose(1, 
2).reshape(batch_size * num_channels, -1).contiguous().long() + data["decoder_input_ids"] = data["decoder_input_ids"][:, :-1] + data["decoder_attention_mask"] = data["decoder_attention_mask"][:, :-1] + + return BatchFeature(data=data, tensor_type=return_tensors) + + def batch_decode( + self, + decoder_input_ids: "torch.Tensor", + audio_prompt_len: Optional[int] = None, + **kwargs: Unpack[DiaProcessorKwargs], + ) -> list["torch.Tensor"]: + """ + Decodes a batch of audio codebook sequences into their respective audio waveforms via the + `audio_tokenizer`. See [`~DacModel.decode`] for more information. + + Args: + decoder_input_ids (`torch.Tensor`): The complete output sequence of the decoder. + audio_prompt_len (`int`): The audio prefix length (e.g. when using voice cloning). + """ + output_kwargs = self._merge_kwargs( + DiaProcessorKwargs, + **kwargs, + ) + audio_kwargs = output_kwargs["audio_kwargs"] + + delay_pattern = audio_kwargs.pop("delay_pattern", None) + audio_bos_token_id = audio_kwargs.pop("bos_token_id", None) + audio_pad_token_id = audio_kwargs.pop("pad_token_id", None) + if audio_bos_token_id is None or audio_pad_token_id is None or delay_pattern is None: + raise ValueError( + "To enable decoding for Dia, we need the `bos_token_id`, `pad_token_id`, " + "and `delay_pattern`. You may have accidentally overwritten one of those." 
+ ) + + # either decode the whole audio sequence or only the generated parts + if audio_prompt_len is not None: + audio_prompt_len = torch.tensor(audio_prompt_len, device=decoder_input_ids.device, dtype=torch.long) + start_of_generation_idx = audio_prompt_len[None].expand(decoder_input_ids.shape[0]) + else: + start_of_generation_idx = (decoder_input_ids[:, :, 0] == audio_bos_token_id).sum(dim=-1) + # -1 for the eos token + end_of_generation_idx = ( + decoder_input_ids.shape[1] - (decoder_input_ids[:, :, 0] == audio_pad_token_id).sum(dim=-1) - 1 + ) + + # revert delay + bsz, seq_len, num_channels = decoder_input_ids.shape + precomputed_idx = self.build_indices( + bsz=bsz, + seq_len=seq_len, + num_channels=num_channels, + delay_pattern=delay_pattern, + revert=True, + ) + + output_sequences = self.apply_audio_delay( + audio=decoder_input_ids, + # We do not care about these values as we cut them out + # with `start_of_generation_idx` and `end_of_generation_idx` + pad_token_id=-1, + bos_token_id=-1, + precomputed_idx=precomputed_idx, + ).transpose(1, 2) + + # retrieve the correct sequences each + audios = [] + # TODO: see above, dac doesn't work in batches yet + with torch.no_grad(): + for i in range(start_of_generation_idx.shape[0]): + output_i = output_sequences[i, :, start_of_generation_idx[i] : end_of_generation_idx[i]][None, ...] + output_i = output_i.to(self.audio_tokenizer.device) + audio_i = self.audio_tokenizer.decode(audio_codes=output_i).audio_values.cpu().squeeze() + audios.append(audio_i) + + return audios + + def decode( + self, + decoder_input_ids: "torch.Tensor", + audio_prompt_len: Optional[int] = None, + **kwargs: Unpack[DiaProcessorKwargs], + ) -> "torch.Tensor": + """ + Decodes a single sequence of audio codebooks into the respective audio waveform via the + `audio_tokenizer`. See [`~DacModel.decode`] and [`~DiaProcessor.batch_decode`] for more information. 
+ """ + if decoder_input_ids.shape[0] != 1: + raise ValueError( + f"Expecting a single output to be decoded but received {decoder_input_ids.shape[0]} samples instead." + ) + + return self.batch_decode(decoder_input_ids, audio_prompt_len, **kwargs)[0] + + def get_audio_prompt_len( + self, + decoder_attention_mask: "torch.Tensor", + **kwargs: Unpack[DiaProcessorKwargs], + ) -> int: + """Utility function to get the audio prompt length.""" + output_kwargs = self._merge_kwargs( + DiaProcessorKwargs, + **kwargs, + ) + audio_kwargs = output_kwargs["audio_kwargs"] + + delay_pattern = audio_kwargs.pop("delay_pattern", None) + if delay_pattern is None: + raise ValueError( + "To enable the utility of retrieving the prompt length for Dia, we need the " + "`delay_pattern`. You may have accidentally overwritten this." + ) + return decoder_attention_mask.shape[1] - max(delay_pattern) + + # Copied from transformers.models.csm.processing_csm.CsmProcessor.save_audio with Csm->Dia + def save_audio( + self, + audio: AudioInput, + saving_path: Union[str, Path, list[Union[str, Path]]], + **kwargs: Unpack[DiaProcessorKwargs], + ): + # TODO: @eustlb, this should be in AudioProcessor + if not is_soundfile_available(): + raise ImportError("Please install `soundfile` to save audio files.") + + # ensure correct audio input + audio = make_list_of_audio(audio) + + # ensure correct saving path + if isinstance(saving_path, (str, Path)): + saving_path = [saving_path] + elif not (isinstance(saving_path, (list, tuple)) and all(isinstance(p, (str, Path)) for p in saving_path)): + raise ValueError("Invalid input path. 
Please provide a string, or a list of strings") + + if len(audio) != len(saving_path): + raise ValueError("The number of audio and saving paths must be the same") + + output_kwargs = self._merge_kwargs( + DiaProcessorKwargs, + **kwargs, + ) + audio_kwargs = output_kwargs["audio_kwargs"] + sampling_rate = audio_kwargs["sampling_rate"] + + for audio_value, p in zip(audio, saving_path): + if isinstance(audio_value, torch.Tensor): + audio_value = audio_value.cpu().float().numpy() + sf.write(p, audio_value, sampling_rate) + + @staticmethod + def build_indices( + bsz: int, + seq_len: int, + num_channels: int, + delay_pattern: list[int], + revert: bool = False, + ) -> tuple["torch.Tensor", "torch.Tensor"]: + """ + Precompute (sequence_idx, all_idx) so that out[seq, channel] = in[seq - delay[channel], channel] + or in[seq, channel] = out[seq + delay[channel], channel] if `revert`. + Negative sequence_idx => BOS; sequence_idx >= seq_len => PAD. + """ + delay_array = torch.tensor(delay_pattern, dtype=torch.int32) + + # (0..seq_len-1) + sequence_idx = torch.arange(seq_len, dtype=torch.int32)[None, :].expand(bsz, seq_len)[..., None] + # + or - delay depending if we delay or revert the delay + if not revert: + sequence_idx = sequence_idx - delay_array[None, None, :] + else: + sequence_idx = sequence_idx + delay_array[None, None, :] + # if delay goes over the range we clamp back to valid values + valid_sequence_idx = torch.clamp(sequence_idx, 0, seq_len - 1) + + batch_idx = torch.arange(bsz, dtype=torch.int32)[:, None, None].expand(bsz, seq_len, num_channels) + channel_idx = torch.arange(num_channels, dtype=torch.int32)[None, None, :].expand(bsz, seq_len, num_channels) + + all_idx = torch.stack( + [batch_idx.reshape(-1), valid_sequence_idx.reshape(-1), channel_idx.reshape(-1)], + dim=1, + ).long() + + return sequence_idx, all_idx + + @staticmethod + def apply_audio_delay( + audio: "torch.Tensor", + pad_token_id: int, + bos_token_id: int, + precomputed_idx: tuple["torch.Tensor", 
"torch.Tensor"], + ) -> "torch.Tensor": + """ + Applies or reverts the delay pattern to batched audio tokens using precomputed indices, + inserting BOS where sequence_idx < 0 and PAD where sequence_idx >= seq_len. + + Args: + audio: audio tokens of shape [bsz, seq_len, num_channels] + pad_token_id: the PAD token + bos_token_id: the BOS token + precomputed_idx: from `build_indices` + + Returns: + final_audio: delayed or reverted audio tokens of shape [bsz, seq_len, num_channels] + """ + # Move everything to the same device + device = audio.device + sequence_idx, all_idx = precomputed_idx + sequence_idx = sequence_idx.to(device) + all_idx = all_idx.to(device) + + # Gather per precomputed indices + batch_idx, valid_sequence_idx, channel_idx = torch.unbind(all_idx, dim=-1) + gathered_audio = audio[batch_idx, valid_sequence_idx, channel_idx].view(audio.size()) + + # Mask according to negative sequence_idx => BOS; sequence_idx >= seq_len => PAD + mask_bos = sequence_idx < 0 + mask_pad = sequence_idx >= audio.shape[1] + final_audio = torch.where(mask_bos, bos_token_id, torch.where(mask_pad, pad_token_id, gathered_audio)) + + return final_audio + + +__all__ = ["DiaProcessor"] diff --git a/src/transformers/models/dia/tokenization_dia.py b/src/transformers/models/dia/tokenization_dia.py new file mode 100644 index 0000000000..4e205906ea --- /dev/null +++ b/src/transformers/models/dia/tokenization_dia.py @@ -0,0 +1,118 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization class for Dia.""" + +from typing import Optional + +from ...tokenization_utils import AddedToken, PreTrainedTokenizer +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class DiaTokenizer(PreTrainedTokenizer): + """ + Construct a Dia tokenizer. Dia simply uses raw bytes utf-8 encoding except for special tokens `[S1]` and `[S2]`. + + This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should + refer to this superclass for more information regarding those methods. + + Args: + pad_token (`str`, *optional*, defaults to `""`): + The token used for padding, for example when batching sequences of different lengths. + unk_token (`str`, *optional*, defaults to `""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + max_length (`int`, *optional*, defaults to 1024): + The maximum length of the sequences when encoding. Sequences longer than this will be truncated. + offset (`int`, *optional*, defaults to 0): + The offset of the tokenizer.
+ """ + + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + pad_token: Optional[str] = "", + unk_token: Optional[str] = "", + max_length: Optional[int] = 1024, + offset: int = 0, + **kwargs, + ): + # We have no eos/bos tokens but allow padding -- no l/r strip as we treat them as tokens as well + pad_token = AddedToken(pad_token) if isinstance(pad_token, str) else pad_token + unk_token = AddedToken(unk_token) if isinstance(unk_token, str) else unk_token + + self._utf_vocab_size = 2**8 # utf is 8 bits + self._added_tokens_decoder = {0: pad_token, 1: AddedToken("[S1]"), 2: AddedToken("[S2]")} + self.offset = offset + super().__init__( + unk_token=unk_token, + pad_token=pad_token, + max_length=max_length, + **kwargs, + ) + + @property + def vocab_size(self): + return self._utf_vocab_size + + def get_vocab(self): + vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size + self.offset)} + vocab.update(self.added_tokens_encoder) + return vocab + + def _tokenize(self, text: str) -> list[str]: + """Take as input a string and return a list of strings (tokens) for words/sub-words""" + tokens = [chr(i) for i in text.encode("utf-8")] + return tokens + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + + if len(token) != 1: + token_id = None + else: + token_id = ord(token) + self.offset + + return token_id + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + token = chr(index - self.offset) + return token + + def convert_tokens_to_string(self, tokens: list[str]) -> str: + """Converts a sequence of tokens (string) in a single string.""" + bstring = b"" + for token in tokens: + if token in self.added_tokens_decoder: + added_token_obj = self.added_tokens_decoder[token] + tok_string = str(added_token_obj).encode("utf-8") + elif token in self.added_tokens_encoder: + tok_string = token.encode("utf-8") + else: + tok_string = 
token.encode("utf-8") # Assume general string token + bstring += tok_string + string = bstring.decode("utf-8", errors="ignore") + return string + + # No vocab file + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]: + return () + + +__all__ = ["DiaTokenizer"] diff --git a/src/transformers/pipelines/text_to_audio.py b/src/transformers/pipelines/text_to_audio.py index 79b0e9b35f..afeae13ae7 100644 --- a/src/transformers/pipelines/text_to_audio.py +++ b/src/transformers/pipelines/text_to_audio.py @@ -80,15 +80,21 @@ class TextToAudioPipeline(Pipeline): See the list of available models on [huggingface.co/models](https://huggingface.co/models?filter=text-to-speech). """ + # Introducing the processor at load time for new behaviour + _load_processor = True + _pipeline_calls_generate = True # Make sure the docstring is updated when the default generation config is changed _default_generation_config = GenerationConfig( max_new_tokens=256, ) - def __init__(self, *args, vocoder=None, sampling_rate=None, **kwargs): + def __init__(self, *args, vocoder=None, sampling_rate=None, no_processor=True, **kwargs): super().__init__(*args, **kwargs) + # Legacy behaviour just uses the tokenizer while new models use the processor as a whole at any given time + self.no_processor = no_processor + if self.framework == "tf": raise ValueError("The TextToAudioPipeline is only available in PyTorch.") @@ -117,6 +123,10 @@ class TextToAudioPipeline(Pipeline): if sampling_rate is not None: self.sampling_rate = sampling_rate + # last fallback to get the sampling rate based on processor + if self.sampling_rate is None and not self.no_processor and hasattr(self.processor, "feature_extractor"): + self.sampling_rate = self.processor.feature_extractor.sampling_rate + def preprocess(self, text, **kwargs): if isinstance(text, str): text = [text] @@ -136,7 +146,8 @@ class TextToAudioPipeline(Pipeline): kwargs = new_kwargs - output = self.tokenizer(text, 
**kwargs, return_tensors="pt") + preprocessor = self.tokenizer if self.no_processor else self.processor + output = preprocessor(text, **kwargs, return_tensors="pt") return output @@ -228,12 +239,21 @@ class TextToAudioPipeline(Pipeline): return preprocess_params, params, postprocess_params - def postprocess(self, waveform): + def postprocess(self, audio): output_dict = {} - if isinstance(waveform, dict): - waveform = waveform["waveform"] - elif isinstance(waveform, tuple): - waveform = waveform[0] + + # We directly get the waveform + if self.no_processor: + if isinstance(audio, dict): + waveform = audio["waveform"] + elif isinstance(audio, tuple): + waveform = audio[0] + else: + waveform = audio + # Or we need to postprocess to get the waveform + else: + waveform = self.processor.decode(audio) + output_dict["audio"] = waveform.to(device="cpu", dtype=torch.float).numpy() output_dict["sampling_rate"] = self.sampling_rate diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py index e7adab4b1d..2a97cde3cc 100644 --- a/src/transformers/processing_utils.py +++ b/src/transformers/processing_utils.py @@ -49,6 +49,7 @@ from .tokenization_utils_base import ( TruncationStrategy, ) from .utils import ( + AUDIO_TOKENIZER_NAME, CHAT_TEMPLATE_DIR, CHAT_TEMPLATE_FILE, LEGACY_PROCESSOR_CHAT_TEMPLATE_FILE, @@ -61,12 +62,17 @@ from .utils import ( download_url, is_offline_mode, is_remote_url, + is_torch_available, list_repo_templates, logging, ) from .utils.deprecation import deprecate_kwarg +if is_torch_available(): + from .modeling_utils import PreTrainedAudioTokenizerBase + + logger = logging.get_logger(__name__) # Dynamically import the Transformers module to grab the attribute classes of the processor from their names. 
@@ -499,7 +505,7 @@ class ProcessorMixin(PushToHubMixin): """ attributes = ["feature_extractor", "tokenizer"] - optional_attributes = ["chat_template"] + optional_attributes = ["chat_template", "audio_tokenizer"] optional_call_args: list[str] = [] # Names need to be attr_class for attr in attributes feature_extractor_class = None @@ -511,7 +517,19 @@ class ProcessorMixin(PushToHubMixin): # First, extract optional attributes from kwargs if present # Optional attributes can never be positional arguments for optional_attribute in self.optional_attributes: - setattr(self, optional_attribute, kwargs.pop(optional_attribute, None)) + optional_attribute_value = kwargs.pop(optional_attribute, None) + setattr(self, optional_attribute, optional_attribute_value) + + # Check audio tokenizer for its class but do not treat it as attr to avoid saving weights + if optional_attribute == "audio_tokenizer" and optional_attribute_value is not None: + proper_class = self.check_argument_for_proper_class(optional_attribute, optional_attribute_value) + + if not (is_torch_available() and isinstance(optional_attribute_value, PreTrainedAudioTokenizerBase)): + raise ValueError( + f"Tried to use `{proper_class}` for audio tokenization. However, this class is not" + " registered for audio tokenization." + ) + # Sanitize args and kwargs for key in kwargs: if key not in self.attributes: @@ -530,21 +548,30 @@ class ProcessorMixin(PushToHubMixin): # Check each arg is of the proper class (this will also catch a user initializing in the wrong order) for attribute_name, arg in kwargs.items(): - class_name = getattr(self, f"{attribute_name}_class") - # Nothing is ever going to be an instance of "AutoXxx", in that case we check the base class. 
- class_name = AUTO_TO_BASE_CLASS_MAPPING.get(class_name, class_name) - if isinstance(class_name, tuple): - proper_class = tuple(self.get_possibly_dynamic_module(n) for n in class_name if n is not None) - else: - proper_class = self.get_possibly_dynamic_module(class_name) - - if not isinstance(arg, proper_class): - raise TypeError( - f"Received a {type(arg).__name__} for argument {attribute_name}, but a {class_name} was expected." - ) - + self.check_argument_for_proper_class(attribute_name, arg) setattr(self, attribute_name, arg) + def check_argument_for_proper_class(self, argument_name, argument): + """ + Checks the passed argument's class against the expected transformers class. In case of an unexpected + mismatch between expected and actual class, an error is raised. Otherwise, the proper retrieved class + is returned. + """ + class_name = getattr(self, f"{argument_name}_class") + # Nothing is ever going to be an instance of "AutoXxx", in that case we check the base class. + class_name = AUTO_TO_BASE_CLASS_MAPPING.get(class_name, class_name) + if isinstance(class_name, tuple): + proper_class = tuple(self.get_possibly_dynamic_module(n) for n in class_name if n is not None) + else: + proper_class = self.get_possibly_dynamic_module(class_name) + + if not isinstance(argument, proper_class): + raise TypeError( + f"Received a {type(argument).__name__} for argument {argument_name}, but a {class_name} was expected." + ) + + return proper_class + def to_dict(self) -> dict[str, Any]: """ Serializes this instance to a Python dictionary. 
@@ -577,6 +604,8 @@ class ProcessorMixin(PushToHubMixin): del output["feature_extractor"] if "chat_template" in output: del output["chat_template"] + if "audio_tokenizer" in output: + del output["audio_tokenizer"] # Some attributes have different names but containing objects that are not simple strings output = { @@ -695,6 +724,7 @@ class ProcessorMixin(PushToHubMixin): save_directory, LEGACY_PROCESSOR_CHAT_TEMPLATE_FILE ) # Legacy filename chat_template_dir = os.path.join(save_directory, CHAT_TEMPLATE_DIR) + output_audio_tokenizer_file = os.path.join(save_directory, AUDIO_TOKENIZER_NAME) processor_dict = self.to_dict() # Save `chat_template` in its own file. We can't get it from `processor_dict` as we popped it in `to_dict` @@ -737,6 +767,19 @@ class ProcessorMixin(PushToHubMixin): "separate files using the `save_jinja_files` argument." ) + if self.audio_tokenizer is not None: + audio_tokenizer_class = self.audio_tokenizer.__class__.__name__ + audio_tokenizer_name_or_path = self.audio_tokenizer.name_or_path + + audio_tokenizer_dict = { + "audio_tokenizer_class": audio_tokenizer_class, + "audio_tokenizer_name_or_path": audio_tokenizer_name_or_path, + } + audio_tokenizer_json = json.dumps(audio_tokenizer_dict, indent=2, sort_keys=True) + "\n" + + with open(output_audio_tokenizer_file, "w", encoding="utf-8") as writer: + writer.write(audio_tokenizer_json) + # For now, let's not save to `processor_config.json` if the processor doesn't have extra attributes and # `auto_map` is not specified. if set(processor_dict.keys()) != {"processor_class"}: @@ -774,6 +817,9 @@ class ProcessorMixin(PushToHubMixin): Returns: `tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the processor object. 
""" + # holding a copy for optionally loading the audio tokenizer (if available) + audio_tokenizer_kwargs = copy.deepcopy(kwargs) + cache_dir = kwargs.pop("cache_dir", None) force_download = kwargs.pop("force_download", False) resume_download = kwargs.pop("resume_download", None) @@ -803,16 +849,18 @@ class ProcessorMixin(PushToHubMixin): resolved_additional_chat_template_files = {} if os.path.isfile(pretrained_model_name_or_path): resolved_processor_file = pretrained_model_name_or_path - # can't load chat-template when given a file as pretrained_model_name_or_path + # can't load chat-template and audio tokenizer when given a file as pretrained_model_name_or_path resolved_chat_template_file = None resolved_raw_chat_template_file = None + resolved_audio_tokenizer_file = None is_local = True elif is_remote_url(pretrained_model_name_or_path): processor_file = pretrained_model_name_or_path resolved_processor_file = download_url(pretrained_model_name_or_path) - # can't load chat-template when given a file url as pretrained_model_name_or_path + # can't load chat-template and audio tokenizer when given a file url as pretrained_model_name_or_path resolved_chat_template_file = None resolved_raw_chat_template_file = None + resolved_audio_tokenizer_file = None else: if is_local: template_dir = Path(pretrained_model_name_or_path, CHAT_TEMPLATE_DIR) @@ -899,6 +947,21 @@ class ProcessorMixin(PushToHubMixin): ) for template_name, template_file in additional_chat_template_files.items() } + + resolved_audio_tokenizer_file = cached_file( + pretrained_model_name_or_path, + AUDIO_TOKENIZER_NAME, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + token=token, + user_agent=user_agent, + revision=revision, + subfolder=subfolder, + _raise_exceptions_for_missing_entries=False, + ) except OSError: # Raise any environment error raise by `cached_file`. 
It will have a helpful error message adapted to # the original exception. @@ -939,6 +1002,22 @@ class ProcessorMixin(PushToHubMixin): if chat_templates: kwargs["chat_template"] = chat_templates + # Same as chat template, adding as kwarg after loading the model + audio_tokenizer = None + if resolved_audio_tokenizer_file is not None: + with open(resolved_audio_tokenizer_file, "r", encoding="utf-8") as reader: + # The json contains the references we need to init the correct model + audio_tokenizer_references = json.load(reader) + audio_tokenizer_class = cls.get_possibly_dynamic_module( + audio_tokenizer_references["audio_tokenizer_class"] + ) + audio_tokenizer_path = audio_tokenizer_references["audio_tokenizer_name_or_path"] + + audio_tokenizer = audio_tokenizer_class.from_pretrained(audio_tokenizer_path, **audio_tokenizer_kwargs) + + if audio_tokenizer is not None: + kwargs["audio_tokenizer"] = audio_tokenizer + # Existing processors on the Hub created before #27761 being merged don't have `processor_config.json` (if not # updated afterward), and we need to keep `from_pretrained` work. So here it fallbacks to the empty dict. 
# (`cached_file` called using `_raise_exceptions_for_missing_entries=False` to avoid exception) @@ -947,7 +1026,9 @@ class ProcessorMixin(PushToHubMixin): # In any case we need to pass `chat_template` if it is available processor_dict = {} if "chat_template" in kwargs: - processor_dict = {"chat_template": kwargs.pop("chat_template")} + processor_dict["chat_template"] = kwargs.pop("chat_template") + if "audio_tokenizer" in kwargs: + processor_dict["audio_tokenizer"] = kwargs.pop("audio_tokenizer") return processor_dict, kwargs try: @@ -972,6 +1053,8 @@ class ProcessorMixin(PushToHubMixin): if "chat_template" in kwargs: processor_dict["chat_template"] = kwargs.pop("chat_template") + if "audio_tokenizer" in kwargs: + processor_dict["audio_tokenizer"] = kwargs.pop("audio_tokenizer") return processor_dict, kwargs @@ -1276,6 +1359,7 @@ class ProcessorMixin(PushToHubMixin): attribute_class = cls.get_possibly_dynamic_module(class_name) args.append(attribute_class.from_pretrained(pretrained_model_name_or_path, **kwargs)) + return args @staticmethod @@ -1287,6 +1371,7 @@ class ProcessorMixin(PushToHubMixin): transformers_module.VIDEO_PROCESSOR_MAPPING, transformers_module.TOKENIZER_MAPPING, transformers_module.FEATURE_EXTRACTOR_MAPPING, + transformers_module.MODEL_FOR_AUDIO_TOKENIZATION_MAPPING, ] for lookup_location in lookup_locations: for custom_class in lookup_location._extra_content.values(): diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index 7ca4c35528..4943e91e73 100644 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -292,6 +292,7 @@ CONFIG_NAME = "config.json" FEATURE_EXTRACTOR_NAME = "preprocessor_config.json" IMAGE_PROCESSOR_NAME = "preprocessor_config.json" VIDEO_PROCESSOR_NAME = "video_preprocessor_config.json" +AUDIO_TOKENIZER_NAME = "audio_tokenizer_config.json" PROCESSOR_NAME = "processor_config.json" GENERATION_CONFIG_NAME = "generation_config.json" MODEL_CARD_NAME = "modelcard.json" 
diff --git a/tests/generation/test_logits_process.py b/tests/generation/test_logits_process.py index ea0a7581e5..834c502b1a 100644 --- a/tests/generation/test_logits_process.py +++ b/tests/generation/test_logits_process.py @@ -56,7 +56,12 @@ if is_torch_available(): UnbatchedClassifierFreeGuidanceLogitsProcessor, WatermarkLogitsProcessor, ) - from transformers.generation.logits_process import BarkEosPrioritizerLogitsProcessor + from transformers.generation.logits_process import ( + BarkEosPrioritizerLogitsProcessor, + DiaClassifierFreeGuidanceLogitsProcessor, + DiaEOSChannelFilterLogitsProcessor, + DiaEOSDelayPatternLogitsProcessor, + ) @require_torch @@ -1211,3 +1216,145 @@ class LogitsProcessorTest(unittest.TestCase): ) ) self.assertTrue(is_close) + + def test_dia_classifier_free_guidance(self): + input_ids = torch.LongTensor([[0]]) + logits_uncond = torch.tensor([[1.0, 0, 1.5]]) + logits_cond = torch.tensor([[1.0, 1.0, 1.0]]) + + # base cfg with conditioned as center + cfg = DiaClassifierFreeGuidanceLogitsProcessor(guidance_scale=1.5) + out = cfg(input_ids, torch.cat([logits_cond, logits_uncond], dim=0)) + + res = logits_cond + 1.5 * (logits_cond - logits_uncond) + + self.assertAlmostEqual(out[0, 0].item(), res[0, 0].item()) + self.assertAlmostEqual(out[0, 1].item(), res[0, 1].item()) + self.assertAlmostEqual(out[0, 2].item(), res[0, 2].item()) + + # additional top k (on cond logits) + cfg = DiaClassifierFreeGuidanceLogitsProcessor(guidance_scale=1.5, guidance_top_k=1) + out = cfg(input_ids, torch.cat([logits_cond, logits_uncond], dim=0)) + + res = logits_cond + 1.5 * (logits_cond - logits_uncond) + mask = res == res.max() + res = logits_cond.clone() + res[~mask.bool()] = -float("inf") + + self.assertAlmostEqual(out[0, 0].item(), res[0, 0].item()) + self.assertAlmostEqual(out[0, 1].item(), res[0, 1].item()) + self.assertAlmostEqual(out[0, 2].item(), res[0, 2].item()) + + def test_dia_channel_filter(self): + eos = 2 + bsz, channels, vocab = 2, 2, 4 + + input_ids 
= torch.LongTensor([[0]]) + logits = torch.zeros(size=(bsz, channels, vocab)).view(bsz * channels, vocab) + logits[0, eos] = 1 # Eos max (forced) + logits[1, eos] = 1 # Eos max (forced) but not channel 0 + + channel_filter = DiaEOSChannelFilterLogitsProcessor(num_channels=channels, eos_token_id=eos) + out = channel_filter(input_ids, logits).view(bsz, channels, vocab) + + for i in range(vocab): + if i > eos: + # special tokens are not to be predicted + self.assertTrue((out[:, :, i] == -float("inf")).all()) + elif i == eos: + # Eos forced on channel 0 + self.assertTrue(out[0, 0, i] == 1) + # Eos suppressed on everything else (even if max before) + self.assertTrue(out[0, 1, i] == -float("inf")) + self.assertTrue((out[1, :, i] == -float("inf")).all()) + else: + # Eos forced on channel 0 + self.assertTrue(out[0, 0, i] == -float("inf")) + # previous values + self.assertTrue(out[0, 1, i] == 0) + self.assertTrue((out[1, :, i] == 0).all()) + + def test_dia_delay_pattern(self): + def check_eos_logits(out, logits, batch, channel, eos): + for i in range(vocab): + if i == eos: + self.assertTrue(out[batch, channel, i] == 0) + else: + self.assertTrue(out[batch, channel, i] == -float("inf")) + + for c in range(channel): + if c != channel: + self.assertTrue((out[batch, c] == logits[batch, c]).all()) + + eos = 2 + delay_pattern = [0, 2, 3] + max_generation_len = 10 + bsz, channels, vocab = 2, 3, 4 + + input_ids = torch.LongTensor([[0]]) + logits = torch.zeros(size=(bsz, channels, vocab)) + # Ensure that argmax can not result in eos + logits[:, :, eos] = -1 + + delay_pattern_processor = DiaEOSDelayPatternLogitsProcessor( + delay_pattern=delay_pattern, eos_token_id=eos, max_generation_len=max_generation_len + ) + out = delay_pattern_processor(input_ids, logits.clone()).view(bsz, channels, vocab) + + # Nothing should happen except for init of some attributes + self.assertTrue((out == logits).all()) + self.assertTrue((~delay_pattern_processor.active_batches).all()) + self.assertTrue( + 
(delay_pattern_processor.delay_pattern == torch.tensor([delay_pattern for _ in range(bsz)])).all() + ) + + # Make first batch end + logits[0, 0, eos] = 1 + + # Go through the complete delay pattern + for i in range(max(delay_pattern) + 1): + out = delay_pattern_processor(input_ids, logits.clone()).view(bsz, channels, vocab) + + # no delay should kick in + if i == 1: + self.assertTrue((out == logits).all()) + else: + j = i if i == 0 else i - 1 + check_eos_logits(out=out, logits=logits, batch=0, channel=j, eos=eos) + self.assertTrue((out[1] == logits[1]).all()) + self.assertTrue(delay_pattern_processor.active_batches[0]) + self.assertFalse(delay_pattern_processor.active_batches[1]) + self.assertTrue( + ( + delay_pattern_processor.delay_pattern[0] + == torch.tensor([delay - (i + 1) for delay in delay_pattern]) + ).all() + ) + self.assertTrue((delay_pattern_processor.delay_pattern[1] == torch.tensor(delay_pattern)).all()) + + # Make second batch end + logits[1, 0, eos] = 1 + + # Just to check if other batches could work + out = delay_pattern_processor(input_ids, logits.clone()).view(bsz, channels, vocab) + + self.assertTrue((out[0] == logits[0]).all()) + self.assertTrue(delay_pattern_processor.active_batches.all()) + self.assertTrue( + (delay_pattern_processor.delay_pattern[0] == torch.tensor([delay - 5 for delay in delay_pattern])).all() + ) + self.assertTrue( + (delay_pattern_processor.delay_pattern[1] == torch.tensor([delay - 1 for delay in delay_pattern])).all() + ) + + # Last check on max generation length reached (with delay in mind until last channel produces eos) + input_ids = torch.LongTensor([[0] * (max_generation_len - max(delay_pattern) - 1)]) + delay_pattern_processor = DiaEOSDelayPatternLogitsProcessor( + delay_pattern=delay_pattern, eos_token_id=eos, max_generation_len=max_generation_len + ) + out = delay_pattern_processor(input_ids, logits.clone()).view(bsz, channels, vocab) + + check_eos_logits(out=out, logits=logits, batch=0, channel=0, eos=eos) + 
check_eos_logits(out=out, logits=logits, batch=1, channel=0, eos=eos) + self.assertTrue(delay_pattern_processor.active_batches.all()) + self.assertTrue((delay_pattern_processor.delay_pattern == torch.tensor(delay_pattern) - 1).all()) diff --git a/tests/models/auto/test_processor_auto.py b/tests/models/auto/test_processor_auto.py index 2a1bc30dbb..60500001a3 100644 --- a/tests/models/auto/test_processor_auto.py +++ b/tests/models/auto/test_processor_auto.py @@ -26,6 +26,7 @@ import transformers from transformers import ( CONFIG_MAPPING, FEATURE_EXTRACTOR_MAPPING, + MODEL_FOR_AUDIO_TOKENIZATION_MAPPING, PROCESSOR_MAPPING, TOKENIZER_MAPPING, AutoConfig, @@ -265,6 +266,8 @@ class AutoFeatureExtractorTest(unittest.TestCase): del TOKENIZER_MAPPING._extra_content[CustomConfig] if CustomConfig in PROCESSOR_MAPPING._extra_content: del PROCESSOR_MAPPING._extra_content[CustomConfig] + if CustomConfig in MODEL_FOR_AUDIO_TOKENIZATION_MAPPING._extra_content: + del MODEL_FOR_AUDIO_TOKENIZATION_MAPPING._extra_content[CustomConfig] def test_from_pretrained_dynamic_processor_conflict(self): class NewFeatureExtractor(Wav2Vec2FeatureExtractor): @@ -317,6 +320,8 @@ class AutoFeatureExtractorTest(unittest.TestCase): del TOKENIZER_MAPPING._extra_content[CustomConfig] if CustomConfig in PROCESSOR_MAPPING._extra_content: del PROCESSOR_MAPPING._extra_content[CustomConfig] + if CustomConfig in MODEL_FOR_AUDIO_TOKENIZATION_MAPPING._extra_content: + del MODEL_FOR_AUDIO_TOKENIZATION_MAPPING._extra_content[CustomConfig] def test_from_pretrained_dynamic_processor_with_extra_attributes(self): class NewFeatureExtractor(Wav2Vec2FeatureExtractor): @@ -356,6 +361,8 @@ class AutoFeatureExtractorTest(unittest.TestCase): del TOKENIZER_MAPPING._extra_content[CustomConfig] if CustomConfig in PROCESSOR_MAPPING._extra_content: del PROCESSOR_MAPPING._extra_content[CustomConfig] + if CustomConfig in MODEL_FOR_AUDIO_TOKENIZATION_MAPPING._extra_content: + del 
MODEL_FOR_AUDIO_TOKENIZATION_MAPPING._extra_content[CustomConfig] def test_dynamic_processor_with_specific_dynamic_subcomponents(self): class NewFeatureExtractor(Wav2Vec2FeatureExtractor): @@ -390,6 +397,8 @@ class AutoFeatureExtractorTest(unittest.TestCase): del TOKENIZER_MAPPING._extra_content[CustomConfig] if CustomConfig in PROCESSOR_MAPPING._extra_content: del PROCESSOR_MAPPING._extra_content[CustomConfig] + if CustomConfig in MODEL_FOR_AUDIO_TOKENIZATION_MAPPING._extra_content: + del MODEL_FOR_AUDIO_TOKENIZATION_MAPPING._extra_content[CustomConfig] def test_auto_processor_creates_tokenizer(self): processor = AutoProcessor.from_pretrained("hf-internal-testing/tiny-random-bert") diff --git a/tests/models/dia/__init__.py b/tests/models/dia/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/models/dia/test_feature_extraction_dia.py b/tests/models/dia/test_feature_extraction_dia.py new file mode 100644 index 0000000000..6243dc4791 --- /dev/null +++ b/tests/models/dia/test_feature_extraction_dia.py @@ -0,0 +1,231 @@ +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Tests for the Dia feature extractor.""" + +import itertools +import random +import unittest + +import numpy as np + +from transformers import DiaFeatureExtractor +from transformers.testing_utils import require_torch +from transformers.utils.import_utils import is_torch_available + +from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin + + +if is_torch_available(): + import torch + + +global_rng = random.Random() + + +# Copied from tests.models.whisper.test_feature_extraction_whisper.floats_list +def floats_list(shape, scale=1.0, rng=None, name=None): + """Creates a random float32 tensor""" + if rng is None: + rng = global_rng + + values = [] + for batch_idx in range(shape[0]): + values.append([]) + for _ in range(shape[1]): + values[-1].append(rng.random() * scale) + + return values + + +@require_torch +class DiaFeatureExtractionTester: + # Copied from tests.models.dac.test_feature_extraction_dac.DacFeatureExtractionTester.__init__ + def __init__( + self, + parent, + batch_size=7, + min_seq_length=400, + max_seq_length=2000, + feature_size=1, + padding_value=0.0, + sampling_rate=16000, + hop_length=512, + ): + self.parent = parent + self.batch_size = batch_size + self.min_seq_length = min_seq_length + self.max_seq_length = max_seq_length + self.hop_length = hop_length + self.seq_length_diff = (self.max_seq_length - self.min_seq_length) // (self.batch_size - 1) + self.feature_size = feature_size + self.padding_value = padding_value + self.sampling_rate = sampling_rate + + # Copied from tests.models.dac.test_feature_extraction_dac.DacFeatureExtractionTester.prepare_feat_extract_dict + def prepare_feat_extract_dict(self): + return { + "feature_size": self.feature_size, + "padding_value": self.padding_value, + "sampling_rate": self.sampling_rate, + "hop_length": self.hop_length, + } + + # Copied from tests.models.encodec.test_feature_extraction_encodec.EnCodecFeatureExtractionTester.prepare_inputs_for_common + def 
prepare_inputs_for_common(self, equal_length=False, numpify=False): + def _flatten(list_of_lists): + return list(itertools.chain(*list_of_lists)) + + if equal_length: + audio_inputs = floats_list((self.batch_size, self.max_seq_length)) + else: + # make sure that inputs increase in size + audio_inputs = [ + _flatten(floats_list((x, self.feature_size))) + for x in range(self.min_seq_length, self.max_seq_length, self.seq_length_diff) + ] + + if numpify: + audio_inputs = [np.asarray(x) for x in audio_inputs] + + return audio_inputs + + +@require_torch +class DiaFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase): + feature_extraction_class = DiaFeatureExtractor + + def setUp(self): + self.feat_extract_tester = DiaFeatureExtractionTester(self) + + # Copied from tests.models.dac.test_feature_extraction_dac.DacFeatureExtractionTest.test_call + def test_call(self): + # Tests that all call wrap to encode_plus and batch_encode_plus + feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict()) + # create three inputs of length 800, 1000, and 1200 + audio_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)] + np_audio_inputs = [np.asarray(audio_input) for audio_input in audio_inputs] + + # Test not batched input + encoded_sequences_1 = feat_extract(audio_inputs[0], return_tensors="np").input_values + encoded_sequences_2 = feat_extract(np_audio_inputs[0], return_tensors="np").input_values + self.assertTrue(np.allclose(encoded_sequences_1, encoded_sequences_2, atol=1e-3)) + + # Test batched + encoded_sequences_1 = feat_extract(audio_inputs, padding=True, return_tensors="np").input_values + encoded_sequences_2 = feat_extract(np_audio_inputs, padding=True, return_tensors="np").input_values + for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2): + self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3)) + + # Copied from 
tests.models.dac.test_feature_extraction_dac.DacFeatureExtractionTest.test_double_precision_pad + def test_double_precision_pad(self): + feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict()) + np_audio_inputs = np.random.rand(100).astype(np.float64) + py_audio_inputs = np_audio_inputs.tolist() + + for inputs in [py_audio_inputs, np_audio_inputs]: + np_processed = feature_extractor.pad([{"input_values": inputs}], return_tensors="np") + self.assertTrue(np_processed.input_values.dtype == np.float32) + pt_processed = feature_extractor.pad([{"input_values": inputs}], return_tensors="pt") + self.assertTrue(pt_processed.input_values.dtype == torch.float32) + + # Copied from tests.models.dac.test_feature_extraction_dac.DacFeatureExtractionTest._load_datasamples + def _load_datasamples(self, num_samples): + from datasets import load_dataset + + ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + # automatic decoding with librispeech + audio_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"] + + return [x["array"] for x in audio_samples] + + # Copied from tests.models.dac.test_feature_extraction_dac.DacFeatureExtractionTest.test_integration with Dac->Dia + def test_integration(self): + # fmt: off + EXPECTED_INPUT_VALUES = torch.tensor( + [ 2.3803711e-03, 2.0751953e-03, 1.9836426e-03, 2.1057129e-03, + 1.6174316e-03, 3.0517578e-04, 9.1552734e-05, 3.3569336e-04, + 9.7656250e-04, 1.8310547e-03, 2.0141602e-03, 2.1057129e-03, + 1.7395020e-03, 4.5776367e-04, -3.9672852e-04, 4.5776367e-04, + 1.0070801e-03, 9.1552734e-05, 4.8828125e-04, 1.1596680e-03, + 7.3242188e-04, 9.4604492e-04, 1.8005371e-03, 1.8310547e-03, + 8.8500977e-04, 4.2724609e-04, 4.8828125e-04, 7.3242188e-04, + 1.0986328e-03, 2.1057129e-03] + ) + # fmt: on + input_audio = self._load_datasamples(1) + feature_extractor = DiaFeatureExtractor() + input_values = feature_extractor(input_audio, 
return_tensors="pt")["input_values"] + self.assertEqual(input_values.shape, (1, 1, 93696)) + torch.testing.assert_close(input_values[0, 0, :30], EXPECTED_INPUT_VALUES, rtol=1e-4, atol=1e-4) + audio_input_end = torch.tensor(input_audio[0][-30:], dtype=torch.float32) + torch.testing.assert_close(input_values[0, 0, -46:-16], audio_input_end, rtol=1e-4, atol=1e-4) + + def test_integration_stereo(self): + # fmt: off + EXPECTED_INPUT_VALUES = torch.tensor( + [2.3804e-03, 2.0752e-03, 1.9836e-03, 2.1057e-03, 1.6174e-03, + 3.0518e-04, 9.1553e-05, 3.3569e-04, 9.7656e-04, 1.8311e-03, + 2.0142e-03, 2.1057e-03, 1.7395e-03, 4.5776e-04, -3.9673e-04, + 4.5776e-04, 1.0071e-03, 9.1553e-05, 4.8828e-04, 1.1597e-03, + 7.3242e-04, 9.4604e-04, 1.8005e-03, 1.8311e-03, 8.8501e-04, + 4.2725e-04, 4.8828e-04, 7.3242e-04, 1.0986e-03, 2.1057e-03] + ) + # fmt: on + input_audio = self._load_datasamples(1) + input_audio = [np.tile(input_audio[0][None], reps=(2, 1))] + feature_extractor = DiaFeatureExtractor(feature_size=2) + input_values = feature_extractor(input_audio, return_tensors="pt").input_values + self.assertEqual(input_values.shape, (1, 1, 93696)) + torch.testing.assert_close(input_values[0, 0, :30], EXPECTED_INPUT_VALUES, rtol=1e-4, atol=1e-4) + + # Copied from tests.models.dac.test_feature_extraction_dac.DacFeatureExtractionTest.test_truncation_and_padding with Dac->Dia + def test_truncation_and_padding(self): + input_audio = self._load_datasamples(2) + # would be easier if the stride was like + feature_extractor = DiaFeatureExtractor() + + # pad and trunc raise an error ? + with self.assertRaisesRegex( + ValueError, + "^Both padding and truncation were set. 
Make sure you only set one.$", + ): + truncated_outputs = feature_extractor( + input_audio, padding="max_length", truncation=True, return_tensors="pt" + ).input_values + + # force truncate to max_length + truncated_outputs = feature_extractor( + input_audio, truncation=True, max_length=48000, return_tensors="pt" + ).input_values + self.assertEqual(truncated_outputs.shape, (2, 1, 48128)) + + # pad: + padded_outputs = feature_extractor(input_audio, padding=True, return_tensors="pt").input_values + self.assertEqual(padded_outputs.shape, (2, 1, 93696)) + + # force pad to max length + truncated_outputs = feature_extractor( + input_audio, padding="max_length", max_length=100000, return_tensors="pt" + ).input_values + self.assertEqual(truncated_outputs.shape, (2, 1, 100352)) + + # force no pad + with self.assertRaisesRegex( + ValueError, + "^Unable to create tensor, you should probably activate padding with 'padding=True' to have batched tensors with the same length.$", + ): + truncated_outputs = feature_extractor(input_audio, padding=False, return_tensors="pt").input_values + + truncated_outputs = feature_extractor(input_audio[0], padding=False, return_tensors="pt").input_values + self.assertEqual(truncated_outputs.shape, (1, 1, 93680)) diff --git a/tests/models/dia/test_modeling_dia.py b/tests/models/dia/test_modeling_dia.py new file mode 100644 index 0000000000..f9427160c2 --- /dev/null +++ b/tests/models/dia/test_modeling_dia.py @@ -0,0 +1,752 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the PyTorch Dia model.""" + +import copy +import pathlib +import tempfile +import unittest + +import pytest + +from transformers.models.dia import DiaConfig, DiaDecoderConfig, DiaEncoderConfig +from transformers.testing_utils import ( + cleanup, + is_flaky, + require_torch, + require_torch_accelerator, + require_torch_sdpa, + slow, + torch_device, +) +from transformers.utils import is_soundfile_available, is_torch_available, is_torchaudio_available +from transformers.utils.import_utils import is_datasets_available + +from ...generation.test_utils import GenerationTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, ids_tensor +from ...test_pipeline_mixin import PipelineTesterMixin + + +if is_datasets_available(): + from datasets import Audio, load_dataset + +if is_torch_available(): + import torch + + from transformers import ( + DiaForConditionalGeneration, + DiaModel, + DiaProcessor, + PretrainedConfig, + PreTrainedModel, + ) + from transformers.cache_utils import ( + Cache, + StaticCache, + ) + from transformers.models.dia.modeling_dia import DiaDecoder, DiaEncoder + +if is_torchaudio_available(): + import torchaudio + +if is_soundfile_available(): + import soundfile as sf + + +@require_torch +class DiaModelTester: + def __init__( + self, + parent, + batch_size=3, # need batch_size != num_hidden_layers + seq_length=7, + max_length=50, + is_training=True, + vocab_size=100, + hidden_size=16, + intermediate_size=37, + num_hidden_layers=2, + num_attention_heads=2, + head_dim=8, + decoder_hidden_size=32, # typically larger than encoder + hidden_act="silu", + eos_token_id=97, # special tokens all occur after eos + pad_token_id=98, + bos_token_id=99, + delay_pattern=None, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + 
self.max_length = max_length + self.is_training = is_training + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.head_dim = head_dim + self.decoder_hidden_size = decoder_hidden_size + self.hidden_act = hidden_act + self.eos_token_id = eos_token_id + self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + # Set default delay pattern if not provided + self.delay_pattern = delay_pattern if delay_pattern is not None else [0, 1, 2] + self.num_channels = len(self.delay_pattern) + + def get_config(self): + encoder_config = DiaEncoderConfig( + max_position_embeddings=self.max_length, + num_hidden_layers=self.num_hidden_layers, + hidden_size=self.hidden_size, + num_attention_heads=self.num_attention_heads, + num_key_value_heads=self.num_attention_heads, # same as num_attention_heads for testing + head_dim=self.head_dim, + intermediate_size=self.intermediate_size, + vocab_size=self.vocab_size, + hidden_act=self.hidden_act, + ) + + decoder_config = DiaDecoderConfig( + max_position_embeddings=self.max_length, + num_hidden_layers=self.num_hidden_layers, + hidden_size=self.decoder_hidden_size, + intermediate_size=self.intermediate_size, + num_attention_heads=self.num_attention_heads, + num_key_value_heads=1, # GQA + head_dim=self.head_dim, + cross_num_attention_heads=self.num_attention_heads, + cross_head_dim=self.head_dim, + cross_num_key_value_heads=1, # GQA + cross_hidden_size=self.hidden_size, # match encoder hidden size + vocab_size=self.vocab_size, + hidden_act=self.hidden_act, + num_channels=self.num_channels, + ) + + config = DiaConfig( + encoder_config=encoder_config, + decoder_config=decoder_config, + eos_token_id=self.eos_token_id, + pad_token_id=self.pad_token_id, + bos_token_id=self.bos_token_id, + delay_pattern=self.delay_pattern, + ) + + return config + + def prepare_config_and_inputs(self) -> 
tuple[DiaConfig, dict]: + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + attention_mask = input_ids.ne(self.pad_token_id) + + decoder_input_ids = ids_tensor([self.batch_size, self.seq_length, self.num_channels], self.vocab_size) + decoder_attention_mask = decoder_input_ids[..., 0].ne(self.pad_token_id) + + config = self.get_config() + inputs_dict = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "decoder_input_ids": decoder_input_ids, + "decoder_attention_mask": decoder_attention_mask, + } + return config, inputs_dict + + def prepare_config_and_inputs_for_common(self) -> tuple[DiaConfig, dict]: + config, inputs_dict = self.prepare_config_and_inputs() + return config, inputs_dict + + def create_and_check_model_forward(self, config, inputs_dict): + model = DiaModel(config=config).to(torch_device).eval() + + input_ids = inputs_dict["input_ids"] + decoder_input_ids = inputs_dict["decoder_input_ids"] + + # first forward pass + last_hidden_state = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids).last_hidden_state + + self.parent.assertEqual( + last_hidden_state.shape, (self.batch_size, self.seq_length, config.decoder_config.hidden_size) + ) + + def check_encoder_decoder_model_standalone(self, config, inputs_dict): + model = DiaModel(config=config).to(torch_device).eval() + outputs = model(**inputs_dict) + + encoder_last_hidden_state = outputs.encoder_last_hidden_state + last_hidden_state = outputs.last_hidden_state + + with tempfile.TemporaryDirectory() as tmpdirname: + encoder = model.get_encoder() + encoder.save_pretrained(tmpdirname) + encoder = DiaEncoder.from_pretrained(tmpdirname).to(torch_device) + + encoder_last_hidden_state_2 = encoder( + input_ids=inputs_dict["input_ids"], attention_mask=inputs_dict["attention_mask"] + )[0] + + self.parent.assertTrue((encoder_last_hidden_state_2 - encoder_last_hidden_state).abs().max().item() < 3e-3) + + with tempfile.TemporaryDirectory() as tmpdirname: + decoder =
model.get_decoder() + decoder.save_pretrained(tmpdirname) + decoder = DiaDecoder.from_pretrained(tmpdirname).to(torch_device) + + last_hidden_state_2 = decoder( + input_ids=inputs_dict["decoder_input_ids"], + attention_mask=inputs_dict["decoder_attention_mask"], + encoder_hidden_states=encoder_last_hidden_state, + )[0] + + self.parent.assertTrue((last_hidden_state_2 - last_hidden_state).abs().max().item() < 3e-3) + + +@require_torch +class DiaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): + all_model_classes = (DiaModel, DiaForConditionalGeneration) if is_torch_available() else () + # We only allow greedy search / sampling with one sequence; see `skip_non_greedy_generate` + all_generative_model_classes = (DiaForConditionalGeneration,) + # TODO: support new pipeline behavior in tests + pipeline_model_mapping = {} + # pipeline_model_mapping = {"text-to-audio": DiaForConditionalGeneration} if is_torch_available() else {} + test_pruning = False + test_head_masking = False + test_resize_embeddings = False + is_encoder_decoder = True + # Indicates VLMs usually but there are many audio models which are also composite + _is_composite = True + + def setUp(self): + self.model_tester = DiaModelTester(self) + # Skipping `has_text_modality` but manually testing down below + self.config_tester = ConfigTester(self, has_text_modality=False, config_class=DiaConfig) + self.skip_non_greedy_generate() + + def skip_non_greedy_generate(self): + skippable_tests = [ + "test_sample_generate_dict_output", # return sequences > 1 + "test_beam", + "test_group_beam", + "test_constrained_beam", + "test_contrastive", + "test_assisted", + "test_dola", + "test_prompt_lookup", + "test_model_parallel_beam_search", + "test_generate_without_input_ids", + "test_generate_with_head_masking", + ] + + for test in skippable_tests: + if self._testMethodName.startswith(test): + self.skipTest(reason="Dia only supports greedy search / sampling with one sequence.") + + 
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): + """Overriden to account for the 2D flattened structure""" + inputs_dict = copy.deepcopy(inputs_dict) + + if return_labels: + inputs_dict["labels"] = torch.ones( + ( + self.model_tester.batch_size * self.model_tester.num_channels, + self.model_tester.seq_length, + ), + dtype=torch.long, + device=torch_device, + ) + + return inputs_dict + + def test_config(self): + self.config_tester.run_common_tests() + + # Manual testing because of composite configs + config = self.model_tester.prepare_config_and_inputs()[0] + self.assertTrue(hasattr(config.encoder_config, "vocab_size"), msg="Encoder `vocab_size` does not exist") + self.assertTrue(hasattr(config.decoder_config, "vocab_size"), msg="Decoder `vocab_size` does not exist") + + def test_model_forward(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model_forward(*config_and_inputs) + + @is_flaky + def test_encoder_decoder_model_standalone(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common() + self.model_tester.check_encoder_decoder_model_standalone(*config_and_inputs) + + # Overriding shape checks as Dia has different shapes on encoder/decoder using a composite config + # + additional special cases where 3D x 2D meshes confuse the expected shape + def _check_logits(self, batch_size, logits, config): + batch_size *= len(config.delay_pattern) # Account for flattening + vocab_size = config.decoder_config.vocab_size + self.assertIsInstance(logits, tuple) + self.assertListEqual([iter_logits.shape[0] for iter_logits in logits], [batch_size] * len(logits)) + # vocabulary difference equal to one (imagegptmodel?) 
or zero (all other models) + vocab_diff = vocab_size - logits[0].shape[-1] + self.assertTrue(vocab_diff in [0, 1]) + self.assertListEqual([vocab_size - score.shape[-1] for score in logits], [vocab_diff] * len(logits)) + + def _check_attentions_for_generate( + self, batch_size, attentions, prompt_length, output_length, config, decoder_past_key_values + ): + self.assertIsInstance(attentions, tuple) + self.assertListEqual( + [isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions) + ) + self.assertEqual(len(attentions), (output_length - prompt_length)) + + use_cache = decoder_past_key_values is not None + has_static_cache = isinstance(decoder_past_key_values, StaticCache) + + # When `output_attentions=True`, each iteration of generate appends the attentions corresponding to the new + # token(s) + for generated_length, iter_attentions in enumerate(attentions): + # regardless of using cache, the first forward pass will have the full prompt as input + if use_cache and generated_length > 0: + model_input_length = 1 + else: + model_input_length = prompt_length + generated_length + query_length = ( + prompt_length + generated_length + if not has_static_cache + else decoder_past_key_values.get_max_cache_shape() + ) + + expected_shape = ( + batch_size, + config.decoder_config.num_attention_heads, # Decoder config + model_input_length, + query_length, + ) + # check attn size + self.assertListEqual( + [layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions) + ) + + def _check_encoder_attention_for_generate(self, attentions, batch_size, config, prompt_length): + # Encoder config + encoder_expected_shape = (batch_size, config.encoder_config.num_attention_heads, prompt_length, prompt_length) + self.assertIsInstance(attentions, tuple) + self.assertListEqual( + [layer_attentions.shape for layer_attentions in attentions], + [encoder_expected_shape] * len(attentions), + ) + + def 
_check_hidden_states_for_generate( + self, batch_size, hidden_states, prompt_length, output_length, config, use_cache=False + ): + self.assertIsInstance(hidden_states, tuple) + self.assertListEqual( + [isinstance(iter_hidden_states, tuple) for iter_hidden_states in hidden_states], + [True] * len(hidden_states), + ) + self.assertEqual(len(hidden_states), (output_length - prompt_length)) + + # When `output_hidden_states=True`, each iteration of generate appends the hidden states corresponding to the + # new token(s) + for generated_length, iter_hidden_states in enumerate(hidden_states): + # regardless of using cache, the first forward pass will have the full prompt as input + if use_cache and generated_length > 0: + model_input_length = 1 + else: + model_input_length = prompt_length + generated_length + + # check hidden size + # we can have different hidden sizes between encoder and decoder --> check both + expected_shape_encoder = (batch_size, model_input_length, config.encoder_config.hidden_size) + expected_shape_decoder = (batch_size, model_input_length, config.decoder_config.hidden_size) + self.assertTrue( + [layer_hidden_states.shape for layer_hidden_states in iter_hidden_states] + == [expected_shape_encoder] * len(iter_hidden_states) + or [layer_hidden_states.shape for layer_hidden_states in iter_hidden_states] + == [expected_shape_decoder] * len(iter_hidden_states) + ) + + def _check_encoder_hidden_states_for_generate(self, hidden_states, batch_size, config, prompt_length): + # Encoder config + encoder_expected_shape = (batch_size, prompt_length, config.encoder_config.hidden_size) + self.assertIsInstance(hidden_states, tuple) + self.assertListEqual( + [layer_hidden_states.shape for layer_hidden_states in hidden_states], + [encoder_expected_shape] * len(hidden_states), + ) + + def _check_past_key_values_for_generate(self, batch_size, decoder_past_key_values, cache_length, config): + self.assertIsInstance(decoder_past_key_values, (tuple, Cache)) + + # we need 
the decoder config here + config = config.decoder_config + + # (batch, head, seq_length, head_features) + expected_shape = ( + batch_size, + config.num_key_value_heads if hasattr(config, "num_key_value_heads") else config.num_attention_heads, + cache_length, + config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads, + ) + + if isinstance(decoder_past_key_values, Cache): + self.assertListEqual( + [key_tensor.shape for key_tensor in decoder_past_key_values.key_cache], + [expected_shape] * len(decoder_past_key_values.key_cache), + ) + self.assertListEqual( + [value_tensor.shape for value_tensor in decoder_past_key_values.value_cache], + [expected_shape] * len(decoder_past_key_values.value_cache), + ) + + def _check_scores(self, batch_size, scores, generated_length, config): + # Special case where Dia keeps score in a 2D mesh of (bsz * channels, vocab) + vocab_size = config.decoder_config.vocab_size + expected_shape = (batch_size * len(config.delay_pattern), vocab_size) + self.assertIsInstance(scores, tuple) + self.assertEqual(len(scores), generated_length) + self.assertListEqual([iter_scores.shape for iter_scores in scores], [expected_shape] * len(scores)) + + @require_torch_sdpa + def test_sdpa_can_dispatch_composite_models(self): + """ + Overwritten as it relies on hardcoded namings atm - checking for our case here specifically + """ + for model_class in self.all_model_classes: + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model = model_class.from_pretrained(tmpdirname) + + sub_models_supporting_sdpa = [ + (module._supports_sdpa or module._supports_attention_backend) + for name, module in model.named_modules() + if isinstance(module, PreTrainedModel) and name != "" + ] + supports_sdpa_all_modules = ( + all(sub_models_supporting_sdpa) + if len(sub_models_supporting_sdpa) > 0 + 
else (model._supports_sdpa or model._supports_attention_backend) + ) + + if not supports_sdpa_all_modules: + with self.assertRaises(ValueError): + model_sdpa = model_class.from_pretrained(tmpdirname, attn_implementation="sdpa") + else: + model_sdpa = model_class.from_pretrained(tmpdirname, attn_implementation="sdpa") + for key in model_sdpa.config: + if isinstance(getattr(model_sdpa.config, key), PretrainedConfig): + sub_config = getattr(model_sdpa.config, key) + self.assertTrue(sub_config._attn_implementation == "sdpa") + + @pytest.mark.generate + @unittest.skip(reason="Custom processor `DiaEOSDelayPatternLogitsProcessor` forces eos token.") + def test_generate_continue_from_past_key_values(self): + """Only a small change due to the expected shapes""" + # Tests that we can continue generating from past key values, returned from a previous `generate` call + for model_class in self.all_generative_model_classes: + config, inputs = self.model_tester.prepare_config_and_inputs_for_common() + + # Let's make it always: + # 1. use cache (for obvious reasons) + # 2. generate to max length (which can be achieved by setting the eos token to an invalid value), which + # would make the test flaky (e.g. EOS is generated on iteration 1 on both generations, but the + # continuation would force it to generate beyond an EOS token) + # 3. ignore `token_type_ids` for simplicity + # 4. ignore `forced_eos_token_id`, which requires further manipulation of the continuation inputs and is + # active by default on some models + # 5. ignore `encoder_no_repeat_ngram_size`, which is set by default in some encoder-decoder models. When + # we use their decoder as a stand-alone model, `encoder_no_repeat_ngram_size` actually prevents + # repetition exclusively from the prompt. This test relies on comparing one call vs 2 calls + # with cache, what is considered a prompt is different in the two cases. 
+ + if "token_type_ids" in inputs: + del inputs["token_type_ids"] + + model = model_class(config).to(torch_device) + model.eval() + + generate_kwargs = { + "pad_token_id": -1, + "eos_token_id": -1, + "forced_eos_token_id": None, + "encoder_no_repeat_ngram_size": 0, + "use_cache": True, + "do_sample": False, + "return_dict_in_generate": True, + "output_scores": True, + } + + # Traditional way of generating text, with `return_dict_in_generate` to return the past key values + outputs = model.generate(**inputs, **generate_kwargs, max_new_tokens=4) + + # Let's generate again, but passing the past key values in between (3 + 1 = 4 tokens). Note that the + # inputs may need to be tweaked across `generate` calls (like the attention mask). + outputs_cached = model.generate(**inputs, **generate_kwargs, max_new_tokens=3) + + # Continue from the tokens generated above, preparing the inputs accordingly + inputs["past_key_values"] = outputs_cached.past_key_values + new_attention_len = outputs_cached.sequences.shape[1] # the only real modification in this test + inputs["decoder_input_ids"] = outputs_cached.sequences + if "decoder_attention_mask" in inputs: + inputs["decoder_attention_mask"] = torch.nn.functional.pad( + inputs["decoder_attention_mask"], + (0, new_attention_len - inputs["decoder_attention_mask"].shape[1]), + mode="constant", + value=1, + ) + + first_caches_scores = outputs_cached.scores + outputs_cached = model.generate(**inputs, **generate_kwargs, max_new_tokens=1) + full_cached_scores = first_caches_scores + outputs_cached.scores + outputs_cached.scores = full_cached_scores + + # The two sets of generated text and past kv should be equal to each other + self._check_similar_generate_outputs(outputs, outputs_cached) + for layer_idx in range(len(outputs_cached.past_key_values)): + for kv_idx in range(len(outputs_cached.past_key_values[layer_idx])): + self.assertTrue( + torch.allclose( + outputs.past_key_values[layer_idx][kv_idx], + 
outputs_cached.past_key_values[layer_idx][kv_idx], + ) + ) + + @unittest.skip(reason="Indirectly checked in Dia through the generate methods.") + def test_past_key_values_format(self, custom_all_cache_shapes=None): + pass + + @unittest.skip(reason="Indirectly checked in Dia through the generate methods.") + def test_hidden_states_output(self): + pass + + @unittest.skip( + reason="Dia has too many mixed embedding types which would cause unintentional side effects, e.g. attempts at tying embeddings" + ) + def test_model_get_set_embeddings(self): + pass + + @unittest.skip(reason="Theoretically works but kernel library causes issues.") + def test_torchscript_output_hidden_state(self): + pass + + @unittest.skip(reason="Theoretically works but kernel library causes issues.") + def test_torchscript_simple(self): + pass + + @unittest.skip(reason="Encoder-Decoder cache can not be initialized.") + def test_multi_gpu_data_parallel_forward(self): + pass + + +class DiaForConditionalGenerationIntegrationTest(unittest.TestCase): + """ + See https://gist.github.com/vasqu/0e3b06360373a4e612aa3b9a7c09185e for generating the integration tests + + NOTE: We add a single `eos` line for the last channel which is skipped in the original Dia + (It doesn't change the behaviour as we cut by the eos token position) + """ + + def setUp(self): + # it's a dummy ckpt but should suffice for testing purposes + self.model_checkpoint = "AntonV/Dia-1.6B" + self.sampling_rate = 44100 + + # prepare audio + librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=self.sampling_rate)) + audio_sample_1 = librispeech_dummy[-1]["audio"]["array"] + audio_sample_2 = librispeech_dummy[-2]["audio"]["array"] + # 10 and 5 codebooks as prefix - saved as files as we need wav files for the original Dia + dac_chunk_len = 512 + self.audio_prompt_1_path = "/tmp/dia_test_sample_1.mp3" + 
self.audio_prompt_2_path = "/tmp/dia_test_sample_2.mp3" + sf.write(self.audio_prompt_1_path, audio_sample_1[: (dac_chunk_len * 10)], self.sampling_rate) + sf.write(self.audio_prompt_2_path, audio_sample_2[: (dac_chunk_len * 5)], self.sampling_rate) + + def tearDown(self): + pathlib.Path(self.audio_prompt_1_path).unlink() + pathlib.Path(self.audio_prompt_2_path).unlink() + cleanup(torch_device, gc_collect=True) + + @slow + @require_torch_accelerator + def test_dia_model_integration_generate_tts(self): + text = ["[S1] Dia is an open weights text to dialogue model.", "This is a test"] + processor = DiaProcessor.from_pretrained(self.model_checkpoint) + inputs = processor(text=text, padding=True, return_tensors="pt").to(torch_device) + + model = DiaForConditionalGeneration.from_pretrained(self.model_checkpoint).to(torch_device) + outputs = model.generate(**inputs, max_new_tokens=32, do_sample=False) + + # fmt: off + EXPECTED_OUTPUT_TOKENS = torch.tensor([[[1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026], + [ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026], + [ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026], + [ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026], + [ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026], + [ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026], + [ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026], + [ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026], + [ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026], + [ 568, 778, 1026, 1026, 1026, 1026, 1026, 1026, 1026], + [ 568, 778, 338, 1026, 1026, 1026, 1026, 1026, 1026], + [ 568, 804, 10, 524, 1026, 1026, 1026, 1026, 1026], + [ 568, 804, 10, 674, 967, 1026, 1026, 1026, 1026], + [ 568, 804, 10, 674, 364, 360, 1026, 1026, 1026], + [ 568, 804, 10, 674, 364, 981, 728, 1026, 1026], + [ 568, 804, 10, 674, 364, 981, 741, 550, 1026], + [ 568, 804, 10, 674, 364, 981, 568, 378, 90], + [1024, 804, 10, 674, 364, 981, 568, 378, 731], + [1025, 804, 10, 674, 364, 981, 568, 378, 
731], + [1025, 804, 10, 674, 364, 981, 568, 378, 731], + [1025, 804, 10, 674, 364, 981, 568, 378, 731], + [1025, 804, 10, 674, 364, 981, 568, 378, 731], + [1025, 804, 10, 674, 364, 981, 568, 378, 731], + [1025, 804, 10, 674, 364, 981, 568, 378, 731], + [1025, 804, 10, 674, 364, 981, 568, 378, 731], + [1025, 1024, 10, 674, 364, 981, 568, 378, 731], + [1025, 1025, 1024, 674, 364, 981, 568, 378, 731], + [1025, 1025, 1025, 1024, 364, 981, 568, 378, 731], + [1025, 1025, 1025, 1025, 1024, 981, 568, 378, 731], + [1025, 1025, 1025, 1025, 1025, 1024, 568, 378, 731], + [1025, 1025, 1025, 1025, 1025, 1025, 1024, 378, 731], + [1025, 1025, 1025, 1025, 1025, 1025, 1025, 1024, 731], + [1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1024]], + + [[1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026], + [ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026], + [ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026], + [ 698, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026], + [ 592, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026], + [ 592, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026], + [ 592, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026], + [ 592, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026], + [ 592, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026], + [ 592, 778, 1026, 1026, 1026, 1026, 1026, 1026, 1026], + [ 592, 778, 338, 1026, 1026, 1026, 1026, 1026, 1026], + [ 592, 697, 10, 524, 1026, 1026, 1026, 1026, 1026], + [ 592, 288, 476, 649, 967, 1026, 1026, 1026, 1026], + [ 592, 740, 386, 674, 364, 360, 1026, 1026, 1026], + [ 592, 402, 386, 347, 362, 981, 728, 1026, 1026], + [ 592, 402, 721, 728, 327, 981, 741, 550, 1026], + [ 592, 402, 721, 728, 460, 62, 676, 378, 90], + [1024, 402, 721, 728, 837, 595, 195, 982, 784], + [1025, 402, 721, 677, 497, 102, 692, 24, 330], + [1025, 402, 721, 677, 511, 102, 503, 871, 609], + [1025, 402, 721, 677, 511, 96, 801, 871, 894], + [1025, 402, 721, 677, 511, 745, 314, 498, 775], + [1025, 402, 721, 677, 511, 745, 314, 498, 105], + [1025, 402, 
721, 677, 511, 745, 314, 861, 889], + [1025, 893, 721, 677, 511, 744, 314, 871, 353], + [1025, 1024, 888, 677, 511, 744, 314, 871, 332], + [1025, 1025, 1024, 518, 511, 744, 314, 871, 366], + [1025, 1025, 1025, 1024, 611, 744, 314, 871, 366], + [1025, 1025, 1025, 1025, 1024, 980, 314, 871, 366], + [1025, 1025, 1025, 1025, 1025, 1024, 45, 124, 366], + [1025, 1025, 1025, 1025, 1025, 1025, 1024, 871, 366], + [1025, 1025, 1025, 1025, 1025, 1025, 1025, 1024, 719], + [1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1024]]]) + # fmt: on + + torch.testing.assert_close(outputs.cpu(), EXPECTED_OUTPUT_TOKENS) + + @slow + @require_torch_accelerator + def test_dia_model_integration_generate_audio_context(self): + text = ["[S1] Dia is an open weights text to dialogue model.", "This is a test"] + audio_sample_1 = torchaudio.load(self.audio_prompt_1_path, channels_first=True)[0].squeeze().numpy() + audio_sample_2 = torchaudio.load(self.audio_prompt_2_path, channels_first=True)[0].squeeze().numpy() + audio = [audio_sample_1, audio_sample_2] + + processor = DiaProcessor.from_pretrained(self.model_checkpoint) + inputs = processor(text=text, audio=audio, padding=True, return_tensors="pt").to(torch_device) + + model = DiaForConditionalGeneration.from_pretrained(self.model_checkpoint).to(torch_device) + # dia has right padding while we have left padding (for faster prefill) + # additionally we have new tokens vs dia's max tokens (hence we compare each in the respective settings) + outputs_1 = model.generate(**inputs, max_new_tokens=22, do_sample=False) + outputs_2 = model.generate(**inputs, max_new_tokens=27, do_sample=False) + + # fmt: off + EXPECTED_OUTPUT_TOKENS_1 = torch.tensor([[1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026], + [ 578, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026], + [ 592, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026], + [ 494, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026], + [ 330, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026], + [ 330, 1026, 1026, 
1026, 1026, 1026, 1026, 1026, 1026], + [ 330, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026], + [ 330, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026], + [ 330, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026], + [ 330, 501, 1026, 1026, 1026, 1026, 1026, 1026, 1026], + [ 330, 204, 34, 1026, 1026, 1026, 1026, 1026, 1026], + [ 330, 254, 915, 863, 1026, 1026, 1026, 1026, 1026], + [ 330, 215, 458, 313, 50, 1026, 1026, 1026, 1026], + [ 330, 615, 529, 216, 801, 237, 1026, 1026, 1026], + [ 330, 580, 563, 233, 337, 37, 1018, 1026, 1026], + [ 330, 567, 530, 753, 607, 179, 954, 242, 1026], + [ 330, 627, 6, 1010, 500, 189, 598, 858, 247], + [1024, 432, 480, 530, 122, 3, 788, 149, 814], + [1025, 875, 826, 458, 98, 540, 181, 122, 608], + [1025, 495, 840, 413, 337, 784, 591, 150, 1017], + [1025, 808, 189, 137, 445, 0, 227, 658, 345], + [1025, 397, 89, 753, 1016, 173, 984, 0, 910], + [1025, 875, 460, 934, 50, 335, 670, 818, 722], + [1025, 875, 460, 762, 119, 372, 503, 858, 584], + [1025, 348, 555, 475, 469, 458, 963, 41, 664], + [1025, 1024, 852, 683, 761, 193, 595, 895, 885], + [1025, 1025, 1024, 135, 761, 902, 163, 623, 385], + [1025, 1025, 1025, 1024, 852, 282, 581, 623, 70], + [1025, 1025, 1025, 1025, 1024, 41, 661, 790, 977], + [1025, 1025, 1025, 1025, 1025, 1024, 580, 401, 464], + [1025, 1025, 1025, 1025, 1025, 1025, 1024, 756, 61], + [1025, 1025, 1025, 1025, 1025, 1025, 1025, 1024, 752], + [1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1024]]) + + EXPECTED_OUTPUT_TOKENS_2 = torch.tensor([[1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026], + [ 619, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026], + [ 315, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026], + [ 315, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026], + [ 315, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026], + [ 315, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026], + [ 315, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026], + [ 315, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026], + [ 315, 1026, 1026, 1026, 
1026, 1026, 1026, 1026, 1026], + [ 315, 968, 1026, 1026, 1026, 1026, 1026, 1026, 1026], + [ 315, 1007, 458, 1026, 1026, 1026, 1026, 1026, 1026], + [ 315, 35, 266, 68, 1026, 1026, 1026, 1026, 1026], + [ 315, 359, 285, 811, 154, 1026, 1026, 1026, 1026], + [ 315, 906, 407, 297, 785, 649, 1026, 1026, 1026], + [ 315, 249, 678, 868, 899, 257, 950, 1026, 1026], + [ 315, 249, 217, 471, 292, 908, 196, 469, 1026], + [ 315, 249, 825, 771, 839, 802, 633, 590, 531], + [1024, 249, 150, 53, 126, 76, 794, 626, 442], + [1025, 249, 825, 218, 359, 864, 526, 626, 770], + [1025, 249, 150, 137, 530, 845, 877, 600, 111], + [1025, 249, 150, 287, 730, 991, 135, 259, 39], + [1025, 249, 825, 104, 198, 1020, 719, 625, 208], + [1025, 249, 825, 997, 602, 256, 859, 322, 518], + [1025, 668, 825, 979, 584, 256, 98, 665, 589], + [1025, 954, 458, 54, 206, 52, 244, 822, 599], + [1025, 1024, 104, 914, 435, 579, 860, 92, 661], + [1025, 1025, 1024, 848, 126, 74, 304, 92, 753], + [1025, 1025, 1025, 1024, 362, 376, 304, 586, 753], + [1025, 1025, 1025, 1025, 1024, 633, 996, 586, 83], + [1025, 1025, 1025, 1025, 1025, 1024, 179, 898, 928], + [1025, 1025, 1025, 1025, 1025, 1025, 1024, 506, 102], + [1025, 1025, 1025, 1025, 1025, 1025, 1025, 1024, 79], + [1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1024]]) + # fmt: on + + torch.testing.assert_close(outputs_1[0].cpu(), EXPECTED_OUTPUT_TOKENS_1) + torch.testing.assert_close(outputs_2[1, 5:].cpu(), EXPECTED_OUTPUT_TOKENS_2) # left padding diff --git a/tests/models/dia/test_processor_dia.py b/tests/models/dia/test_processor_dia.py new file mode 100644 index 0000000000..8ce15f4330 --- /dev/null +++ b/tests/models/dia/test_processor_dia.py @@ -0,0 +1,269 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import shutil +import tempfile +import unittest + +import numpy as np +from parameterized import parameterized + +from transformers import DacModel, DiaFeatureExtractor, DiaProcessor, DiaTokenizer +from transformers.testing_utils import require_torch +from transformers.utils import is_torch_available + + +if is_torch_available(): + import torch + + +# Copied from tests.utils.test_modeling_utils.check_models_equal +def check_models_equal(model1, model2): + models_are_equal = True + for model1_p, model2_p in zip(model1.parameters(), model2.parameters()): + if model1_p.data.ne(model2_p.data).sum() > 0: + models_are_equal = False + + return models_are_equal + + +@require_torch +class DiaProcessorTest(unittest.TestCase): + def setUp(self): + self.checkpoint = "AntonV/Dia-1.6B" + self.audio_tokenizer_checkpoint = "descript/dac_44khz" + self.tmpdirname = tempfile.mkdtemp() + + # Audio tokenizer is a bigger model so we will reuse this if possible + self.processor = DiaProcessor( + tokenizer=self.get_tokenizer(), + feature_extractor=self.get_feature_extractor(), + audio_tokenizer=self.get_audio_tokenizer(), + ) + + # Default audio values based on Dia and Dac + self.pad_id = 1025 + self.bos_id = 1026 + self.dac_chunk_len = 512 + self.delay_pattern = [0, 8, 9, 10, 11, 12, 13, 14, 15] + + def get_tokenizer(self, **kwargs): + return DiaTokenizer.from_pretrained(self.checkpoint, **kwargs) + + def get_feature_extractor(self, **kwargs): + return DiaFeatureExtractor.from_pretrained(self.checkpoint, **kwargs) + + def get_audio_tokenizer(self, **kwargs): + return
DacModel.from_pretrained(self.audio_tokenizer_checkpoint, **kwargs) + + def tearDown(self): + shutil.rmtree(self.tmpdirname) + del self.processor + + def test_save_load_pretrained_default(self): + tokenizer = self.get_tokenizer() + feature_extractor = self.get_feature_extractor() + audio_tokenizer = self.get_audio_tokenizer() + + processor = DiaProcessor( + tokenizer=tokenizer, feature_extractor=feature_extractor, audio_tokenizer=audio_tokenizer + ) + + processor.save_pretrained(self.tmpdirname) + processor = DiaProcessor.from_pretrained(self.tmpdirname) + + self.assertEqual(processor.tokenizer.get_vocab(), tokenizer.get_vocab()) + self.assertIsInstance(processor.tokenizer, DiaTokenizer) + + self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor.to_json_string()) + self.assertIsInstance(processor.feature_extractor, DiaFeatureExtractor) + + self.assertEqual(processor.audio_tokenizer.__class__.__name__, audio_tokenizer.__class__.__name__) + self.assertEqual(processor.audio_tokenizer.name_or_path, audio_tokenizer.name_or_path) + self.assertTrue(check_models_equal(processor.audio_tokenizer, audio_tokenizer)) + self.assertIsInstance(processor.audio_tokenizer, DacModel) + + def test_save_load_pretrained_additional_features(self): + processor = DiaProcessor( + tokenizer=self.get_tokenizer(), + feature_extractor=self.get_feature_extractor(), + audio_tokenizer=self.get_audio_tokenizer(), + ) + processor.save_pretrained(self.tmpdirname) + + tokenizer_add_kwargs = self.get_tokenizer() + feature_extractor_add_kwargs = self.get_feature_extractor() + audio_tokenizer_add_kwargs = self.get_audio_tokenizer() + + processor = DiaProcessor.from_pretrained(self.tmpdirname) + + self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab()) + self.assertIsInstance(processor.tokenizer, DiaTokenizer) + + self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string()) + 
self.assertIsInstance(processor.feature_extractor, DiaFeatureExtractor) + + self.assertEqual(processor.audio_tokenizer.__class__.__name__, audio_tokenizer_add_kwargs.__class__.__name__) + self.assertEqual(processor.audio_tokenizer.name_or_path, audio_tokenizer_add_kwargs.name_or_path) + self.assertTrue(check_models_equal(processor.audio_tokenizer, audio_tokenizer_add_kwargs)) + self.assertIsInstance(processor.audio_tokenizer, DacModel) + + def test_model_input_names(self): + tokenizer = self.get_tokenizer() + + self.assertListEqual( + self.processor.model_input_names, + list(dict.fromkeys(tokenizer.model_input_names + ["decoder_input_ids", "decoder_attention_mask"])), + msg="`processor` model input names do not match the expected names.", + ) + + def test_tokenize(self): + tokenizer = self.get_tokenizer() + random_text = ["This is a processing test for tokenization", "[S1] Dia template style [S2] Nice"] + + input_tokenizer = tokenizer(random_text, padding=True, return_tensors="pt") + input_processor = self.processor(random_text) + + for key in input_tokenizer.keys(): + self.assertTrue((input_tokenizer[key] == input_processor[key]).all()) + + def test_no_audio(self): + random_text = ["Dummy Input"] * 2 + input_processor = self.processor(random_text) + audio_tokens, audio_mask = input_processor["decoder_input_ids"], input_processor["decoder_attention_mask"] + + # full mask with +1 for bos + self.assertTrue(audio_mask.sum() == (max(self.delay_pattern) + 1) * len(random_text)) + self.assertTrue( + audio_tokens.shape + == ( + len(random_text), + max(self.delay_pattern) + 1, + len(self.delay_pattern), + ) + ) + + for channel_idx, delay in enumerate(self.delay_pattern): + expected_sequence = torch.ones(size=(audio_tokens.shape[:-1])) * self.pad_id + expected_sequence[:, : delay + 1] = self.bos_id + self.assertTrue((audio_tokens[..., channel_idx] == expected_sequence).all()) + + def test_audio(self): + audio_tokenizer = self.get_audio_tokenizer() + feature_extractor = 
self.get_feature_extractor() + + random_text = ["Dummy Input"] * 2 + # Dac only starts accepting audio from a certain length (ensured via >=1024) + raw_speeches = [np.random.rand(2048).astype(np.float32), np.random.rand(1024).astype(np.float32)] + input_processor = self.processor(random_text, raw_speeches) + audio_tokens, audio_mask = input_processor["decoder_input_ids"], input_processor["decoder_attention_mask"] + + sequence_len = audio_mask.shape[1] + for batch_idx, speech in enumerate(raw_speeches): + raw_audio = feature_extractor(speech, return_tensors="pt")["input_values"] + codebooks = audio_tokenizer(raw_audio).audio_codes.transpose(1, 2) + + pad_len = sequence_len - audio_mask.sum(dim=-1)[batch_idx] + for channel_idx, delay in enumerate(self.delay_pattern): + # Left padding filled bos, right padding (delay) are pad + start_idx = pad_len + delay + 1 + end_idx = start_idx + codebooks.shape[1] + + encoded_sequence = audio_tokens[batch_idx, :, channel_idx] + expected_sequence = torch.ones(size=(sequence_len,)) * self.pad_id + expected_sequence[:start_idx] = self.bos_id + expected_sequence[start_idx:end_idx] = codebooks[0, :, channel_idx] + + self.assertTrue((encoded_sequence == expected_sequence).all()) + + # Just to make sure the masking correctly only ignores bos tokens + self.assertTrue((audio_tokens[~audio_mask.bool()] == self.bos_id).all()) + + @parameterized.expand([([1, 1],), ([1, 5],), ([2, 4, 6],)]) + def test_decode_audio(self, audio_lens): + feature_extractor = self.get_feature_extractor() + audio_tokenizer = self.get_audio_tokenizer() + + random_text = ["Dummy Input"] * len(audio_lens) + raw_speeches = [np.random.rand(self.dac_chunk_len * l).astype(np.float32) for l in audio_lens] + # we need eos (given if training) to decode properly, also enforced via custom logits processor + input_processor = self.processor(random_text, raw_speeches, generation=False) + audio_tokens = input_processor["decoder_input_ids"] + + decoded_speeches = 
self.processor.batch_decode(audio_tokens) + for batch_idx, speech in enumerate(raw_speeches): + raw_audio = feature_extractor(speech, return_tensors="pt")["input_values"] + codebooks = audio_tokenizer(raw_audio).audio_codes + + decoded_audio = decoded_speeches[batch_idx] + expected_audio = audio_tokenizer.decode(audio_codes=codebooks).audio_values + + self.assertTrue((expected_audio == decoded_audio).all()) + self.assertTrue(decoded_speeches[batch_idx].shape[-1] == audio_lens[batch_idx] * self.dac_chunk_len) + + @parameterized.expand([(1, 2, [0, 1, 4]), (2, 4, [1, 3, 2]), (4, 8, [0, 5, 7])]) + def test_delay_in_audio(self, bsz, seq_len, delay_pattern): + # static functions which are crucial, hence we also test them here + build_indices_fn = DiaProcessor.build_indices + delay_fn = DiaProcessor.apply_audio_delay + + bos, pad = -2, -1 + num_channels = len(delay_pattern) + + audio_input = torch.arange(bsz * seq_len * num_channels).view(bsz, seq_len, num_channels) + # imitate a delay mask with zeroes + audio_input = torch.cat([audio_input, torch.zeros(size=(bsz, max(delay_pattern), num_channels))], dim=1) + + precomputed_idx = build_indices_fn( + bsz=bsz, + seq_len=seq_len + max(delay_pattern), + num_channels=num_channels, + delay_pattern=delay_pattern, + revert=False, + ) + delayed_audio_out = delay_fn( + audio=audio_input, + pad_token_id=pad, + bos_token_id=bos, + precomputed_idx=precomputed_idx, + ) + + # every channel idx is shifted by delay_pattern[idx] + delayed_audio_res = audio_input.clone() + for idx, delay in enumerate(delay_pattern): + delayed_audio_res[:, :delay, idx] = bos + remaining_input = seq_len + max(delay_pattern) - delay + delayed_audio_res[:, delay:, idx] = audio_input[:, :remaining_input, idx] + + self.assertTrue((delayed_audio_out == delayed_audio_res).all()) + + # we should get back to the original audio we had (when removing the delay pad) + bsz, new_seq_len, num_channels = delayed_audio_out.shape + precomputed_idx = build_indices_fn( + 
bsz=bsz, + seq_len=new_seq_len, + num_channels=num_channels, + delay_pattern=delay_pattern, + revert=True, + ) + reverted_audio_out = delay_fn( + audio=delayed_audio_out, + pad_token_id=pad, + bos_token_id=bos, + precomputed_idx=precomputed_idx, + ) + + reverted_audio_res = audio_input.clone()[:, :seq_len] + + self.assertTrue((reverted_audio_out[:, :seq_len] == reverted_audio_res).all()) diff --git a/tests/models/dia/test_tokenization_dia.py b/tests/models/dia/test_tokenization_dia.py new file mode 100644 index 0000000000..4ade611f68 --- /dev/null +++ b/tests/models/dia/test_tokenization_dia.py @@ -0,0 +1,123 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +from transformers.models.dia import DiaTokenizer +from transformers.testing_utils import slow + +from ...test_tokenization_common import TokenizerTesterMixin + + +# Special tokens +PAD = 0 +S1 = 1 +S2 = 2 + + +class DiaTokenizerTest(TokenizerTesterMixin, unittest.TestCase): + tokenizer_class = DiaTokenizer + test_rust_tokenizer = False + + @classmethod + def setUpClass(cls): + super().setUpClass() + tokenizer = DiaTokenizer() + tokenizer.save_pretrained(cls.tmpdirname) + + def test_convert_token_and_id(self): + """Test ``_convert_token_to_id`` and ``_convert_id_to_token``.""" + token = "i" + token_id = 105 + + self.assertEqual(self.get_tokenizer()._convert_token_to_id(token), token_id) + self.assertEqual(self.get_tokenizer()._convert_id_to_token(token_id), token) + + def test_get_vocab(self): + vocab_keys = list(self.get_tokenizer().get_vocab().keys()) + + self.assertEqual(vocab_keys[PAD], "<pad>") + self.assertEqual(vocab_keys[S1], "[S1]") + self.assertEqual(vocab_keys[S2], "[S2]") + self.assertEqual(len(vocab_keys), 256) + + def test_vocab_size(self): + # byte-level vocabulary: 2**8 == 256 possible byte values + self.assertEqual(self.get_tokenizer().vocab_size, 256) + + def test_full_tokenizer(self): + tokenizer = DiaTokenizer.from_pretrained(self.tmpdirname) + + tokens = tokenizer.tokenize("Hello, world!") + self.assertListEqual(tokens, ["H", "e", "l", "l", "o", ",", " ", "w", "o", "r", "l", "d", "!"]) + ids = tokenizer.convert_tokens_to_ids(tokens) + self.assertListEqual(ids, [72, 101, 108, 108, 111, 44, 32, 119, 111, 114, 108, 100, 33]) + back_tokens = tokenizer.convert_ids_to_tokens(ids) + self.assertListEqual(back_tokens, ["H", "e", "l", "l", "o", ",", " ", "w", "o", "r", "l", "d", "!"]) + + tokens = tokenizer.tokenize("[S1] Hello [S2] Hello<pad>") + self.assertListEqual( + tokens, + ["[S1]", " ", "H", "e", "l", "l", "o", " ", "[S2]", " ", "H", "e", "l", "l", "o", "<pad>"], + ) + ids = tokenizer.convert_tokens_to_ids(tokens) + self.assertListEqual(ids, [S1, 32, 72, 101, 108, 108, 111,
32, S2, 32, 72, 101, 108, 108, 111, PAD]) + back_tokens = tokenizer.convert_ids_to_tokens(ids) + self.assertListEqual( + back_tokens, ["[S1]", " ", "H", "e", "l", "l", "o", " ", "[S2]", " ", "H", "e", "l", "l", "o", "<pad>"] + ) + + @slow + def test_tokenizer_integration(self): + # Overwritten as decoding will lead to all single bytes (i.e. characters) while usually the string format is expected + expected_encoding = {'input_ids': [[84, 114, 97, 110, 115, 102, 111, 114, 109, 101, 114, 115, 32, 40, 102, 111, 114, 109, 101, 114, 108, 121, 32, 107, 110, 111, 119, 110, 32, 97, 115, 32, 112, 121, 116, 111, 114, 99, 104, 45, 116, 114, 97, 110, 115, 102, 111, 114, 109, 101, 114, 115, 32, 97, 110, 100, 32, 112, 121, 116, 111, 114, 99, 104, 45, 112, 114, 101, 116, 114, 97, 105, 110, 101, 100, 45, 98, 101, 114, 116, 41, 32, 112, 114, 111, 118, 105, 100, 101, 115, 32, 103, 101, 110, 101, 114, 97, 108, 45, 112, 117, 114, 112, 111, 115, 101, 32, 97, 114, 99, 104, 105, 116, 101, 99, 116, 117, 114, 101, 115, 32, 40, 66, 69, 82, 84, 44, 32, 71, 80, 84, 45, 50, 44, 32, 82, 111, 66, 69, 82, 84, 97, 44, 32, 88, 76, 77, 44, 32, 68, 105, 115, 116, 105, 108, 66, 101, 114, 116, 44, 32, 88, 76, 78, 101, 116, 46, 46, 46, 41, 32, 102, 111, 114, 32, 78, 97, 116, 117, 114, 97, 108, 32, 76, 97, 110, 103, 117, 97, 103, 101, 32, 85, 110, 100, 101, 114, 115, 116, 97, 110, 100, 105, 110, 103, 32, 40, 78, 76, 85, 41, 32, 97, 110, 100, 32, 78, 97, 116, 117, 114, 97, 108, 32, 76, 97, 110, 103, 117, 97, 103, 101, 32, 71, 101, 110, 101, 114, 97, 116, 105, 111, 110, 32, 40, 78, 76, 71, 41, 32, 119, 105, 116, 104, 32, 111, 118, 101, 114, 32, 51, 50, 43, 32, 112, 114, 101, 116, 114, 97, 105, 110, 101, 100, 32, 109, 111, 100, 101, 108, 115, 32, 105, 110, 32, 49, 48, 48, 43, 32, 108, 97, 110, 103, 117, 97, 103, 101, 115, 32, 97, 110, 100, 32, 100, 101, 101, 112, 32, 105, 110, 116, 101, 114, 111, 112, 101, 114, 97, 98, 105, 108, 105, 116, 121, 32, 98, 101, 116, 119, 101, 101, 110, 32, 74, 97, 120, 44, 32, 80,
121, 84, 111, 114, 99, 104, 32, 97, 110, 100, 32, 84, 101, 110, 115, 111, 114, 70, 108, 111, 119, 46], [66, 69, 82, 84, 32, 105, 115, 32, 100, 101, 115, 105, 103, 110, 101, 100, 32, 116, 111, 32, 112, 114, 101, 45, 116, 114, 97, 105, 110, 32, 100, 101, 101, 112, 32, 98, 105, 100, 105, 114, 101, 99, 116, 105, 111, 110, 97, 108, 32, 114, 101, 112, 114, 101, 115, 101, 110, 116, 97, 116, 105, 111, 110, 115, 32, 102, 114, 111, 109, 32, 117, 110, 108, 97, 98, 101, 108, 101, 100, 32, 116, 101, 120, 116, 32, 98, 121, 32, 106, 111, 105, 110, 116, 108, 121, 32, 99, 111, 110, 100, 105, 116, 105, 111, 110, 105, 110, 103, 32, 111, 110, 32, 98, 111, 116, 104, 32, 108, 101, 102, 116, 32, 97, 110, 100, 32, 114, 105, 103, 104, 116, 32, 99, 111, 110, 116, 101, 120, 116, 32, 105, 110, 32, 97, 108, 108, 32, 108, 97, 121, 101, 114, 115, 46], [84, 104, 101, 32, 113, 117, 105, 99, 107, 32, 98, 114, 111, 119, 110, 32, 102, 111, 120, 32, 106, 117, 109, 112, 115, 32, 111, 118, 101, 114, 32, 116, 104, 101, 32, 108, 97, 122, 121, 32, 100, 111, 103, 46]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]} # fmt: skip + + sequences = [ + "Transformers (formerly known as pytorch-transformers and pytorch-pretrained-bert) provides " + "general-purpose architectures (BERT, GPT-2, RoBERTa, XLM, DistilBert, XLNet...) for Natural " + "Language Understanding (NLU) and Natural Language Generation (NLG) with over 32+ pretrained " + "models in 100+ languages and deep interoperability between Jax, PyTorch and TensorFlow.", + "BERT is designed to pre-train deep bidirectional representations from unlabeled text by jointly " + "conditioning on both left and right context in all layers.", + "The quick brown fox jumps over the lazy dog.", + ] + + tokenizer_classes = [self.tokenizer_class] + if self.test_rust_tokenizer: + tokenizer_classes.append(self.rust_tokenizer_class) + + for tokenizer_class in tokenizer_classes: + tokenizer = tokenizer_class.from_pretrained("AntonV/Dia-1.6B") + + encoding = tokenizer(sequences) + encoding_data = encoding.data + self.assertDictEqual(encoding_data, expected_encoding) + + # Byte decoding leads to characters so we need to join them + decoded_sequences = [ + "".join(tokenizer.decode(seq, skip_special_tokens=True)) for seq in encoding["input_ids"] + ] + + for expected, decoded in zip(sequences, decoded_sequences): + 
if self.test_sentencepiece_ignore_case: + expected = expected.lower() + self.assertEqual(expected, decoded) + + @unittest.skip(reason="Dia relies on whole input string due to the byte-level nature.") + def test_pretokenized_inputs(self): + pass + + @unittest.skip + def test_tokenizer_slow_store_full_signature(self): + pass diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index a5d9c90068..d3f8456f54 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -4574,6 +4574,11 @@ class ModelTesterMixin: head_dim = config.head_dim config.head_dim = max(16, config.head_dim) + cross_head_dim = None + if hasattr(config, "cross_head_dim") and config.cross_head_dim is not None: + cross_head_dim = config.cross_head_dim + config.cross_head_dim = max(16, config.cross_head_dim) + if ( getattr(config, "hidden_size", None) is not None and getattr(config, "num_attention_heads", None) is not None @@ -4588,6 +4593,17 @@ class ModelTesterMixin: decoder_head_dim = config.decoder_hidden_size // config.decoder_num_attention_heads config.decoder_hidden_size *= max(16 // decoder_head_dim, 1) + if ( + getattr(config, "cross_hidden_size", None) is not None + and getattr(config, "cross_num_attention_heads", None) is not None + ): + cross_head_dim = ( + cross_head_dim + if cross_head_dim is not None + else config.cross_hidden_size // config.cross_num_attention_heads + ) + config.cross_hidden_size *= max(16 // cross_head_dim, 1) + # Set default attention to flex and update config values update_config_for_flex(config) for key in config.sub_configs: diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py index 46c2bb1a9f..22d6b033af 100644 --- a/utils/check_config_attributes.py +++ b/utils/check_config_attributes.py @@ -32,6 +32,10 @@ transformers = direct_transformers_import(PATH_TO_TRANSFORMERS) CONFIG_MAPPING = transformers.models.auto.configuration_auto.CONFIG_MAPPING SPECIAL_CASES_TO_ALLOW = { + # used internally 
during generation to provide the custom logit processors with their necessary information + "DiaConfig": [ + "delay_pattern", + ], # 'max_position_embeddings' is not used in modeling file, but needed for eval frameworks like Huggingface's lighteval (https://github.com/huggingface/lighteval/blob/af24080ea4f16eaf1683e353042a2dfc9099f038/src/lighteval/models/base_model.py#L264). # periods and offsets are not used in modeling file, but used in the configuration file to define `layers_block_type` and `layers_num_experts`. "BambaConfig": [