diff --git a/README.md b/README.md index 2a8e515e14..cc1fd458f5 100644 --- a/README.md +++ b/README.md @@ -411,6 +411,7 @@ Current number of checkpoints: ![](https://img.shields.io/endpoint?url=https://h 1. **[MobileViTV2](https://huggingface.co/docs/transformers/model_doc/mobilevitv2)** (from Apple) released with the paper [Separable Self-attention for Mobile Vision Transformers](https://arxiv.org/abs/2206.02680) by Sachin Mehta and Mohammad Rastegari. 1. **[MPNet](https://huggingface.co/docs/transformers/model_doc/mpnet)** (from Microsoft Research) released with the paper [MPNet: Masked and Permuted Pre-training for Language Understanding](https://arxiv.org/abs/2004.09297) by Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu. 1. **[MT5](https://huggingface.co/docs/transformers/model_doc/mt5)** (from Google AI) released with the paper [mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/abs/2010.11934) by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel. +1. **[MusicGen](https://huggingface.co/docs/transformers/main/model_doc/musicgen)** (from Meta) released with the paper [Simple and Controllable Music Generation](https://arxiv.org/abs/2306.05284) by Jade Copet, Felix Kreuk, Itai Gat, Tal Remez, David Kant, Gabriel Synnaeve, Yossi Adi and Alexandre Défossez. 1. **[MVP](https://huggingface.co/docs/transformers/model_doc/mvp)** (from RUC AI Box) released with the paper [MVP: Multi-task Supervised Pre-training for Natural Language Generation](https://arxiv.org/abs/2206.12131) by Tianyi Tang, Junyi Li, Wayne Xin Zhao and Ji-Rong Wen. 1. **[NAT](https://huggingface.co/docs/transformers/model_doc/nat)** (from SHI Labs) released with the paper [Neighborhood Attention Transformer](https://arxiv.org/abs/2204.07143) by Ali Hassani, Steven Walton, Jiachen Li, Shen Li, and Humphrey Shi. 1. 
**[Nezha](https://huggingface.co/docs/transformers/model_doc/nezha)** (from Huawei Noah’s Ark Lab) released with the paper [NEZHA: Neural Contextualized Representation for Chinese Language Understanding](https://arxiv.org/abs/1909.00204) by Junqiu Wei, Xiaozhe Ren, Xiaoguang Li, Wenyong Huang, Yi Liao, Yasheng Wang, Jiashu Lin, Xin Jiang, Xiao Chen and Qun Liu. diff --git a/README_es.md b/README_es.md index f5077993aa..fcb6049870 100644 --- a/README_es.md +++ b/README_es.md @@ -386,6 +386,7 @@ Número actual de puntos de control: ![](https://img.shields.io/endpoint?url=htt 1. **[MobileViTV2](https://huggingface.co/docs/transformers/model_doc/mobilevitv2)** (from Apple) released with the paper [Separable Self-attention for Mobile Vision Transformers](https://arxiv.org/abs/2206.02680) by Sachin Mehta and Mohammad Rastegari. 1. **[MPNet](https://huggingface.co/docs/transformers/model_doc/mpnet)** (from Microsoft Research) released with the paper [MPNet: Masked and Permuted Pre-training for Language Understanding](https://arxiv.org/abs/2004.09297) by Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu. 1. **[MT5](https://huggingface.co/docs/transformers/model_doc/mt5)** (from Google AI) released with the paper [mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/abs/2010.11934) by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel. +1. **[MusicGen](https://huggingface.co/docs/transformers/main/model_doc/musicgen)** (from Meta) released with the paper [Simple and Controllable Music Generation](https://arxiv.org/abs/2306.05284) by Jade Copet, Felix Kreuk, Itai Gat, Tal Remez, David Kant, Gabriel Synnaeve, Yossi Adi and Alexandre Défossez. 1. 
**[MVP](https://huggingface.co/docs/transformers/model_doc/mvp)** (from RUC AI Box) released with the paper [MVP: Multi-task Supervised Pre-training for Natural Language Generation](https://arxiv.org/abs/2206.12131) by Tianyi Tang, Junyi Li, Wayne Xin Zhao and Ji-Rong Wen. 1. **[NAT](https://huggingface.co/docs/transformers/model_doc/nat)** (from SHI Labs) released with the paper [Neighborhood Attention Transformer](https://arxiv.org/abs/2204.07143) by Ali Hassani, Steven Walton, Jiachen Li, Shen Li, and Humphrey Shi. 1. **[Nezha](https://huggingface.co/docs/transformers/model_doc/nezha)** (from Huawei Noah’s Ark Lab) released with the paper [NEZHA: Neural Contextualized Representation for Chinese Language Understanding](https://arxiv.org/abs/1909.00204) by Junqiu Wei, Xiaozhe Ren, Xiaoguang Li, Wenyong Huang, Yi Liao, Yasheng Wang, Jiashu Lin, Xin Jiang, Xiao Chen and Qun Liu. diff --git a/README_hd.md b/README_hd.md index 146af0585c..9b694ed607 100644 --- a/README_hd.md +++ b/README_hd.md @@ -358,6 +358,7 @@ conda install -c huggingface transformers 1. **[MobileViTV2](https://huggingface.co/docs/transformers/model_doc/mobilevitv2)** (Apple से) Sachin Mehta and Mohammad Rastegari. द्वाराअनुसंधान पत्र [Separable Self-attention for Mobile Vision Transformers](https://arxiv.org/abs/2206.02680) के साथ जारी किया गया 1. **[MPNet](https://huggingface.co/docs/transformers/model_doc/mpnet)** (from Microsoft Research) released with the paper [MPNet: Masked and Permuted Pre-training for Language Understanding](https://arxiv.org/abs/2004.09297) by Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu. 1. **[MT5](https://huggingface.co/docs/transformers/model_doc/mt5)** (Google AI से) साथ वाला पेपर [mT5: एक व्यापक बहुभाषी पूर्व-प्रशिक्षित टेक्स्ट-टू-टेक्स्ट ट्रांसफॉर्मर]( https://arxiv.org/abs/2010.11934) लिंटिंग ज़ू, नोआ कॉन्सटेंट, एडम रॉबर्ट्स, मिहिर काले, रामी अल-रफू, आदित्य सिद्धांत, आदित्य बरुआ, कॉलिन रैफेल द्वारा पोस्ट किया गया। +1. 
**[MusicGen](https://huggingface.co/docs/transformers/main/model_doc/musicgen)** (from Meta) released with the paper [Simple and Controllable Music Generation](https://arxiv.org/abs/2306.05284) by Jade Copet, Felix Kreuk, Itai Gat, Tal Remez, David Kant, Gabriel Synnaeve, Yossi Adi and Alexandre Défossez. 1. **[MVP](https://huggingface.co/docs/transformers/model_doc/mvp)** (from RUC AI Box) released with the paper [MVP: Multi-task Supervised Pre-training for Natural Language Generation](https://arxiv.org/abs/2206.12131) by Tianyi Tang, Junyi Li, Wayne Xin Zhao and Ji-Rong Wen. 1. **[NAT](https://huggingface.co/docs/transformers/model_doc/nat)** (from SHI Labs) released with the paper [Neighborhood Attention Transformer](https://arxiv.org/abs/2204.07143) by Ali Hassani, Steven Walton, Jiachen Li, Shen Li, and Humphrey Shi. 1. **[Nezha](https://huggingface.co/docs/transformers/model_doc/nezha)** (हुआवेई नूह के आर्क लैब से) साथ में कागज़ [NEZHA: चीनी भाषा समझ के लिए तंत्रिका प्रासंगिक प्रतिनिधित्व](https :/ /arxiv.org/abs/1909.00204) जुन्किउ वेई, ज़ियाओज़े रेन, ज़िआओगुआंग ली, वेनयोंग हुआंग, यी लियाओ, याशेंग वांग, जियाशू लिन, शिन जियांग, जिओ चेन और कुन लियू द्वारा। diff --git a/README_ja.md b/README_ja.md index 5584a84fac..60b14191c1 100644 --- a/README_ja.md +++ b/README_ja.md @@ -420,6 +420,7 @@ Flax、PyTorch、TensorFlowをcondaでインストールする方法は、それ 1. **[MobileViTV2](https://huggingface.co/docs/transformers/model_doc/mobilevitv2)** (Apple から) Sachin Mehta and Mohammad Rastegari. から公開された研究論文 [Separable Self-attention for Mobile Vision Transformers](https://arxiv.org/abs/2206.02680) 1. **[MPNet](https://huggingface.co/docs/transformers/model_doc/mpnet)** (Microsoft Research から) Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu から公開された研究論文: [MPNet: Masked and Permuted Pre-training for Language Understanding](https://arxiv.org/abs/2004.09297) 1. 
**[MT5](https://huggingface.co/docs/transformers/model_doc/mt5)** (Google AI から) Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel から公開された研究論文: [mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/abs/2010.11934) +1. **[MusicGen](https://huggingface.co/docs/transformers/main/model_doc/musicgen)** (from Meta) released with the paper [Simple and Controllable Music Generation](https://arxiv.org/abs/2306.05284) by Jade Copet, Felix Kreuk, Itai Gat, Tal Remez, David Kant, Gabriel Synnaeve, Yossi Adi and Alexandre Défossez. 1. **[MVP](https://huggingface.co/docs/transformers/model_doc/mvp)** (RUC AI Box から) Tianyi Tang, Junyi Li, Wayne Xin Zhao and Ji-Rong Wen から公開された研究論文: [MVP: Multi-task Supervised Pre-training for Natural Language Generation](https://arxiv.org/abs/2206.12131) 1. **[NAT](https://huggingface.co/docs/transformers/model_doc/nat)** (SHI Labs から) Ali Hassani, Steven Walton, Jiachen Li, Shen Li, and Humphrey Shi から公開された研究論文: [Neighborhood Attention Transformer](https://arxiv.org/abs/2204.07143) 1. **[Nezha](https://huggingface.co/docs/transformers/model_doc/nezha)** (Huawei Noah’s Ark Lab から) Junqiu Wei, Xiaozhe Ren, Xiaoguang Li, Wenyong Huang, Yi Liao, Yasheng Wang, Jiashu Lin, Xin Jiang, Xiao Chen and Qun Liu から公開された研究論文: [NEZHA: Neural Contextualized Representation for Chinese Language Understanding](https://arxiv.org/abs/1909.00204) diff --git a/README_ko.md b/README_ko.md index a476f4493b..cdbeec9a4b 100644 --- a/README_ko.md +++ b/README_ko.md @@ -335,6 +335,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는 1. **[MobileViTV2](https://huggingface.co/docs/transformers/model_doc/mobilevitv2)** (Apple 에서 제공)은 Sachin Mehta and Mohammad Rastegari.의 [Separable Self-attention for Mobile Vision Transformers](https://arxiv.org/abs/2206.02680)논문과 함께 발표했습니다. 1. 
**[MPNet](https://huggingface.co/docs/transformers/model_doc/mpnet)** (Microsoft Research 에서) Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu 의 [MPNet: Masked and Permuted Pre-training for Language Understanding](https://arxiv.org/abs/2004.09297) 논문과 함께 발표했습니다. 1. **[MT5](https://huggingface.co/docs/transformers/model_doc/mt5)** (Google AI 에서) Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel 의 [mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/abs/2010.11934) 논문과 함께 발표했습니다. +1. **[MusicGen](https://huggingface.co/docs/transformers/main/model_doc/musicgen)** (from Meta) released with the paper [Simple and Controllable Music Generation](https://arxiv.org/abs/2306.05284) by Jade Copet, Felix Kreuk, Itai Gat, Tal Remez, David Kant, Gabriel Synnaeve, Yossi Adi and Alexandre Défossez. 1. **[MVP](https://huggingface.co/docs/transformers/model_doc/mvp)** (RUC AI Box 에서) Tianyi Tang, Junyi Li, Wayne Xin Zhao and Ji-Rong Wen 의 [MVP: Multi-task Supervised Pre-training for Natural Language Generation](https://arxiv.org/abs/2206.12131) 논문과 함께 발표했습니다. 1. **[NAT](https://huggingface.co/docs/transformers/model_doc/nat)** (SHI Labs 에서) Ali Hassani, Steven Walton, Jiachen Li, Shen Li, and Humphrey Shi 의 [Neighborhood Attention Transformer](https://arxiv.org/abs/2204.07143) 논문과 함께 발표했습니다. 1. **[Nezha](https://huggingface.co/docs/transformers/model_doc/nezha)** (Huawei Noah’s Ark Lab 에서) Junqiu Wei, Xiaozhe Ren, Xiaoguang Li, Wenyong Huang, Yi Liao, Yasheng Wang, Jiashu Lin, Xin Jiang, Xiao Chen and Qun Liu 의 [NEZHA: Neural Contextualized Representation for Chinese Language Understanding](https://arxiv.org/abs/1909.00204) 논문과 함께 발표했습니다. diff --git a/README_zh-hans.md b/README_zh-hans.md index bdc661091a..db1cbd4237 100644 --- a/README_zh-hans.md +++ b/README_zh-hans.md @@ -359,6 +359,7 @@ conda install -c huggingface transformers 1. 
**[MobileViTV2](https://huggingface.co/docs/transformers/model_doc/mobilevitv2)** (来自 Apple) 伴随论文 [Separable Self-attention for Mobile Vision Transformers](https://arxiv.org/abs/2206.02680) 由 Sachin Mehta and Mohammad Rastegari 发布。 1. **[MPNet](https://huggingface.co/docs/transformers/model_doc/mpnet)** (来自 Microsoft Research) 伴随论文 [MPNet: Masked and Permuted Pre-training for Language Understanding](https://arxiv.org/abs/2004.09297) 由 Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu 发布。 1. **[MT5](https://huggingface.co/docs/transformers/model_doc/mt5)** (来自 Google AI) 伴随论文 [mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/abs/2010.11934) 由 Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel 发布。 +1. **[MusicGen](https://huggingface.co/docs/transformers/main/model_doc/musicgen)** (from Meta) released with the paper [Simple and Controllable Music Generation](https://arxiv.org/abs/2306.05284) by Jade Copet, Felix Kreuk, Itai Gat, Tal Remez, David Kant, Gabriel Synnaeve, Yossi Adi and Alexandre Défossez. 1. **[MVP](https://huggingface.co/docs/transformers/model_doc/mvp)** (来自 中国人民大学 AI Box) 伴随论文 [MVP: Multi-task Supervised Pre-training for Natural Language Generation](https://arxiv.org/abs/2206.12131) 由 Tianyi Tang, Junyi Li, Wayne Xin Zhao and Ji-Rong Wen 发布。 1. **[NAT](https://huggingface.co/docs/transformers/model_doc/nat)** (来自 SHI Labs) 伴随论文 [Neighborhood Attention Transformer](https://arxiv.org/abs/2204.07143) 由 Ali Hassani, Steven Walton, Jiachen Li, Shen Li, and Humphrey Shi 发布。 1. 
**[Nezha](https://huggingface.co/docs/transformers/model_doc/nezha)** (来自华为诺亚方舟实验室) 伴随论文 [NEZHA: Neural Contextualized Representation for Chinese Language Understanding](https://arxiv.org/abs/1909.00204) 由 Junqiu Wei, Xiaozhe Ren, Xiaoguang Li, Wenyong Huang, Yi Liao, Yasheng Wang, Jiashu Lin, Xin Jiang, Xiao Chen and Qun Liu 发布。 diff --git a/README_zh-hant.md b/README_zh-hant.md index 94a275dc21..e66cd1f867 100644 --- a/README_zh-hant.md +++ b/README_zh-hant.md @@ -371,6 +371,7 @@ conda install -c huggingface transformers 1. **[MobileViTV2](https://huggingface.co/docs/transformers/model_doc/mobilevitv2)** (from Apple) released with the paper [Separable Self-attention for Mobile Vision Transformers](https://arxiv.org/abs/2206.02680) by Sachin Mehta and Mohammad Rastegari. 1. **[MPNet](https://huggingface.co/docs/transformers/model_doc/mpnet)** (from Microsoft Research) released with the paper [MPNet: Masked and Permuted Pre-training for Language Understanding](https://arxiv.org/abs/2004.09297) by Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu. 1. **[MT5](https://huggingface.co/docs/transformers/model_doc/mt5)** (from Google AI) released with the paper [mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/abs/2010.11934) by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel. +1. **[MusicGen](https://huggingface.co/docs/transformers/main/model_doc/musicgen)** (from Meta) released with the paper [Simple and Controllable Music Generation](https://arxiv.org/abs/2306.05284) by Jade Copet, Felix Kreuk, Itai Gat, Tal Remez, David Kant, Gabriel Synnaeve, Yossi Adi and Alexandre Défossez. 1. **[MVP](https://huggingface.co/docs/transformers/model_doc/mvp)** (from RUC AI Box) released with the paper [MVP: Multi-task Supervised Pre-training for Natural Language Generation](https://arxiv.org/abs/2206.12131) by Tianyi Tang, Junyi Li, Wayne Xin Zhao and Ji-Rong Wen. 1. 
**[NAT](https://huggingface.co/docs/transformers/model_doc/nat)** (from SHI Labs) released with the paper [Neighborhood Attention Transformer](https://arxiv.org/abs/2204.07143) by Ali Hassani, Steven Walton, Jiachen Li, Shen Li, and Humphrey Shi. 1. **[Nezha](https://huggingface.co/docs/transformers/model_doc/nezha)** (from Huawei Noah’s Ark Lab) released with the paper [NEZHA: Neural Contextualized Representation for Chinese Language Understanding](https://arxiv.org/abs/1909.00204) by Junqiu Wei, Xiaozhe Ren, Xiaoguang Li, Wenyong Huang, Yi Liao, Yasheng Wang, Jiashu Lin, Xin Jiang, Xiao Chen and Qun Liu. diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 414e67ffe7..e36282b21c 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -549,6 +549,8 @@ title: MCTCT - local: model_doc/mms title: MMS + - local: model_doc/musicgen + title: MusicGen - local: model_doc/sew title: SEW - local: model_doc/sew-d diff --git a/docs/source/en/index.md b/docs/source/en/index.md index 4c7f238815..91c57e0f39 100644 --- a/docs/source/en/index.md +++ b/docs/source/en/index.md @@ -175,6 +175,7 @@ The documentation is organized into five sections: 1. **[MobileViTV2](model_doc/mobilevitv2)** (from Apple) released with the paper [Separable Self-attention for Mobile Vision Transformers](https://arxiv.org/abs/2206.02680) by Sachin Mehta and Mohammad Rastegari. 1. **[MPNet](model_doc/mpnet)** (from Microsoft Research) released with the paper [MPNet: Masked and Permuted Pre-training for Language Understanding](https://arxiv.org/abs/2004.09297) by Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu. 1. **[MT5](model_doc/mt5)** (from Google AI) released with the paper [mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/abs/2010.11934) by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel. +1. 
**[MusicGen](model_doc/musicgen)** (from Meta) released with the paper [Simple and Controllable Music Generation](https://arxiv.org/abs/2306.05284) by Jade Copet, Felix Kreuk, Itai Gat, Tal Remez, David Kant, Gabriel Synnaeve, Yossi Adi and Alexandre Défossez. 1. **[MVP](model_doc/mvp)** (from RUC AI Box) released with the paper [MVP: Multi-task Supervised Pre-training for Natural Language Generation](https://arxiv.org/abs/2206.12131) by Tianyi Tang, Junyi Li, Wayne Xin Zhao and Ji-Rong Wen. 1. **[NAT](model_doc/nat)** (from SHI Labs) released with the paper [Neighborhood Attention Transformer](https://arxiv.org/abs/2204.07143) by Ali Hassani, Steven Walton, Jiachen Li, Shen Li, and Humphrey Shi. 1. **[Nezha](model_doc/nezha)** (from Huawei Noah’s Ark Lab) released with the paper [NEZHA: Neural Contextualized Representation for Chinese Language Understanding](https://arxiv.org/abs/1909.00204) by Junqiu Wei, Xiaozhe Ren, Xiaoguang Li, Wenyong Huang, Yi Liao, Yasheng Wang, Jiashu Lin, Xin Jiang, Xiao Chen and Qun Liu. @@ -380,6 +381,7 @@ Flax), PyTorch, and/or TensorFlow. | MobileViTV2 | ❌ | ❌ | ✅ | ❌ | ❌ | | MPNet | ✅ | ✅ | ✅ | ✅ | ❌ | | MT5 | ✅ | ✅ | ✅ | ✅ | ✅ | +| MusicGen | ❌ | ❌ | ✅ | ❌ | ❌ | | MVP | ✅ | ✅ | ✅ | ❌ | ❌ | | NAT | ❌ | ❌ | ✅ | ❌ | ❌ | | Nezha | ❌ | ❌ | ✅ | ❌ | ❌ | diff --git a/docs/source/en/model_doc/musicgen.md b/docs/source/en/model_doc/musicgen.md new file mode 100644 index 0000000000..72250a86fc --- /dev/null +++ b/docs/source/en/model_doc/musicgen.md @@ -0,0 +1,277 @@ + + +# MusicGen + +## Overview + +The MusicGen model was proposed in the paper [Simple and Controllable Music Generation](https://arxiv.org/abs/2306.05284) +by Jade Copet, Felix Kreuk, Itai Gat, Tal Remez, David Kant, Gabriel Synnaeve, Yossi Adi and Alexandre Défossez. + +MusicGen is a single stage auto-regressive Transformer model capable of generating high-quality music samples conditioned +on text descriptions or audio prompts. 
The text descriptions are passed through a frozen text encoder model to obtain a +sequence of hidden-state representations. MusicGen is then trained to predict discrete audio tokens, or *audio codes*, +conditioned on these hidden-states. These audio tokens are then decoded using an audio compression model, such as EnCodec, +to recover the audio waveform. + +Through an efficient token interleaving pattern, MusicGen does not require a self-supervised semantic representation of +the text/audio prompts, thus eliminating the need to cascade multiple models to predict a set of codebooks (e.g. +hierarchically or upsampling). Instead, it is able to generate all the codebooks in a single forward pass. + +The abstract from the paper is the following: + +*We tackle the task of conditional music generation. We introduce MusicGen, a single Language Model (LM) that operates +over several streams of compressed discrete music representation, i.e., tokens. Unlike prior work, MusicGen is comprised +of a single-stage transformer LM together with efficient token interleaving patterns, which eliminates the need for +cascading several models, e.g., hierarchically or upsampling. Following this approach, we demonstrate how MusicGen +can generate high-quality samples, while being conditioned on textual description or melodic features, allowing better +controls over the generated output. We conduct extensive empirical evaluation, considering both automatic and human +studies, showing the proposed approach is superior to the evaluated baselines on a standard text-to-music benchmark. +Through ablation studies, we shed light over the importance of each of the components comprising MusicGen.* + +This model was contributed by [sanchit-gandhi](https://huggingface.co/sanchit-gandhi). The original code can be found +[here](https://github.com/facebookresearch/audiocraft). 
The pre-trained checkpoints can be found on the +[Hugging Face Hub](https://huggingface.co/models?sort=downloads&search=facebook%2Fmusicgen-). + +## Generation + +MusicGen is compatible with two generation modes: greedy and sampling. In practice, sampling leads to significantly +better results than greedy, thus we encourage sampling mode to be used where possible. Sampling is enabled by default, +and can be explicitly specified by setting `do_sample=True` in the call to [`MusicgenForConditionalGeneration.generate`], +or by overriding the model's generation config (see below). + +### Unconditional Generation + +The inputs for unconditional (or 'null') generation can be obtained through the method +[`MusicgenForConditionalGeneration.get_unconditional_inputs`]: + +```python +>>> from transformers import MusicgenForConditionalGeneration + +>>> model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small") +>>> unconditional_inputs = model.get_unconditional_inputs(num_samples=1) + +>>> audio_values = model.generate(**unconditional_inputs, do_sample=True, max_new_tokens=256) +``` + +The audio outputs are a three-dimensional Torch tensor of shape `(batch_size, num_channels, sequence_length)`. To listen +to the generated audio samples, you can either play them in an ipynb notebook: + +```python +from IPython.display import Audio + +sampling_rate = model.config.audio_encoder.sampling_rate +Audio(audio_values[0].numpy(), rate=sampling_rate) +``` + +Or save them as a `.wav` file using a third-party library, e.g. 
`scipy`: + +```python +>>> import scipy.io.wavfile + +>>> sampling_rate = model.config.audio_encoder.sampling_rate +>>> scipy.io.wavfile.write("musicgen_out.wav", rate=sampling_rate, data=audio_values[0, 0].numpy()) +``` + +### Text-Conditional Generation + +The model can generate an audio sample conditioned on a text prompt through use of the [`MusicgenProcessor`] to pre-process +the inputs: + +```python +>>> from transformers import AutoProcessor, MusicgenForConditionalGeneration + +>>> processor = AutoProcessor.from_pretrained("facebook/musicgen-small") +>>> model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small") + +>>> inputs = processor( +... text=["80s pop track with bassy drums and synth", "90s rock song with loud guitars and heavy drums"], +... padding=True, +... return_tensors="pt", +... ) +>>> audio_values = model.generate(**inputs, do_sample=True, guidance_scale=3, max_new_tokens=256) +``` + +The `guidance_scale` is used in classifier-free guidance (CFG), setting the weighting between the conditional logits +(which are predicted from the text prompts) and the unconditional logits (which are predicted from an unconditional or +'null' prompt). A higher guidance scale encourages the model to generate samples that are more closely linked to the input +prompt, usually at the expense of poorer audio quality. CFG is enabled by setting `guidance_scale > 1`. For best results, +use `guidance_scale=3` (default). + +### Audio-Prompted Generation + +The same [`MusicgenProcessor`] can be used to pre-process an audio prompt that is used for audio continuation. 
In the +following example, we load an audio file using the 🤗 Datasets library, which can be pip installed through the command +below: + +``` +pip install --upgrade pip +pip install datasets[audio] +``` + +```python +>>> from transformers import AutoProcessor, MusicgenForConditionalGeneration +>>> from datasets import load_dataset + +>>> processor = AutoProcessor.from_pretrained("facebook/musicgen-small") +>>> model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small") + +>>> dataset = load_dataset("sanchit-gandhi/gtzan", split="train", streaming=True) +>>> sample = next(iter(dataset))["audio"] + +>>> # take the first half of the audio sample +>>> sample["array"] = sample["array"][: len(sample["array"]) // 2] + +>>> inputs = processor( +... audio=sample["array"], +... sampling_rate=sample["sampling_rate"], +... text=["80s blues track with groovy saxophone"], +... padding=True, +... return_tensors="pt", +... ) +>>> audio_values = model.generate(**inputs, do_sample=True, guidance_scale=3, max_new_tokens=256) +``` + +For batched audio-prompted generation, the generated `audio_values` can be post-processed to remove padding by using the +[`MusicgenProcessor`] class: + +```python +>>> from transformers import AutoProcessor, MusicgenForConditionalGeneration +>>> from datasets import load_dataset + +>>> processor = AutoProcessor.from_pretrained("facebook/musicgen-small") +>>> model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small") + +>>> dataset = load_dataset("sanchit-gandhi/gtzan", split="train", streaming=True) +>>> sample = next(iter(dataset))["audio"] + +>>> # take the first quarter of the audio sample +>>> sample_1 = sample["array"][: len(sample["array"]) // 4] + +>>> # take the first half of the audio sample +>>> sample_2 = sample["array"][: len(sample["array"]) // 2] + +>>> inputs = processor( +... audio=[sample_1, sample_2], +... sampling_rate=sample["sampling_rate"], +... 
text=["80s blues track with groovy saxophone", "90s rock song with loud guitars and heavy drums"], +... padding=True, +... return_tensors="pt", +... ) +>>> audio_values = model.generate(**inputs, do_sample=True, guidance_scale=3, max_new_tokens=256) + +>>> # post-process to remove padding from the batched audio +>>> audio_values = processor.batch_decode(audio_values, padding_mask=inputs.padding_mask) +``` + +### Generation Configuration + +The default parameters that control the generation process, such as sampling, guidance scale and number of generated +tokens, can be found in the model's generation config, and updated as desired: + +```python +>>> from transformers import MusicgenForConditionalGeneration + +>>> model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small") + +>>> # inspect the default generation config +>>> model.generation_config + +>>> # increase the guidance scale to 4.0 +>>> model.generation_config.guidance_scale = 4.0 + +>>> # decrease the max length to 256 tokens +>>> model.generation_config.max_length = 256 +``` + +Note that any arguments passed to the generate method will **supersede** those in the generation config, so setting +`do_sample=False` in the call to generate will supersede the setting of `model.generation_config.do_sample` in the +generation config. + +## Model Structure + +The MusicGen model can be decomposed into three distinct stages: +1. Text encoder: maps the text inputs to a sequence of hidden-state representations. The pre-trained MusicGen models use a frozen text encoder from either T5 or Flan-T5 +2. MusicGen decoder: a language model (LM) that auto-regressively generates audio tokens (or codes) conditional on the encoder hidden-state representations +3. 
Audio encoder/decoder: used to encode an audio prompt to use as prompt tokens, and recover the audio waveform from the audio tokens predicted by the decoder + +Thus, the MusicGen model can either be used as a standalone decoder model, corresponding to the class [`MusicgenForCausalLM`], +or as a composite model that includes the text encoder and audio encoder/decoder, corresponding to the class +[`MusicgenForConditionalGeneration`]. + +Since the text encoder and audio encoder/decoder models are frozen during training, the MusicGen decoder [`MusicgenForCausalLM`] +can be trained standalone on a dataset of encoder hidden-states and audio codes. For inference, the trained decoder can +be combined with the frozen text encoder and audio encoder/decoders to recover the composite [`MusicgenForConditionalGeneration`] +model. + +Below, we demonstrate how to construct the composite [`MusicgenForConditionalGeneration`] model from its three constituent +parts, as would typically be done following training of the MusicGen decoder LM: + +```python +>>> from transformers import AutoConfig, AutoModelForTextEncoding, AutoModel, MusicgenForCausalLM, MusicgenForConditionalGeneration + +>>> text_encoder = AutoModelForTextEncoding.from_pretrained("t5-base") +>>> audio_encoder = AutoModel.from_pretrained("facebook/encodec_32khz") +>>> decoder_config = AutoConfig.from_pretrained("facebook/musicgen-small").decoder +>>> decoder = MusicgenForCausalLM.from_pretrained("facebook/musicgen-small", **decoder_config) + +>>> model = MusicgenForConditionalGeneration.from_sub_models_pretrained(text_encoder, audio_encoder, decoder) +``` + +If only the decoder needs to be loaded from the pre-trained checkpoint for the composite model, it can be loaded by first +specifying the correct config, or be accessed through the `.decoder` attribute of the composite model: + +```python +>>> from transformers import AutoConfig, MusicgenForCausalLM, MusicgenForConditionalGeneration + +>>> # Option 1: get decoder 
config and pass to `.from_pretrained` +>>> decoder_config = AutoConfig.from_pretrained("facebook/musicgen-small").decoder +>>> decoder = MusicgenForCausalLM.from_pretrained("facebook/musicgen-small", **decoder_config) + +>>> # Option 2: load the entire composite model, but only return the decoder +>>> decoder = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small").decoder +``` + +Tips: +* MusicGen is trained on the 32kHz checkpoint of Encodec. You should ensure you use a compatible version of the Encodec model. +* Sampling mode tends to deliver better results than greedy - you can toggle sampling with the variable `do_sample` in the call to [`MusicgenForConditionalGeneration.generate`] + +## MusicgenDecoderConfig + +[[autodoc]] MusicgenDecoderConfig + +## MusicgenConfig + +[[autodoc]] MusicgenConfig + +## MusicgenProcessor + +[[autodoc]] MusicgenProcessor + +## MusicgenModel + +[[autodoc]] MusicgenModel + - forward + +## MusicgenForCausalLM + +[[autodoc]] MusicgenForCausalLM + - forward + +## MusicgenForConditionalGeneration + +[[autodoc]] MusicgenForConditionalGeneration + - forward diff --git a/docs/source/en/tasks/language_modeling.md b/docs/source/en/tasks/language_modeling.md index f6f9b37afe..4986d90b28 100644 --- a/docs/source/en/tasks/language_modeling.md +++ b/docs/source/en/tasks/language_modeling.md @@ -37,7 +37,7 @@ You can finetune other architectures for causal language modeling following the Choose one of the following architectures: -[BART](../model_doc/bart), [BERT](../model_doc/bert), [Bert Generation](../model_doc/bert-generation), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BioGpt](../model_doc/biogpt), [Blenderbot](../model_doc/blenderbot), [BlenderbotSmall](../model_doc/blenderbot-small), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CodeGen](../model_doc/codegen), [CPM-Ant](../model_doc/cpmant), [CTRL](../model_doc/ctrl), 
[Data2VecText](../model_doc/data2vec-text), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [GIT](../model_doc/git), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT NeoX Japanese](../model_doc/gpt_neox_japanese), [GPT-J](../model_doc/gptj), [LLaMA](../model_doc/llama), [Marian](../model_doc/marian), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [MVP](../model_doc/mvp), [OpenLlama](../model_doc/open-llama), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Pegasus](../model_doc/pegasus), [PLBart](../model_doc/plbart), [ProphetNet](../model_doc/prophetnet), [QDQBert](../model_doc/qdqbert), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [RWKV](../model_doc/rwkv), [Speech2Text2](../model_doc/speech_to_text_2), [Transformer-XL](../model_doc/transfo-xl), [TrOCR](../model_doc/trocr), [XGLM](../model_doc/xglm), [XLM](../model_doc/xlm), [XLM-ProphetNet](../model_doc/xlm-prophetnet), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod) +[BART](../model_doc/bart), [BERT](../model_doc/bert), [Bert Generation](../model_doc/bert-generation), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BioGpt](../model_doc/biogpt), [Blenderbot](../model_doc/blenderbot), [BlenderbotSmall](../model_doc/blenderbot-small), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CodeGen](../model_doc/codegen), [CPM-Ant](../model_doc/cpmant), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [ELECTRA](../model_doc/electra), 
[ERNIE](../model_doc/ernie), [GIT](../model_doc/git), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT NeoX Japanese](../model_doc/gpt_neox_japanese), [GPT-J](../model_doc/gptj), [LLaMA](../model_doc/llama), [Marian](../model_doc/marian), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [MusicGen](../model_doc/musicgen), [MVP](../model_doc/mvp), [OpenLlama](../model_doc/open-llama), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Pegasus](../model_doc/pegasus), [PLBart](../model_doc/plbart), [ProphetNet](../model_doc/prophetnet), [QDQBert](../model_doc/qdqbert), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [RWKV](../model_doc/rwkv), [Speech2Text2](../model_doc/speech_to_text_2), [Transformer-XL](../model_doc/transfo-xl), [TrOCR](../model_doc/trocr), [XGLM](../model_doc/xglm), [XLM](../model_doc/xlm), [XLM-ProphetNet](../model_doc/xlm-prophetnet), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 202d9a6101..3747e1951f 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -402,6 +402,11 @@ _import_structure = { "models.mobilevitv2": ["MOBILEVITV2_PRETRAINED_CONFIG_ARCHIVE_MAP", "MobileViTV2Config"], "models.mpnet": ["MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "MPNetConfig", "MPNetTokenizer"], "models.mt5": ["MT5Config"], + "models.musicgen": [ + "MUSICGEN_PRETRAINED_CONFIG_ARCHIVE_MAP", + "MusicgenConfig", + "MusicgenDecoderConfig", + ], "models.mvp": ["MvpConfig", "MvpTokenizer"], "models.nat": 
["NAT_PRETRAINED_CONFIG_ARCHIVE_MAP", "NatConfig"], "models.nezha": ["NEZHA_PRETRAINED_CONFIG_ARCHIVE_MAP", "NezhaConfig"], @@ -2134,6 +2139,16 @@ else: _import_structure["models.mt5"].extend( ["MT5EncoderModel", "MT5ForConditionalGeneration", "MT5ForQuestionAnswering", "MT5Model", "MT5PreTrainedModel"] ) + _import_structure["models.musicgen"].extend( + [ + "MUSICGEN_PRETRAINED_MODEL_ARCHIVE_LIST", + "MusicgenForCausalLM", + "MusicgenForConditionalGeneration", + "MusicgenModel", + "MusicgenPreTrainedModel", + "MusicgenProcessor", + ] + ) _import_structure["models.mvp"].extend( [ "MVP_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -4249,6 +4264,11 @@ if TYPE_CHECKING: from .models.mobilevitv2 import MOBILEVITV2_PRETRAINED_CONFIG_ARCHIVE_MAP, MobileViTV2Config from .models.mpnet import MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP, MPNetConfig, MPNetTokenizer from .models.mt5 import MT5Config + from .models.musicgen import ( + MUSICGEN_PRETRAINED_CONFIG_ARCHIVE_MAP, + MusicgenConfig, + MusicgenDecoderConfig, + ) from .models.mvp import MvpConfig, MvpTokenizer from .models.nat import NAT_PRETRAINED_CONFIG_ARCHIVE_MAP, NatConfig from .models.nezha import NEZHA_PRETRAINED_CONFIG_ARCHIVE_MAP, NezhaConfig @@ -5709,6 +5729,14 @@ if TYPE_CHECKING: MT5Model, MT5PreTrainedModel, ) + from .models.musicgen import ( + MUSICGEN_PRETRAINED_MODEL_ARCHIVE_LIST, + MusicgenForCausalLM, + MusicgenForConditionalGeneration, + MusicgenModel, + MusicgenPreTrainedModel, + MusicgenProcessor, + ) from .models.mvp import ( MVP_PRETRAINED_MODEL_ARCHIVE_LIST, MvpForCausalLM, diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index 99426790db..4514ccef76 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -185,6 +185,10 @@ class GenerationConfig(PushToHubMixin): Dictionary that maps a sequence of tokens to its bias term. 
Positive biases increase the odds of the sequence being selected, while negative biases do the opposite. Check [`~generation.SequenceBiasLogitsProcessor`] for further documentation and examples. + guidance_scale (`float`, *optional*): + The guidance scale for classifier free guidance (CFG). CFG is enabled by setting `guidance_scale > 1`. + Higher guidance scale encourages the model to generate samples that are more closely linked to the input + prompt, usually at the expense of poorer quality. > Parameters that define the output variables of `generate` @@ -265,6 +269,7 @@ class GenerationConfig(PushToHubMixin): self.begin_suppress_tokens = kwargs.pop("begin_suppress_tokens", None) self.forced_decoder_ids = kwargs.pop("forced_decoder_ids", None) self.sequence_bias = kwargs.pop("sequence_bias", None) + self.guidance_scale = kwargs.pop("guidance_scale", None) # Parameters that define the output variables of `generate` self.num_return_sequences = kwargs.pop("num_return_sequences", 1) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index 6a73fcf964..a7fad67845 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -1065,3 +1065,40 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor): scores[k, : self.timestamp_begin] = -float("inf") return scores + + +class ClassifierFreeGuidanceLogitsProcessor(LogitsProcessor): + r"""Logits processor for classifier free guidance (CFG). The scores are split over the batch dimension, + where the first half correspond to the conditional logits (predicted from the input prompt) and the second half + correspond to the unconditional logits (predicted from an empty or 'null' prompt). The processor computes a + weighted average across the conditional and unconditional logits, parameterised by the `guidance_scale`. + + Args: + guidance_scale (float): + The guidance scale for classifier free guidance (CFG). 
CFG is enabled by setting `guidance_scale > 1`. + Higher guidance scale encourages the model to generate samples that are more closely linked to the input + prompt, usually at the expense of poorer quality. + """ + + def __init__(self, guidance_scale): + if guidance_scale > 1: + self.guidance_scale = guidance_scale + else: + raise ValueError( + "Require guidance scale >1 to use the classifier free guidance processor, got guidance scale " + f"{guidance_scale}." + ) + + def __call__(self, input_ids, scores): + # simple check to make sure we have compatible batch sizes between our + # logits scores (cond + uncond) and input ids (cond only) + if scores.shape[0] != 2 * input_ids.shape[0]: + raise ValueError( + f"Logits should have twice the batch size of the input ids, the first half of batches corresponding to " + f"the conditional inputs, and the second half of batches corresponding to the unconditional inputs. Got " + f"batch size {scores.shape[0]} for the logits and {input_ids.shape[0]} for the input ids." 
+ ) + unguided_bsz = scores.shape[0] // 2 + cond_logits, uncond_logits = scores.split(unguided_bsz, dim=0) + scores = uncond_logits + (cond_logits - uncond_logits) * self.guidance_scale + return scores diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index cb681cc720..69120ab8f8 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -38,6 +38,7 @@ from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer from .configuration_utils import GenerationConfig from .logits_process import ( + ClassifierFreeGuidanceLogitsProcessor, EncoderNoRepeatNGramLogitsProcessor, EncoderRepetitionPenaltyLogitsProcessor, EpsilonLogitsWarper, @@ -940,6 +941,8 @@ class GenerationMixin: ) if generation_config.forced_decoder_ids is not None: processors.append(ForceTokensLogitsProcessor(generation_config.forced_decoder_ids)) + if generation_config.guidance_scale is not None and generation_config.guidance_scale > 1: + processors.append(ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale)) processors = self._merge_criteria_processor_list(processors, logits_processor) # `LogitNormalization` should always be the last logit processor, when present if generation_config.renormalize_logits is True: diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 771ba9be09..d8345c9ef8 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -135,6 +135,7 @@ from . 
import ( mobilevitv2, mpnet, mt5, + musicgen, mvp, nat, nezha, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 2f2714e619..5cbaa0705a 100755 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -138,6 +138,7 @@ CONFIG_MAPPING_NAMES = OrderedDict( ("mobilevitv2", "MobileViTV2Config"), ("mpnet", "MPNetConfig"), ("mt5", "MT5Config"), + ("musicgen", "MusicgenConfig"), ("mvp", "MvpConfig"), ("nat", "NatConfig"), ("nezha", "NezhaConfig"), @@ -327,6 +328,7 @@ CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict( ("mobilevit", "MOBILEVIT_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("mobilevitv2", "MOBILEVITV2_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("mpnet", "MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("musicgen", "MUSICGEN_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("mvp", "MVP_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("nat", "NAT_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("nezha", "NEZHA_PRETRAINED_CONFIG_ARCHIVE_MAP"), @@ -532,6 +534,7 @@ MODEL_NAMES_MAPPING = OrderedDict( ("mobilevitv2", "MobileViTV2"), ("mpnet", "MPNet"), ("mt5", "MT5"), + ("musicgen", "MusicGen"), ("mvp", "MVP"), ("nat", "NAT"), ("nezha", "Nezha"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 8c5e5e8c15..8bb6ea37aa 100755 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -389,6 +389,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict( ("mbart", "MBartForCausalLM"), ("mega", "MegaForCausalLM"), ("megatron-bert", "MegatronBertForCausalLM"), + ("musicgen", "MusicgenForCausalLM"), ("mvp", "MvpForCausalLM"), ("open-llama", "OpenLlamaForCausalLM"), ("openai-gpt", "OpenAIGPTLMHeadModel"), diff --git a/src/transformers/models/musicgen/__init__.py b/src/transformers/models/musicgen/__init__.py new file mode 100644 index 0000000000..7fa695eba8 --- /dev/null +++ 
b/src/transformers/models/musicgen/__init__.py @@ -0,0 +1,67 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available + + +_import_structure = { + "configuration_musicgen": [ + "MUSICGEN_PRETRAINED_CONFIG_ARCHIVE_MAP", + "MusicgenConfig", + "MusicgenDecoderConfig", + ], + "processing_musicgen": ["MusicgenProcessor"], +} + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_musicgen"] = [ + "MUSICGEN_PRETRAINED_MODEL_ARCHIVE_LIST", + "MusicgenForConditionalGeneration", + "MusicgenForCausalLM", + "MusicgenModel", + "MusicgenPreTrainedModel", + ] + +if TYPE_CHECKING: + from .configuration_musicgen import ( + MUSICGEN_PRETRAINED_CONFIG_ARCHIVE_MAP, + MusicgenConfig, + MusicgenDecoderConfig, + ) + from .processing_musicgen import MusicgenProcessor + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_musicgen import ( + MUSICGEN_PRETRAINED_MODEL_ARCHIVE_LIST, + MusicgenForCausalLM, + MusicgenForConditionalGeneration, + MusicgenModel, + MusicgenPreTrainedModel, + ) + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, 
module_spec=__spec__) diff --git a/src/transformers/models/musicgen/configuration_musicgen.py b/src/transformers/models/musicgen/configuration_musicgen.py new file mode 100644 index 0000000000..2882c49f75 --- /dev/null +++ b/src/transformers/models/musicgen/configuration_musicgen.py @@ -0,0 +1,243 @@ +# coding=utf-8 +# Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" MusicGen model configuration""" +import copy + +from ...configuration_utils import PretrainedConfig +from ...utils import logging +from ..auto.configuration_auto import AutoConfig + + +logger = logging.get_logger(__name__) + +MUSICGEN_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "facebook/musicgen-small": "https://huggingface.co/facebook/musicgen-small/resolve/main/config.json", + # See all Musicgen models at https://huggingface.co/models?filter=musicgen +} + + +class MusicgenDecoderConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of an [`MusicgenDecoder`]. It is used to instantiate a + MusicGen decoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the MusicGen + [facebook/musicgen-small](https://huggingface.co/facebook/musicgen-small) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. 
Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 2048): + Vocabulary size of the MusicgenDecoder model. Defines the number of different tokens that can be + represented by the `input_ids` passed when calling [`MusicgenDecoder`]. + hidden_size (`int`, *optional*, defaults to 1024): + Dimensionality of the layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 24): + Number of decoder layers. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer block. + ffn_dim (`int`, *optional*, defaults to 4096): + Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer block. + activation_function (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the decoder and pooler. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + dropout (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, text_encoder, and pooler. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + activation_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for activations inside the fully connected layer. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. Typically, set this to something large + just in case (e.g., 512 or 1024 or 2048). + initializer_factor (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layerdrop (`float`, *optional*, defaults to 0.0): + The LayerDrop probability for the decoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556) + for more details. 
+ scale_embedding (`bool`, *optional*, defaults to `False`): + Scale embeddings by dividing by sqrt(hidden_size). + use_cache (`bool`, *optional*, defaults to `True`): + Whether the model should return the last key/values attentions (not used by all models). + num_codebooks (`int`, *optional*, defaults to 4): + The number of parallel codebooks forwarded to the model. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether input and output word embeddings should be tied. + """ + model_type = "musicgen_decoder" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=2048, + max_position_embeddings=2048, + num_hidden_layers=24, + ffn_dim=4096, + num_attention_heads=16, + layerdrop=0.0, + use_cache=True, + activation_function="gelu", + hidden_size=1024, + dropout=0.1, + attention_dropout=0.0, + activation_dropout=0.0, + initializer_factor=0.02, + scale_embedding=False, + num_codebooks=4, + pad_token_id=2048, + bos_token_id=2048, + eos_token_id=None, + tie_word_embeddings=False, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.ffn_dim = ffn_dim + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.dropout = dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.activation_function = activation_function + self.initializer_factor = initializer_factor + self.layerdrop = layerdrop + self.use_cache = use_cache + self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True + self.num_codebooks = num_codebooks + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +class MusicgenConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`MusicgenModel`]. 
It is used to instantiate a + MusicGen model according to the specified arguments, defining the text encoder, audio encoder and MusicGen decoder + configs. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + kwargs (*optional*): + Dictionary of keyword arguments. Notably: + + - **text_encoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that + defines the text encoder config. + - **audio_encoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that + defines the audio encoder config. + - **decoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that defines + the decoder config. + + Example: + + ```python + >>> from transformers import ( + ... MusicgenConfig, + ... MusicgenDecoderConfig, + ... T5Config, + ... EncodecConfig, + ... MusicgenForConditionalGeneration, + ... ) + + >>> # Initializing text encoder, audio encoder, and decoder model configurations + >>> text_encoder_config = T5Config() + >>> audio_encoder_config = EncodecConfig() + >>> decoder_config = MusicgenDecoderConfig() + + >>> configuration = MusicgenConfig.from_sub_models_config( + ... text_encoder_config, audio_encoder_config, decoder_config + ... 
) + + >>> # Initializing a MusicgenForConditionalGeneration (with random weights) from the facebook/musicgen-small style configuration + >>> model = MusicgenForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + >>> config_text_encoder = model.config.text_encoder + >>> config_audio_encoder = model.config.audio_encoder + >>> config_decoder = model.config.decoder + + >>> # Saving the model, including its configuration + >>> model.save_pretrained("musicgen-model") + + >>> # loading model and config from pretrained folder + >>> musicgen_config = MusicgenConfig.from_pretrained("musicgen-model") + >>> model = MusicgenForConditionalGeneration.from_pretrained("musicgen-model", config=musicgen_config) + ```""" + + model_type = "musicgen" + is_composition = True + + def __init__(self, **kwargs): + super().__init__(**kwargs) + if "text_encoder" not in kwargs or "audio_encoder" not in kwargs or "decoder" not in kwargs: + raise ValueError("Config has to be initialized with text_encoder, audio_encoder and decoder config") + + text_encoder_config = kwargs.pop("text_encoder") + text_encoder_model_type = text_encoder_config.pop("model_type") + + audio_encoder_config = kwargs.pop("audio_encoder") + audio_encoder_model_type = audio_encoder_config.pop("model_type") + + decoder_config = kwargs.pop("decoder") + + self.text_encoder = AutoConfig.for_model(text_encoder_model_type, **text_encoder_config) + self.audio_encoder = AutoConfig.for_model(audio_encoder_model_type, **audio_encoder_config) + self.decoder = MusicgenDecoderConfig(**decoder_config) + self.is_encoder_decoder = True + + @classmethod + def from_sub_models_config( + cls, + text_encoder_config: PretrainedConfig, + audio_encoder_config: PretrainedConfig, + decoder_config: MusicgenDecoderConfig, + **kwargs, + ): + r""" + Instantiate a [`MusicgenConfig`] (or a derived class) from text encoder, audio encoder and decoder + configurations. 
+ + Returns: + [`MusicgenConfig`]: An instance of a configuration object + """ + + return cls( + text_encoder=text_encoder_config.to_dict(), + audio_encoder=audio_encoder_config.to_dict(), + decoder=decoder_config.to_dict(), + **kwargs, + ) + + def to_dict(self): + """ + Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. + + Returns: + `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, + """ + output = copy.deepcopy(self.__dict__) + output["text_encoder"] = self.text_encoder.to_dict() + output["audio_encoder"] = self.audio_encoder.to_dict() + output["decoder"] = self.decoder.to_dict() + output["model_type"] = self.__class__.model_type + return output diff --git a/src/transformers/models/musicgen/convert_musicgen_transformers.py b/src/transformers/models/musicgen/convert_musicgen_transformers.py new file mode 100644 index 0000000000..517f0099d0 --- /dev/null +++ b/src/transformers/models/musicgen/convert_musicgen_transformers.py @@ -0,0 +1,209 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Convert MusicGen checkpoints from the original repository.""" +import argparse +from pathlib import Path +from typing import Dict, OrderedDict, Tuple + +import torch +from audiocraft.models import MusicGen + +from transformers import ( + AutoFeatureExtractor, + AutoTokenizer, + EncodecModel, + MusicgenDecoderConfig, + MusicgenForConditionalGeneration, + MusicgenProcessor, + T5EncoderModel, +) +from transformers.models.musicgen.modeling_musicgen import MusicgenForCausalLM +from transformers.utils import logging + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + + +EXPECTED_MISSING_KEYS = ["model.decoder.embed_positions.weights"] + + +def rename_keys(name): + if "emb" in name: + name = name.replace("emb", "model.decoder.embed_tokens") + if "transformer" in name: + name = name.replace("transformer", "model.decoder") + if "cross_attention" in name: + name = name.replace("cross_attention", "encoder_attn") + if "linear1" in name: + name = name.replace("linear1", "fc1") + if "linear2" in name: + name = name.replace("linear2", "fc2") + if "norm1" in name: + name = name.replace("norm1", "self_attn_layer_norm") + if "norm_cross" in name: + name = name.replace("norm_cross", "encoder_attn_layer_norm") + if "norm2" in name: + name = name.replace("norm2", "final_layer_norm") + if "out_norm" in name: + name = name.replace("out_norm", "model.decoder.layer_norm") + if "linears" in name: + name = name.replace("linears", "lm_heads") + if "condition_provider.conditioners.description.output_proj" in name: + name = name.replace("condition_provider.conditioners.description.output_proj", "enc_to_dec_proj") + return name + + +def rename_state_dict(state_dict: OrderedDict, hidden_size: int) -> Tuple[Dict, Dict]: + """Function that takes the fairseq Musicgen state dict and renames it according to the HF + module names. 
It further partitions the state dict into the decoder (LM) state dict, and that for the + encoder-decoder projection.""" + keys = list(state_dict.keys()) + enc_dec_proj_state_dict = {} + for key in keys: + val = state_dict.pop(key) + key = rename_keys(key) + if "in_proj_weight" in key: + # split fused qkv proj + state_dict[key.replace("in_proj_weight", "q_proj.weight")] = val[:hidden_size, :] + state_dict[key.replace("in_proj_weight", "k_proj.weight")] = val[hidden_size : 2 * hidden_size, :] + state_dict[key.replace("in_proj_weight", "v_proj.weight")] = val[-hidden_size:, :] + elif "enc_to_dec_proj" in key: + enc_dec_proj_state_dict[key[len("enc_to_dec_proj.") :]] = val + else: + state_dict[key] = val + return state_dict, enc_dec_proj_state_dict + + +def decoder_config_from_checkpoint(checkpoint: str) -> MusicgenDecoderConfig: + if checkpoint == "small": + # default config values + hidden_size = 1024 + num_hidden_layers = 24 + num_attention_heads = 16 + elif checkpoint == "medium": + hidden_size = 1536 + num_hidden_layers = 48 + num_attention_heads = 24 + elif checkpoint == "large": + hidden_size = 2048 + num_hidden_layers = 48 + num_attention_heads = 32 + else: + raise ValueError(f"Checkpoint should be one of `['small', 'medium', 'large']`, got {checkpoint}.") + config = MusicgenDecoderConfig( + hidden_size=hidden_size, + ffn_dim=hidden_size * 4, + num_hidden_layers=num_hidden_layers, + num_attention_heads=num_attention_heads, + ) + return config + + +@torch.no_grad() +def convert_musicgen_checkpoint(checkpoint, pytorch_dump_folder=None, repo_id=None, device="cpu"): + fairseq_model = MusicGen.get_pretrained(checkpoint, device=device) + decoder_config = decoder_config_from_checkpoint(checkpoint) + + decoder_state_dict = fairseq_model.lm.state_dict() + decoder_state_dict, enc_dec_proj_state_dict = rename_state_dict( + decoder_state_dict, hidden_size=decoder_config.hidden_size + ) + + text_encoder = T5EncoderModel.from_pretrained("t5-base") + audio_encoder = 
EncodecModel.from_pretrained("facebook/encodec_32khz") + decoder = MusicgenForCausalLM(decoder_config).eval() + + # load all decoder weights - expect that we'll be missing embeddings and enc-dec projection + missing_keys, unexpected_keys = decoder.load_state_dict(decoder_state_dict, strict=False) + + for key in missing_keys.copy(): + if key.startswith(("text_encoder", "audio_encoder")) or key in EXPECTED_MISSING_KEYS: + missing_keys.remove(key) + + if len(missing_keys) > 0: + raise ValueError(f"Missing key(s) in state_dict: {missing_keys}") + + if len(unexpected_keys) > 0: + raise ValueError(f"Unexpected key(s) in state_dict: {unexpected_keys}") + + # init the composite model + model = MusicgenForConditionalGeneration(text_encoder=text_encoder, audio_encoder=audio_encoder, decoder=decoder) + + # load the pre-trained enc-dec projection (from the decoder state dict) + model.enc_to_dec_proj.load_state_dict(enc_dec_proj_state_dict) + + # check we can do a forward pass + input_ids = torch.arange(0, 8, dtype=torch.long).reshape(2, -1) + decoder_input_ids = input_ids.reshape(2 * 4, -1) + + with torch.no_grad(): + logits = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids).logits + + if logits.shape != (8, 1, 2048): + raise ValueError("Incorrect shape for logits") + + # now construct the processor + tokenizer = AutoTokenizer.from_pretrained("t5-base") + feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/encodec_32khz", padding_side="left") + + processor = MusicgenProcessor(feature_extractor=feature_extractor, tokenizer=tokenizer) + + # set the appropriate bos/pad token ids + model.generation_config.decoder_start_token_id = 2048 + model.generation_config.pad_token_id = 2048 + + # set other default generation config params + model.generation_config.max_length = int(30 * audio_encoder.config.frame_rate) + model.generation_config.do_sample = True + model.generation_config.guidance_scale = 3.0 + + if pytorch_dump_folder is not None: + 
Path(pytorch_dump_folder).mkdir(exist_ok=True) + logger.info(f"Saving model {checkpoint} to {pytorch_dump_folder}") + model.save_pretrained(pytorch_dump_folder) + processor.save_pretrained(pytorch_dump_folder) + + if repo_id: + logger.info(f"Pushing model {checkpoint} to {repo_id}") + model.push_to_hub(repo_id) + processor.push_to_hub(repo_id) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--checkpoint", + default="small", + type=str, + help="Checkpoint size of the MusicGen model you'd like to convert. Can be one of: `['small', 'medium', 'large']`.", + ) + parser.add_argument( + "--pytorch_dump_folder", + required=True, + default=None, + type=str, + help="Path to the output PyTorch model directory.", + ) + parser.add_argument( + "--push_to_hub", default=None, type=str, help="Where to upload the converted model on the 🤗 hub." + ) + parser.add_argument( + "--device", default="cpu", type=str, help="Torch device to run the conversion, either cpu or cuda." + ) + + args = parser.parse_args() + convert_musicgen_checkpoint(args.checkpoint, args.pytorch_dump_folder, args.push_to_hub, args.device) diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py new file mode 100644 index 0000000000..bcd83e476f --- /dev/null +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -0,0 +1,2512 @@ +# coding=utf-8 +# Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch Musicgen model.""" +import copy +import inspect +import math +import random +from dataclasses import dataclass +from typing import Any, Dict, Optional, Tuple, Union + +import torch +import torch.nn as nn +from torch.nn import CrossEntropyLoss +from torch.utils.checkpoint import checkpoint + +from ...activations import ACT2FN +from ...generation.configuration_utils import GenerationConfig +from ...generation.logits_process import LogitsProcessorList +from ...generation.stopping_criteria import StoppingCriteriaList +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + ModelOutput, + Seq2SeqLMOutput, +) +from ...modeling_utils import PreTrainedModel +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from ..auto.configuration_auto import AutoConfig +from ..auto.modeling_auto import AutoModel +from .configuration_musicgen import MusicgenConfig, MusicgenDecoderConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "MusicgenConfig" +_CHECKPOINT_FOR_DOC = "facebook/musicgen-small" + +MUSICGEN_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "facebook/musicgen-small", + # See all Musicgen models at https://huggingface.co/models?filter=musicgen +] + + +@dataclass +class MusicgenUnconditionalInput(ModelOutput): + """ + Args: + encoder_outputs (`Tuple[torch.FloatTensor]` of length 1, with tensor shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the text encoder model. + attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Encoder attention mask to avoid performing attention on padding token indices. 
Mask values selected in `[0, + 1]`: 1 for tokens that are **not masked**, 0 for tokens that are **masked**. + guidance_scale (`float`, *optional*): + Guidance scale for classifier free guidance, setting the balance between the conditional logits (predicted + from the prompts) and the unconditional logits (predicted without prompts). + """ + + encoder_outputs: Tuple[torch.FloatTensor] = None + attention_mask: torch.LongTensor = None + guidance_scale: float = None + + +# Copied from transformers.models.bart.modeling_bart._make_causal_mask +def _make_causal_mask( + input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 +): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + + +# Copied from transformers.models.bart.modeling_bart._expand_mask +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. 
+ """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + +# Copied from transformers.models.encoder_decoder.modeling_encoder_decoder.shift_tokens_right +def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): + """ + Shift input ids one token to the right. + """ + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() + if decoder_start_token_id is None: + raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.") + shifted_input_ids[:, 0] = decoder_start_token_id + + if pad_token_id is None: + raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.") + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +class MusicgenSinusoidalPositionalEmbedding(nn.Module): + """This module produces sinusoidal positional embeddings of any length.""" + + def __init__(self, num_positions: int, embedding_dim: int): + super().__init__() + self.embedding_dim = embedding_dim + self.make_weights(num_positions, embedding_dim) + + def make_weights(self, num_embeddings: int, embedding_dim: int): + emb_weights = self.get_embedding(num_embeddings, embedding_dim) + if hasattr(self, "weights"): + # in forward put the weights on the correct dtype and device of the param + emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device) + + self.weights = nn.Parameter(emb_weights) + self.weights.requires_grad = False + self.weights.detach_() + + @staticmethod + def get_embedding(num_embeddings: int, embedding_dim: int): + """ + Build 
sinusoidal embeddings. This matches the implementation in tensor2tensor, but differs slightly from the + description in Section 3.5 of "Attention Is All You Need". + """ + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb) + emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0) + emb = torch.cat([torch.cos(emb), torch.sin(emb)], dim=1).view(num_embeddings, -1) + if embedding_dim % 2 == 1: + # zero pad + emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1) + return emb.to(torch.get_default_dtype()) + + @torch.no_grad() + def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0): + bsz, codebooks, seq_len = input_ids.size() + # Create the position ids from the input token ids. + position_ids = (torch.arange(seq_len) + past_key_values_length).to(input_ids.device) + # expand embeddings if needed: the largest position id is `seq_len + past_key_values_length - 1` + if seq_len + past_key_values_length > self.weights.size(0): + self.make_weights(seq_len + past_key_values_length, self.embedding_dim) + return self.weights.index_select(0, position_ids.view(-1)).detach() + + +# Copied from transformers.models.bart.modeling_bart.BartAttention +class MusicgenAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})."
+ ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = 
self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.reshape(*proj_shape) + value_states = value_states.reshape(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if 
layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. 
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value + + +class MusicgenDecoderLayer(nn.Module): + def __init__(self, config: MusicgenDecoderConfig): + super().__init__() + self.embed_dim = config.hidden_size + + self.self_attn = MusicgenAttention( + embed_dim=self.embed_dim, + num_heads=config.num_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + bias=False, + ) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.encoder_attn = MusicgenAttention( + self.embed_dim, + config.num_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + bias=False, + ) + self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim, bias=False) + self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim, bias=False) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + # Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + cross_attn_layer_head_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = True, + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. 
+ encoder_hidden_states (`torch.FloatTensor`): + cross attention input to the layer of shape `(batch, seq_len, embed_dim)` + encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of + size `(decoder_attention_heads,)`. + past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + # add present self-attn cache to positions 1,2 of present_key_value tuple + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=self_attn_past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + # Cross-Attention Block + cross_attn_present_key_value = None + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + 
hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=cross_attn_past_key_value, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + # add cross-attn to positions 3,4 of present_key_value tuple + present_key_value = present_key_value + cross_attn_present_key_value + + # Fully Connected + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class MusicgenPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = MusicgenDecoderConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["MusicgenDecoderLayer", "MusicgenAttention"] + + def _init_weights(self, module): + std = self.config.initializer_factor + if isinstance(module, (nn.Linear, nn.Conv1d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, MusicgenDecoder): + module.gradient_checkpointing = value + + +MUSICGEN_START_DOCSTRING = r""" + + The Musicgen model was proposed in [Simple and Controllable Music Generation](https://arxiv.org/abs/2306.05284) by + Jade Copet, Felix Kreuk, Itai Gat, Tal Remez, David Kant, Gabriel Synnaeve, Yossi Adi, Alexandre Défossez. It is an + encoder-decoder transformer trained on the task of conditional music generation. + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, + etc.). + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage + and behavior. + + Parameters: + config ([`MusicgenConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+""" + +MUSICGEN_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + decoder_input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary, corresponding to the sequence of audio codes. + + Indices can be obtained by encoding an audio prompt with an audio encoder model to predict audio codes, + such as with the [`EncodecModel`]. See [`EncodecModel.encode`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + + + The `decoder_input_ids` will automatically be converted from shape `(batch_size * num_codebooks, + target_sequence_length)` to `(batch_size, num_codebooks, target_sequence_length)` in the forward pass. If + you obtain audio codes from an audio encoding model, such as [`EncodecModel`], ensure that the number of + frames is equal to 1, and that you reshape the audio codes from `(frames, batch_size, num_codebooks, + target_sequence_length)` to `(batch_size * num_codebooks, target_sequence_length)` prior to passing them as + `decoder_input_ids`. + + + + decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. 
Causal mask will also + be used by default. + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0, + 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): + Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of + hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. 
+ + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded + representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be + input (see `past_key_values`). This is useful if you want more control over how to convert + `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. + + If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value + of `inputs_embeds`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+""" + +MUSICGEN_DECODER_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, sequence_length)`): + Indices of input sequence tokens in the vocabulary, corresponding to the sequence of audio codes. + + Indices can be obtained by encoding an audio prompt with an audio encoder model to predict audio codes, + such as with the [`EncodecModel`]. See [`EncodecModel.encode`] for details. + + [What are input IDs?](../glossary#input-ids) + + + + The `input_ids` will automatically be converted from shape `(batch_size * num_codebooks, + target_sequence_length)` to `(batch_size, num_codebooks, target_sequence_length)` in the forward pass. If + you obtain audio codes from an audio encoding model, such as [`EncodecModel`], ensure that the number of + frames is equal to 1, and that you reshape the audio codes from `(frames, batch_size, num_codebooks, + target_sequence_length)` to `(batch_size * num_codebooks, target_sequence_length)` prior to passing them as + `input_ids`. + + + + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of + the decoder. + encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): + Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values + selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. 
+ + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing + cross-attention on hidden heads. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of shape + `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you + can choose to directly pass an embedded representation. This is useful if you want more control over how to + convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. 
+ output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +class MusicgenDecoder(MusicgenPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MusicgenDecoderLayer`] + """ + + def __init__(self, config: MusicgenDecoderConfig): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.layerdrop + self.max_target_positions = config.max_position_embeddings + self.d_model = config.hidden_size + self.num_codebooks = config.num_codebooks + self.embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0 + + embed_dim = config.vocab_size + 1 + self.embed_tokens = nn.ModuleList( + [nn.Embedding(embed_dim, config.hidden_size) for _ in range(config.num_codebooks)] + ) + + self.embed_positions = MusicgenSinusoidalPositionalEmbedding( + config.max_position_embeddings, + config.hidden_size, + ) + + self.layers = nn.ModuleList([MusicgenDecoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.layer_norm = nn.LayerNorm(config.hidden_size) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask + def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, 
src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + inputs_embeds.dtype, + device=inputs_embeds.device, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( + inputs_embeds.device + ) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + + @add_start_docstrings_to_model_forward(MUSICGEN_DECODER_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same 
time") + elif input_ids is not None: + # (bsz * codebooks, seq_len) -> (bsz, codebooks, seq_len) + input = input_ids.reshape(-1, self.num_codebooks, input_ids.shape[-1]) + bsz, num_codebooks, seq_len = input.shape + input_shape = (bsz, seq_len) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + input = inputs_embeds[:, :, -1:] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if inputs_embeds is None: + inputs_embeds = torch.zeros((bsz, seq_len, self.d_model), device=input_ids.device) + + for codebook in range(num_codebooks): + inputs_embeds += self.embed_tokens[codebook](input[:, codebook]) + + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) + + # embed positions + positions = self.embed_positions(input, past_key_values_length) + + hidden_states = inputs_embeds + positions.to(inputs_embeds.device) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + next_decoder_cache = () if use_cache else None + + # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): + if attn_mask is not None: + if attn_mask.size()[0] != len(self.layers): + raise ValueError( + f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" + f" {attn_mask.size()[0]}." + ) + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + dropout_probability = random.uniform(0, 1) + if self.training and (dropout_probability < self.layerdrop): + continue + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, use_cache) + + return custom_forward + + layer_outputs = checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + head_mask[idx] if head_mask is not None else None, + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, + None, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=( + cross_attn_head_mask[idx] if cross_attn_head_mask is not 
None else None + ), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + hidden_states = self.layer_norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + +@add_start_docstrings( + "The bare Musicgen decoder model outputting raw hidden-states without any specific head on top.", + MUSICGEN_START_DOCSTRING, +) +class MusicgenModel(MusicgenPreTrainedModel): + def __init__(self, config: MusicgenDecoderConfig): + super().__init__(config) + self.decoder = MusicgenDecoder(config) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.decoder.embed_tokens = value + + def get_decoder(self): + return self.decoder + + @add_start_docstrings_to_model_forward(MUSICGEN_DECODER_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: 
Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + encoder_hidden_states=encoder_hidden_states, + head_mask=head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + return decoder_outputs + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + hidden_states=decoder_outputs.hidden_states, + attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + ) + + +@add_start_docstrings( + "The MusicGen decoder model with a language modelling head on top.", + MUSICGEN_START_DOCSTRING, +) +class MusicgenForCausalLM(MusicgenPreTrainedModel): + def __init__(self, config: MusicgenDecoderConfig): + super().__init__(config) + + self.model = MusicgenModel(config) + + 
self.num_codebooks = config.num_codebooks + self.lm_heads = nn.ModuleList( + [nn.Linear(config.hidden_size, config.vocab_size, bias=False) for _ in range(config.num_codebooks)] + ) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.model.decoder.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_heads + + def set_output_embeddings(self, new_embeddings): + self.lm_heads = new_embeddings + + def set_decoder(self, decoder): + self.model.decoder = decoder + + def get_decoder(self): + return self.model.decoder + + @add_start_docstrings_to_model_forward(MUSICGEN_DECODER_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e.
you can set + `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100` + are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`. + Returns: + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + head_mask=head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + + lm_logits = torch.stack([head(hidden_states) for head in self.lm_heads], dim=1) + + loss = None + if labels is not None: + raise NotImplementedError("Training is not implemented for Musicgen.") + + # (bsz, num_codebooks, seq_len, vocab_size) -> (bsz * num_codebooks, seq_len, vocab_size) + lm_logits = lm_logits.reshape(-1, *lm_logits.shape[2:]) + + if not return_dict: + output = (lm_logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=lm_logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + head_mask=None, + cross_attn_head_mask=None, + past_key_values=None, + use_cache=True, + delay_pattern_mask=None, + guidance_scale=None, + **kwargs, + ): + if delay_pattern_mask is None: + input_ids, delay_pattern_mask = self.build_delay_pattern_mask( + input_ids, + pad_token_id=self.generation_config.pad_token_id, +
max_length=self.generation_config.max_length, + ) + + # apply the delay pattern mask + input_ids = self.apply_delay_pattern_mask(input_ids, delay_pattern_mask) + + if guidance_scale is not None and guidance_scale > 1: + # for classifier free guidance we need to replicate the decoder args across the batch dim (we'll split these + # before sampling) + input_ids = input_ids.repeat((2, 1)) + if attention_mask is not None: + attention_mask = attention_mask.repeat((2, 1)) + + if past_key_values is not None: + input_ids = input_ids[:, -1:] + + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "encoder_hidden_states": encoder_hidden_states, + "encoder_attention_mask": encoder_attention_mask, + "head_mask": head_mask, + "cross_attn_head_mask": cross_attn_head_mask, + "past_key_values": past_key_values, + "use_cache": use_cache, + } + + def build_delay_pattern_mask(self, input_ids: torch.LongTensor, pad_token_id: int, max_length: int = None): + """Build a delayed pattern mask to the input_ids. Each codebook is offset by the previous codebook by + one, giving a delayed pattern mask at the start of sequence and end of sequence. Take the example where there + are 4 codebooks and a max sequence length of 8, we have the delayed pattern mask of shape `(codebooks, + seq_len)`: + - [P, -1, -1, -1, -1, P, P, P] + - [P, P, -1, -1, -1, -1, P, P] + - [P, P, P, -1, -1, -1, -1, P] + - [P, P, P, P, -1, -1, -1, -1] + where P is the special padding token id and -1 indicates that the token is valid for prediction. If we include + a prompt (decoder input ids), the -1 positions indicate where new tokens should be predicted. Otherwise, the + mask is set to the value in the prompt: + - [P, a, b, -1, -1, P, P, P] + - [P, P, c, d, -1, -1, P, P] + - [P, P, P, e, f, -1, -1, P] + - [P, P, P, P, g, h, -1, -1] + where a-h indicate the input prompt (decoder input ids) that are offset by 1. Now, we only override the -1 + tokens in our prediction. 
+ """ + # (bsz * num_codebooks, seq_len) -> (bsz, num_codebooks, seq_len) + input_ids = input_ids.reshape(-1, self.num_codebooks, input_ids.shape[-1]) + bsz, num_codebooks, seq_len = input_ids.shape + + max_length = max_length if max_length is not None else self.generation_config.max_length + input_ids_shifted = ( + torch.ones((bsz, num_codebooks, max_length), dtype=torch.long, device=input_ids.device) * -1 + ) + + # we only apply the mask if we have a large enough seq len - otherwise we return as is + if max_length < 2 * num_codebooks - 1: + return input_ids.reshape(bsz * num_codebooks, -1), input_ids_shifted.reshape(bsz * num_codebooks, -1) + + # fill the shifted ids with the prompt entries, offset by the codebook idx + for codebook in range(num_codebooks): + input_ids_shifted[:, codebook, codebook : seq_len + codebook] = input_ids[:, codebook] + + # construct a pattern mask that indicates the positions of padding tokens for each codebook + # first fill the upper triangular part (the EOS padding) + delay_pattern = torch.triu( + torch.ones((num_codebooks, max_length), dtype=torch.bool), diagonal=max_length - num_codebooks + 1 + ) + # then fill the lower triangular part (the BOS padding) + delay_pattern = delay_pattern + torch.tril(torch.ones((num_codebooks, max_length), dtype=torch.bool)) + mask = ~delay_pattern.to(input_ids.device) + input_ids = mask * input_ids_shifted + ~mask * pad_token_id + + # find the first position to start generating - this is the first place we have the -1 token + # and will always be in the first codebook (since it has no codebook offset) + first_codebook_ids = input_ids[:, 0, :] + start_ids = (first_codebook_ids == -1).nonzero()[:, 1] + if len(start_ids) > 0: + first_start_id = min(start_ids) + else: + # we have no tokens that need to be filled - return entire matrix of input ids + first_start_id = seq_len + + # (bsz, num_codebooks, seq_len) -> (bsz * num_codebooks, seq_len) + pattern_mask = input_ids.reshape(bsz * num_codebooks, -1) +
input_ids = input_ids[..., :first_start_id].reshape(bsz * num_codebooks, -1) + return input_ids, pattern_mask + + @staticmethod + def apply_delay_pattern_mask(input_ids, decoder_pad_token_mask): + """Apply a delay pattern mask to the decoder input ids, only preserving predictions where + the mask is set to -1, and otherwise setting to the value detailed in the mask.""" + seq_len = input_ids.shape[-1] + decoder_pad_token_mask = decoder_pad_token_mask[..., :seq_len] + input_ids = torch.where(decoder_pad_token_mask == -1, input_ids, decoder_pad_token_mask) + return input_ids + + @torch.no_grad() + def generate( + self, + inputs: Optional[torch.Tensor] = None, + generation_config: Optional[GenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + synced_gpus: Optional[bool] = None, + **kwargs, + ): + """ + + Generates sequences of token ids for models with a language modeling head. + + + + Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the + model's default generation configuration. You can override any `generation_config` by passing the corresponding + parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`. + + For an overview of generation strategies and code examples, check out the [following + guide](./generation_strategies). + + + + Parameters: + inputs (`torch.Tensor` of varying shape depending on the modality, *optional*): + The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the + method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs` + should be in the format `input_ids`. For encoder-decoder models *inputs* can represent any of + `input_ids`, `input_values`, `input_features`, or `pixel_values`. 
+ generation_config (`~generation.GenerationConfig`, *optional*): + The generation configuration to be used as base parametrization for the generation call. `**kwargs` + passed to generate matching the attributes of `generation_config` will override them. If + `generation_config` is not provided, the default will be used, which has the following loading + priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model + configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s + default values, whose documentation should be checked to parameterize generation. + logits_processor (`LogitsProcessorList`, *optional*): + Custom logits processors that complement the default logits processors built from arguments and + generation config. If a logit processor is passed that is already created with the arguments or a + generation config an error is thrown. This feature is intended for advanced users. + stopping_criteria (`StoppingCriteriaList`, *optional*): + Custom stopping criteria that complement the default stopping criteria built from arguments and a + generation config. If a stopping criteria is passed that is already created with the arguments or a + generation config an error is thrown. This feature is intended for advanced users. + synced_gpus (`bool`, *optional*, defaults to `False`): + Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + kwargs: + Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be + forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder + specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*. + + Return: + [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True` + or when `config.return_dict_in_generate=True`) or a `torch.LongTensor`.
+ + If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible + [`~utils.ModelOutput`] types are: + + - [`~generation.GreedySearchDecoderOnlyOutput`], + - [`~generation.SampleDecoderOnlyOutput`], + - [`~generation.BeamSearchDecoderOnlyOutput`], + - [`~generation.BeamSampleDecoderOnlyOutput`] + + If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible + [`~utils.ModelOutput`] types are: + + - [`~generation.GreedySearchEncoderDecoderOutput`], + - [`~generation.SampleEncoderDecoderOutput`], + - [`~generation.BeamSearchEncoderDecoderOutput`], + - [`~generation.BeamSampleEncoderDecoderOutput`] + """ + # 1. Handle `generation_config` and kwargs that might update it, and validate the resulting objects + if generation_config is None: + generation_config = self.generation_config + + generation_config = copy.deepcopy(generation_config) + model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs + generation_config.validate() + self._validate_model_kwargs(model_kwargs.copy()) + + # 2. Set generation parameters if not already defined + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + + if generation_config.pad_token_id is None and generation_config.eos_token_id is not None: + if model_kwargs.get("attention_mask", None) is None: + logger.warning( + "The attention mask and the pad token id were not set. As a consequence, you may observe " + "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results." + ) + eos_token_id = generation_config.eos_token_id + if isinstance(eos_token_id, list): + eos_token_id = eos_token_id[0] + logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.") + generation_config.pad_token_id = eos_token_id + + # 3. 
Define model inputs + # inputs_tensor has to be defined + # model_input_name is defined if model-specific keyword input is passed + # otherwise model_input_name is None + # all model-specific keyword inputs are removed from `model_kwargs` + input_ids, model_input_name, model_kwargs = self._prepare_model_inputs( + inputs, generation_config.bos_token_id, model_kwargs + ) + batch_size = input_ids.shape[0] // self.num_codebooks + + # 4. Define other model kwargs + model_kwargs["output_attentions"] = generation_config.output_attentions + model_kwargs["output_hidden_states"] = generation_config.output_hidden_states + model_kwargs["use_cache"] = generation_config.use_cache + model_kwargs["guidance_scale"] = generation_config.guidance_scale + + requires_attention_mask = "encoder_outputs" not in model_kwargs + if model_kwargs.get("attention_mask", None) is None and requires_attention_mask: + model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( + input_ids, generation_config.pad_token_id, generation_config.eos_token_id + ) + + # 5. Prepare `max_length` depending on other stopping criteria. + input_ids_seq_length = input_ids.shape[-1] + has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None + if has_default_max_length and generation_config.max_new_tokens is None: + logger.warning( + f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. " + "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we" + " recommend using `max_new_tokens` to control the maximum length of the generation." + ) + elif generation_config.max_new_tokens is not None: + if not has_default_max_length: + logger.warning( + f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length` (=" + f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence.
" + "Please refer to the documentation for more information. " + "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)" + ) + generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length + + if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length: + raise ValueError( + f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than" + f" the maximum length ({generation_config.max_length})" + ) + if input_ids_seq_length >= generation_config.max_length: + logger.warning( + f"Input length of decoder_input_ids is {input_ids_seq_length}, but `max_length` is set to" + f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" + " increasing `max_new_tokens`." + ) + + # 6. Prepare `input_ids` which will be used for auto-regressive generation + # Build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to MusicGen) + input_ids, delay_pattern_mask = self.build_delay_pattern_mask( + input_ids, + pad_token_id=generation_config.decoder_start_token_id, + max_length=generation_config.max_length, + ) + + # stash the delay mask so that we don't have to recompute it in each forward pass + model_kwargs["delay_pattern_mask"] = delay_pattern_mask + + # 7. determine generation mode + is_greedy_gen_mode = ( + (generation_config.num_beams == 1) + and (generation_config.num_beam_groups == 1) + and generation_config.do_sample is False + ) + is_sample_gen_mode = ( + (generation_config.num_beams == 1) + and (generation_config.num_beam_groups == 1) + and generation_config.do_sample is True + ) + + # 8. 
prepare distribution pre_processing samplers + logits_processor = self._get_logits_processor( + generation_config=generation_config, + input_ids_seq_length=input_ids_seq_length, + encoder_input_ids=input_ids, + prefix_allowed_tokens_fn=None, + logits_processor=logits_processor, + ) + + # 9. prepare stopping criteria + stopping_criteria = self._get_stopping_criteria( + generation_config=generation_config, stopping_criteria=stopping_criteria + ) + + if is_greedy_gen_mode: + if generation_config.num_return_sequences > 1: + raise ValueError( + "num_return_sequences has to be 1 when doing greedy search, " + f"but is {generation_config.num_return_sequences}." + ) + + # 10. run greedy search + outputs = self.greedy_search( + input_ids, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + elif is_sample_gen_mode: + # 10. prepare logits warper + logits_warper = self._get_logits_warper(generation_config) + + # expand input_ids with `num_return_sequences` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_return_sequences, + **model_kwargs, + ) + + # 11. run sample + outputs = self.sample( + input_ids, + logits_processor=logits_processor, + logits_warper=logits_warper, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + else: + raise ValueError( + "Got incompatible mode for generation, should be one of greedy or sampling. "
+ "Ensure that beam search is de-activated by setting `num_beams=1` and `num_beam_groups=1`." + ) + + if generation_config.return_dict_in_generate: + output_ids = outputs.sequences + else: + output_ids = outputs + + # apply the pattern mask to the final ids + output_ids = self.apply_delay_pattern_mask(output_ids, model_kwargs["delay_pattern_mask"]) + + # revert the pattern delay mask by filtering the pad token id + output_ids = output_ids[output_ids != generation_config.pad_token_id].reshape( + batch_size, self.num_codebooks, -1 + ) + + if generation_config.return_dict_in_generate: + outputs.sequences = output_ids + return outputs + else: + return output_ids + + +@add_start_docstrings( + "The composite MusicGen model with a text encoder, audio encoder and Musicgen decoder, " + "for music generation tasks with one or both of text and audio prompts.", + MUSICGEN_START_DOCSTRING, +) +class MusicgenForConditionalGeneration(PreTrainedModel): + config_class = MusicgenConfig + base_model_prefix = "encoder_decoder" + main_input_name = "input_ids" + supports_gradient_checkpointing = True + + def __init__( + self, + config: Optional[MusicgenConfig] = None, + text_encoder: Optional[PreTrainedModel] = None, + audio_encoder: Optional[PreTrainedModel] = None, + decoder: Optional[MusicgenForCausalLM] = None, + ): + if config is None and (text_encoder is None or audio_encoder is None or decoder is None): + raise ValueError( + "Either a configuration has to be provided, or all three of text encoder, audio encoder and MusicGen decoder."
+ ) + if config is None: + config = MusicgenConfig.from_sub_models_config(text_encoder.config, audio_encoder.config, decoder.config) + else: + if not isinstance(config, self.config_class): + raise ValueError(f"Config: {config} has to be of type {self.config_class}") + + if config.decoder.cross_attention_hidden_size is not None: + if config.decoder.cross_attention_hidden_size != config.text_encoder.hidden_size: + raise ValueError( + "If `cross_attention_hidden_size` is specified in the MusicGen decoder's configuration, it has to be equal" + f" to the text encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for" + f" `config.decoder.cross_attention_hidden_size` and {config.text_encoder.hidden_size} for" + " `config.text_encoder.hidden_size`." + ) + + # initialize with config + super().__init__(config) + + if text_encoder is None: + from ..auto.modeling_auto import AutoModelForTextEncoding + + text_encoder = AutoModelForTextEncoding.from_config(config.text_encoder) + + if audio_encoder is None: + from ..auto.modeling_auto import AutoModel + + audio_encoder = AutoModel.from_config(config.audio_encoder) + + if decoder is None: + decoder = MusicgenForCausalLM(config.decoder) + + self.text_encoder = text_encoder + self.audio_encoder = audio_encoder + self.decoder = decoder + + if self.text_encoder.config.to_dict() != self.config.text_encoder.to_dict(): + logger.warning( + f"Config of the text_encoder: {self.text_encoder.__class__} is overwritten by shared text_encoder config:" + f" {self.config.text_encoder}" + ) + if self.audio_encoder.config.to_dict() != self.config.audio_encoder.to_dict(): + logger.warning( + f"Config of the audio_encoder: {self.audio_encoder.__class__} is overwritten by shared audio_encoder config:" + f" {self.config.audio_encoder}" + ) + if self.decoder.config.to_dict() != self.config.decoder.to_dict(): + logger.warning( + f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config:" + f" 
{self.config.decoder}" + ) + + # make sure that the individual model's config refers to the shared config + # so that the updates to the config will be synced + self.text_encoder.config = self.config.text_encoder + self.audio_encoder.config = self.config.audio_encoder + self.decoder.config = self.config.decoder + + # text encoder outputs might need to be projected to different dimension for decoder + if ( + self.text_encoder.config.hidden_size != self.decoder.config.hidden_size + and self.decoder.config.cross_attention_hidden_size is None + ): + self.enc_to_dec_proj = nn.Linear(self.text_encoder.config.hidden_size, self.decoder.config.hidden_size) + + if self.text_encoder.get_output_embeddings() is not None: + raise ValueError( + f"The encoder {self.text_encoder} should not have a LM Head. Please use a model without an LM Head" + ) + + decoder_signature = set(inspect.signature(self.decoder.forward).parameters.keys()) + if "encoder_hidden_states" not in decoder_signature: + raise ValueError( + "The selected decoder is not prepared for the encoder hidden states to be passed.
Please see the " + "following discussion on GitHub: https://github.com/huggingface/transformers/issues/23350" + ) + + # tie text encoder, decoder weights if config set accordingly + self.tie_weights() + + def tie_weights(self): + # tie text encoder & decoder if needed + if self.config.tie_encoder_decoder: + # tie text encoder and decoder base model + decoder_base_model_prefix = self.decoder.base_model_prefix + self._tie_encoder_decoder_weights( + self.text_encoder, self.decoder._modules[decoder_base_model_prefix], self.decoder.base_model_prefix + ) + + def _set_gradient_checkpointing(self, module, value=False): + # call both encoder and decoder function on gradient checkpointing + self.text_encoder._set_gradient_checkpointing(module, value=value) + self.decoder._set_gradient_checkpointing(module, value=value) + + def get_audio_encoder(self): + return self.audio_encoder + + def get_text_encoder(self): + return self.text_encoder + + def get_encoder(self): + # get the text encoder to compute the encoder hidden-states for generation + return self.get_text_encoder() + + def get_decoder(self): + return self.decoder + + def get_input_embeddings(self): + return self.text_encoder.get_input_embeddings() + + def get_output_embeddings(self): + return self.decoder.get_output_embeddings() + + def set_output_embeddings(self, new_embeddings): + return self.decoder.set_output_embeddings(new_embeddings) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): + r""" + Example: + + ```python + >>> from transformers import MusicgenForConditionalGeneration + + >>> model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small") + ```""" + + # At the moment fast initialization is not supported for composite models + if kwargs.get("_fast_init", False): + logger.warning( + "Fast initialization is currently not supported for MusicgenForConditionalGeneration. " + "Falling back to slow initialization..." 
+ ) + kwargs["_fast_init"] = False + + return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) + + @classmethod + def from_sub_models_pretrained( + cls, + text_encoder_pretrained_model_name_or_path: str = None, + audio_encoder_pretrained_model_name_or_path: str = None, + decoder_pretrained_model_name_or_path: str = None, + *model_args, + **kwargs, + ) -> PreTrainedModel: + r""" + Instantiate a text encoder, an audio encoder, and a MusicGen decoder from one, two or three base classes of the + library from pretrained model checkpoints. + + + The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train + the model, you need to first set it back in training mode with `model.train()`. + + Params: + text_encoder_pretrained_model_name_or_path (`str`, *optional*): + Information necessary to initiate the text encoder. Can be either: + + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. + Valid model ids can be located at the root-level, like `t5-base`, or namespaced under a user or + organization name, like `google/flan-t5-base`. + - A path to a *directory* containing model weights saved using + [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. + + audio_encoder_pretrained_model_name_or_path (`str`, *optional*): + Information necessary to initiate the audio encoder. Can be either: + + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. + Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a + user or organization name, like `facebook/encodec_24khz`. + - A path to a *directory* containing model weights saved using + [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. + + decoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `None`): + Information necessary to initiate the decoder.
Can be either: + + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. + Valid model ids can be located at the root-level, like `gpt2`, or namespaced under a user or + organization name, like `facebook/musicgen-small`. + - A path to a *directory* containing model weights saved using + [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. + + model_args (remaining positional arguments, *optional*): + All remaining positional arguments will be passed to the underlying model's `__init__` method. + + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., + `output_attentions=True`). + + - To update the text encoder configuration, use the prefix *text_encoder_* for each configuration + parameter. + - To update the audio encoder configuration, use the prefix *audio_encoder_* for each configuration + parameter. + - To update the decoder configuration, use the prefix *decoder_* for each configuration parameter. + - To update the parent model configuration, do not use a prefix for each configuration parameter. + + Behaves differently depending on whether a `config` is provided or automatically loaded. + + Example: + + ```python + >>> from transformers import MusicgenForConditionalGeneration + + >>> # initialize a musicgen model from a t5 text encoder, encodec audio encoder, and musicgen decoder + >>> model = MusicgenForConditionalGeneration.from_sub_models_pretrained( + ... text_encoder_pretrained_model_name_or_path="t5-base", + ... audio_encoder_pretrained_model_name_or_path="facebook/encodec_24khz", + ... decoder_pretrained_model_name_or_path="facebook/musicgen-small", + ... 
) + >>> # saving model after fine-tuning + >>> model.save_pretrained("./musicgen-ft") + >>> # load fine-tuned model + >>> model = MusicgenForConditionalGeneration.from_pretrained("./musicgen-ft") + ```""" + + kwargs_text_encoder = { + argument[len("text_encoder_") :]: value + for argument, value in kwargs.items() + if argument.startswith("text_encoder_") + } + + kwargs_audio_encoder = { + argument[len("audio_encoder_") :]: value + for argument, value in kwargs.items() + if argument.startswith("audio_encoder_") + } + + kwargs_decoder = { + argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_") + } + + # remove text encoder, audio encoder and decoder kwargs from kwargs + for key in kwargs_text_encoder.keys(): + del kwargs["text_encoder_" + key] + for key in kwargs_audio_encoder.keys(): + del kwargs["audio_encoder_" + key] + for key in kwargs_decoder.keys(): + del kwargs["decoder_" + key] + + # Load and initialize the encoder and decoder + # The distinction between encoder and decoder at the model level is made + # by the value of the flag `is_decoder` that we need to set correctly. + text_encoder = kwargs_text_encoder.pop("model", None) + if text_encoder is None: + if text_encoder_pretrained_model_name_or_path is None: + raise ValueError( + "If `text_encoder_model` is not defined as an argument, a `text_encoder_pretrained_model_name_or_path` has " + "to be defined." + ) + + if "config" not in kwargs_text_encoder: + encoder_config, kwargs_text_encoder = AutoConfig.from_pretrained( + text_encoder_pretrained_model_name_or_path, **kwargs_text_encoder, return_unused_kwargs=True + ) + + if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True: + logger.info( + f"Initializing {text_encoder_pretrained_model_name_or_path} as a text_encoder model " + "from a decoder model. Cross-attention and causal mask are disabled." 
+ ) + encoder_config.is_decoder = False + encoder_config.add_cross_attention = False + + kwargs_text_encoder["config"] = encoder_config + + text_encoder = AutoModel.from_pretrained( + text_encoder_pretrained_model_name_or_path, *model_args, **kwargs_text_encoder + ) + + audio_encoder = kwargs_audio_encoder.pop("model", None) + if audio_encoder is None: + if audio_encoder_pretrained_model_name_or_path is None: + raise ValueError( + "If `audio_encoder_model` is not defined as an argument, an `audio_encoder_pretrained_model_name_or_path` has " + "to be defined." + ) + + if "config" not in kwargs_audio_encoder: + encoder_config, kwargs_audio_encoder = AutoConfig.from_pretrained( + audio_encoder_pretrained_model_name_or_path, **kwargs_audio_encoder, return_unused_kwargs=True + ) + + if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True: + logger.info( + f"Initializing {audio_encoder_pretrained_model_name_or_path} as an audio_encoder model " + "from a decoder model. Cross-attention and causal mask are disabled." + ) + encoder_config.is_decoder = False + encoder_config.add_cross_attention = False + + kwargs_audio_encoder["config"] = encoder_config + + audio_encoder = AutoModel.from_pretrained( + audio_encoder_pretrained_model_name_or_path, *model_args, **kwargs_audio_encoder + ) + + decoder = kwargs_decoder.pop("model", None) + if decoder is None: + if decoder_pretrained_model_name_or_path is None: + raise ValueError( + "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has " + "to be defined." 
+ ) + + if "config" not in kwargs_decoder: + decoder_config, kwargs_decoder = AutoConfig.from_pretrained( + decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True + ) + + if isinstance(decoder_config, MusicgenConfig): + decoder_config = decoder_config.decoder + + if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False: + logger.info( + f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention" + f" layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if" + f" {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers." + ) + decoder_config.is_decoder = True + decoder_config.add_cross_attention = True + + kwargs_decoder["config"] = decoder_config + + if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False: + logger.warning( + f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. 
" + f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, " + "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` " + "passed to `.from_sub_models_pretrained(...)` are set to `True` or do not pass a " + "`decoder_config` to `.from_sub_models_pretrained(...)`" + ) + + decoder = MusicgenForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder) + + # instantiate config with corresponding kwargs + config = MusicgenConfig.from_sub_models_config( + text_encoder.config, audio_encoder.config, decoder.config, **kwargs + ) + return cls(text_encoder=text_encoder, audio_encoder=audio_encoder, decoder=decoder, config=config) + + @add_start_docstrings_to_model_forward(MUSICGEN_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.BoolTensor] = None, + input_values: Optional[torch.FloatTensor] = None, + padding_mask: Optional[torch.BoolTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None, + past_key_values: Tuple[Tuple[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple, Seq2SeqLMOutput]: + r""" + Returns: + + Examples: + ```python + >>> from transformers import AutoProcessor, MusicgenForConditionalGeneration + >>> import torch + + >>> processor = AutoProcessor.from_pretrained("facebook/musicgen-small") + >>> model = 
MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small") + + >>> inputs = processor( + ... text=["80s pop track with bassy drums and synth", "90s rock song with loud guitars and heavy drums"], + ... padding=True, + ... return_tensors="pt", + ... ) + + >>> pad_token_id = model.generation_config.pad_token_id + >>> decoder_input_ids = ( + ... torch.ones((inputs.input_ids.shape[0] * model.decoder.num_codebooks, 1), dtype=torch.long) + ... * pad_token_id + ... ) + + >>> logits = model(**inputs, decoder_input_ids=decoder_input_ids).logits + >>> logits.shape # (bsz * num_codebooks, tgt_len, vocab_size) + torch.Size([8, 1, 2048]) + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + kwargs_text_encoder = { + argument[len("text_encoder_") :]: value + for argument, value in kwargs.items() + if argument.startswith("text_encoder_") + } + + kwargs_audio_encoder = { + argument[len("audio_encoder_") :]: value + for argument, value in kwargs.items() + if argument.startswith("audio_encoder_") + } + + kwargs_decoder = { + argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_") + } + + if encoder_outputs is None: + encoder_outputs = self.text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + **kwargs_text_encoder, + ) + elif isinstance(encoder_outputs, tuple): + encoder_outputs = BaseModelOutput(*encoder_outputs) + + encoder_hidden_states = encoder_outputs[0] + + # optionally project encoder_hidden_states + if ( + self.text_encoder.config.hidden_size != self.decoder.config.hidden_size + and self.decoder.config.cross_attention_hidden_size is None + ): + encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states) + + if attention_mask is not None: + encoder_hidden_states = encoder_hidden_states * 
attention_mask[..., None] + + if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None): + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + elif decoder_input_ids is None and decoder_inputs_embeds is None: + audio_encoder_outputs = self.audio_encoder( + input_values=input_values, + padding_mask=padding_mask, + **kwargs_audio_encoder, + ) + audio_codes = audio_encoder_outputs.audio_codes + frames, bsz, codebooks, seq_len = audio_codes.shape + if frames != 1: + raise ValueError( + f"Expected 1 frame in the audio code outputs, got {frames} frames. Ensure chunking is " + "disabled by setting `chunk_length=None` in the audio encoder." + ) + decoder_input_ids = audio_codes[0, ...].reshape(bsz * self.decoder.num_codebooks, seq_len) + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=attention_mask, + inputs_embeds=decoder_inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + use_cache=use_cache, + past_key_values=past_key_values, + return_dict=return_dict, + **kwargs_decoder, + ) + + loss = None + if labels is not None: + logits = decoder_outputs.logits if return_dict else decoder_outputs[0] + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + if loss is not None: + return (loss,) + decoder_outputs + encoder_outputs + else: + return decoder_outputs + encoder_outputs + + return Seq2SeqLMOutput( + loss=loss, + logits=decoder_outputs.logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + 
encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + decoder_input_ids, + past_key_values=None, + attention_mask=None, + head_mask=None, + decoder_attention_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + use_cache=None, + encoder_outputs=None, + decoder_delay_pattern_mask=None, + guidance_scale=None, + **kwargs, + ): + if decoder_delay_pattern_mask is None: + decoder_input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask( + decoder_input_ids, + self.generation_config.pad_token_id, + max_length=self.generation_config.max_length, + ) + + # apply the delay pattern mask + decoder_input_ids = self.decoder.apply_delay_pattern_mask(decoder_input_ids, decoder_delay_pattern_mask) + + if guidance_scale is not None and guidance_scale > 1: + # for classifier free guidance we need to replicate the decoder args across the batch dim (we'll split these + # before sampling) + decoder_input_ids = decoder_input_ids.repeat((2, 1)) + if decoder_attention_mask is not None: + decoder_attention_mask = decoder_attention_mask.repeat((2, 1)) + + # cut decoder_input_ids if past is used + if past_key_values is not None: + decoder_input_ids = decoder_input_ids[:, -1:] + + return { + "input_ids": None, # encoder_outputs is defined. 
input_ids not needed + "encoder_outputs": encoder_outputs, + "past_key_values": past_key_values, + "decoder_input_ids": decoder_input_ids, + "attention_mask": attention_mask, + "decoder_attention_mask": decoder_attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, + "use_cache": use_cache, + } + + def _prepare_decoder_input_ids_for_generation( + self, + batch_size: int, + model_input_name: str, + model_kwargs: Dict[str, torch.Tensor], + decoder_start_token_id: int = None, + bos_token_id: int = None, + device: torch.device = None, + ) -> Tuple[torch.LongTensor, Dict[str, torch.Tensor]]: + """Prepares `decoder_input_ids` for generation with encoder-decoder models""" + + # 1. Check whether the user has defined `decoder_input_ids` manually. To facilitate in terms of input naming, + # we also allow the user to pass it under `input_ids`, if the encoder does not use it as the main input. + if model_kwargs is not None and "decoder_input_ids" in model_kwargs: + decoder_input_ids = model_kwargs.pop("decoder_input_ids") + elif "input_ids" in model_kwargs and model_input_name != "input_ids": + decoder_input_ids = model_kwargs.pop("input_ids") + else: + decoder_input_ids = None + + # 2. Encoder-decoder models expect the `decoder_input_ids` to start with a special token. Let's ensure that. 
+ decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id) + if device is None: + device = self.device + decoder_input_ids_start = ( + torch.ones((batch_size * self.decoder.num_codebooks, 1), dtype=torch.long, device=device) + * decoder_start_token_id + ) + + # no user input -> use decoder_start_token_id as decoder_input_ids + if decoder_input_ids is None: + decoder_input_ids = decoder_input_ids_start + + # user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust + # decoder_attention_mask if provided) + elif (decoder_input_ids[..., 0] != decoder_start_token_id).all().item(): + decoder_input_ids = torch.cat([decoder_input_ids_start, decoder_input_ids], dim=-1) + if "decoder_attention_mask" in model_kwargs: + decoder_attention_mask = model_kwargs["decoder_attention_mask"] + decoder_attention_mask = torch.cat( + (torch.ones_like(decoder_attention_mask)[:, :1], decoder_attention_mask), + dim=-1, + ) + model_kwargs["decoder_attention_mask"] = decoder_attention_mask + + return decoder_input_ids, model_kwargs + + def _prepare_text_encoder_kwargs_for_generation( + self, + inputs_tensor: torch.Tensor, + model_kwargs, + model_input_name: Optional[str] = None, + guidance_scale: Optional[float] = None, + ) -> Dict[str, Any]: + # 1. get text encoder + encoder = self.get_text_encoder() + # Compatibility with Accelerate big model inference: we need the encoder to output tensors on the same device + # as the inputs. + if hasattr(encoder, "_hf_hook"): + encoder._hf_hook.io_same_device = True + + # 2. Prepare encoder args and encoder kwargs from model kwargs. 
+ irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"] + encoder_kwargs = { + argument: value + for argument, value in model_kwargs.items() + if not any(argument.startswith(p) for p in irrelevant_prefix) + } + encoder_signature = set(inspect.signature(encoder.forward).parameters) + encoder_accepts_wildcard = "kwargs" in encoder_signature or "model_kwargs" in encoder_signature + if not encoder_accepts_wildcard: + encoder_kwargs = { + argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature + } + + # 3. make sure that encoder returns `ModelOutput` + model_input_name = model_input_name if model_input_name is not None else self.text_encoder.main_input_name + encoder_kwargs["return_dict"] = True + encoder_kwargs[model_input_name] = inputs_tensor + last_hidden_state = encoder(**encoder_kwargs).last_hidden_state + + # for classifier free guidance we need to add a 'null' input to our encoder hidden states + if guidance_scale is not None and guidance_scale > 1: + last_hidden_state = torch.concatenate([last_hidden_state, torch.zeros_like(last_hidden_state)], dim=0) + if "attention_mask" in model_kwargs: + model_kwargs["attention_mask"] = torch.concatenate( + [model_kwargs["attention_mask"], torch.zeros_like(model_kwargs["attention_mask"])], dim=0 + ) + + model_kwargs["encoder_outputs"] = BaseModelOutput(last_hidden_state=last_hidden_state) + + return model_kwargs + + def _prepare_audio_encoder_kwargs_for_generation( + self, input_values, model_kwargs, model_input_name: Optional[str] = None + ): + # 1. get audio encoder + encoder = self.get_audio_encoder() + # Compatibility with Accelerate big model inference: we need the encoder to output tensors on the same device + # as the inputs. + if hasattr(encoder, "_hf_hook"): + encoder._hf_hook.io_same_device = True + + # 2. Prepare encoder args and encoder kwargs from model kwargs. 
+ irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"] + encoder_kwargs = { + argument: value + for argument, value in model_kwargs.items() + if not any(argument.startswith(p) for p in irrelevant_prefix) + } + encoder_signature = set(inspect.signature(encoder.forward).parameters) + encoder_accepts_wildcard = "kwargs" in encoder_signature or "model_kwargs" in encoder_signature + if not encoder_accepts_wildcard: + encoder_kwargs = { + argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature + } + + # 3. make sure that encoder returns `ModelOutput` + model_input_name = model_input_name if model_input_name is not None else self.audio_encoder.main_input_name + encoder_kwargs["return_dict"] = True + encoder_kwargs[model_input_name] = input_values + + audio_encoder_outputs = encoder.encode(**encoder_kwargs) + + audio_codes = audio_encoder_outputs.audio_codes + frames, bsz, codebooks, seq_len = audio_codes.shape + + if frames != 1: + raise ValueError( + f"Expected 1 frame in the audio code outputs, got {frames} frames. Ensure chunking is " + "disabled by setting `chunk_length=None` in the audio encoder." + ) + + decoder_input_ids = audio_codes[0, ...].reshape(bsz * self.decoder.num_codebooks, seq_len) + + model_kwargs["decoder_input_ids"] = decoder_input_ids + model_kwargs["audio_scales"] = audio_encoder_outputs.audio_scales + return model_kwargs + + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) + + def resize_token_embeddings(self, *args, **kwargs): + raise NotImplementedError( + "Resizing the embedding layers via the EncoderDecoderModel directly is not supported. Please use the" + " respective methods of the wrapped objects (model.encoder.resize_token_embeddings(...) 
or" + " model.decoder.resize_token_embeddings(...))" + ) + + def _maybe_initialize_input_ids_for_generation( + self, + inputs: Optional[torch.Tensor] = None, + bos_token_id: Optional[int] = None, + model_kwargs: Optional[Dict[str, torch.Tensor]] = None, + ) -> torch.LongTensor: + """Initializes input ids for generation, if necessary.""" + if inputs is not None: + return inputs + + encoder_outputs = model_kwargs.get("encoder_outputs") + if encoder_outputs is not None: + # make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding + shape = encoder_outputs[0].size()[:-1] + return torch.ones(shape, dtype=torch.long, device=self.device) * -100 + + if bos_token_id is None: + raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.") + + # If there is some tensor in `model_kwargs`, we can infer the batch size from it. This is helpful with + # soft-prompting or in multimodal implementations built on top of decoder-only language models. + batch_size = 1 + for value in model_kwargs.values(): + if isinstance(value, torch.Tensor): + batch_size = value.shape[0] + break + return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id + + @torch.no_grad() + def generate( + self, + inputs: Optional[torch.Tensor] = None, + generation_config: Optional[GenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + synced_gpus: Optional[bool] = None, + **kwargs, + ): + """ + + Generates sequences of token ids for models with a language modeling head. + + + + Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the + model's default generation configuration. You can override any `generation_config` by passing the corresponding + parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`. 
+ + For an overview of generation strategies and code examples, check out the [following + guide](./generation_strategies). + + + + Parameters: + inputs (`torch.Tensor` of varying shape depending on the modality, *optional*): + The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the + method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs` + should be in the format `input_ids`. For encoder-decoder models *inputs* can represent any of + `input_ids`, `input_values`, `input_features`, or `pixel_values`. + generation_config (`~generation.GenerationConfig`, *optional*): + The generation configuration to be used as base parametrization for the generation call. `**kwargs` + passed to generate matching the attributes of `generation_config` will override them. If + `generation_config` is not provided, the default will be used, which has the following loading + priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model + configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s + default values, whose documentation should be checked to parameterize generation. + logits_processor (`LogitsProcessorList`, *optional*): + Custom logits processors that complement the default logits processors built from arguments and + generation config. If a logits processor is passed that is already created with the arguments or a + generation config an error is thrown. This feature is intended for advanced users. + stopping_criteria (`StoppingCriteriaList`, *optional*): + Custom stopping criteria that complement the default stopping criteria built from arguments and a + generation config. If a stopping criterion is passed that is already created with the arguments or a + generation config an error is thrown. This feature is intended for advanced users. 
+ synced_gpus (`bool`, *optional*, defaults to `False`): + Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + kwargs: + Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be + forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder + specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*. + + Return: + [`~utils.ModelOutput`] or `torch.FloatTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True` + or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`. + + If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible + [`~utils.ModelOutput`] types are: + + - [`~generation.GreedySearchDecoderOnlyOutput`], + - [`~generation.SampleDecoderOnlyOutput`], + - [`~generation.BeamSearchDecoderOnlyOutput`], + - [`~generation.BeamSampleDecoderOnlyOutput`] + + If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible + [`~utils.ModelOutput`] types are: + + - [`~generation.GreedySearchEncoderDecoderOutput`], + - [`~generation.SampleEncoderDecoderOutput`], + - [`~generation.BeamSearchEncoderDecoderOutput`], + - [`~generation.BeamSampleEncoderDecoderOutput`] + """ + # 1. 
Handle `generation_config` and kwargs that might update it, and validate the resulting objects + if generation_config is None: + generation_config = self.generation_config + + generation_config = copy.deepcopy(generation_config) + model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs + generation_config.validate() + self._validate_model_kwargs(model_kwargs.copy()) + + if model_kwargs.get("encoder_outputs") is not None and type(model_kwargs["encoder_outputs"]) == tuple: + # wrap the unconditional outputs as a BaseModelOutput for compatibility with the rest of generate + model_kwargs["encoder_outputs"] = BaseModelOutput(last_hidden_state=model_kwargs["encoder_outputs"][0]) + + # 2. Set generation parameters if not already defined + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + + if generation_config.pad_token_id is None and generation_config.eos_token_id is not None: + if model_kwargs.get("attention_mask", None) is None: + logger.warning( + "The attention mask and the pad token id were not set. As a consequence, you may observe " + "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results." + ) + eos_token_id = generation_config.eos_token_id + if isinstance(eos_token_id, list): + eos_token_id = eos_token_id[0] + logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.") + generation_config.pad_token_id = eos_token_id + + # 3. 
Define model inputs + # inputs_tensor has to be defined + # model_input_name is defined if model-specific keyword input is passed + # otherwise model_input_name is None + # all model-specific keyword inputs are removed from `model_kwargs` + inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs( + inputs, generation_config.bos_token_id, model_kwargs + ) + batch_size = inputs_tensor.shape[0] + + # 4. Define other model kwargs + model_kwargs["output_attentions"] = generation_config.output_attentions + model_kwargs["output_hidden_states"] = generation_config.output_hidden_states + model_kwargs["use_cache"] = generation_config.use_cache + model_kwargs["guidance_scale"] = generation_config.guidance_scale + + requires_attention_mask = "encoder_outputs" not in model_kwargs + + if model_kwargs.get("attention_mask", None) is None and requires_attention_mask: + model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( + inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id + ) + + if "encoder_outputs" not in model_kwargs: + # encoder_outputs are created and added to `model_kwargs` + model_kwargs = self._prepare_text_encoder_kwargs_for_generation( + inputs_tensor, + model_kwargs, + model_input_name, + guidance_scale=generation_config.guidance_scale, + ) + + if "decoder_input_ids" not in model_kwargs and "input_values" in model_kwargs: + model_kwargs = self._prepare_audio_encoder_kwargs_for_generation( + model_kwargs["input_values"], + model_kwargs, + ) + + # 5. Prepare `input_ids` which will be used for auto-regressive generation + input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation( + batch_size=batch_size, + model_input_name=model_input_name, + model_kwargs=model_kwargs, + decoder_start_token_id=generation_config.decoder_start_token_id, + bos_token_id=generation_config.bos_token_id, + device=inputs_tensor.device, + ) + + # 6. Prepare `max_length` depending on other stopping criteria. 
+ input_ids_seq_length = input_ids.shape[-1] + has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None + if has_default_max_length and generation_config.max_new_tokens is None: + logger.warning( + f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. " + "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we" + " recommend using `max_new_tokens` to control the maximum length of the generation." + ) + elif generation_config.max_new_tokens is not None: + if not has_default_max_length: + logger.warning( + f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" + f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. " + "Please refer to the documentation for more information. " + "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)" + ) + generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length + + if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length: + raise ValueError( + f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than" + f" the maximum length ({generation_config.max_length})" + ) + if input_ids_seq_length >= generation_config.max_length: + logger.warning( + f"Input length of decoder_input_ids is {input_ids_seq_length}, but `max_length` is set to" + f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" + " increasing `max_new_tokens`." 
+ ) + + # build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to MusicGen) + input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask( + input_ids, + pad_token_id=generation_config.decoder_start_token_id, + max_length=generation_config.max_length, + ) + # stash the delay mask so that we don't have to recompute in each forward pass + model_kwargs["decoder_delay_pattern_mask"] = decoder_delay_pattern_mask + + # 7. determine generation mode + is_greedy_gen_mode = ( + (generation_config.num_beams == 1) + and (generation_config.num_beam_groups == 1) + and generation_config.do_sample is False + ) + is_sample_gen_mode = ( + (generation_config.num_beams == 1) + and (generation_config.num_beam_groups == 1) + and generation_config.do_sample is True + ) + + # 8. prepare distribution pre_processing samplers + logits_processor = self._get_logits_processor( + generation_config=generation_config, + input_ids_seq_length=input_ids_seq_length, + encoder_input_ids=inputs_tensor, + prefix_allowed_tokens_fn=None, + logits_processor=logits_processor, + ) + + # 9. prepare stopping criteria + stopping_criteria = self._get_stopping_criteria( + generation_config=generation_config, stopping_criteria=stopping_criteria + ) + + if is_greedy_gen_mode: + if generation_config.num_return_sequences > 1: + raise ValueError( + "num_return_sequences has to be 1 when doing greedy search, " + f"but is {generation_config.num_return_sequences}." + ) + + # 10. run greedy search + outputs = self.greedy_search( + input_ids, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + elif is_sample_gen_mode: + # 11. 
prepare logits warper + logits_warper = self._get_logits_warper(generation_config) + + # expand input_ids with `num_return_sequences` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_return_sequences, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + + # 12. run sample + outputs = self.sample( + input_ids, + logits_processor=logits_processor, + logits_warper=logits_warper, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + else: + raise ValueError( + "Got incompatible mode for generation, should be one of greedy or sampling. " + "Ensure that beam search is de-activated by setting `num_beams=1` and `num_beam_groups=1`." + ) + + if generation_config.return_dict_in_generate: + output_ids = outputs.sequences + else: + output_ids = outputs + + # apply the pattern mask to the final ids + output_ids = self.decoder.apply_delay_pattern_mask(output_ids, model_kwargs["decoder_delay_pattern_mask"]) + + # revert the pattern delay mask by filtering the pad token id + output_ids = output_ids[output_ids != generation_config.pad_token_id].reshape( + batch_size, self.decoder.num_codebooks, -1 + ) + + # append the frame dimension back to the audio codes + output_ids = output_ids[None, ...]
+ + audio_scales = model_kwargs.get("audio_scales") + if audio_scales is None: + audio_scales = [None] * batch_size + + output_values = self.audio_encoder.decode( + output_ids, + audio_scales=audio_scales, + ) + + if generation_config.return_dict_in_generate: + outputs.sequences = output_values.audio_values + return outputs + else: + return output_values.audio_values + + def get_unconditional_inputs(self, num_samples=1): + """ + Helper function to get null inputs for unconditional generation, enabling the model to be used without the + feature extractor or tokenizer. + + Args: + num_samples (int, *optional*, defaults to 1): + Number of audio samples to unconditionally generate. + + The length of each sample is set separately, by passing `max_new_tokens` to `generate`: more tokens mean longer + audio samples, at the expense of longer inference (since more audio tokens need to be generated per sample). + + Example: + ```python + >>> from transformers import MusicgenForConditionalGeneration + + >>> model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small") + + >>> # get the unconditional (or 'null') inputs for the model + >>> unconditional_inputs = model.get_unconditional_inputs(num_samples=1) + >>> audio_samples = model.generate(**unconditional_inputs, max_new_tokens=256) + ```""" + last_hidden_state = torch.zeros( + (num_samples, 1, self.config.text_encoder.hidden_size), device=self.device, dtype=self.dtype + ) + + attention_mask = torch.zeros((num_samples, 1), device=self.device, dtype=torch.long) + + return MusicgenUnconditionalInput( + encoder_outputs=(last_hidden_state,), + attention_mask=attention_mask, + guidance_scale=1.0, + ) diff --git a/src/transformers/models/musicgen/processing_musicgen.py b/src/transformers/models/musicgen/processing_musicgen.py new file mode 100644 index 0000000000..ed8d1277f2 --- /dev/null +++ b/src/transformers/models/musicgen/processing_musicgen.py @@ -0,0 +1,139 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Text/audio processor class for MusicGen +""" +from typing import List, Optional + +import numpy as np + +from ...processing_utils import ProcessorMixin +from ...utils import to_numpy + + +class MusicgenProcessor(ProcessorMixin): + r""" + Constructs a MusicGen processor which wraps an EnCodec feature extractor and a T5 tokenizer into a single processor + class. + + [`MusicgenProcessor`] offers all the functionalities of [`EncodecFeatureExtractor`] and [`T5Tokenizer`]. See + [`~MusicgenProcessor.__call__`] and [`~MusicgenProcessor.decode`] for more information. + + Args: + feature_extractor (`EncodecFeatureExtractor`): + An instance of [`EncodecFeatureExtractor`]. The feature extractor is a required input. + tokenizer (`T5Tokenizer`): + An instance of [`T5Tokenizer`]. The tokenizer is a required input.
+ """ + feature_extractor_class = "EncodecFeatureExtractor" + tokenizer_class = ("T5Tokenizer", "T5TokenizerFast") + + def __init__(self, feature_extractor, tokenizer): + super().__init__(feature_extractor, tokenizer) + self.current_processor = self.feature_extractor + self._in_target_context_manager = False + + def get_decoder_prompt_ids(self, task=None, language=None, no_timestamps=True): + return self.tokenizer.get_decoder_prompt_ids(task=task, language=language, no_timestamps=no_timestamps) + + def __call__(self, *args, **kwargs): + """ + Forwards the `audio` argument to EncodecFeatureExtractor's [`~EncodecFeatureExtractor.__call__`] and the `text` + argument to [`~T5Tokenizer.__call__`]. Please refer to the docstring of the above two methods for more + information. + """ + # For backward compatibility + if self._in_target_context_manager: + return self.current_processor(*args, **kwargs) + + audio = kwargs.pop("audio", None) + sampling_rate = kwargs.pop("sampling_rate", None) + text = kwargs.pop("text", None) + if len(args) > 0: + audio = args[0] + args = args[1:] + + if audio is None and text is None: + raise ValueError("You need to specify either an `audio` or `text` input to process.") + + if text is not None: + inputs = self.tokenizer(text, **kwargs) + + if audio is not None: + audio_inputs = self.feature_extractor(audio, *args, sampling_rate=sampling_rate, **kwargs) + + if audio is None: + return inputs + + elif text is None: + return audio_inputs + + else: + inputs["input_values"] = audio_inputs["input_values"] + if "padding_mask" in audio_inputs: + inputs["padding_mask"] = audio_inputs["padding_mask"] + return inputs + + def batch_decode(self, *args, **kwargs): + """ + This method is used to decode either batches of audio outputs from the MusicGen model, or batches of token ids + from the tokenizer. In the case of decoding token ids, this method forwards all its arguments to T5Tokenizer's + [`~PreTrainedTokenizer.batch_decode`].
Please refer to the docstring of this method for more information. + """ + audio_values = kwargs.pop("audio", None) + padding_mask = kwargs.pop("padding_mask", None) + + if len(args) > 0: + audio_values = args[0] + args = args[1:] + + if audio_values is not None: + return self._decode_audio(audio_values, padding_mask=padding_mask) + else: + return self.tokenizer.batch_decode(*args, **kwargs) + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to T5Tokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to the + docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + def _decode_audio(self, audio_values, padding_mask: Optional = None) -> List[np.ndarray]: + """ + This method strips any padding from the audio values to return a list of numpy audio arrays. + """ + audio_values = to_numpy(audio_values) + bsz, channels, seq_len = audio_values.shape + + if padding_mask is None: + return list(audio_values) + + padding_mask = to_numpy(padding_mask) + + # match the sequence length of the padding mask to the generated audio arrays by padding with the **non-padding** + # token (so that the generated audio values are **not** treated as padded tokens) + difference = seq_len - padding_mask.shape[-1] + padding_value = 1 - self.feature_extractor.padding_value + padding_mask = np.pad(padding_mask, ((0, 0), (0, difference)), "constant", constant_values=padding_value) + + audio_values = audio_values.tolist() + for i in range(bsz): + sliced_audio = np.asarray(audio_values[i])[ + padding_mask[i][None, :] != self.feature_extractor.padding_value + ] + audio_values[i] = sliced_audio.reshape(channels, -1) + + return audio_values diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 4513032607..2c40f7143d 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -4936,6 +4936,44 @@ class 
MT5PreTrainedModel(metaclass=DummyObject): requires_backends(self, ["torch"]) +MUSICGEN_PRETRAINED_MODEL_ARCHIVE_LIST = None + + +class MusicgenForCausalLM(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class MusicgenForConditionalGeneration(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class MusicgenModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class MusicgenPreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class MusicgenProcessor(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + MVP_PRETRAINED_MODEL_ARCHIVE_LIST = None diff --git a/tests/models/musicgen/__init__.py b/tests/models/musicgen/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/models/musicgen/test_modeling_musicgen.py b/tests/models/musicgen/test_modeling_musicgen.py new file mode 100644 index 0000000000..00f249b09d --- /dev/null +++ b/tests/models/musicgen/test_modeling_musicgen.py @@ -0,0 +1,1346 @@ +# coding=utf-8 +# Copyright 2021, The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" Testing suite for the PyTorch Musicgen model. """ +import copy +import inspect +import math +import unittest + +import numpy as np + +from transformers import ( + EncodecConfig, + MusicgenConfig, + MusicgenDecoderConfig, + MusicgenProcessor, + PretrainedConfig, + T5Config, +) +from transformers.testing_utils import is_torch_available, require_torch, slow, torch_device +from transformers.utils import cached_property + +from ...generation.test_utils import GenerationTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor +from ...test_pipeline_mixin import PipelineTesterMixin + + +if is_torch_available(): + import torch + + from transformers import ( + MusicgenForCausalLM, + MusicgenForConditionalGeneration, + MusicgenModel, + set_seed, + ) + from transformers.generation import ( + GreedySearchDecoderOnlyOutput, + GreedySearchEncoderDecoderOutput, + InfNanRemoveLogitsProcessor, + LogitsProcessorList, + SampleDecoderOnlyOutput, + SampleEncoderDecoderOutput, + ) + + +def _config_zero_init(config): + configs_no_init = copy.deepcopy(config) + for key in configs_no_init.__dict__.keys(): + if "_range" in key or "_std" in key or "initializer_factor" in key or "layer_scale" in key: + setattr(configs_no_init, key, 1e-10) + if isinstance(getattr(configs_no_init, key, None), PretrainedConfig): + no_init_subconfig = _config_zero_init(getattr(configs_no_init, key)) + setattr(configs_no_init, key, no_init_subconfig) + return configs_no_init + + +def prepare_musicgen_decoder_inputs_dict( + config, + input_ids, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + cross_attn_head_mask=None, +): + if attention_mask is None: + attention_mask = input_ids.reshape(-1, config.num_codebooks, input_ids.shape[-1])[:, 0, :] + attention_mask = attention_mask.ne(config.pad_token_id) + if head_mask is None: + head_mask = 
torch.ones(config.num_hidden_layers, config.num_attention_heads, device=torch_device) + if encoder_attention_mask is None and encoder_hidden_states is not None: + encoder_attention_mask = torch.ones(encoder_hidden_states.shape[:2], device=torch_device) + if cross_attn_head_mask is None: + cross_attn_head_mask = torch.ones(config.num_hidden_layers, config.num_attention_heads, device=torch_device) + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "encoder_hidden_states": encoder_hidden_states, + "encoder_attention_mask": encoder_attention_mask, + "head_mask": head_mask, + "cross_attn_head_mask": cross_attn_head_mask, + } + + +class MusicgenDecoderTester: + def __init__( + self, + parent, + batch_size=2, + seq_length=7, + is_training=False, + use_labels=False, + vocab_size=99, + hidden_size=16, + num_hidden_layers=2, + num_attention_heads=4, + intermediate_size=4, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=100, + pad_token_id=99, + bos_token_id=99, + num_codebooks=4, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_labels = use_labels + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + self.num_codebooks = num_codebooks + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size * self.num_codebooks, self.seq_length], self.vocab_size) + encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size]) + + config = 
self.get_config() + inputs_dict = prepare_musicgen_decoder_inputs_dict( + config, input_ids, encoder_hidden_states=encoder_hidden_states + ) + return config, inputs_dict + + def get_config(self): + config = MusicgenDecoderConfig( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + d_ff=self.intermediate_size, + pad_token_id=self.pad_token_id, + decoder_start_token_id=self.bos_token_id, + bos_token_id=self.bos_token_id, + num_codebooks=self.num_codebooks, + tie_word_embeddings=False, + ) + return config + + def prepare_config_and_inputs_for_common(self): + config, inputs_dict = self.prepare_config_and_inputs() + return config, inputs_dict + + +@require_torch +class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): + all_model_classes = (MusicgenModel, MusicgenForCausalLM) if is_torch_available() else () + greedy_sample_model_classes = ( + (MusicgenForCausalLM,) if is_torch_available() else () + ) # we don't want to run all the generation tests, only a specific subset + pipeline_model_mapping = {} + test_pruning = False + test_resize_embeddings = False + + def setUp(self): + self.model_tester = MusicgenDecoderTester(self) + self.config_tester = ConfigTester(self, config_class=MusicgenDecoderConfig, hidden_size=16) + + def test_config(self): + self.config_tester.run_common_tests() + + # override since we have to compute the input embeddings over codebooks + def test_inputs_embeds(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + model.eval() + + inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class)) + + input_ids = inputs["input_ids"] + del inputs["input_ids"] + + embed_tokens = model.get_input_embeddings() + + input_ids = input_ids.reshape(-1, 
config.num_codebooks, input_ids.shape[-1]) + + inputs["inputs_embeds"] = sum( + [embed_tokens[codebook](input_ids[:, codebook]) for codebook in range(config.num_codebooks)] + ) + + with torch.no_grad(): + model(**inputs)[0] + + # override since we have embeddings / LM heads over multiple codebooks + def test_model_common_attributes(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + first_embed = model.get_input_embeddings()[0] + self.assertIsInstance(first_embed, torch.nn.Embedding) + lm_heads = model.get_output_embeddings() + self.assertTrue(lm_heads is None or isinstance(lm_heads[0], torch.nn.Linear)) + + # skip as this model doesn't support all arguments tested + def test_model_outputs_equivalence(self): + pass + + # skip as this model has multiple inputs embeds and lm heads that should not be tied + def test_tie_model_weights(self): + pass + + # skip as this model has multiple inputs embeds and lm heads that should not be tied + def test_tied_weights_keys(self): + pass + + def _get_input_ids_and_config(self, batch_size=2): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + input_ids = inputs_dict["input_ids"] + + # take max batch_size + sequence_length = input_ids.shape[-1] + input_ids = input_ids[: batch_size * config.num_codebooks, :] + + # generate max 3 tokens + max_length = input_ids.shape[-1] + 3 + attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long) + return config, input_ids, attention_mask, max_length + + @staticmethod + def _get_logits_processor_and_kwargs( + input_length, + eos_token_id, + forced_bos_token_id=None, + forced_eos_token_id=None, + max_length=None, + diversity_penalty=None, + ): + process_kwargs = { + "min_length": input_length + 1 if max_length is None else max_length - 1, + } + logits_processor = LogitsProcessorList() + return process_kwargs, logits_processor + + # 
override since we don't expect the outputs of `.generate` and `.greedy_search` to be the same, since we perform + # additional post-processing in the former + def test_greedy_generate_dict_outputs(self): + for model_class in self.greedy_sample_model_classes: + # disable cache + config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + config.use_cache = False + model = model_class(config).to(torch_device).eval() + output_greedy, output_generate = self._greedy_generate( + model=model, + input_ids=input_ids, + attention_mask=attention_mask, + max_length=max_length, + output_scores=True, + output_hidden_states=True, + output_attentions=True, + return_dict_in_generate=True, + ) + + self.assertIsInstance(output_greedy, GreedySearchDecoderOnlyOutput) + self.assertIsInstance(output_generate, GreedySearchDecoderOnlyOutput) + + self.assertNotIn(config.pad_token_id, output_generate) + + # override since we don't expect the outputs of `.generate` and `.greedy_search` to be the same, since we perform + # additional post-processing in the former + def test_greedy_generate_dict_outputs_use_cache(self): + for model_class in self.greedy_sample_model_classes: + # enable cache + config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + + config.use_cache = True + config.is_decoder = True + model = model_class(config).to(torch_device).eval() + output_greedy, output_generate = self._greedy_generate( + model=model, + input_ids=input_ids, + attention_mask=attention_mask, + max_length=max_length, + output_scores=True, + output_hidden_states=True, + output_attentions=True, + return_dict_in_generate=True, + ) + + self.assertIsInstance(output_greedy, GreedySearchDecoderOnlyOutput) + self.assertIsInstance(output_generate, GreedySearchDecoderOnlyOutput) + + # override since we don't expect the outputs of `.generate` and `.sample` to be the same, since we perform + # additional post-processing in the former + def test_sample_generate(self): + 
for model_class in self.greedy_sample_model_classes: + config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + model = model_class(config).to(torch_device).eval() + + process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( + input_ids.shape[-1], + model.config.eos_token_id, + forced_bos_token_id=model.config.forced_bos_token_id, + forced_eos_token_id=model.config.forced_eos_token_id, + max_length=max_length, + ) + logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=2) + + # check `generate()` and `sample()` are equal + output_sample, output_generate = self._sample_generate( + model=model, + input_ids=input_ids, + attention_mask=attention_mask, + max_length=max_length, + num_return_sequences=3, + logits_processor=logits_processor, + logits_warper=logits_warper, + logits_warper_kwargs=logits_warper_kwargs, + process_kwargs=process_kwargs, + ) + self.assertIsInstance(output_sample, torch.Tensor) + self.assertIsInstance(output_generate, torch.Tensor) + + # override since we don't expect the outputs of `.generate` and `.sample` to be the same, since we perform + # additional post-processing in the former + def test_sample_generate_dict_output(self): + for model_class in self.greedy_sample_model_classes: + # disable cache + config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + config.use_cache = False + model = model_class(config).to(torch_device).eval() + + process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( + input_ids.shape[-1], + model.config.eos_token_id, + forced_bos_token_id=model.config.forced_bos_token_id, + forced_eos_token_id=model.config.forced_eos_token_id, + max_length=max_length, + ) + logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1) + + output_sample, output_generate = self._sample_generate( + model=model, + input_ids=input_ids, + attention_mask=attention_mask, + max_length=max_length, + 
num_return_sequences=1, + logits_processor=logits_processor, + logits_warper=logits_warper, + logits_warper_kwargs=logits_warper_kwargs, + process_kwargs=process_kwargs, + output_scores=True, + output_hidden_states=True, + output_attentions=True, + return_dict_in_generate=True, + ) + + self.assertIsInstance(output_sample, SampleDecoderOnlyOutput) + self.assertIsInstance(output_generate, SampleDecoderOnlyOutput) + + +def prepare_musicgen_inputs_dict( + config, + input_ids, + decoder_input_ids, + attention_mask=None, + decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, +): + if decoder_attention_mask is None: + decoder_attention_mask = decoder_input_ids.reshape( + -1, config.decoder.num_codebooks, decoder_input_ids.shape[-1] + )[:, 0, :] + decoder_attention_mask = decoder_attention_mask.ne(config.decoder.pad_token_id) + if head_mask is None: + head_mask = torch.ones( + config.text_encoder.num_hidden_layers, config.text_encoder.num_attention_heads, device=torch_device + ) + if decoder_head_mask is None: + decoder_head_mask = torch.ones( + config.decoder.num_hidden_layers, config.decoder.num_attention_heads, device=torch_device + ) + if cross_attn_head_mask is None: + cross_attn_head_mask = torch.ones( + config.decoder.num_hidden_layers, config.decoder.num_attention_heads, device=torch_device + ) + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "decoder_input_ids": decoder_input_ids, + "decoder_attention_mask": decoder_attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, + } + + +class MusicgenTester: + def __init__( + self, + parent, + batch_size=2, + seq_length=7, + is_training=False, + use_labels=False, + vocab_size=99, + hidden_size=16, + num_hidden_layers=2, + num_attention_heads=4, + intermediate_size=4, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + 
max_position_embeddings=100, + pad_token_id=99, + bos_token_id=99, + num_codebooks=4, + num_filters=4, + codebook_size=128, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_labels = use_labels + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + self.num_codebooks = num_codebooks + self.num_filters = num_filters + self.codebook_size = codebook_size + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + decoder_input_ids = ids_tensor([self.batch_size * self.num_codebooks, self.seq_length], self.vocab_size) + + config = self.get_config() + inputs_dict = prepare_musicgen_inputs_dict(config, input_ids, decoder_input_ids=decoder_input_ids) + return config, inputs_dict + + def get_config(self): + text_encoder_config = T5Config( + vocab_size=self.vocab_size, + d_model=self.hidden_size, + d_ff=self.intermediate_size, + num_layers=self.num_hidden_layers, + num_heads=self.num_attention_heads, + ) + audio_encoder_config = EncodecConfig( + hidden_size=self.vocab_size, + compress=1, + num_filters=self.num_filters, + codebook_size=self.codebook_size, + codebook_dim=self.vocab_size, + ) + decoder_config = MusicgenDecoderConfig( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + ffn_dim=self.intermediate_size, + pad_token_id=self.pad_token_id, + decoder_start_token_id=self.bos_token_id, + 
bos_token_id=self.bos_token_id, + num_codebooks=self.num_codebooks, + tie_word_embeddings=False, + ) + config = MusicgenConfig.from_sub_models_config(text_encoder_config, audio_encoder_config, decoder_config) + return config + + def prepare_config_and_inputs_for_common(self): + config, inputs_dict = self.prepare_config_and_inputs() + return config, inputs_dict + + +@require_torch +class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): + all_model_classes = (MusicgenForConditionalGeneration,) if is_torch_available() else () + greedy_sample_model_classes = (MusicgenForConditionalGeneration,) if is_torch_available() else () + pipeline_model_mapping = {} + test_pruning = False # training is not supported yet for MusicGen + test_headmasking = False + test_resize_embeddings = False + + def setUp(self): + self.model_tester = MusicgenTester(self) + + def _check_output_with_attentions(self, outputs, config, input_ids, decoder_input_ids): + text_encoder_config = config.text_encoder + decoder_config = config.decoder + + encoder_attentions = outputs["encoder_attentions"] + self.assertEqual(len(encoder_attentions), text_encoder_config.num_hidden_layers) + + self.assertEqual( + encoder_attentions[0].shape[-3:], + (text_encoder_config.num_attention_heads, input_ids.shape[-1], input_ids.shape[-1]), + ) + + decoder_attentions = outputs["decoder_attentions"] + num_decoder_layers = decoder_config.num_hidden_layers + self.assertEqual(len(decoder_attentions), num_decoder_layers) + + self.assertEqual( + decoder_attentions[0].shape[-3:], + (decoder_config.num_attention_heads, decoder_input_ids.shape[-1], decoder_input_ids.shape[-1]), + ) + + cross_attentions = outputs["cross_attentions"] + self.assertEqual(len(cross_attentions), num_decoder_layers) + + cross_attention_input_seq_len = decoder_input_ids.shape[-1] + self.assertEqual( + cross_attentions[0].shape[-3:], + (decoder_config.num_attention_heads, cross_attention_input_seq_len, 
input_ids.shape[-1]), + ) + + def check_musicgen_model_output_attentions( + self, + model_class, + config, + input_ids, + attention_mask, + decoder_input_ids, + decoder_attention_mask, + **kwargs, + ): + model = model_class(config) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + outputs = model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + output_attentions=True, + **kwargs, + ) + self._check_output_with_attentions(outputs, config, input_ids, decoder_input_ids) + + def check_musicgen_model_output_attentions_from_config( + self, + model_class, + config, + input_ids, + attention_mask, + decoder_input_ids, + decoder_attention_mask, + **kwargs, + ): + # Similar to `check_musicgen_model_output_attentions`, but with `output_attentions` triggered from the + # config file. Contrarily to most models, changing the model's config won't work -- the defaults are loaded + # from the inner models' configurations. 
+ config.output_attentions = True # model config -> won't work + + model = model_class(config) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + outputs = model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + **kwargs, + ) + self.assertTrue( + all(key not in outputs for key in ["encoder_attentions", "decoder_attentions", "cross_attentions"]) + ) + config.text_encoder.output_attentions = True # inner model config -> will work + config.audio_encoder.output_attentions = True + config.decoder.output_attentions = True + + model = model_class(config) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + outputs = model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + **kwargs, + ) + self._check_output_with_attentions(outputs, config, input_ids, decoder_input_ids) + + # override since changing `output_attentions` from the top-level model config won't work + def test_attention_outputs(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + self.check_musicgen_model_output_attentions(model_class, config, **inputs_dict) + self.check_musicgen_model_output_attentions_from_config(model_class, config, **inputs_dict) + + # override since we have a specific forward signature for musicgen + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.forward) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + expected_arg_names = [ + "input_ids", + "attention_mask", + "input_values", + "padding_mask", + "decoder_input_ids", + 
"decoder_attention_mask", + ] + expected_arg_names.extend( + ["head_mask", "decoder_head_mask", "cross_attn_head_mask", "encoder_outputs"] + if "head_mask" and "decoder_head_mask" and "cross_attn_head_mask" in arg_names + else ["encoder_outputs"] + ) + self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names) + + # override since changing `gradient_checkpointing` from the top-level model config won't work + def test_gradient_checkpointing_backward_compatibility(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + if not model_class.supports_gradient_checkpointing: + continue + + config.text_encoder.gradient_checkpointing = True + config.audio_encoder.gradient_checkpointing = True + config.decoder.gradient_checkpointing = True + model = model_class(config) + self.assertTrue(model.is_gradient_checkpointing) + + # skip as this model has multiple inputs embeds and lm heads that should not be tied + def test_tie_model_weights(self): + pass + + # skip as this model has multiple inputs embeds and lm heads that should not be tied + def test_tied_model_weights_key_ignore(self): + pass + + # skip as this model has multiple inputs embeds and lm heads that should not be tied + def test_tied_weights_keys(self): + pass + + # override since changing `output_hidden_states` / `output_attentions` from the top-level model config won't work + def test_retain_grad_hidden_states_attentions(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.text_encoder.output_hidden_states = True + config.audio_encoder.output_hidden_states = True + config.decoder.output_hidden_states = True + + config.text_encoder.output_attentions = True + config.decoder.output_attentions = True + + # no need to test all models as different heads yield the same functionality + model_class = self.all_model_classes[0] + model = model_class(config) + model.to(torch_device) 
+ + inputs = self._prepare_for_class(inputs_dict, model_class) + + outputs = model(**inputs) + + output = outputs[0] + + encoder_hidden_states = outputs.encoder_hidden_states[0] + encoder_hidden_states.retain_grad() + + decoder_hidden_states = outputs.decoder_hidden_states[0] + decoder_hidden_states.retain_grad() + + if self.has_attentions: + encoder_attentions = outputs.encoder_attentions[0] + encoder_attentions.retain_grad() + + decoder_attentions = outputs.decoder_attentions[0] + decoder_attentions.retain_grad() + + cross_attentions = outputs.cross_attentions[0] + cross_attentions.retain_grad() + + output.flatten()[0].backward(retain_graph=True) + + self.assertIsNotNone(encoder_hidden_states.grad) + self.assertIsNotNone(decoder_hidden_states.grad) + + if self.has_attentions: + self.assertIsNotNone(encoder_attentions.grad) + self.assertIsNotNone(decoder_attentions.grad) + self.assertIsNotNone(cross_attentions.grad) + + # override since changing `output_hidden_states` from the top-level model config won't work + def test_hidden_states_output(self): + def check_hidden_states_output(inputs_dict, config, model_class): + model = model_class(config) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + hidden_states = outputs.encoder_hidden_states + + expected_num_layers = self.model_tester.num_hidden_layers + 1 + self.assertEqual(len(hidden_states), expected_num_layers) + + seq_length = self.model_tester.seq_length + self.assertListEqual( + list(hidden_states[0].shape[-2:]), + [seq_length, self.model_tester.hidden_size], + ) + + hidden_states = outputs.decoder_hidden_states + self.assertIsInstance(hidden_states, (list, tuple)) + self.assertEqual(len(hidden_states), expected_num_layers) + + self.assertListEqual( + list(hidden_states[0].shape[-2:]), + [seq_length, self.model_tester.hidden_size], + ) + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + 
for model_class in self.all_model_classes: + inputs_dict["output_hidden_states"] = True + check_hidden_states_output(inputs_dict, config, model_class) + + # check that output_hidden_states also works using config + del inputs_dict["output_hidden_states"] + config.text_encoder.output_hidden_states = True + config.audio_encoder.output_hidden_states = True + config.decoder.output_hidden_states = True + + check_hidden_states_output(inputs_dict, config, model_class) + + # override since the conv layers and LSTMs in EnCodec are exceptions + def test_initialization(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + configs_no_init = _config_zero_init(config) + for model_class in self.all_model_classes: + model = model_class(config=configs_no_init) + for name, param in model.named_parameters(): + uniform_init_params = ["conv"] + ignore_init = ["lstm"] + if param.requires_grad: + if any([x in name for x in uniform_init_params]): + self.assertTrue( + -1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0, + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + elif not any([x in name for x in ignore_init]): + self.assertIn( + ((param.data.mean() * 1e9).round() / 1e9).item(), + [0.0, 1.0], + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + + # override since we have embeddings / LM heads over multiple codebooks + def test_model_common_attributes(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + self.assertIsInstance(model.get_input_embeddings(), torch.nn.Embedding) + lm_heads = model.get_output_embeddings() + self.assertTrue(lm_heads is None or isinstance(lm_heads[0], torch.nn.Linear)) + + def _get_input_ids_and_config(self, batch_size=2): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + input_ids = 
inputs_dict["input_ids"] + + # take max batch_size + sequence_length = input_ids.shape[-1] + input_ids = input_ids[:batch_size, :] + attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long) + + # generate max 3 tokens + decoder_input_ids = inputs_dict["decoder_input_ids"] + max_length = decoder_input_ids.shape[-1] + 3 + decoder_input_ids = decoder_input_ids[: batch_size * config.decoder.num_codebooks, :] + return config, input_ids, attention_mask, decoder_input_ids, max_length + + # override since the `input_ids` cannot be used as the `decoder_input_ids` for musicgen (input / outputs are + # different modalities -> different shapes) + def _greedy_generate( + self, + model, + input_ids, + attention_mask, + decoder_input_ids, + max_length, + output_scores=False, + output_attentions=False, + output_hidden_states=False, + return_dict_in_generate=False, + ): + logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( + input_ids.shape[-1], + eos_token_id=model.config.eos_token_id, + forced_bos_token_id=model.config.forced_bos_token_id, + forced_eos_token_id=model.config.forced_eos_token_id, + max_length=max_length, + ) + + model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} + output_generate = model.generate( + input_ids, + do_sample=False, + num_beams=1, + max_length=max_length, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_scores=output_scores, + return_dict_in_generate=return_dict_in_generate, + remove_invalid_values=True, + **logits_process_kwargs, + **model_kwargs, + ) + + encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs( + model, + input_ids, + attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + with torch.no_grad(): + model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} + output_greedy = model.greedy_search( + decoder_input_ids, 
+ max_length=max_length, + logits_processor=logits_processor, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_scores=output_scores, + return_dict_in_generate=return_dict_in_generate, + encoder_outputs=encoder_outputs, + **model_kwargs, + ) + return output_greedy, output_generate + + # override since the `input_ids` cannot be used as the `decoder_input_ids` for musicgen (input / outputs are + # different modalities -> different shapes) + def _sample_generate( + self, + model, + input_ids, + attention_mask, + decoder_input_ids, + max_length, + num_return_sequences, + logits_processor, + logits_warper, + logits_warper_kwargs, + process_kwargs, + output_scores=False, + output_attentions=False, + output_hidden_states=False, + return_dict_in_generate=False, + ): + torch.manual_seed(0) + model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} + output_generate = model.generate( + input_ids, + do_sample=True, + num_beams=1, + max_length=max_length, + num_return_sequences=num_return_sequences, + output_scores=output_scores, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict_in_generate=return_dict_in_generate, + remove_invalid_values=True, + **logits_warper_kwargs, + **process_kwargs, + **model_kwargs, + ) + + torch.manual_seed(0) + encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs( + model, + input_ids, + attention_mask, + num_interleave=num_return_sequences, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + # prevent flaky generation test failures + logits_processor.append(InfNanRemoveLogitsProcessor()) + + with torch.no_grad(): + model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} + output_sample = model.sample( + decoder_input_ids.repeat_interleave(num_return_sequences, dim=0), + max_length=max_length, + logits_processor=logits_processor, + 
logits_warper=logits_warper, + output_scores=output_scores, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict_in_generate=return_dict_in_generate, + encoder_outputs=encoder_outputs, + **model_kwargs, + ) + + return output_sample, output_generate + + @staticmethod + def _get_logits_processor_and_kwargs( + input_length, + eos_token_id, + forced_bos_token_id=None, + forced_eos_token_id=None, + max_length=None, + diversity_penalty=None, + ): + process_kwargs = { + "min_length": input_length + 1 if max_length is None else max_length - 1, + } + logits_processor = LogitsProcessorList() + return process_kwargs, logits_processor + + def test_greedy_generate_dict_outputs(self): + for model_class in self.greedy_sample_model_classes: + # disable cache + config, input_ids, attention_mask, decoder_input_ids, max_length = self._get_input_ids_and_config() + config.use_cache = False + model = model_class(config).to(torch_device).eval() + output_greedy, output_generate = self._greedy_generate( + model=model, + input_ids=input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + max_length=max_length, + output_scores=True, + output_hidden_states=True, + output_attentions=True, + return_dict_in_generate=True, + ) + + self.assertIsInstance(output_greedy, GreedySearchEncoderDecoderOutput) + self.assertIsInstance(output_generate, GreedySearchEncoderDecoderOutput) + + self.assertNotIn(config.pad_token_id, output_generate) + + def test_greedy_generate_dict_outputs_use_cache(self): + for model_class in self.greedy_sample_model_classes: + # enable cache + config, input_ids, attention_mask, decoder_input_ids, max_length = self._get_input_ids_and_config() + + config.use_cache = True + config.is_decoder = True + model = model_class(config).to(torch_device).eval() + output_greedy, output_generate = self._greedy_generate( + model=model, + input_ids=input_ids, + attention_mask=attention_mask, + 
decoder_input_ids=decoder_input_ids, + max_length=max_length, + output_scores=True, + output_hidden_states=True, + output_attentions=True, + return_dict_in_generate=True, + ) + + self.assertIsInstance(output_greedy, GreedySearchEncoderDecoderOutput) + self.assertIsInstance(output_generate, GreedySearchEncoderDecoderOutput) + + def test_sample_generate(self): + for model_class in self.greedy_sample_model_classes: + config, input_ids, attention_mask, decoder_input_ids, max_length = self._get_input_ids_and_config() + model = model_class(config).to(torch_device).eval() + + process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( + input_ids.shape[-1], + model.config.eos_token_id, + forced_bos_token_id=model.config.forced_bos_token_id, + forced_eos_token_id=model.config.forced_eos_token_id, + max_length=max_length, + ) + logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=2) + + # check `generate()` and `sample()` are equal + output_sample, output_generate = self._sample_generate( + model=model, + input_ids=input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + max_length=max_length, + num_return_sequences=1, + logits_processor=logits_processor, + logits_warper=logits_warper, + logits_warper_kwargs=logits_warper_kwargs, + process_kwargs=process_kwargs, + ) + self.assertIsInstance(output_sample, torch.Tensor) + self.assertIsInstance(output_generate, torch.Tensor) + + def test_sample_generate_dict_output(self): + for model_class in self.greedy_sample_model_classes: + # disable cache + config, input_ids, attention_mask, decoder_input_ids, max_length = self._get_input_ids_and_config() + config.use_cache = False + model = model_class(config).to(torch_device).eval() + + process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( + input_ids.shape[-1], + model.config.eos_token_id, + forced_bos_token_id=model.config.forced_bos_token_id, + forced_eos_token_id=model.config.forced_eos_token_id, + 
max_length=max_length, + ) + logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1) + + output_sample, output_generate = self._sample_generate( + model=model, + input_ids=input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + max_length=max_length, + num_return_sequences=3, + logits_processor=logits_processor, + logits_warper=logits_warper, + logits_warper_kwargs=logits_warper_kwargs, + process_kwargs=process_kwargs, + output_scores=True, + output_hidden_states=True, + output_attentions=True, + return_dict_in_generate=True, + ) + + self.assertIsInstance(output_sample, SampleEncoderDecoderOutput) + self.assertIsInstance(output_generate, SampleEncoderDecoderOutput) + + def test_generate_without_input_ids(self): + config, _, _, _, max_length = self._get_input_ids_and_config() + + # if no bos token id => cannot generate from None + if config.bos_token_id is None: + return + + for model_class in self.greedy_sample_model_classes: + model = model_class(config).to(torch_device) + model.eval() + + output_ids_generate = model.generate(do_sample=False, max_length=max_length, remove_invalid_values=True) + self.assertIsNotNone(output_ids_generate) + + def test_generate_fp16(self): + config, input_dict = self.model_tester.prepare_config_and_inputs() + + for model_class in self.greedy_sample_model_classes: + model = model_class(config).eval().to(torch_device) + if torch_device == "cuda": + model.half() + model.generate(**input_dict, max_new_tokens=10) + model.generate(**input_dict, do_sample=True, max_new_tokens=10) + + +def get_bip_bip(bip_duration=0.125, duration=0.5, sample_rate=32000): + """Produces a series of 'bip bip' sounds at a fixed frequency of 440 Hz.""" + timesteps = np.arange(int(duration * sample_rate)) / sample_rate + wav = np.cos(2 * math.pi * 440 * timesteps) + time_period = (timesteps % (2 * bip_duration)) / (2 * bip_duration) + envelope = time_period >= 0.5 + return wav * envelope + + +def 
place_dict_on_device(dict_to_place, device): + for key in dict_to_place: + if dict_to_place[key] is not None and isinstance(dict_to_place[key], torch.Tensor): + dict_to_place[key] = dict_to_place[key].to(device) + return dict_to_place + + +@require_torch +class MusicgenIntegrationTests(unittest.TestCase): + @cached_property + def model(self): + return MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small").to(torch_device) + + @cached_property + def processor(self): + return MusicgenProcessor.from_pretrained("facebook/musicgen-small") + + @slow + def test_logits_text_prompt(self): + model = self.model + processor = self.processor + + inputs = processor(text=["80s music", "Club techno"], padding=True, return_tensors="pt") + + # prepare the encoder inputs + input_ids = inputs.input_ids.to(torch_device) + attention_mask = inputs.attention_mask.to(torch_device) + + # prepare the decoder inputs + pad_token_id = model.generation_config.pad_token_id + decoder_input_ids = ( + torch.ones((input_ids.shape[0] * model.decoder.num_codebooks, 1), dtype=torch.long).to(torch_device) + * pad_token_id + ) + + with torch.no_grad(): + logits = model( + input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + ).logits + + # fmt: off + EXPECTED_LOGITS = torch.tensor( + [ + -0.9708, -3.0149, -4.6415, -1.4754, -0.2786, -2.3523, -2.6049, -6.7467, + -1.0206, -3.2984, -3.3968, -1.5108, -1.5786, -3.1493, -1.1503, -0.0545, + ] + ) + # fmt: on + + self.assertTrue(logits.shape == (*decoder_input_ids.shape, model.decoder.config.vocab_size)) + self.assertTrue(torch.allclose(logits[0, 0, :16].cpu(), EXPECTED_LOGITS, atol=1e-4)) + + @slow + def test_logits_text_audio_prompt(self): + model = self.model + processor = self.processor + + audio = [get_bip_bip(duration=0.5), get_bip_bip(duration=1.0)] + text = ["80s music", "Club techno"] + + inputs = processor(audio=audio, text=text, padding=True, return_tensors="pt") + + # prepare the text encoder inputs + 
input_ids = inputs.input_ids.to(torch_device) + attention_mask = inputs.attention_mask.to(torch_device) + + # prepare the audio encoder inputs + input_values = inputs.input_values.to(torch_device) + padding_mask = inputs.padding_mask.to(torch_device) + + with torch.no_grad(): + logits = model( + input_ids, + attention_mask=attention_mask, + input_values=input_values, + padding_mask=padding_mask, + ).logits + + # fmt: off + EXPECTED_LOGITS = torch.tensor( + [ + 0.1841, -2.9324, -0.7898, 0.1857, 0.4971, -2.8685, -1.6525, -1.6541, + 2.7757, -2.5942, -3.0959, -1.0120, -1.0147, -0.4605, -0.8885, 0.6820, + ] + ) + # fmt: on + + self.assertTrue(logits.shape == (8, 50, 2048)) + self.assertTrue(torch.allclose(logits[0, -1, :16].cpu(), EXPECTED_LOGITS, atol=1e-4)) + + @slow + def test_generate_unconditional_greedy(self): + model = self.model + + # only generate 1 sample with greedy - since it's deterministic all elements of the batch will be the same + unconditional_inputs = model.get_unconditional_inputs(num_samples=1) + unconditional_inputs = place_dict_on_device(unconditional_inputs, device=torch_device) + + output_values = model.generate(**unconditional_inputs, do_sample=False, max_new_tokens=5) + + # fmt: off + EXPECTED_VALUES = torch.tensor( + [ + 0.0056, 0.0064, 0.0063, 0.0054, 0.0042, 0.0033, 0.0024, 0.0015, + 0.0015, 0.0010, 0.0004, -0.0012, -0.0036, -0.0055, -0.0067, -0.0071, + ] + ) + # fmt: on + + self.assertTrue(output_values.shape == (1, 1, 3200)) + self.assertTrue(torch.allclose(output_values[0, 0, :16].cpu(), EXPECTED_VALUES, atol=1e-4)) + + @slow + def test_generate_unconditional_sampling(self): + model = self.model + + # for stochastic sampling we can generate multiple outputs + unconditional_inputs = model.get_unconditional_inputs(num_samples=2) + unconditional_inputs = place_dict_on_device(unconditional_inputs, device=torch_device) + + set_seed(0) + output_values = model.generate(**unconditional_inputs, do_sample=True, max_new_tokens=10) + + # fmt: off + 
EXPECTED_VALUES = torch.tensor( + [ + 0.0765, 0.0758, 0.0749, 0.0759, 0.0759, 0.0771, 0.0775, 0.0760, + 0.0762, 0.0765, 0.0767, 0.0760, 0.0738, 0.0714, 0.0713, 0.0730, + ] + ) + # fmt: on + + self.assertTrue(output_values.shape == (2, 1, 4480)) + self.assertTrue(torch.allclose(output_values[0, 0, :16].cpu(), EXPECTED_VALUES, atol=1e-4)) + + @slow + def test_generate_text_prompt_greedy(self): + model = self.model + processor = self.processor + + inputs = processor(text=["80s music", "Club techno"], padding=True, return_tensors="pt") + + # prepare the encoder inputs + input_ids = inputs.input_ids.to(torch_device) + attention_mask = inputs.attention_mask.to(torch_device) + + output_values = model.generate( + input_ids, attention_mask=attention_mask, do_sample=False, guidance_scale=None, max_new_tokens=10 + ) + + # fmt: off + EXPECTED_VALUES = torch.tensor( + [ + -1.1998e-04, -2.2302e-04, 4.6296e-04, 1.0524e-03, 2.4827e-04, + -4.0288e-05, -1.2468e-04, 4.9846e-05, 7.1485e-04, 4.4197e-04, + ] + ) + # fmt: on + + self.assertTrue(output_values.shape == (2, 1, 4480)) + self.assertTrue(torch.allclose(output_values[0, 0, :10].cpu(), EXPECTED_VALUES, atol=1e-4)) + + @slow + def test_generate_text_prompt_greedy_with_classifier_free_guidance(self): + model = self.model + processor = self.processor + + inputs = processor(text=["80s music", "Club techno"], padding=True, return_tensors="pt") + + # prepare the encoder inputs + input_ids = inputs.input_ids.to(torch_device) + attention_mask = inputs.attention_mask.to(torch_device) + + output_values = model.generate( + input_ids, attention_mask=attention_mask, do_sample=False, guidance_scale=3, max_new_tokens=10 + ) + + # fmt: off + EXPECTED_VALUES = torch.tensor( + [ + 0.0283, 0.0246, 0.0650, 0.0640, 0.0599, 0.0711, 0.0420, 0.0112, + 0.0511, 0.0746, 0.1363, 0.1213, 0.0185, -0.0578, -0.0908, 0.0443, + ] + ) + # fmt: on + + self.assertTrue(output_values.shape == (2, 1, 4480)) + self.assertTrue(torch.allclose(output_values[0, 0, 
:16].cpu(), EXPECTED_VALUES, atol=1e-4)) + + @slow + def test_generate_text_prompt_sampling(self): + model = self.model + processor = self.processor + + inputs = processor(text=["80s music", "Club techno"], padding=True, return_tensors="pt") + + # prepare the encoder inputs + input_ids = inputs.input_ids.to(torch_device) + attention_mask = inputs.attention_mask.to(torch_device) + + set_seed(0) + output_values = model.generate( + input_ids, attention_mask=attention_mask, do_sample=True, guidance_scale=None, max_new_tokens=10 + ) + + # fmt: off + EXPECTED_VALUES = torch.tensor( + [ + -0.0047, -0.0094, -0.0028, -0.0018, -0.0057, -0.0007, -0.0104, -0.0211, + -0.0097, -0.0150, -0.0066, -0.0004, -0.0201, -0.0325, -0.0326, -0.0098, + ] + ) + # fmt: on + + self.assertTrue(output_values.shape == (2, 1, 4480)) + self.assertTrue(torch.allclose(output_values[0, 0, :16].cpu(), EXPECTED_VALUES, atol=1e-4)) + + @slow + def test_generate_text_audio_prompt(self): + model = self.model + processor = self.processor + + audio = [get_bip_bip(duration=0.5), get_bip_bip(duration=1.0)] + text = ["80s music", "Club techno"] + + inputs = processor(audio=audio, text=text, padding=True, return_tensors="pt") + inputs = place_dict_on_device(inputs, device=torch_device) + + output_values = model.generate(**inputs, do_sample=False, guidance_scale=None, max_new_tokens=10) + + # fmt: off + EXPECTED_VALUES = torch.tensor( + [ + -0.0036, -0.0130, -0.0261, -0.0384, -0.0557, -0.0718, -0.0680, -0.0632, + -0.0529, -0.0403, -0.0289, -0.0198, -0.0136, -0.0101, -0.0095, -0.0040, + ] + ) + # fmt: on + + self.assertTrue( + output_values.shape == (2, 1, 36480) + ) # input values take shape 32000 and we generate from there + self.assertTrue(torch.allclose(output_values[0, 0, -16:].cpu(), EXPECTED_VALUES, atol=1e-4)) diff --git a/tests/models/musicgen/test_processing_musicgen.py b/tests/models/musicgen/test_processing_musicgen.py new file mode 100644 index 0000000000..41962552ff --- /dev/null +++ 
b/tests/models/musicgen/test_processing_musicgen.py @@ -0,0 +1,173 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for the MusicGen processor.""" + +import random +import shutil +import tempfile +import unittest + +import numpy as np + +from transformers import T5Tokenizer, T5TokenizerFast +from transformers.testing_utils import require_sentencepiece, require_torch +from transformers.utils.import_utils import is_speech_available, is_torch_available + + +if is_torch_available(): + pass + +if is_speech_available(): + from transformers import EncodecFeatureExtractor, MusicgenProcessor + + +global_rng = random.Random() + + +def floats_list(shape, scale=1.0, rng=None, name=None): + """Creates a nested list of random floats with the given 2D shape.""" + if rng is None: + rng = global_rng + + values = [] + for _ in range(shape[0]): + values.append([]) + for _ in range(shape[1]): + values[-1].append(rng.random() * scale) + + return values + + +@require_torch +@require_sentencepiece +class MusicgenProcessorTest(unittest.TestCase): + def setUp(self): + self.checkpoint = "facebook/musicgen-small" + self.tmpdirname = tempfile.mkdtemp() + + def get_tokenizer(self, **kwargs): + return T5Tokenizer.from_pretrained(self.checkpoint, **kwargs) + + def get_feature_extractor(self, **kwargs): + return EncodecFeatureExtractor.from_pretrained(self.checkpoint, **kwargs) + + def tearDown(self): + shutil.rmtree(self.tmpdirname) + + def 
test_save_load_pretrained_default(self): + tokenizer = self.get_tokenizer() + feature_extractor = self.get_feature_extractor() + + processor = MusicgenProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor) + + processor.save_pretrained(self.tmpdirname) + processor = MusicgenProcessor.from_pretrained(self.tmpdirname) + + self.assertEqual(processor.tokenizer.get_vocab(), tokenizer.get_vocab()) + self.assertIsInstance(processor.tokenizer, T5TokenizerFast) + + self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor.to_json_string()) + self.assertIsInstance(processor.feature_extractor, EncodecFeatureExtractor) + + def test_save_load_pretrained_additional_features(self): + processor = MusicgenProcessor(tokenizer=self.get_tokenizer(), feature_extractor=self.get_feature_extractor()) + processor.save_pretrained(self.tmpdirname) + + tokenizer_add_kwargs = self.get_tokenizer(bos_token="(BOS)", eos_token="(EOS)") + feature_extractor_add_kwargs = self.get_feature_extractor(do_normalize=False, padding_value=1.0) + + processor = MusicgenProcessor.from_pretrained( + self.tmpdirname, bos_token="(BOS)", eos_token="(EOS)", do_normalize=False, padding_value=1.0 + ) + + self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab()) + self.assertIsInstance(processor.tokenizer, T5TokenizerFast) + + self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string()) + self.assertIsInstance(processor.feature_extractor, EncodecFeatureExtractor) + + def test_feature_extractor(self): + feature_extractor = self.get_feature_extractor() + tokenizer = self.get_tokenizer() + + processor = MusicgenProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor) + + raw_speech = floats_list((3, 1000)) + + input_feat_extract = feature_extractor(raw_speech, return_tensors="np") + input_processor = processor(raw_speech, return_tensors="np") + + for key in input_feat_extract.keys(): + 
self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2) + + def test_tokenizer(self): + feature_extractor = self.get_feature_extractor() + tokenizer = self.get_tokenizer() + + processor = MusicgenProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor) + + input_str = "This is a test string" + + encoded_processor = processor(text=input_str) + + encoded_tok = tokenizer(input_str) + + for key in encoded_tok.keys(): + self.assertListEqual(encoded_tok[key], encoded_processor[key]) + + def test_tokenizer_decode(self): + feature_extractor = self.get_feature_extractor() + tokenizer = self.get_tokenizer() + + processor = MusicgenProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor) + + predicted_ids = [[1, 4, 5, 8, 1, 0, 8], [3, 4, 3, 1, 1, 8, 9]] + + decoded_processor = processor.batch_decode(sequences=predicted_ids) + decoded_tok = tokenizer.batch_decode(predicted_ids) + + self.assertListEqual(decoded_tok, decoded_processor) + + def test_model_input_names(self): + feature_extractor = self.get_feature_extractor() + tokenizer = self.get_tokenizer() + + processor = MusicgenProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor) + + self.assertListEqual( + processor.model_input_names, + feature_extractor.model_input_names, + msg="`processor` and `feature_extractor` model input names do not match", + ) + + def test_decode_audio(self): + feature_extractor = self.get_feature_extractor(padding_side="left") + tokenizer = self.get_tokenizer() + + processor = MusicgenProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor) + + raw_speech = [floats_list((1, x))[0] for x in range(5, 20, 5)] + padding_mask = processor(raw_speech).padding_mask + + generated_speech = np.asarray(floats_list((3, 20)))[:, None, :] + decoded_audios = processor.batch_decode(generated_speech, padding_mask=padding_mask) + + self.assertIsInstance(decoded_audios, list) + + for audio in decoded_audios: + 
self.assertIsInstance(audio, np.ndarray) + + self.assertTrue(decoded_audios[0].shape == (1, 10)) + self.assertTrue(decoded_audios[1].shape == (1, 15)) + self.assertTrue(decoded_audios[2].shape == (1, 20)) diff --git a/utils/check_config_docstrings.py b/utils/check_config_docstrings.py index 93385b127d..de47348a9e 100644 --- a/utils/check_config_docstrings.py +++ b/utils/check_config_docstrings.py @@ -37,6 +37,7 @@ _re_checkpoint = re.compile(r"\[(.+?)\]\((https://huggingface\.co/.+?)\)") CONFIG_CLASSES_TO_IGNORE_FOR_DOCSTRING_CHECKPOINT_CHECK = { "DecisionTransformerConfig", "EncoderDecoderConfig", + "MusicgenConfig", "RagConfig", "SpeechEncoderDecoderConfig", "TimmBackboneConfig", diff --git a/utils/check_repo.py b/utils/check_repo.py index 2de9fa61e1..46407eb1a5 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -119,6 +119,7 @@ IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [ "MegatronBertEncoder", # Building part of bigger (tested) model. "MegatronBertDecoder", # Building part of bigger (tested) model. "MegatronBertDecoderWrapper", # Building part of bigger (tested) model. + "MusicgenDecoder", # Building part of bigger (tested) model. "MvpDecoderWrapper", # Building part of bigger (tested) model. "MvpEncoder", # Building part of bigger (tested) model. "PegasusEncoder", # Building part of bigger (tested) model. 
@@ -331,6 +332,8 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [ "SpeechT5ForSpeechToSpeech", "SpeechT5ForTextToSpeech", "SpeechT5HifiGan", + "MusicgenModel", + "MusicgenForConditionalGeneration", ] # Update this list for models that have multiple model types for the same diff --git a/utils/documentation_tests.txt b/utils/documentation_tests.txt index 0064a9999b..2bbaabb073 100644 --- a/utils/documentation_tests.txt +++ b/utils/documentation_tests.txt @@ -134,6 +134,9 @@ src/transformers/models/mobilenet_v1/modeling_mobilenet_v1.py src/transformers/models/mobilenet_v2/modeling_mobilenet_v2.py src/transformers/models/mobilevit/modeling_mobilevit.py src/transformers/models/mobilevit/modeling_tf_mobilevit.py +src/transformers/models/musicgen/modeling_musicgen.py +src/transformers/models/musicgen/configuration_musicgen.py +src/transformers/models/musicgen/processing_musicgen.py src/transformers/models/mvp/configuration_mvp.py src/transformers/models/nat/configuration_nat.py src/transformers/models/nat/modeling_nat.py
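The `test_decode_audio` case in `tests/models/musicgen/test_processing_musicgen.py` asserts that three left-padded inputs of length 5, 10 and 15, decoded against 20 generated samples, come back with shapes `(1, 10)`, `(1, 15)` and `(1, 20)`. A minimal sketch of trimming logic that is consistent with those shapes follows. The helper name `trim_left_padded_audio` is hypothetical, and the approach (extend the padding mask with ones up to the generated length, then drop masked samples) is an assumption inferred from the asserted shapes, not a copy of the processor's implementation.

```python
import numpy as np


def trim_left_padded_audio(audio_values, padding_mask):
    """Hypothetical helper: drop padded samples from generated audio.

    audio_values: array of shape (batch, channels, generated_len)
    padding_mask: array of shape (batch, input_len), 1 = real sample, 0 = padding
    Returns a list with one (channels, kept_len) array per batch item.
    """
    audio_values = np.asarray(audio_values)
    padding_mask = np.asarray(padding_mask)
    batch, _, generated_len = audio_values.shape
    input_len = padding_mask.shape[-1]

    # Assumption: samples generated past the input length are never padding,
    # so the mask is extended on the right with ones (needs generated_len >= input_len).
    extra = generated_len - input_len
    full_mask = np.pad(padding_mask, ((0, 0), (0, extra)), constant_values=1)

    trimmed = []
    for i in range(batch):
        keep = full_mask[i].astype(bool)
        # Boolean indexing on the last axis keeps the channel dimension.
        trimmed.append(audio_values[i][:, keep])
    return trimmed


# Mirror the shapes used in `test_decode_audio`: inputs of length 5, 10, 15,
# left-padded to 15, and 20 generated samples per batch item.
lengths = [5, 10, 15]
mask = np.array([[0] * (15 - n) + [1] * n for n in lengths])
generated = np.random.rand(3, 1, 20)
decoded = trim_left_padded_audio(generated, mask)
shapes = [a.shape for a in decoded]
```

Each item loses exactly as many samples as it had padding positions (10, 5 and 0), which is why the decoded lengths differ even though every item starts from 20 generated samples.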