diff --git a/README.md b/README.md
index 2a8e515e14..cc1fd458f5 100644
--- a/README.md
+++ b/README.md
@@ -411,6 +411,7 @@ Current number of checkpoints: ** (from Apple) released with the paper [Separable Self-attention for Mobile Vision Transformers](https://arxiv.org/abs/2206.02680) by Sachin Mehta and Mohammad Rastegari.
1. **[MPNet](https://huggingface.co/docs/transformers/model_doc/mpnet)** (from Microsoft Research) released with the paper [MPNet: Masked and Permuted Pre-training for Language Understanding](https://arxiv.org/abs/2004.09297) by Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu.
1. **[MT5](https://huggingface.co/docs/transformers/model_doc/mt5)** (from Google AI) released with the paper [mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/abs/2010.11934) by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel.
+1. **[MusicGen](https://huggingface.co/docs/transformers/main/model_doc/musicgen)** (from Meta) released with the paper [Simple and Controllable Music Generation](https://arxiv.org/abs/2306.05284) by Jade Copet, Felix Kreuk, Itai Gat, Tal Remez, David Kant, Gabriel Synnaeve, Yossi Adi and Alexandre Défossez.
1. **[MVP](https://huggingface.co/docs/transformers/model_doc/mvp)** (from RUC AI Box) released with the paper [MVP: Multi-task Supervised Pre-training for Natural Language Generation](https://arxiv.org/abs/2206.12131) by Tianyi Tang, Junyi Li, Wayne Xin Zhao and Ji-Rong Wen.
1. **[NAT](https://huggingface.co/docs/transformers/model_doc/nat)** (from SHI Labs) released with the paper [Neighborhood Attention Transformer](https://arxiv.org/abs/2204.07143) by Ali Hassani, Steven Walton, Jiachen Li, Shen Li, and Humphrey Shi.
1. **[Nezha](https://huggingface.co/docs/transformers/model_doc/nezha)** (from Huawei Noah’s Ark Lab) released with the paper [NEZHA: Neural Contextualized Representation for Chinese Language Understanding](https://arxiv.org/abs/1909.00204) by Junqiu Wei, Xiaozhe Ren, Xiaoguang Li, Wenyong Huang, Yi Liao, Yasheng Wang, Jiashu Lin, Xin Jiang, Xiao Chen and Qun Liu.
diff --git a/README_es.md b/README_es.md
index f5077993aa..fcb6049870 100644
--- a/README_es.md
+++ b/README_es.md
@@ -386,6 +386,7 @@ Número actual de puntos de control: ** (from Apple) released with the paper [Separable Self-attention for Mobile Vision Transformers](https://arxiv.org/abs/2206.02680) by Sachin Mehta and Mohammad Rastegari.
1. **[MPNet](https://huggingface.co/docs/transformers/model_doc/mpnet)** (from Microsoft Research) released with the paper [MPNet: Masked and Permuted Pre-training for Language Understanding](https://arxiv.org/abs/2004.09297) by Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu.
1. **[MT5](https://huggingface.co/docs/transformers/model_doc/mt5)** (from Google AI) released with the paper [mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/abs/2010.11934) by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel.
+1. **[MusicGen](https://huggingface.co/docs/transformers/main/model_doc/musicgen)** (from Meta) released with the paper [Simple and Controllable Music Generation](https://arxiv.org/abs/2306.05284) by Jade Copet, Felix Kreuk, Itai Gat, Tal Remez, David Kant, Gabriel Synnaeve, Yossi Adi and Alexandre Défossez.
1. **[MVP](https://huggingface.co/docs/transformers/model_doc/mvp)** (from RUC AI Box) released with the paper [MVP: Multi-task Supervised Pre-training for Natural Language Generation](https://arxiv.org/abs/2206.12131) by Tianyi Tang, Junyi Li, Wayne Xin Zhao and Ji-Rong Wen.
1. **[NAT](https://huggingface.co/docs/transformers/model_doc/nat)** (from SHI Labs) released with the paper [Neighborhood Attention Transformer](https://arxiv.org/abs/2204.07143) by Ali Hassani, Steven Walton, Jiachen Li, Shen Li, and Humphrey Shi.
1. **[Nezha](https://huggingface.co/docs/transformers/model_doc/nezha)** (from Huawei Noah’s Ark Lab) released with the paper [NEZHA: Neural Contextualized Representation for Chinese Language Understanding](https://arxiv.org/abs/1909.00204) by Junqiu Wei, Xiaozhe Ren, Xiaoguang Li, Wenyong Huang, Yi Liao, Yasheng Wang, Jiashu Lin, Xin Jiang, Xiao Chen and Qun Liu.
diff --git a/README_hd.md b/README_hd.md
index 146af0585c..9b694ed607 100644
--- a/README_hd.md
+++ b/README_hd.md
@@ -358,6 +358,7 @@ conda install -c huggingface transformers
1. **[MobileViTV2](https://huggingface.co/docs/transformers/model_doc/mobilevitv2)** (Apple से) Sachin Mehta and Mohammad Rastegari. द्वाराअनुसंधान पत्र [Separable Self-attention for Mobile Vision Transformers](https://arxiv.org/abs/2206.02680) के साथ जारी किया गया
1. **[MPNet](https://huggingface.co/docs/transformers/model_doc/mpnet)** (from Microsoft Research) released with the paper [MPNet: Masked and Permuted Pre-training for Language Understanding](https://arxiv.org/abs/2004.09297) by Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu.
1. **[MT5](https://huggingface.co/docs/transformers/model_doc/mt5)** (Google AI से) साथ वाला पेपर [mT5: एक व्यापक बहुभाषी पूर्व-प्रशिक्षित टेक्स्ट-टू-टेक्स्ट ट्रांसफॉर्मर]( https://arxiv.org/abs/2010.11934) लिंटिंग ज़ू, नोआ कॉन्सटेंट, एडम रॉबर्ट्स, मिहिर काले, रामी अल-रफू, आदित्य सिद्धांत, आदित्य बरुआ, कॉलिन रैफेल द्वारा पोस्ट किया गया।
+1. **[MusicGen](https://huggingface.co/docs/transformers/main/model_doc/musicgen)** (from Meta) released with the paper [Simple and Controllable Music Generation](https://arxiv.org/abs/2306.05284) by Jade Copet, Felix Kreuk, Itai Gat, Tal Remez, David Kant, Gabriel Synnaeve, Yossi Adi and Alexandre Défossez.
1. **[MVP](https://huggingface.co/docs/transformers/model_doc/mvp)** (from RUC AI Box) released with the paper [MVP: Multi-task Supervised Pre-training for Natural Language Generation](https://arxiv.org/abs/2206.12131) by Tianyi Tang, Junyi Li, Wayne Xin Zhao and Ji-Rong Wen.
1. **[NAT](https://huggingface.co/docs/transformers/model_doc/nat)** (from SHI Labs) released with the paper [Neighborhood Attention Transformer](https://arxiv.org/abs/2204.07143) by Ali Hassani, Steven Walton, Jiachen Li, Shen Li, and Humphrey Shi.
1. **[Nezha](https://huggingface.co/docs/transformers/model_doc/nezha)** (हुआवेई नूह के आर्क लैब से) साथ में कागज़ [NEZHA: चीनी भाषा समझ के लिए तंत्रिका प्रासंगिक प्रतिनिधित्व](https :/ /arxiv.org/abs/1909.00204) जुन्किउ वेई, ज़ियाओज़े रेन, ज़िआओगुआंग ली, वेनयोंग हुआंग, यी लियाओ, याशेंग वांग, जियाशू लिन, शिन जियांग, जिओ चेन और कुन लियू द्वारा।
diff --git a/README_ja.md b/README_ja.md
index 5584a84fac..60b14191c1 100644
--- a/README_ja.md
+++ b/README_ja.md
@@ -420,6 +420,7 @@ Flax、PyTorch、TensorFlowをcondaでインストールする方法は、それ
1. **[MobileViTV2](https://huggingface.co/docs/transformers/model_doc/mobilevitv2)** (Apple から) Sachin Mehta and Mohammad Rastegari. から公開された研究論文 [Separable Self-attention for Mobile Vision Transformers](https://arxiv.org/abs/2206.02680)
1. **[MPNet](https://huggingface.co/docs/transformers/model_doc/mpnet)** (Microsoft Research から) Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu から公開された研究論文: [MPNet: Masked and Permuted Pre-training for Language Understanding](https://arxiv.org/abs/2004.09297)
1. **[MT5](https://huggingface.co/docs/transformers/model_doc/mt5)** (Google AI から) Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel から公開された研究論文: [mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/abs/2010.11934)
+1. **[MusicGen](https://huggingface.co/docs/transformers/main/model_doc/musicgen)** (from Meta) released with the paper [Simple and Controllable Music Generation](https://arxiv.org/abs/2306.05284) by Jade Copet, Felix Kreuk, Itai Gat, Tal Remez, David Kant, Gabriel Synnaeve, Yossi Adi and Alexandre Défossez.
1. **[MVP](https://huggingface.co/docs/transformers/model_doc/mvp)** (RUC AI Box から) Tianyi Tang, Junyi Li, Wayne Xin Zhao and Ji-Rong Wen から公開された研究論文: [MVP: Multi-task Supervised Pre-training for Natural Language Generation](https://arxiv.org/abs/2206.12131)
1. **[NAT](https://huggingface.co/docs/transformers/model_doc/nat)** (SHI Labs から) Ali Hassani, Steven Walton, Jiachen Li, Shen Li, and Humphrey Shi から公開された研究論文: [Neighborhood Attention Transformer](https://arxiv.org/abs/2204.07143)
1. **[Nezha](https://huggingface.co/docs/transformers/model_doc/nezha)** (Huawei Noah’s Ark Lab から) Junqiu Wei, Xiaozhe Ren, Xiaoguang Li, Wenyong Huang, Yi Liao, Yasheng Wang, Jiashu Lin, Xin Jiang, Xiao Chen and Qun Liu から公開された研究論文: [NEZHA: Neural Contextualized Representation for Chinese Language Understanding](https://arxiv.org/abs/1909.00204)
diff --git a/README_ko.md b/README_ko.md
index a476f4493b..cdbeec9a4b 100644
--- a/README_ko.md
+++ b/README_ko.md
@@ -335,6 +335,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는
1. **[MobileViTV2](https://huggingface.co/docs/transformers/model_doc/mobilevitv2)** (Apple 에서 제공)은 Sachin Mehta and Mohammad Rastegari.의 [Separable Self-attention for Mobile Vision Transformers](https://arxiv.org/abs/2206.02680)논문과 함께 발표했습니다.
1. **[MPNet](https://huggingface.co/docs/transformers/model_doc/mpnet)** (Microsoft Research 에서) Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu 의 [MPNet: Masked and Permuted Pre-training for Language Understanding](https://arxiv.org/abs/2004.09297) 논문과 함께 발표했습니다.
1. **[MT5](https://huggingface.co/docs/transformers/model_doc/mt5)** (Google AI 에서) Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel 의 [mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/abs/2010.11934) 논문과 함께 발표했습니다.
+1. **[MusicGen](https://huggingface.co/docs/transformers/main/model_doc/musicgen)** (from Meta) released with the paper [Simple and Controllable Music Generation](https://arxiv.org/abs/2306.05284) by Jade Copet, Felix Kreuk, Itai Gat, Tal Remez, David Kant, Gabriel Synnaeve, Yossi Adi and Alexandre Défossez.
1. **[MVP](https://huggingface.co/docs/transformers/model_doc/mvp)** (RUC AI Box 에서) Tianyi Tang, Junyi Li, Wayne Xin Zhao and Ji-Rong Wen 의 [MVP: Multi-task Supervised Pre-training for Natural Language Generation](https://arxiv.org/abs/2206.12131) 논문과 함께 발표했습니다.
1. **[NAT](https://huggingface.co/docs/transformers/model_doc/nat)** (SHI Labs 에서) Ali Hassani, Steven Walton, Jiachen Li, Shen Li, and Humphrey Shi 의 [Neighborhood Attention Transformer](https://arxiv.org/abs/2204.07143) 논문과 함께 발표했습니다.
1. **[Nezha](https://huggingface.co/docs/transformers/model_doc/nezha)** (Huawei Noah’s Ark Lab 에서) Junqiu Wei, Xiaozhe Ren, Xiaoguang Li, Wenyong Huang, Yi Liao, Yasheng Wang, Jiashu Lin, Xin Jiang, Xiao Chen and Qun Liu 의 [NEZHA: Neural Contextualized Representation for Chinese Language Understanding](https://arxiv.org/abs/1909.00204) 논문과 함께 발표했습니다.
diff --git a/README_zh-hans.md b/README_zh-hans.md
index bdc661091a..db1cbd4237 100644
--- a/README_zh-hans.md
+++ b/README_zh-hans.md
@@ -359,6 +359,7 @@ conda install -c huggingface transformers
1. **[MobileViTV2](https://huggingface.co/docs/transformers/model_doc/mobilevitv2)** (来自 Apple) 伴随论文 [Separable Self-attention for Mobile Vision Transformers](https://arxiv.org/abs/2206.02680) 由 Sachin Mehta and Mohammad Rastegari 发布。
1. **[MPNet](https://huggingface.co/docs/transformers/model_doc/mpnet)** (来自 Microsoft Research) 伴随论文 [MPNet: Masked and Permuted Pre-training for Language Understanding](https://arxiv.org/abs/2004.09297) 由 Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu 发布。
1. **[MT5](https://huggingface.co/docs/transformers/model_doc/mt5)** (来自 Google AI) 伴随论文 [mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/abs/2010.11934) 由 Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel 发布。
+1. **[MusicGen](https://huggingface.co/docs/transformers/main/model_doc/musicgen)** (from Meta) released with the paper [Simple and Controllable Music Generation](https://arxiv.org/abs/2306.05284) by Jade Copet, Felix Kreuk, Itai Gat, Tal Remez, David Kant, Gabriel Synnaeve, Yossi Adi and Alexandre Défossez.
1. **[MVP](https://huggingface.co/docs/transformers/model_doc/mvp)** (来自 中国人民大学 AI Box) 伴随论文 [MVP: Multi-task Supervised Pre-training for Natural Language Generation](https://arxiv.org/abs/2206.12131) 由 Tianyi Tang, Junyi Li, Wayne Xin Zhao and Ji-Rong Wen 发布。
1. **[NAT](https://huggingface.co/docs/transformers/model_doc/nat)** (来自 SHI Labs) 伴随论文 [Neighborhood Attention Transformer](https://arxiv.org/abs/2204.07143) 由 Ali Hassani, Steven Walton, Jiachen Li, Shen Li, and Humphrey Shi 发布。
1. **[Nezha](https://huggingface.co/docs/transformers/model_doc/nezha)** (来自华为诺亚方舟实验室) 伴随论文 [NEZHA: Neural Contextualized Representation for Chinese Language Understanding](https://arxiv.org/abs/1909.00204) 由 Junqiu Wei, Xiaozhe Ren, Xiaoguang Li, Wenyong Huang, Yi Liao, Yasheng Wang, Jiashu Lin, Xin Jiang, Xiao Chen and Qun Liu 发布。
diff --git a/README_zh-hant.md b/README_zh-hant.md
index 94a275dc21..e66cd1f867 100644
--- a/README_zh-hant.md
+++ b/README_zh-hant.md
@@ -371,6 +371,7 @@ conda install -c huggingface transformers
1. **[MobileViTV2](https://huggingface.co/docs/transformers/model_doc/mobilevitv2)** (from Apple) released with the paper [Separable Self-attention for Mobile Vision Transformers](https://arxiv.org/abs/2206.02680) by Sachin Mehta and Mohammad Rastegari.
1. **[MPNet](https://huggingface.co/docs/transformers/model_doc/mpnet)** (from Microsoft Research) released with the paper [MPNet: Masked and Permuted Pre-training for Language Understanding](https://arxiv.org/abs/2004.09297) by Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu.
1. **[MT5](https://huggingface.co/docs/transformers/model_doc/mt5)** (from Google AI) released with the paper [mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/abs/2010.11934) by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel.
+1. **[MusicGen](https://huggingface.co/docs/transformers/main/model_doc/musicgen)** (from Meta) released with the paper [Simple and Controllable Music Generation](https://arxiv.org/abs/2306.05284) by Jade Copet, Felix Kreuk, Itai Gat, Tal Remez, David Kant, Gabriel Synnaeve, Yossi Adi and Alexandre Défossez.
1. **[MVP](https://huggingface.co/docs/transformers/model_doc/mvp)** (from RUC AI Box) released with the paper [MVP: Multi-task Supervised Pre-training for Natural Language Generation](https://arxiv.org/abs/2206.12131) by Tianyi Tang, Junyi Li, Wayne Xin Zhao and Ji-Rong Wen.
1. **[NAT](https://huggingface.co/docs/transformers/model_doc/nat)** (from SHI Labs) released with the paper [Neighborhood Attention Transformer](https://arxiv.org/abs/2204.07143) by Ali Hassani, Steven Walton, Jiachen Li, Shen Li, and Humphrey Shi.
1. **[Nezha](https://huggingface.co/docs/transformers/model_doc/nezha)** (from Huawei Noah’s Ark Lab) released with the paper [NEZHA: Neural Contextualized Representation for Chinese Language Understanding](https://arxiv.org/abs/1909.00204) by Junqiu Wei, Xiaozhe Ren, Xiaoguang Li, Wenyong Huang, Yi Liao, Yasheng Wang, Jiashu Lin, Xin Jiang, Xiao Chen and Qun Liu.
diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 414e67ffe7..e36282b21c 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -549,6 +549,8 @@
title: MCTCT
- local: model_doc/mms
title: MMS
+ - local: model_doc/musicgen
+ title: MusicGen
- local: model_doc/sew
title: SEW
- local: model_doc/sew-d
diff --git a/docs/source/en/index.md b/docs/source/en/index.md
index 4c7f238815..91c57e0f39 100644
--- a/docs/source/en/index.md
+++ b/docs/source/en/index.md
@@ -175,6 +175,7 @@ The documentation is organized into five sections:
1. **[MobileViTV2](model_doc/mobilevitv2)** (from Apple) released with the paper [Separable Self-attention for Mobile Vision Transformers](https://arxiv.org/abs/2206.02680) by Sachin Mehta and Mohammad Rastegari.
1. **[MPNet](model_doc/mpnet)** (from Microsoft Research) released with the paper [MPNet: Masked and Permuted Pre-training for Language Understanding](https://arxiv.org/abs/2004.09297) by Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu.
1. **[MT5](model_doc/mt5)** (from Google AI) released with the paper [mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/abs/2010.11934) by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel.
+1. **[MusicGen](model_doc/musicgen)** (from Meta) released with the paper [Simple and Controllable Music Generation](https://arxiv.org/abs/2306.05284) by Jade Copet, Felix Kreuk, Itai Gat, Tal Remez, David Kant, Gabriel Synnaeve, Yossi Adi and Alexandre Défossez.
1. **[MVP](model_doc/mvp)** (from RUC AI Box) released with the paper [MVP: Multi-task Supervised Pre-training for Natural Language Generation](https://arxiv.org/abs/2206.12131) by Tianyi Tang, Junyi Li, Wayne Xin Zhao and Ji-Rong Wen.
1. **[NAT](model_doc/nat)** (from SHI Labs) released with the paper [Neighborhood Attention Transformer](https://arxiv.org/abs/2204.07143) by Ali Hassani, Steven Walton, Jiachen Li, Shen Li, and Humphrey Shi.
1. **[Nezha](model_doc/nezha)** (from Huawei Noah’s Ark Lab) released with the paper [NEZHA: Neural Contextualized Representation for Chinese Language Understanding](https://arxiv.org/abs/1909.00204) by Junqiu Wei, Xiaozhe Ren, Xiaoguang Li, Wenyong Huang, Yi Liao, Yasheng Wang, Jiashu Lin, Xin Jiang, Xiao Chen and Qun Liu.
@@ -380,6 +381,7 @@ Flax), PyTorch, and/or TensorFlow.
| MobileViTV2 | ❌ | ❌ | ✅ | ❌ | ❌ |
| MPNet | ✅ | ✅ | ✅ | ✅ | ❌ |
| MT5 | ✅ | ✅ | ✅ | ✅ | ✅ |
+| MusicGen | ❌ | ❌ | ✅ | ❌ | ❌ |
| MVP | ✅ | ✅ | ✅ | ❌ | ❌ |
| NAT | ❌ | ❌ | ✅ | ❌ | ❌ |
| Nezha | ❌ | ❌ | ✅ | ❌ | ❌ |
diff --git a/docs/source/en/model_doc/musicgen.md b/docs/source/en/model_doc/musicgen.md
new file mode 100644
index 0000000000..72250a86fc
--- /dev/null
+++ b/docs/source/en/model_doc/musicgen.md
@@ -0,0 +1,277 @@
+
+
+# MusicGen
+
+## Overview
+
+The MusicGen model was proposed in the paper [Simple and Controllable Music Generation](https://arxiv.org/abs/2306.05284)
+by Jade Copet, Felix Kreuk, Itai Gat, Tal Remez, David Kant, Gabriel Synnaeve, Yossi Adi and Alexandre Défossez.
+
+MusicGen is a single stage auto-regressive Transformer model capable of generating high-quality music samples conditioned
+on text descriptions or audio prompts. The text descriptions are passed through a frozen text encoder model to obtain a
+sequence of hidden-state representations. MusicGen is then trained to predict discrete audio tokens, or *audio codes*,
+conditioned on these hidden-states. These audio tokens are then decoded using an audio compression model, such as EnCodec,
+to recover the audio waveform.
+
+Through an efficient token interleaving pattern, MusicGen does not require a self-supervised semantic representation of
+the text/audio prompts, thus eliminating the need to cascade multiple models to predict a set of codebooks (e.g.
+hierarchically or upsampling). Instead, it is able to generate all the codebooks in a single forward pass.
+
+The abstract from the paper is the following:
+
+*We tackle the task of conditional music generation. We introduce MusicGen, a single Language Model (LM) that operates
+over several streams of compressed discrete music representation, i.e., tokens. Unlike prior work, MusicGen is comprised
+of a single-stage transformer LM together with efficient token interleaving patterns, which eliminates the need for
+cascading several models, e.g., hierarchically or upsampling. Following this approach, we demonstrate how MusicGen
+can generate high-quality samples, while being conditioned on textual description or melodic features, allowing better
+controls over the generated output. We conduct extensive empirical evaluation, considering both automatic and human
+studies, showing the proposed approach is superior to the evaluated baselines on a standard text-to-music benchmark.
+Through ablation studies, we shed light over the importance of each of the components comprising MusicGen.*
+
+This model was contributed by [sanchit-gandhi](https://huggingface.co/sanchit-gandhi). The original code can be found
+[here](https://github.com/facebookresearch/audiocraft). The pre-trained checkpoints can be found on the
+[Hugging Face Hub](https://huggingface.co/models?sort=downloads&search=facebook%2Fmusicgen-).
+
+## Generation
+
+MusicGen is compatible with two generation modes: greedy and sampling. In practice, sampling leads to significantly
+better results than greedy, thus we encourage sampling mode to be used where possible. Sampling is enabled by default,
+and can be explicitly specified by setting `do_sample=True` in the call to [`MusicgenForConditionalGeneration.generate`],
+or by overriding the model's generation config (see below).
+
+### Unconditional Generation
+
+The inputs for unconditional (or 'null') generation can be obtained through the method
+[`MusicgenForConditionalGeneration.get_unconditional_inputs`]:
+
+```python
+>>> from transformers import MusicgenForConditionalGeneration
+
+>>> model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
+>>> unconditional_inputs = model.get_unconditional_inputs(num_samples=1)
+
+>>> audio_values = model.generate(**unconditional_inputs, do_sample=True, max_new_tokens=256)
+```
+
+The audio outputs are a three-dimensional Torch tensor of shape `(batch_size, num_channels, sequence_length)`. To listen
+to the generated audio samples, you can either play them in an ipynb notebook:
+
+```python
+from IPython.display import Audio
+
+sampling_rate = model.config.audio_encoder.sampling_rate
+Audio(audio_values[0].numpy(), rate=sampling_rate)
+```
+
+Or save them as a `.wav` file using a third-party library, e.g. `scipy`:
+
+```python
+>>> import scipy
+
+>>> sampling_rate = model.config.audio_encoder.sampling_rate
+>>> scipy.io.wavfile.write("musicgen_out.wav", rate=sampling_rate, data=audio_values[0, 0].numpy())
+```
+
+### Text-Conditional Generation
+
+The model can generate an audio sample conditioned on a text prompt through use of the [`MusicgenProcessor`] to pre-process
+the inputs:
+
+```python
+>>> from transformers import AutoProcessor, MusicgenForConditionalGeneration
+
+>>> processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
+>>> model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
+
+>>> inputs = processor(
+... text=["80s pop track with bassy drums and synth", "90s rock song with loud guitars and heavy drums"],
+... padding=True,
+... return_tensors="pt",
+... )
+>>> audio_values = model.generate(**inputs, do_sample=True, guidance_scale=3, max_new_tokens=256)
+```
+
+The `guidance_scale` is used in classifier free guidance (CFG), setting the weighting between the conditional logits
+(which are predicted from the text prompts) and the unconditional logits (which are predicted from an unconditional or
+'null' prompt). Higher guidance scale encourages the model to generate samples that are more closely linked to the input
+prompt, usually at the expense of poorer audio quality. CFG is enabled by setting `guidance_scale > 1`. For best results,
+use `guidance_scale=3` (default).
+
+### Audio-Prompted Generation
+
+The same [`MusicgenProcessor`] can be used to pre-process an audio prompt that is used for audio continuation. In the
+following example, we load an audio file using the 🤗 Datasets library, which can be pip installed through the command
+below:
+
+```
+pip install --upgrade pip
+pip install datasets[audio]
+```
+
+```python
+>>> from transformers import AutoProcessor, MusicgenForConditionalGeneration
+>>> from datasets import load_dataset
+
+>>> processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
+>>> model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
+
+>>> dataset = load_dataset("sanchit-gandhi/gtzan", split="train", streaming=True)
+>>> sample = next(iter(dataset))["audio"]
+
+>>> # take the first half of the audio sample
+>>> sample["array"] = sample["array"][: len(sample["array"]) // 2]
+
+>>> inputs = processor(
+... audio=sample["array"],
+... sampling_rate=sample["sampling_rate"],
+... text=["80s blues track with groovy saxophone"],
+... padding=True,
+... return_tensors="pt",
+... )
+>>> audio_values = model.generate(**inputs, do_sample=True, guidance_scale=3, max_new_tokens=256)
+```
+
+For batched audio-prompted generation, the generated `audio_values` can be post-processed to remove padding by using the
+[`MusicgenProcessor`] class:
+
+```python
+>>> from transformers import AutoProcessor, MusicgenForConditionalGeneration
+>>> from datasets import load_dataset
+
+>>> processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
+>>> model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
+
+>>> dataset = load_dataset("sanchit-gandhi/gtzan", split="train", streaming=True)
+>>> sample = next(iter(dataset))["audio"]
+
+>>> # take the first quarter of the audio sample
+>>> sample_1 = sample["array"][: len(sample["array"]) // 4]
+
+>>> # take the first half of the audio sample
+>>> sample_2 = sample["array"][: len(sample["array"]) // 2]
+
+>>> inputs = processor(
+... audio=[sample_1, sample_2],
+... sampling_rate=sample["sampling_rate"],
+... text=["80s blues track with groovy saxophone", "90s rock song with loud guitars and heavy drums"],
+... padding=True,
+... return_tensors="pt",
+... )
+>>> audio_values = model.generate(**inputs, do_sample=True, guidance_scale=3, max_new_tokens=256)
+
+>>> # post-process to remove padding from the batched audio
+>>> audio_values = processor.batch_decode(audio_values, padding_mask=inputs.padding_mask)
+```
+
+### Generation Configuration
+
+The default parameters that control the generation process, such as sampling, guidance scale and number of generated
+tokens, can be found in the model's generation config, and updated as desired:
+
+```python
+>>> from transformers import MusicgenForConditionalGeneration
+
+>>> model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
+
+>>> # inspect the default generation config
+>>> model.generation_config
+
+>>> # increase the guidance scale to 4.0
+>>> model.generation_config.guidance_scale = 4.0
+
+>>> # decrease the max length to 256 tokens
+>>> model.generation_config.max_length = 256
+```
+
+Note that any arguments passed to the generate method will **supersede** those in the generation config, so setting
+`do_sample=False` in the call to generate will supersede the setting of `model.generation_config.do_sample` in the
+generation config.
+
+## Model Structure
+
+The MusicGen model can be de-composed into three distinct stages:
+1. Text encoder: maps the text inputs to a sequence of hidden-state representations. The pre-trained MusicGen models use a frozen text encoder from either T5 or Flan-T5
+2. MusicGen decoder: a language model (LM) that auto-regressively generates audio tokens (or codes) conditional on the encoder hidden-state representations
+3. Audio encoder/decoder: used to encode an audio prompt to use as prompt tokens, and recover the audio waveform from the audio tokens predicted by the decoder
+
+Thus, the MusicGen model can either be used as a standalone decoder model, corresponding to the class [`MusicgenForCausalLM`],
+or as a composite model that includes the text encoder and audio encoder/decoder, corresponding to the class
+[`MusicgenForConditionalGeneration`].
+
+Since the text encoder and audio encoder/decoder models are frozen during training, the MusicGen decoder [`MusicgenForCausalLM`]
+can be trained standalone on a dataset of encoder hidden-states and audio codes. For inference, the trained decoder can
+be combined with the frozen text encoder and audio encoder/decoders to recover the composite [`MusicgenForConditionalGeneration`]
+model.
+
+Below, we demonstrate how to construct the composite [`MusicgenForConditionalGeneration`] model from its three constituent
+parts, as would typically be done following training of the MusicGen decoder LM:
+
+```python
+>>> from transformers import AutoConfig, AutoModelForTextEncoding, AutoModel, MusicgenForCausalLM, MusicgenForConditionalGeneration
+
+>>> text_encoder = AutoModelForTextEncoding.from_pretrained("t5-base")
+>>> audio_encoder = AutoModel.from_pretrained("facebook/encodec_32khz")
+>>> decoder_config = AutoConfig.from_pretrained("facebook/musicgen-small").decoder
+>>> decoder = MusicgenForCausalLM.from_pretrained("facebook/musicgen-small", **decoder_config)
+
+>>> model = MusicgenForConditionalGeneration.from_sub_models_pretrained(text_encoder, audio_encoder, decoder)
+```
+
+If only the decoder needs to be loaded from the pre-trained checkpoint for the composite model, it can be loaded by first
+specifying the correct config, or be accessed through the `.decoder` attribute of the composite model:
+
+```python
+>>> from transformers import AutoConfig, MusicgenForCausalLM, MusicgenForConditionalGeneration
+
+>>> # Option 1: get decoder config and pass to `.from_pretrained`
+>>> decoder_config = AutoConfig.from_pretrained("facebook/musicgen-small").decoder
+>>> decoder = MusicgenForCausalLM.from_pretrained("facebook/musicgen-small", **decoder_config)
+
+>>> # Option 2: load the entire composite model, but only return the decoder
+>>> decoder = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small").decoder
+```
+
+Tips:
+* MusicGen is trained on the 32kHz checkpoint of Encodec. You should ensure you use a compatible version of the Encodec model.
+* Sampling mode tends to deliver better results than greedy - you can toggle sampling with the variable `do_sample` in the call to [`MusicgenForConditionalGeneration.generate`]
+
+## MusicgenDecoderConfig
+
+[[autodoc]] MusicgenDecoderConfig
+
+## MusicgenConfig
+
+[[autodoc]] MusicgenConfig
+
+## MusicgenProcessor
+
+[[autodoc]] MusicgenProcessor
+
+## MusicgenModel
+
+[[autodoc]] MusicgenModel
+ - forward
+
+## MusicgenForCausalLM
+
+[[autodoc]] MusicgenForCausalLM
+ - forward
+
+## MusicgenForConditionalGeneration
+
+[[autodoc]] MusicgenForConditionalGeneration
+ - forward
diff --git a/docs/source/en/tasks/language_modeling.md b/docs/source/en/tasks/language_modeling.md
index f6f9b37afe..4986d90b28 100644
--- a/docs/source/en/tasks/language_modeling.md
+++ b/docs/source/en/tasks/language_modeling.md
@@ -37,7 +37,7 @@ You can finetune other architectures for causal language modeling following the
Choose one of the following architectures:
-[BART](../model_doc/bart), [BERT](../model_doc/bert), [Bert Generation](../model_doc/bert-generation), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BioGpt](../model_doc/biogpt), [Blenderbot](../model_doc/blenderbot), [BlenderbotSmall](../model_doc/blenderbot-small), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CodeGen](../model_doc/codegen), [CPM-Ant](../model_doc/cpmant), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [GIT](../model_doc/git), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT NeoX Japanese](../model_doc/gpt_neox_japanese), [GPT-J](../model_doc/gptj), [LLaMA](../model_doc/llama), [Marian](../model_doc/marian), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [MVP](../model_doc/mvp), [OpenLlama](../model_doc/open-llama), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Pegasus](../model_doc/pegasus), [PLBart](../model_doc/plbart), [ProphetNet](../model_doc/prophetnet), [QDQBert](../model_doc/qdqbert), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [RWKV](../model_doc/rwkv), [Speech2Text2](../model_doc/speech_to_text_2), [Transformer-XL](../model_doc/transfo-xl), [TrOCR](../model_doc/trocr), [XGLM](../model_doc/xglm), [XLM](../model_doc/xlm), [XLM-ProphetNet](../model_doc/xlm-prophetnet), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod)
+[BART](../model_doc/bart), [BERT](../model_doc/bert), [Bert Generation](../model_doc/bert-generation), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BioGpt](../model_doc/biogpt), [Blenderbot](../model_doc/blenderbot), [BlenderbotSmall](../model_doc/blenderbot-small), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CodeGen](../model_doc/codegen), [CPM-Ant](../model_doc/cpmant), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [GIT](../model_doc/git), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT NeoX Japanese](../model_doc/gpt_neox_japanese), [GPT-J](../model_doc/gptj), [LLaMA](../model_doc/llama), [Marian](../model_doc/marian), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [MusicGen](../model_doc/musicgen), [MVP](../model_doc/mvp), [OpenLlama](../model_doc/open-llama), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Pegasus](../model_doc/pegasus), [PLBart](../model_doc/plbart), [ProphetNet](../model_doc/prophetnet), [QDQBert](../model_doc/qdqbert), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [RWKV](../model_doc/rwkv), [Speech2Text2](../model_doc/speech_to_text_2), [Transformer-XL](../model_doc/transfo-xl), [TrOCR](../model_doc/trocr), [XGLM](../model_doc/xglm), [XLM](../model_doc/xlm), [XLM-ProphetNet](../model_doc/xlm-prophetnet), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod)
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 202d9a6101..3747e1951f 100644
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -402,6 +402,11 @@ _import_structure = {
"models.mobilevitv2": ["MOBILEVITV2_PRETRAINED_CONFIG_ARCHIVE_MAP", "MobileViTV2Config"],
"models.mpnet": ["MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "MPNetConfig", "MPNetTokenizer"],
"models.mt5": ["MT5Config"],
+ "models.musicgen": [
+ "MUSICGEN_PRETRAINED_CONFIG_ARCHIVE_MAP",
+ "MusicgenConfig",
+ "MusicgenDecoderConfig",
+ ],
"models.mvp": ["MvpConfig", "MvpTokenizer"],
"models.nat": ["NAT_PRETRAINED_CONFIG_ARCHIVE_MAP", "NatConfig"],
"models.nezha": ["NEZHA_PRETRAINED_CONFIG_ARCHIVE_MAP", "NezhaConfig"],
@@ -2134,6 +2139,16 @@ else:
_import_structure["models.mt5"].extend(
["MT5EncoderModel", "MT5ForConditionalGeneration", "MT5ForQuestionAnswering", "MT5Model", "MT5PreTrainedModel"]
)
+ _import_structure["models.musicgen"].extend(
+ [
+ "MUSICGEN_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "MusicgenForCausalLM",
+ "MusicgenForConditionalGeneration",
+ "MusicgenModel",
+ "MusicgenPreTrainedModel",
+ "MusicgenProcessor",
+ ]
+ )
_import_structure["models.mvp"].extend(
[
"MVP_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -4249,6 +4264,11 @@ if TYPE_CHECKING:
from .models.mobilevitv2 import MOBILEVITV2_PRETRAINED_CONFIG_ARCHIVE_MAP, MobileViTV2Config
from .models.mpnet import MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP, MPNetConfig, MPNetTokenizer
from .models.mt5 import MT5Config
+ from .models.musicgen import (
+ MUSICGEN_PRETRAINED_CONFIG_ARCHIVE_MAP,
+ MusicgenConfig,
+ MusicgenDecoderConfig,
+ )
from .models.mvp import MvpConfig, MvpTokenizer
from .models.nat import NAT_PRETRAINED_CONFIG_ARCHIVE_MAP, NatConfig
from .models.nezha import NEZHA_PRETRAINED_CONFIG_ARCHIVE_MAP, NezhaConfig
@@ -5709,6 +5729,14 @@ if TYPE_CHECKING:
MT5Model,
MT5PreTrainedModel,
)
+ from .models.musicgen import (
+ MUSICGEN_PRETRAINED_MODEL_ARCHIVE_LIST,
+ MusicgenForCausalLM,
+ MusicgenForConditionalGeneration,
+ MusicgenModel,
+ MusicgenPreTrainedModel,
+ MusicgenProcessor,
+ )
from .models.mvp import (
MVP_PRETRAINED_MODEL_ARCHIVE_LIST,
MvpForCausalLM,
diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py
index 99426790db..4514ccef76 100644
--- a/src/transformers/generation/configuration_utils.py
+++ b/src/transformers/generation/configuration_utils.py
@@ -185,6 +185,10 @@ class GenerationConfig(PushToHubMixin):
Dictionary that maps a sequence of tokens to its bias term. Positive biases increase the odds of the
sequence being selected, while negative biases do the opposite. Check
[`~generation.SequenceBiasLogitsProcessor`] for further documentation and examples.
+ guidance_scale (`float`, *optional*):
+ The guidance scale for classifier free guidance (CFG). CFG is enabled by setting `guidance_scale > 1`.
+ Higher guidance scale encourages the model to generate samples that are more closely linked to the input
+ prompt, usually at the expense of poorer quality.
> Parameters that define the output variables of `generate`
@@ -265,6 +269,7 @@ class GenerationConfig(PushToHubMixin):
self.begin_suppress_tokens = kwargs.pop("begin_suppress_tokens", None)
self.forced_decoder_ids = kwargs.pop("forced_decoder_ids", None)
self.sequence_bias = kwargs.pop("sequence_bias", None)
+ self.guidance_scale = kwargs.pop("guidance_scale", None)
# Parameters that define the output variables of `generate`
self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py
index 6a73fcf964..a7fad67845 100644
--- a/src/transformers/generation/logits_process.py
+++ b/src/transformers/generation/logits_process.py
@@ -1065,3 +1065,40 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
scores[k, : self.timestamp_begin] = -float("inf")
return scores
+
+
+class ClassifierFreeGuidanceLogitsProcessor(LogitsProcessor):
+ r"""Logits processor for classifier free guidance (CFG). The scores are split over the batch dimension,
+ where the first half correspond to the conditional logits (predicted from the input prompt) and the second half
+ correspond to the unconditional logits (predicted from an empty or 'null' prompt). The processor computes a
+ weighted average across the conditional and unconditional logits, parameterised by the `guidance_scale`.
+
+ Args:
+ guidance_scale (float):
+ The guidance scale for classifier free guidance (CFG). CFG is enabled by setting `guidance_scale > 1`.
+ Higher guidance scale encourages the model to generate samples that are more closely linked to the input
+ prompt, usually at the expense of poorer quality.
+ """
+
+ def __init__(self, guidance_scale):
+ if guidance_scale > 1:
+ self.guidance_scale = guidance_scale
+ else:
+ raise ValueError(
+ "Require guidance scale >1 to use the classifier free guidance processor, got guidance scale "
+ f"{guidance_scale}."
+ )
+
+ def __call__(self, input_ids, scores):
+ # simple check to make sure we have compatible batch sizes between our
+ # logits scores (cond + uncond) and input ids (cond only)
+ if scores.shape[0] != 2 * input_ids.shape[0]:
+ raise ValueError(
+ f"Logits should have twice the batch size of the input ids, the first half of batches corresponding to "
+ f"the conditional inputs, and the second half of batches corresponding to the unconditional inputs. Got "
+ f"batch size {scores.shape[0]} for the logits and {input_ids.shape[0]} for the input ids."
+ )
+ unguided_bsz = scores.shape[0] // 2
+ cond_logits, uncond_logits = scores.split(unguided_bsz, dim=0)
+ scores = uncond_logits + (cond_logits - uncond_logits) * self.guidance_scale
+ return scores
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index cb681cc720..69120ab8f8 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -38,6 +38,7 @@ from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint
from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
from .configuration_utils import GenerationConfig
from .logits_process import (
+ ClassifierFreeGuidanceLogitsProcessor,
EncoderNoRepeatNGramLogitsProcessor,
EncoderRepetitionPenaltyLogitsProcessor,
EpsilonLogitsWarper,
@@ -940,6 +941,8 @@ class GenerationMixin:
)
if generation_config.forced_decoder_ids is not None:
processors.append(ForceTokensLogitsProcessor(generation_config.forced_decoder_ids))
+ if generation_config.guidance_scale is not None and generation_config.guidance_scale > 1:
+ processors.append(ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale))
processors = self._merge_criteria_processor_list(processors, logits_processor)
# `LogitNormalization` should always be the last logit processor, when present
if generation_config.renormalize_logits is True:
diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py
index 771ba9be09..d8345c9ef8 100644
--- a/src/transformers/models/__init__.py
+++ b/src/transformers/models/__init__.py
@@ -135,6 +135,7 @@ from . import (
mobilevitv2,
mpnet,
mt5,
+ musicgen,
mvp,
nat,
nezha,
diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py
index 2f2714e619..5cbaa0705a 100755
--- a/src/transformers/models/auto/configuration_auto.py
+++ b/src/transformers/models/auto/configuration_auto.py
@@ -138,6 +138,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("mobilevitv2", "MobileViTV2Config"),
("mpnet", "MPNetConfig"),
("mt5", "MT5Config"),
+ ("musicgen", "MusicgenConfig"),
("mvp", "MvpConfig"),
("nat", "NatConfig"),
("nezha", "NezhaConfig"),
@@ -327,6 +328,7 @@ CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict(
("mobilevit", "MOBILEVIT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("mobilevitv2", "MOBILEVITV2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("mpnet", "MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
+ ("musicgen", "MUSICGEN_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("mvp", "MVP_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("nat", "NAT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("nezha", "NEZHA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
@@ -532,6 +534,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("mobilevitv2", "MobileViTV2"),
("mpnet", "MPNet"),
("mt5", "MT5"),
+ ("musicgen", "MusicGen"),
("mvp", "MVP"),
("nat", "NAT"),
("nezha", "Nezha"),
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index 8c5e5e8c15..8bb6ea37aa 100755
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -389,6 +389,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
("mbart", "MBartForCausalLM"),
("mega", "MegaForCausalLM"),
("megatron-bert", "MegatronBertForCausalLM"),
+ ("musicgen", "MusicgenForCausalLM"),
("mvp", "MvpForCausalLM"),
("open-llama", "OpenLlamaForCausalLM"),
("openai-gpt", "OpenAIGPTLMHeadModel"),
diff --git a/src/transformers/models/musicgen/__init__.py b/src/transformers/models/musicgen/__init__.py
new file mode 100644
index 0000000000..7fa695eba8
--- /dev/null
+++ b/src/transformers/models/musicgen/__init__.py
@@ -0,0 +1,67 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
+
+
+_import_structure = {
+ "configuration_musicgen": [
+ "MUSICGEN_PRETRAINED_CONFIG_ARCHIVE_MAP",
+ "MusicgenConfig",
+ "MusicgenDecoderConfig",
+ ],
+ "processing_musicgen": ["MusicgenProcessor"],
+}
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_musicgen"] = [
+ "MUSICGEN_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "MusicgenForConditionalGeneration",
+ "MusicgenForCausalLM",
+ "MusicgenModel",
+ "MusicgenPreTrainedModel",
+ ]
+
+if TYPE_CHECKING:
+ from .configuration_musicgen import (
+ MUSICGEN_PRETRAINED_CONFIG_ARCHIVE_MAP,
+ MusicgenConfig,
+ MusicgenDecoderConfig,
+ )
+ from .processing_musicgen import MusicgenProcessor
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_musicgen import (
+ MUSICGEN_PRETRAINED_MODEL_ARCHIVE_LIST,
+ MusicgenForCausalLM,
+ MusicgenForConditionalGeneration,
+ MusicgenModel,
+ MusicgenPreTrainedModel,
+ )
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/models/musicgen/configuration_musicgen.py b/src/transformers/models/musicgen/configuration_musicgen.py
new file mode 100644
index 0000000000..2882c49f75
--- /dev/null
+++ b/src/transformers/models/musicgen/configuration_musicgen.py
@@ -0,0 +1,243 @@
+# coding=utf-8
+# Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" MusicGen model configuration"""
+import copy
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+from ..auto.configuration_auto import AutoConfig
+
+
+logger = logging.get_logger(__name__)
+
+MUSICGEN_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+ "facebook/musicgen-small": "https://huggingface.co/facebook/musicgen-small/resolve/main/config.json",
+ # See all Musicgen models at https://huggingface.co/models?filter=musicgen
+}
+
+
+class MusicgenDecoderConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of an [`MusicgenDecoder`]. It is used to instantiate a
+ MusicGen decoder according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the MusicGen
+ [facebook/musicgen-small](https://huggingface.co/facebook/musicgen-small) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 2048):
+ Vocabulary size of the MusicgenDecoder model. Defines the number of different tokens that can be
+ represented by the `inputs_ids` passed when calling [`MusicgenDecoder`].
+ hidden_size (`int`, *optional*, defaults to 1024):
+ Dimensionality of the layers and the pooler layer.
+ num_hidden_layers (`int`, *optional*, defaults to 24):
+ Number of decoder layers.
+ num_attention_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer in the Transformer block.
+ ffn_dim (`int`, *optional*, defaults to 4096):
+ Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer block.
+ activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the decoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
+ dropout (`float`, *optional*, defaults to 0.1):
+ The dropout probability for all fully connected layers in the embeddings, text_encoder, and pooler.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ activation_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for activations inside the fully connected layer.
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
+ The maximum sequence length that this model might ever be used with. Typically, set this to something large
+ just in case (e.g., 512 or 1024 or 2048).
+ initializer_factor (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layerdrop (`float`, *optional*, defaults to 0.0):
+ The LayerDrop probability for the decoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
+ for more details.
+ scale_embedding (`bool`, *optional*, defaults to `False`):
+ Scale embeddings by diving by sqrt(hidden_size).
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether the model should return the last key/values attentions (not used by all models)
+ num_codebooks (`int`, *optional*, defaults to 4):
+ The number of parallel codebooks forwarded to the model.
+ tie_word_embeddings(`bool`, *optional*, defaults to `False`):
+ Whether input and output word embeddings should be tied.
+ """
+ model_type = "musicgen_decoder"
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ vocab_size=2048,
+ max_position_embeddings=2048,
+ num_hidden_layers=24,
+ ffn_dim=4096,
+ num_attention_heads=16,
+ layerdrop=0.0,
+ use_cache=True,
+ activation_function="gelu",
+ hidden_size=1024,
+ dropout=0.1,
+ attention_dropout=0.0,
+ activation_dropout=0.0,
+ initializer_factor=0.02,
+ scale_embedding=False,
+ num_codebooks=4,
+ pad_token_id=2048,
+ bos_token_id=2048,
+ eos_token_id=None,
+ tie_word_embeddings=False,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.ffn_dim = ffn_dim
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.dropout = dropout
+ self.attention_dropout = attention_dropout
+ self.activation_dropout = activation_dropout
+ self.activation_function = activation_function
+ self.initializer_factor = initializer_factor
+ self.layerdrop = layerdrop
+ self.use_cache = use_cache
+ self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
+ self.num_codebooks = num_codebooks
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+
+
+class MusicgenConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`MusicgenModel`]. It is used to instantiate a
+ MusicGen model according to the specified arguments, defining the text encoder, audio encoder and MusicGen decoder
+ configs.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ kwargs (*optional*):
+ Dictionary of keyword arguments. Notably:
+
+ - **text_encoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that
+ defines the text encoder config.
+ - **audio_encoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that
+ defines the audio encoder config.
+ - **decoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that defines
+ the decoder config.
+
+ Example:
+
+ ```python
+ >>> from transformers import (
+ ... MusicgenConfig,
+ ... MusicgenDecoderConfig,
+ ... T5Config,
+ ... EncodecConfig,
+ ... MusicgenForConditionalGeneration,
+ ... )
+
+ >>> # Initializing text encoder, audio encoder, and decoder model configurations
+ >>> text_encoder_config = T5Config()
+ >>> audio_encoder_config = EncodecConfig()
+ >>> decoder_config = MusicgenDecoderConfig()
+
+ >>> configuration = MusicgenConfig.from_sub_models_config(
+ ... text_encoder_config, audio_encoder_config, decoder_config
+ ... )
+
+ >>> # Initializing a MusicgenForConditionalGeneration (with random weights) from the facebook/musicgen-small style configuration
+ >>> model = MusicgenForConditionalGeneration(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ >>> config_text_encoder = model.config.text_encoder
+ >>> config_audio_encoder = model.config.audio_encoder
+ >>> config_decoder = model.config.decoder
+
+ >>> # Saving the model, including its configuration
+ >>> model.save_pretrained("musicgen-model")
+
+ >>> # loading model and config from pretrained folder
+ >>> musicgen_config = MusicgenConfig.from_pretrained("musicgen-model")
+ >>> model = MusicgenForConditionalGeneration.from_pretrained("musicgen-model", config=musicgen_config)
+ ```"""
+
+ model_type = "musicgen"
+ is_composition = True
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ if "text_encoder" not in kwargs or "audio_encoder" not in kwargs or "decoder" not in kwargs:
+ raise ValueError("Config has to be initialized with text_encoder, audio_encoder and decoder config")
+
+ text_encoder_config = kwargs.pop("text_encoder")
+ text_encoder_model_type = text_encoder_config.pop("model_type")
+
+ audio_encoder_config = kwargs.pop("audio_encoder")
+ audio_encoder_model_type = audio_encoder_config.pop("model_type")
+
+ decoder_config = kwargs.pop("decoder")
+
+ self.text_encoder = AutoConfig.for_model(text_encoder_model_type, **text_encoder_config)
+ self.audio_encoder = AutoConfig.for_model(audio_encoder_model_type, **audio_encoder_config)
+ self.decoder = MusicgenDecoderConfig(**decoder_config)
+ self.is_encoder_decoder = True
+
+ @classmethod
+ def from_sub_models_config(
+ cls,
+ text_encoder_config: PretrainedConfig,
+ audio_encoder_config: PretrainedConfig,
+ decoder_config: MusicgenDecoderConfig,
+ **kwargs,
+ ):
+ r"""
+ Instantiate a [`MusicgenConfig`] (or a derived class) from text encoder, audio encoder and decoder
+ configurations.
+
+ Returns:
+ [`MusicgenConfig`]: An instance of a configuration object
+ """
+
+ return cls(
+ text_encoder=text_encoder_config.to_dict(),
+ audio_encoder=audio_encoder_config.to_dict(),
+ decoder=decoder_config.to_dict(),
+ **kwargs,
+ )
+
+ def to_dict(self):
+ """
+ Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
+
+ Returns:
+ `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
+ """
+ output = copy.deepcopy(self.__dict__)
+ output["text_encoder"] = self.text_encoder.to_dict()
+ output["audio_encoder"] = self.audio_encoder.to_dict()
+ output["decoder"] = self.decoder.to_dict()
+ output["model_type"] = self.__class__.model_type
+ return output
diff --git a/src/transformers/models/musicgen/convert_musicgen_transformers.py b/src/transformers/models/musicgen/convert_musicgen_transformers.py
new file mode 100644
index 0000000000..517f0099d0
--- /dev/null
+++ b/src/transformers/models/musicgen/convert_musicgen_transformers.py
@@ -0,0 +1,209 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Convert MusicGen checkpoints from the original repository."""
+import argparse
+from pathlib import Path
+from typing import Dict, OrderedDict, Tuple
+
+import torch
+from audiocraft.models import MusicGen
+
+from transformers import (
+ AutoFeatureExtractor,
+ AutoTokenizer,
+ EncodecModel,
+ MusicgenDecoderConfig,
+ MusicgenForConditionalGeneration,
+ MusicgenProcessor,
+ T5EncoderModel,
+)
+from transformers.models.musicgen.modeling_musicgen import MusicgenForCausalLM
+from transformers.utils import logging
+
+
+logging.set_verbosity_info()
+logger = logging.get_logger(__name__)
+
+
+EXPECTED_MISSING_KEYS = ["model.decoder.embed_positions.weights"]
+
+
+def rename_keys(name):
+ if "emb" in name:
+ name = name.replace("emb", "model.decoder.embed_tokens")
+ if "transformer" in name:
+ name = name.replace("transformer", "model.decoder")
+ if "cross_attention" in name:
+ name = name.replace("cross_attention", "encoder_attn")
+ if "linear1" in name:
+ name = name.replace("linear1", "fc1")
+ if "linear2" in name:
+ name = name.replace("linear2", "fc2")
+ if "norm1" in name:
+ name = name.replace("norm1", "self_attn_layer_norm")
+ if "norm_cross" in name:
+ name = name.replace("norm_cross", "encoder_attn_layer_norm")
+ if "norm2" in name:
+ name = name.replace("norm2", "final_layer_norm")
+ if "out_norm" in name:
+ name = name.replace("out_norm", "model.decoder.layer_norm")
+ if "linears" in name:
+ name = name.replace("linears", "lm_heads")
+ if "condition_provider.conditioners.description.output_proj" in name:
+ name = name.replace("condition_provider.conditioners.description.output_proj", "enc_to_dec_proj")
+ return name
+
+
+def rename_state_dict(state_dict: OrderedDict, hidden_size: int) -> Tuple[Dict, Dict]:
+ """Function that takes the fairseq Musicgen state dict and renames it according to the HF
+ module names. It further partitions the state dict into the decoder (LM) state dict, and that for the
+ encoder-decoder projection."""
+ keys = list(state_dict.keys())
+ enc_dec_proj_state_dict = {}
+ for key in keys:
+ val = state_dict.pop(key)
+ key = rename_keys(key)
+ if "in_proj_weight" in key:
+ # split fused qkv proj
+ state_dict[key.replace("in_proj_weight", "q_proj.weight")] = val[:hidden_size, :]
+ state_dict[key.replace("in_proj_weight", "k_proj.weight")] = val[hidden_size : 2 * hidden_size, :]
+ state_dict[key.replace("in_proj_weight", "v_proj.weight")] = val[-hidden_size:, :]
+ elif "enc_to_dec_proj" in key:
+ enc_dec_proj_state_dict[key[len("enc_to_dec_proj.") :]] = val
+ else:
+ state_dict[key] = val
+ return state_dict, enc_dec_proj_state_dict
+
+
+def decoder_config_from_checkpoint(checkpoint: str) -> MusicgenDecoderConfig:
+ if checkpoint == "small":
+ # default config values
+ hidden_size = 1024
+ num_hidden_layers = 24
+ num_attention_heads = 16
+ elif checkpoint == "medium":
+ hidden_size = 1536
+ num_hidden_layers = 48
+ num_attention_heads = 24
+ elif checkpoint == "large":
+ hidden_size = 2048
+ num_hidden_layers = 48
+ num_attention_heads = 32
+ else:
+ raise ValueError(f"Checkpoint should be one of `['small', 'medium', 'large']`, got {checkpoint}.")
+ config = MusicgenDecoderConfig(
+ hidden_size=hidden_size,
+ ffn_dim=hidden_size * 4,
+ num_hidden_layers=num_hidden_layers,
+ num_attention_heads=num_attention_heads,
+ )
+ return config
+
+
+@torch.no_grad()
+def convert_musicgen_checkpoint(checkpoint, pytorch_dump_folder=None, repo_id=None, device="cpu"):
+ fairseq_model = MusicGen.get_pretrained(checkpoint, device=device)
+ decoder_config = decoder_config_from_checkpoint(checkpoint)
+
+ decoder_state_dict = fairseq_model.lm.state_dict()
+ decoder_state_dict, enc_dec_proj_state_dict = rename_state_dict(
+ decoder_state_dict, hidden_size=decoder_config.hidden_size
+ )
+
+ text_encoder = T5EncoderModel.from_pretrained("t5-base")
+ audio_encoder = EncodecModel.from_pretrained("facebook/encodec_32khz")
+ decoder = MusicgenForCausalLM(decoder_config).eval()
+
+ # load all decoder weights - expect that we'll be missing embeddings and enc-dec projection
+ missing_keys, unexpected_keys = decoder.load_state_dict(decoder_state_dict, strict=False)
+
+ for key in missing_keys.copy():
+ if key.startswith(("text_encoder", "audio_encoder")) or key in EXPECTED_MISSING_KEYS:
+ missing_keys.remove(key)
+
+ if len(missing_keys) > 0:
+ raise ValueError(f"Missing key(s) in state_dict: {missing_keys}")
+
+ if len(unexpected_keys) > 0:
+ raise ValueError(f"Unexpected key(s) in state_dict: {unexpected_keys}")
+
+ # init the composite model
+ model = MusicgenForConditionalGeneration(text_encoder=text_encoder, audio_encoder=audio_encoder, decoder=decoder)
+
+ # load the pre-trained enc-dec projection (from the decoder state dict)
+ model.enc_to_dec_proj.load_state_dict(enc_dec_proj_state_dict)
+
+ # check we can do a forward pass
+ input_ids = torch.arange(0, 8, dtype=torch.long).reshape(2, -1)
+ decoder_input_ids = input_ids.reshape(2 * 4, -1)
+
+ with torch.no_grad():
+ logits = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids).logits
+
+ if logits.shape != (8, 1, 2048):
+ raise ValueError("Incorrect shape for logits")
+
+ # now construct the processor
+ tokenizer = AutoTokenizer.from_pretrained("t5-base")
+ feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/encodec_32khz", padding_side="left")
+
+ processor = MusicgenProcessor(feature_extractor=feature_extractor, tokenizer=tokenizer)
+
+ # set the appropriate bos/pad token ids
+ model.generation_config.decoder_start_token_id = 2048
+ model.generation_config.pad_token_id = 2048
+
+ # set other default generation config params
+ model.generation_config.max_length = int(30 * audio_encoder.config.frame_rate)
+ model.generation_config.do_sample = True
+ model.generation_config.guidance_scale = 3.0
+
+ if pytorch_dump_folder is not None:
+ Path(pytorch_dump_folder).mkdir(exist_ok=True)
+ logger.info(f"Saving model {checkpoint} to {pytorch_dump_folder}")
+ model.save_pretrained(pytorch_dump_folder)
+ processor.save_pretrained(pytorch_dump_folder)
+
+ if repo_id:
+ logger.info(f"Pushing model {checkpoint} to {repo_id}")
+ model.push_to_hub(repo_id)
+ processor.push_to_hub(repo_id)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ # Required parameters
+ parser.add_argument(
+ "--checkpoint",
+ default="small",
+ type=str,
+ help="Checkpoint size of the MusicGen model you'd like to convert. Can be one of: `['small', 'medium', 'large']`.",
+ )
+ parser.add_argument(
+ "--pytorch_dump_folder",
+ required=True,
+ default=None,
+ type=str,
+ help="Path to the output PyTorch model directory.",
+ )
+ parser.add_argument(
+ "--push_to_hub", default=None, type=str, help="Where to upload the converted model on the 🤗 hub."
+ )
+ parser.add_argument(
+ "--device", default="cpu", type=str, help="Torch device to run the conversion, either cpu or cuda."
+ )
+
+ args = parser.parse_args()
+ convert_musicgen_checkpoint(args.checkpoint, args.pytorch_dump_folder, args.push_to_hub)
diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py
new file mode 100644
index 0000000000..bcd83e476f
--- /dev/null
+++ b/src/transformers/models/musicgen/modeling_musicgen.py
@@ -0,0 +1,2512 @@
+# coding=utf-8
+# Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch Musicgen model."""
+import copy
+import inspect
+import math
+import random
+from dataclasses import dataclass
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+from torch.nn import CrossEntropyLoss
+from torch.utils.checkpoint import checkpoint
+
+from ...activations import ACT2FN
+from ...generation.configuration_utils import GenerationConfig
+from ...generation.logits_process import LogitsProcessorList
+from ...generation.stopping_criteria import StoppingCriteriaList
+from ...modeling_outputs import (
+ BaseModelOutput,
+ BaseModelOutputWithPastAndCrossAttentions,
+ CausalLMOutputWithCrossAttentions,
+ ModelOutput,
+ Seq2SeqLMOutput,
+)
+from ...modeling_utils import PreTrainedModel
+from ...utils import (
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from ..auto.configuration_auto import AutoConfig
+from ..auto.modeling_auto import AutoModel
+from .configuration_musicgen import MusicgenConfig, MusicgenDecoderConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "MusicgenConfig"
+_CHECKPOINT_FOR_DOC = "facebook/musicgen-small"
+
+MUSICGEN_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "facebook/musicgen-small",
+ # See all Musicgen models at https://huggingface.co/models?filter=musicgen
+]
+
+
+@dataclass
+class MusicgenUnconditionalInput(ModelOutput):
+ """
+ Args:
+ encoder_outputs (`Tuple[torch.FloatTensor]` of length 1, with tensor shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the text encoder model.
+ attention_mask (`torch.LongTensor`) of shape `(batch_size, sequence_length)`, *optional*):
+ Encoder attention mask to avoid performing attention on padding token indices. Mask values selected in `[0,
+ 1]`: 1 for tokens that are **not masked**, 0 for tokens that are **masked**.
+ guidance_scale (`float`, *optional*):
+ Guidance scale for classifier free guidance, setting the balance between the conditional logits (predicted
+ from the prompts) and the unconditional logits (predicted without prompts).
+ """
+
+ encoder_outputs: Tuple[torch.FloatTensor] = None
+ attention_mask: torch.LongTensor = None
+ guidance_scale: float = None
+
+
+# Copied from transformers.models.bart.modeling_bart._make_causal_mask
+def _make_causal_mask(
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
+):
+ """
+ Make causal mask used for bi-directional self-attention.
+ """
+ bsz, tgt_len = input_ids_shape
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
+ mask_cond = torch.arange(mask.size(-1), device=device)
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
+ mask = mask.to(dtype)
+
+ if past_key_values_length > 0:
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
+
+
+# Copied from transformers.models.bart.modeling_bart._expand_mask
+def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+ """
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+ """
+ bsz, src_len = mask.size()
+ tgt_len = tgt_len if tgt_len is not None else src_len
+
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
+
+ inverted_mask = 1.0 - expanded_mask
+
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
+
+
+# Copied from transformers.models.encoder_decoder.modeling_encoder_decoder.shift_tokens_right
+def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
+ """
+ Shift input ids one token to the right.
+ """
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
+ shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
+ if decoder_start_token_id is None:
+ raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.")
+ shifted_input_ids[:, 0] = decoder_start_token_id
+
+ if pad_token_id is None:
+ raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.")
+ # replace possible -100 values in labels by `pad_token_id`
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
+
+ return shifted_input_ids
+
+
+class MusicgenSinusoidalPositionalEmbedding(nn.Module):
+ """This module produces sinusoidal positional embeddings of any length."""
+
+ def __init__(self, num_positions: int, embedding_dim: int):
+ super().__init__()
+ self.embedding_dim = embedding_dim
+ self.make_weights(num_positions, embedding_dim)
+
+ def make_weights(self, num_embeddings: int, embedding_dim: int):
+ emb_weights = self.get_embedding(num_embeddings, embedding_dim)
+ if hasattr(self, "weights"):
+ # in forward put the weights on the correct dtype and device of the param
+ emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)
+
+ self.weights = nn.Parameter(emb_weights)
+ self.weights.requires_grad = False
+ self.weights.detach_()
+
+ @staticmethod
+ def get_embedding(num_embeddings: int, embedding_dim: int):
+ """
+ Build sinusoidal embeddings. This matches the implementation in tensor2tensor, but differs slightly from the
+ description in Section 3.5 of "Attention Is All You Need".
+ """
+ half_dim = embedding_dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
+ emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
+ emb = torch.cat([torch.cos(emb), torch.sin(emb)], dim=1).view(num_embeddings, -1)
+ if embedding_dim % 2 == 1:
+ # zero pad
+ emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
+ return emb.to(torch.get_default_dtype())
+
+ @torch.no_grad()
+ def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
+ bsz, codebooks, seq_len = input_ids.size()
+ # Create the position ids from the input token ids.
+ position_ids = (torch.arange(seq_len) + past_key_values_length).to(input_ids.device)
+ # expand embeddings if needed
+ if seq_len > self.weights.size(0):
+ self.make_weights(seq_len + self.offset, self.embedding_dim)
+ return self.weights.index_select(0, position_ids.view(-1)).detach()
+
+
+# Copied from transformers.models.bart.modeling_bart.BartAttention
+class MusicgenAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(
+ self,
+ embed_dim: int,
+ num_heads: int,
+ dropout: float = 0.0,
+ is_decoder: bool = False,
+ bias: bool = True,
+ ):
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ self.dropout = dropout
+ self.head_dim = embed_dim // num_heads
+
+ if (self.head_dim * num_heads) != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
+ f" and `num_heads`: {num_heads})."
+ )
+ self.scaling = self.head_dim**-0.5
+ self.is_decoder = is_decoder
+
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ key_value_states: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ """Input shape: Batch x Time x Channel"""
+
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+
+ bsz, tgt_len, _ = hidden_states.size()
+
+ # get query proj
+ query_states = self.q_proj(hidden_states) * self.scaling
+ # get key, value proj
+ # `past_key_value[0].shape[2] == key_value_states.shape[1]`
+ # is checking that the `sequence_length` of the `past_key_value` is the same as
+ # the provided `key_value_states` to support prefix tuning
+ if (
+ is_cross_attention
+ and past_key_value is not None
+ and past_key_value[0].shape[2] == key_value_states.shape[1]
+ ):
+ # reuse k,v, cross_attentions
+ key_states = past_key_value[0]
+ value_states = past_key_value[1]
+ elif is_cross_attention:
+ # cross_attentions
+ key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
+ value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
+ elif past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
+ else:
+ # self_attention
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+ if self.is_decoder:
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+ # Further calls to cross_attention layer can then reuse all cross-attention
+ # key/value_states (first "if" case)
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
+ past_key_value = (key_states, value_states)
+
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
+ key_states = key_states.reshape(*proj_shape)
+ value_states = value_states.reshape(*proj_shape)
+
+ src_len = key_states.size(1)
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+ raise ValueError(
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+ )
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ if layer_head_mask is not None:
+ if layer_head_mask.size() != (self.num_heads,):
+ raise ValueError(
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+ f" {layer_head_mask.size()}"
+ )
+ attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ if output_attentions:
+ # this operation is a bit awkward, but it's required to
+ # make sure that attn_weights keeps its gradient.
+ # In order to do so, attn_weights have to be reshaped
+ # twice and have to be reused in the following
+ attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
+ else:
+ attn_weights_reshaped = None
+
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+
+ attn_output = torch.bmm(attn_probs, value_states)
+
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+ attn_output = attn_output.transpose(1, 2)
+
+ # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
+ # partitioned across GPUs when using tensor-parallelism.
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, attn_weights_reshaped, past_key_value
+
+
+class MusicgenDecoderLayer(nn.Module):
+ def __init__(self, config: MusicgenDecoderConfig):
+ super().__init__()
+ self.embed_dim = config.hidden_size
+
+ self.self_attn = MusicgenAttention(
+ embed_dim=self.embed_dim,
+ num_heads=config.num_attention_heads,
+ dropout=config.attention_dropout,
+ is_decoder=True,
+ bias=False,
+ )
+ self.dropout = config.dropout
+ self.activation_fn = ACT2FN[config.activation_function]
+ self.activation_dropout = config.activation_dropout
+
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+ self.encoder_attn = MusicgenAttention(
+ self.embed_dim,
+ config.num_attention_heads,
+ dropout=config.attention_dropout,
+ is_decoder=True,
+ bias=False,
+ )
+ self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+ self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim, bias=False)
+ self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim, bias=False)
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+
+ # Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer.forward
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = True,
+ ) -> torch.Tensor:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ encoder_hidden_states (`torch.FloatTensor`):
+ cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
+ encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
+ `(encoder_attention_heads,)`.
+ cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
+ size `(decoder_attention_heads,)`.
+ past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ """
+ residual = hidden_states
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+
+ # Self Attention
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
+ # add present self-attn cache to positions 1,2 of present_key_value tuple
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ past_key_value=self_attn_past_key_value,
+ attention_mask=attention_mask,
+ layer_head_mask=layer_head_mask,
+ output_attentions=output_attentions,
+ )
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+
+ # Cross-Attention Block
+ cross_attn_present_key_value = None
+ cross_attn_weights = None
+ if encoder_hidden_states is not None:
+ residual = hidden_states
+ hidden_states = self.encoder_attn_layer_norm(hidden_states)
+
+ # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
+ cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
+ hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
+ hidden_states=hidden_states,
+ key_value_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ layer_head_mask=cross_attn_layer_head_mask,
+ past_key_value=cross_attn_past_key_value,
+ output_attentions=output_attentions,
+ )
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+
+ # add cross-attn to positions 3,4 of present_key_value tuple
+ present_key_value = present_key_value + cross_attn_present_key_value
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.final_layer_norm(hidden_states)
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+ hidden_states = self.fc2(hidden_states)
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights, cross_attn_weights)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
+
+
+class MusicgenPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = MusicgenDecoderConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["MusicgenDecoderLayer", "MusicgenAttention"]
+
+ def _init_weights(self, module):
+ std = self.config.initializer_factor
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, MusicgenDecoder):
+ module.gradient_checkpointing = value
+
+
+MUSICGEN_START_DOCSTRING = r"""
+
+ The Musicgen model was proposed in [Simple and Controllable Music Generation](https://arxiv.org/abs/2306.05284) by
+ Jade Copet, Felix Kreuk, Itai Gat, Tal Remez, David Kant, Gabriel Synnaeve, Yossi Adi, Alexandre Défossez. It is an
+ encoder decoder transformer trained on the task of conditional music generation
+
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+ and behavior.
+
+ Parameters:
+ config ([`MusicgenConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+MUSICGEN_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ decoder_input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, target_sequence_length)`, *optional*):
+ Indices of decoder input sequence tokens in the vocabulary, corresponding to the sequence of audio codes.
+
+ Indices can be obtained by encoding an audio prompt with an audio encoder model to predict audio codes,
+ such as with the [`EncodecModel`]. See [`EncodecModel.encode`] for details.
+
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
+
+
+
+ The `decoder_input_ids` will automatically be converted from shape `(batch_size * num_codebooks,
+ target_sequence_length)` to `(batch_size, num_codebooks, target_sequence_length)` in the forward pass. If
+ you obtain audio codes from an audio encoding model, such as [`EncodecModel`], ensure that the number of
+ frames is equal to 1, and that you reshape the audio codes from `(frames, batch_size, num_codebooks,
+ target_sequence_length)` to `(batch_size * num_codebooks, target_sequence_length)` prior to passing them as
+ `decoder_input_ids`.
+
+
+
+ decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
+ be used by default.
+ head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0,
+ 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
+ Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
+ `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
+ hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of shape
+ `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you
+ can choose to directly pass an embedded representation. This is useful if you want more control over how to
+ convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
+ decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
+ representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
+ input (see `past_key_values`). This is useful if you want more control over how to convert
+ `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
+
+ If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
+ of `inputs_embeds`.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+MUSICGEN_DECODER_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary, corresponding to the sequence of audio codes.
+
+ Indices can be obtained by encoding an audio prompt with an audio encoder model to predict audio codes,
+ such as with the [`EncodecModel`]. See [`EncodecModel.encode`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+
+
+
+ The `input_ids` will automatically be converted from shape `(batch_size * num_codebooks,
+ target_sequence_length)` to `(batch_size, num_codebooks, target_sequence_length)` in the forward pass. If
+ you obtain audio codes from an audio encoding model, such as [`EncodecModel`], ensure that the number of
+ frames is equal to 1, and that you reshape the audio codes from `(frames, batch_size, num_codebooks,
+ target_sequence_length)` to `(batch_size * num_codebooks, target_sequence_length)` prior to passing them as
+ `input_ids`.
+
+
+
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of
+ the decoder.
+ encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
+ Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
+ selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
+ cross-attention on hidden heads. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of shape
+ `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you
+ can choose to directly pass an embedded representation. This is useful if you want more control over how to
+ convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+class MusicgenDecoder(MusicgenPreTrainedModel):
+ """
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MusicgenDecoderLayer`]
+ """
+
+ def __init__(self, config: MusicgenDecoderConfig):
+ super().__init__(config)
+ self.dropout = config.dropout
+ self.layerdrop = config.layerdrop
+ self.max_target_positions = config.max_position_embeddings
+ self.d_model = config.hidden_size
+ self.num_codebooks = config.num_codebooks
+ self.embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0
+
+ embed_dim = config.vocab_size + 1
+ self.embed_tokens = nn.ModuleList(
+ [nn.Embedding(embed_dim, config.hidden_size) for _ in range(config.num_codebooks)]
+ )
+
+ self.embed_positions = MusicgenSinusoidalPositionalEmbedding(
+ config.max_position_embeddings,
+ config.hidden_size,
+ )
+
+ self.layers = nn.ModuleList([MusicgenDecoderLayer(config) for _ in range(config.num_hidden_layers)])
+ self.layer_norm = nn.LayerNorm(config.hidden_size)
+
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.embed_tokens = value
+
+ # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
+ # create causal mask
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ combined_attention_mask = None
+ if input_shape[-1] > 1:
+ combined_attention_mask = _make_causal_mask(
+ input_shape,
+ inputs_embeds.dtype,
+ device=inputs_embeds.device,
+ past_key_values_length=past_key_values_length,
+ )
+
+ if attention_mask is not None:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
+ inputs_embeds.device
+ )
+ combined_attention_mask = (
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
+ )
+
+ return combined_attention_mask
+
+ @add_start_docstrings_to_model_forward(MUSICGEN_DECODER_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+ elif input_ids is not None:
+ # (bsz * codebooks, seq_len) -> (bsz, codebooks, seq_len)
+ input = input_ids.reshape(-1, self.num_codebooks, input_ids.shape[-1])
+ bsz, num_codebooks, seq_len = input.shape
+ input_shape = (bsz, seq_len)
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ input = inputs_embeds[:, :, -1:]
+ else:
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+ # past_key_values_length
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+
+ if inputs_embeds is None:
+ inputs_embeds = torch.zeros((bsz, seq_len, self.d_model), device=input_ids.device)
+
+ for codebook in range(num_codebooks):
+ inputs_embeds += self.embed_tokens[codebook](input[:, codebook])
+
+ attention_mask = self._prepare_decoder_attention_mask(
+ attention_mask, input_shape, inputs_embeds, past_key_values_length
+ )
+
+ # expand encoder attention mask
+ if encoder_hidden_states is not None and encoder_attention_mask is not None:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
+
+ # embed positions
+ positions = self.embed_positions(input, past_key_values_length)
+
+ hidden_states = inputs_embeds + positions.to(inputs_embeds.device)
+
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
+ next_decoder_cache = () if use_cache else None
+
+ # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
+ for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
+ if attn_mask is not None:
+ if attn_mask.size()[0] != len(self.layers):
+ raise ValueError(
+ f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+ f" {attn_mask.size()[0]}."
+ )
+ for idx, decoder_layer in enumerate(self.layers):
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+ dropout_probability = random.uniform(0, 1)
+ if self.training and (dropout_probability < self.layerdrop):
+ continue
+
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+ if self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ # None for past_key_value
+ return module(*inputs, output_attentions, use_cache)
+
+ return custom_forward
+
+ layer_outputs = checkpoint(
+ create_custom_forward(decoder_layer),
+ hidden_states,
+ attention_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ head_mask[idx] if head_mask is not None else None,
+ cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
+ None,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+ cross_attn_layer_head_mask=(
+ cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
+ ),
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ if encoder_hidden_states is not None:
+ all_cross_attentions += (layer_outputs[2],)
+
+ hidden_states = self.layer_norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = next_decoder_cache if use_cache else None
+ if not return_dict:
+ return tuple(
+ v
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
+ if v is not None
+ )
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ cross_attentions=all_cross_attentions,
+ )
+
+
+@add_start_docstrings(
+ "The bare Musicgen decoder model outputting raw hidden-states without any specific head on top.",
+ MUSICGEN_START_DOCSTRING,
+)
+class MusicgenModel(MusicgenPreTrainedModel):
+ def __init__(self, config: MusicgenDecoderConfig):
+ super().__init__(config)
+ self.decoder = MusicgenDecoder(config)
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.decoder.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.decoder.embed_tokens = value
+
+ def get_decoder(self):
+ return self.decoder
+
+ @add_start_docstrings_to_model_forward(MUSICGEN_DECODER_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
+ decoder_outputs = self.decoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ head_mask=head_mask,
+ cross_attn_head_mask=cross_attn_head_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ if not return_dict:
+ return decoder_outputs
+
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=decoder_outputs.last_hidden_state,
+ past_key_values=decoder_outputs.past_key_values,
+ hidden_states=decoder_outputs.hidden_states,
+ attentions=decoder_outputs.attentions,
+ cross_attentions=decoder_outputs.cross_attentions,
+ )
+
+
+@add_start_docstrings(
+ "The MusicGen decoder model with a language modelling head on top.",
+ MUSICGEN_START_DOCSTRING,
+)
+class MusicgenForCausalLM(MusicgenPreTrainedModel):
+ def __init__(self, config: MusicgenDecoderConfig):
+ super().__init__(config)
+
+ self.model = MusicgenModel(config)
+
+ self.num_codebooks = config.num_codebooks
+ self.lm_heads = nn.ModuleList(
+ [nn.Linear(config.hidden_size, config.vocab_size, bias=False) for _ in range(config.num_codebooks)]
+ )
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.decoder.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.decoder.embed_tokens = value
+
+ def get_output_embeddings(self):
+ return self.lm_heads
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_heads = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model.decoder = decoder
+
+ def get_decoder(self):
+ return self.model.decoder
+
+ @add_start_docstrings_to_model_forward(MUSICGEN_DECODER_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+ `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
+ are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
+ Returns:
+ """
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ head_mask=head_mask,
+ cross_attn_head_mask=cross_attn_head_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs[0]
+
+ lm_logits = torch.stack([head(hidden_states) for head in self.lm_heads], dim=1)
+
+ loss = None
+ if labels is not None:
+ raise NotImplementedError("Training is not implemented for Musicgen.")
+
+ # (bsz, num_codebooks, seq_len, vocab_size) -> (bsz * num_codebooks, seq_len, vocab_size)
+ lm_logits = lm_logits.reshape(-1, *lm_logits.shape[2:])
+
+ if not return_dict:
+ output = (lm_logits,) + outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return CausalLMOutputWithCrossAttentions(
+ loss=loss,
+ logits=lm_logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ cross_attentions=outputs.cross_attentions,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ attention_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ head_mask=None,
+ cross_attn_head_mask=None,
+ past_key_values=None,
+ use_cache=True,
+ delay_pattern_mask=None,
+ guidance_scale=None,
+ **kwargs,
+ ):
+ if delay_pattern_mask is None:
+ input_ids, delay_pattern_mask = self.build_delay_pattern_mask(
+ input_ids,
+ pad_token_id=self.generation_config.pad_token_id,
+ max_length=self.generation_config.max_length,
+ )
+
+ # apply the delay pattern mask
+ input_ids = self.apply_delay_pattern_mask(input_ids, delay_pattern_mask)
+
+ if guidance_scale is not None and guidance_scale > 1:
+ # for classifier free guidance we need to replicate the decoder args across the batch dim (we'll split these
+ # before sampling)
+ input_ids = input_ids.repeat((2, 1))
+ if attention_mask is not None:
+ attention_mask = attention_mask.repeat((2, 1))
+
+ if past_key_values is not None:
+ input_ids = input_ids[:, -1:]
+
+ return {
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ "encoder_hidden_states": encoder_hidden_states,
+ "encoder_attention_mask": encoder_attention_mask,
+ "head_mask": head_mask,
+ "cross_attn_head_mask": cross_attn_head_mask,
+ "past_key_values": past_key_values,
+ "use_cache": use_cache,
+ }
+
+ def build_delay_pattern_mask(self, input_ids: torch.LongTensor, pad_token_id: int, max_length: int = None):
+ """Build a delayed pattern mask to the input_ids. Each codebook is offset by the previous codebook by
+ one, giving a delayed pattern mask at the start of sequence and end of sequence. Take the example where there
+ are 4 codebooks and a max sequence length of 8, we have the delayed pattern mask of shape `(codebooks,
+ seq_len)`:
+ - [P, -1, -1, -1, -1, P, P, P]
+ - [P, P, -1, -1, -1, -1, P, P]
+ - [P, P, P, -1, -1, -1, -1, P]
+ - [P, P, P, P, -1, -1, -1, -1]
+ where P is the special padding token id and -1 indicates that the token is valid for prediction. If we include
+ a prompt (decoder input ids), the -1 positions indicate where new tokens should be predicted. Otherwise, the
+ mask is set to the value in the prompt:
+ - [P, a, b, -1, -1, P, P, P]
+ - [P, P, c, d, -1, -1, P, P]
+ - [P, P, P, e, f, -1, -1, P]
+ - [P, P, P, P, g, h, -1, -1]
+ where a-h indicate the input prompt (decoder input ids) that are offset by 1. Now, we only override the -1
+ tokens in our prediction.
+ """
+ # (bsz * num_codebooks, seq_len) -> (bsz, num_codebooks, seq_len)
+ input_ids = input_ids.reshape(-1, self.num_codebooks, input_ids.shape[-1])
+ bsz, num_codebooks, seq_len = input_ids.shape
+
+ max_length = max_length if max_length is not None else self.generation_config.max_length
+ input_ids_shifted = (
+ torch.ones((bsz, num_codebooks, max_length), dtype=torch.long, device=input_ids.device) * -1
+ )
+
+ # we only apply the mask if we have a large enough seq len - otherwise we return as is
+ if max_length < 2 * num_codebooks - 1:
+ return input_ids.reshape(bsz * num_codebooks, -1), input_ids_shifted.reshape(bsz * num_codebooks, -1)
+
+ # fill the shifted ids with the prompt entries, offset by the codebook idx
+ for codebook in range(num_codebooks):
+ input_ids_shifted[:, codebook, codebook : seq_len + codebook] = input_ids[:, codebook]
+
+ # construct a pattern mask that indicates the positions of padding tokens for each codebook
+ # first fill the upper triangular part (the EOS padding)
+ delay_pattern = torch.triu(
+ torch.ones((num_codebooks, max_length), dtype=torch.bool), diagonal=max_length - num_codebooks + 1
+ )
+ # then fill the lower triangular part (the BOS padding)
+ delay_pattern = delay_pattern + torch.tril(torch.ones((num_codebooks, max_length), dtype=torch.bool))
+ mask = ~delay_pattern.to(input_ids.device)
+ input_ids = mask * input_ids_shifted + ~mask * pad_token_id
+
+ # find the first position to start generating - this is the first place we have the -1 token
+ # and will always be in the first codebook (since it has no codebook offset)
+ first_codebook_ids = input_ids[:, 0, :]
+ start_ids = (first_codebook_ids == -1).nonzero()[:, 1]
+ if len(start_ids) > 0:
+ first_start_id = min(start_ids)
+ else:
+ # we have no tokens that need to be filled - return entire matrix of input ids
+ first_start_id = seq_len
+
+ # (bsz * num_codebooks, seq_len) -> (bsz, num_codebooks, seq_len)
+ pattern_mask = input_ids.reshape(bsz * num_codebooks, -1)
+ input_ids = input_ids[..., :first_start_id].reshape(bsz * num_codebooks, -1)
+ return input_ids, pattern_mask
+
+ @staticmethod
+ def apply_delay_pattern_mask(input_ids, decoder_pad_token_mask):
+ """Apply a delay pattern mask to the decoder input ids, only preserving predictions where
+ the mask is set to -1, and otherwise setting to the value detailed in the mask."""
+ seq_len = input_ids.shape[-1]
+ decoder_pad_token_mask = decoder_pad_token_mask[..., :seq_len]
+ input_ids = torch.where(decoder_pad_token_mask == -1, input_ids, decoder_pad_token_mask)
+ return input_ids
+
+ @torch.no_grad()
+ def generate(
+ self,
+ inputs: Optional[torch.Tensor] = None,
+ generation_config: Optional[GenerationConfig] = None,
+ logits_processor: Optional[LogitsProcessorList] = None,
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
+ synced_gpus: Optional[bool] = None,
+ **kwargs,
+ ):
+ """
+
+ Generates sequences of token ids for models with a language modeling head.
+
+
+
+ Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
+ model's default generation configuration. You can override any `generation_config` by passing the corresponding
+ parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`.
+
+ For an overview of generation strategies and code examples, check out the [following
+ guide](./generation_strategies).
+
+
+
+ Parameters:
+ inputs (`torch.Tensor` of varying shape depending on the modality, *optional*):
+ The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the
+ method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs`
+ should be in the format `input_ids`. For encoder-decoder models *inputs* can represent any of
+ `input_ids`, `input_values`, `input_features`, or `pixel_values`.
+ generation_config (`~generation.GenerationConfig`, *optional*):
+ The generation configuration to be used as base parametrization for the generation call. `**kwargs`
+ passed to generate matching the attributes of `generation_config` will override them. If
+ `generation_config` is not provided, the default will be used, which had the following loading
+ priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
+ configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
+ default values, whose documentation should be checked to parameterize generation.
+ logits_processor (`LogitsProcessorList`, *optional*):
+ Custom logits processors that complement the default logits processors built from arguments and
+ generation config. If a logit processor is passed that is already created with the arguments or a
+ generation config an error is thrown. This feature is intended for advanced users.
+ stopping_criteria (`StoppingCriteriaList`, *optional*):
+ Custom stopping criteria that complement the default stopping criteria built from arguments and a
+ generation config. If a stopping criteria is passed that is already created with the arguments or a
+ generation config an error is thrown. This feature is intended for advanced users.
+ synced_gpus (`bool`, *optional*, defaults to `False`):
+ Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
+ kwargs:
+ Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
+ forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
+ specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.
+
+ Return:
+ [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
+ or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`.
+
+ If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
+ [`~utils.ModelOutput`] types are:
+
+ - [`~generation.GreedySearchDecoderOnlyOutput`],
+ - [`~generation.SampleDecoderOnlyOutput`],
+ - [`~generation.BeamSearchDecoderOnlyOutput`],
+ - [`~generation.BeamSampleDecoderOnlyOutput`]
+
+ If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
+ [`~utils.ModelOutput`] types are:
+
+ - [`~generation.GreedySearchEncoderDecoderOutput`],
+ - [`~generation.SampleEncoderDecoderOutput`],
+ - [`~generation.BeamSearchEncoderDecoderOutput`],
+ - [`~generation.BeamSampleEncoderDecoderOutput`]
+ """
+ # 1. Handle `generation_config` and kwargs that might update it, and validate the resulting objects
+ if generation_config is None:
+ generation_config = self.generation_config
+
+ generation_config = copy.deepcopy(generation_config)
+ model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
+ generation_config.validate()
+ self._validate_model_kwargs(model_kwargs.copy())
+
+ # 2. Set generation parameters if not already defined
+ logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
+ stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+
+ if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
+ if model_kwargs.get("attention_mask", None) is None:
+ logger.warning(
+ "The attention mask and the pad token id were not set. As a consequence, you may observe "
+ "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
+ )
+ eos_token_id = generation_config.eos_token_id
+ if isinstance(eos_token_id, list):
+ eos_token_id = eos_token_id[0]
+ logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
+ generation_config.pad_token_id = eos_token_id
+
+ # 3. Define model inputs
+ # inputs_tensor has to be defined
+ # model_input_name is defined if model-specific keyword input is passed
+ # otherwise model_input_name is None
+ # all model-specific keyword inputs are removed from `model_kwargs`
+ input_ids, model_input_name, model_kwargs = self._prepare_model_inputs(
+ inputs, generation_config.bos_token_id, model_kwargs
+ )
+ batch_size = input_ids.shape[0] // self.num_codebooks
+
+ # 4. Define other model kwargs
+ model_kwargs["output_attentions"] = generation_config.output_attentions
+ model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
+ model_kwargs["use_cache"] = generation_config.use_cache
+ model_kwargs["guidance_scale"] = generation_config.guidance_scale
+
+ requires_attention_mask = "encoder_outputs" not in model_kwargs
+ if model_kwargs.get("attention_mask", None) is None and requires_attention_mask:
+ model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
+ input_ids, generation_config.pad_token_id, generation_config.eos_token_id
+ )
+
+ # 5. Prepare `max_length` depending on other stopping criteria.
+ input_ids_seq_length = input_ids.shape[-1]
+ has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
+ if has_default_max_length and generation_config.max_new_tokens is None:
+ logger.warning(
+ f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
+ "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
+ " recommend using `max_new_tokens` to control the maximum length of the generation.",
+ UserWarning,
+ )
+ elif generation_config.max_new_tokens is not None:
+ if not has_default_max_length:
+ logger.warning(
+ f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
+ f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
+ "Please refer to the documentation for more information. "
+ "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
+ )
+ generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
+
+ if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
+ raise ValueError(
+ f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than"
+ f" the maximum length ({generation_config.max_length})"
+ )
+ if input_ids_seq_length >= generation_config.max_length:
+ logger.warning(
+ f"Input length of decoder_input_ids is {input_ids_seq_length}, but `max_length` is set to"
+ f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
+ " increasing `max_new_tokens`."
+ )
+
+ # 6. Prepare `input_ids` which will be used for auto-regressive generation
+ # Build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to MusicGen)
+ input_ids, delay_pattern_mask = self.build_delay_pattern_mask(
+ input_ids,
+ pad_token_id=generation_config.decoder_start_token_id,
+ max_length=generation_config.max_length,
+ )
+
+ # stash the delay mask so that we don't have to recompute it in each forward pass
+ model_kwargs["delay_pattern_mask"] = delay_pattern_mask
+
+ # 7. determine generation mode
+ is_greedy_gen_mode = (
+ (generation_config.num_beams == 1)
+ and (generation_config.num_beam_groups == 1)
+ and generation_config.do_sample is False
+ )
+ is_sample_gen_mode = (
+ (generation_config.num_beams == 1)
+ and (generation_config.num_beam_groups == 1)
+ and generation_config.do_sample is True
+ )
+
+ # 8. prepare distribution pre_processing samplers
+ logits_processor = self._get_logits_processor(
+ generation_config=generation_config,
+ input_ids_seq_length=input_ids_seq_length,
+ encoder_input_ids=input_ids,
+ prefix_allowed_tokens_fn=None,
+ logits_processor=logits_processor,
+ )
+
+ # 9. prepare stopping criteria
+ stopping_criteria = self._get_stopping_criteria(
+ generation_config=generation_config, stopping_criteria=stopping_criteria
+ )
+
+ if is_greedy_gen_mode:
+ if generation_config.num_return_sequences > 1:
+ raise ValueError(
+ "num_return_sequences has to be 1 when doing greedy search, "
+ f"but is {generation_config.num_return_sequences}."
+ )
+
+ # 8. run greedy search
+ outputs = self.greedy_search(
+ input_ids,
+ logits_processor=logits_processor,
+ stopping_criteria=stopping_criteria,
+ pad_token_id=generation_config.pad_token_id,
+ eos_token_id=generation_config.eos_token_id,
+ output_scores=generation_config.output_scores,
+ return_dict_in_generate=generation_config.return_dict_in_generate,
+ synced_gpus=synced_gpus,
+ **model_kwargs,
+ )
+
+ elif is_sample_gen_mode:
+ # 9. prepare logits warper
+ logits_warper = self._get_logits_warper(generation_config)
+
+ # expand input_ids with `num_return_sequences` additional sequences per batch
+ input_ids, model_kwargs = self._expand_inputs_for_generation(
+ input_ids=input_ids,
+ expand_size=generation_config.num_return_sequences,
+ **model_kwargs,
+ )
+
+ # 10. run sample
+ outputs = self.sample(
+ input_ids,
+ logits_processor=logits_processor,
+ logits_warper=logits_warper,
+ stopping_criteria=stopping_criteria,
+ pad_token_id=generation_config.pad_token_id,
+ eos_token_id=generation_config.eos_token_id,
+ output_scores=generation_config.output_scores,
+ return_dict_in_generate=generation_config.return_dict_in_generate,
+ synced_gpus=synced_gpus,
+ **model_kwargs,
+ )
+
+ else:
+ raise ValueError(
+ "Got incompatible mode for generation, should be one of greedy or sampling."
+ "Ensure that beam search is de-activated by setting `num_beams=1` and `num_beam_groups=1`."
+ )
+
+ if generation_config.return_dict_in_generate:
+ output_ids = outputs.sequences
+ else:
+ output_ids = outputs
+
+ # apply the pattern mask to the final ids
+ output_ids = self.apply_delay_pattern_mask(output_ids, model_kwargs["delay_pattern_mask"])
+
+ # revert the pattern delay mask by filtering the pad token id
+ output_ids = output_ids[output_ids != generation_config.pad_token_id].reshape(
+ batch_size, self.num_codebooks, -1
+ )
+
+ if generation_config.return_dict_in_generate:
+ outputs.sequences = output_ids
+ return outputs
+ else:
+ return output_ids
+
+
+@add_start_docstrings(
+ "The composite MusicGen model with a text encoder, audio encoder and Musicgen decoder,"
+ "for music generation tasks with one or both of text and audio prompts.",
+ MUSICGEN_START_DOCSTRING,
+)
+class MusicgenForConditionalGeneration(PreTrainedModel):
+ config_class = MusicgenConfig
+ base_model_prefix = "encoder_decoder"
+ main_input_name = "input_ids"
+ supports_gradient_checkpointing = True
+
+ def __init__(
+ self,
+ config: Optional[MusicgenConfig] = None,
+ text_encoder: Optional[PreTrainedModel] = None,
+ audio_encoder: Optional[PreTrainedModel] = None,
+ decoder: Optional[MusicgenForCausalLM] = None,
+ ):
+ if config is None and (text_encoder is None or audio_encoder is None or decoder is None):
+ raise ValueError(
+ "Either a configuration has to be provided, or all three of text encoder, audio encoder and MusicGen decoder."
+ )
+ if config is None:
+ config = MusicgenConfig.from_sub_models_config(text_encoder.config, audio_encoder.config, decoder.config)
+ else:
+ if not isinstance(config, self.config_class):
+ raise ValueError(f"Config: {config} has to be of type {self.config_class}")
+
+ if config.decoder.cross_attention_hidden_size is not None:
+ if config.decoder.cross_attention_hidden_size != config.text_encoder.hidden_size:
+ raise ValueError(
+ "If `cross_attention_hidden_size` is specified in the MusicGen decoder's configuration, it has to be equal"
+ f" to the text encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for"
+ f" `config.decoder.cross_attention_hidden_size` and {config.text_encoder.hidden_size} for"
+ " `config.text_encoder.hidden_size`."
+ )
+
+ # initialize with config
+ super().__init__(config)
+
+ if text_encoder is None:
+ from ..auto.modeling_auto import AutoModelForTextEncoding
+
+ text_encoder = AutoModelForTextEncoding.from_config(config.text_encoder)
+
+ if audio_encoder is None:
+ from ..auto.modeling_auto import AutoModel
+
+ audio_encoder = AutoModel.from_config(config.audio_encoder)
+
+ if decoder is None:
+ decoder = MusicgenForCausalLM(config.decoder)
+
+ self.text_encoder = text_encoder
+ self.audio_encoder = audio_encoder
+ self.decoder = decoder
+
+ if self.text_encoder.config.to_dict() != self.config.text_encoder.to_dict():
+ logger.warning(
+ f"Config of the text_encoder: {self.text_encoder.__class__} is overwritten by shared text_encoder config:"
+ f" {self.config.text_encoder}"
+ )
+ if self.audio_encoder.config.to_dict() != self.config.audio_encoder.to_dict():
+ logger.warning(
+ f"Config of the audio_encoder: {self.audio_encoder.__class__} is overwritten by shared audio_encoder config:"
+ f" {self.config.audio_encoder}"
+ )
+ if self.decoder.config.to_dict() != self.config.decoder.to_dict():
+ logger.warning(
+ f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config:"
+ f" {self.config.decoder}"
+ )
+
+ # make sure that the individual model's config refers to the shared config
+ # so that the updates to the config will be synced
+ self.text_encoder.config = self.config.text_encoder
+ self.audio_encoder.config = self.config.audio_encoder
+ self.decoder.config = self.config.decoder
+
+ # text encoder outputs might need to be projected to different dimension for decoder
+ if (
+ self.text_encoder.config.hidden_size != self.decoder.config.hidden_size
+ and self.decoder.config.cross_attention_hidden_size is None
+ ):
+ self.enc_to_dec_proj = nn.Linear(self.text_encoder.config.hidden_size, self.decoder.config.hidden_size)
+
+ if self.text_encoder.get_output_embeddings() is not None:
+ raise ValueError(
+ f"The encoder {self.text_encoder} should not have a LM Head. Please use a model without and LM Head"
+ )
+
+ decoder_signature = set(inspect.signature(self.decoder.forward).parameters.keys())
+ if "encoder_hidden_states" not in decoder_signature:
+ raise ValueError(
+ "The selected decoder is not prepared for the encoder hidden states to be passed. Please see the "
+ "following discussion on GitHub: https://github.com/huggingface/transformers/issues/23350"
+ )
+
+ # tie text encoder, decoder weights if config set accordingly
+ self.tie_weights()
+
+ def tie_weights(self):
+ # tie text encoder & decoder if needed
+ if self.config.tie_encoder_decoder:
+ # tie text encoder and decoder base model
+ decoder_base_model_prefix = self.decoder.base_model_prefix
+ self._tie_encoder_decoder_weights(
+ self.text_encoder, self.decoder._modules[decoder_base_model_prefix], self.decoder.base_model_prefix
+ )
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ # call both encoder and decoder function on gradient checkpointing
+ self.text_encoder._set_gradient_checkpointing(module, value=value)
+ self.decoder._set_gradient_checkpointing(module, value=value)
+
+ def get_audio_encoder(self):
+ return self.audio_encoder
+
+ def get_text_encoder(self):
+ return self.text_encoder
+
+ def get_encoder(self):
+ # get the text encoder to compute the encoder hidden-states for generation
+ return self.get_text_encoder()
+
+ def get_decoder(self):
+ return self.decoder
+
+ def get_input_embeddings(self):
+ return self.text_encoder.get_input_embeddings()
+
+ def get_output_embeddings(self):
+ return self.decoder.get_output_embeddings()
+
+ def set_output_embeddings(self, new_embeddings):
+ return self.decoder.set_output_embeddings(new_embeddings)
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+ r"""
+ Example:
+
+ ```python
+ >>> from transformers import MusicgenForConditionalGeneration
+
+ >>> model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
+ ```"""
+
+ # At the moment fast initialization is not supported for composite models
+ if kwargs.get("_fast_init", False):
+ logger.warning(
+ "Fast initialization is currently not supported for MusicgenForConditionalGeneration. "
+ "Falling back to slow initialization..."
+ )
+ kwargs["_fast_init"] = False
+
+ return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
+
+ @classmethod
+ def from_sub_models_pretrained(
+ cls,
+ text_encoder_pretrained_model_name_or_path: str = None,
+ audio_encoder_pretrained_model_name_or_path: str = None,
+ decoder_pretrained_model_name_or_path: str = None,
+ *model_args,
+ **kwargs,
+ ) -> PreTrainedModel:
+ r"""
+ Instantiate a text encoder, an audio encoder, and a MusicGen decoder from one, two or three base classes of the
+ library from pretrained model checkpoints.
+
+
+ The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
+ the model, you need to first set it back in training mode with `model.train()`.
+
+ Params:
+ text_encoder_pretrained_model_name_or_path (`str`, *optional*):
+ Information necessary to initiate the text encoder. Can be either:
+
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+ Valid model ids can be located at the root-level, like `t5-base`, or namespaced under a user or
+ organization name, like `google/flan-t5-base.
+ - A path to a *directory* containing model weights saved using
+ [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
+
+ audio_encoder_pretrained_model_name_or_path (`str`, *optional*):
+ Information necessary to initiate the audio encoder. Can be either:
+
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+ Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
+ user or organization name, like `facebook/encodec_24khz`.
+ - A path to a *directory* containing model weights saved using
+ [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
+
+ decoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `None`):
+ Information necessary to initiate the decoder. Can be either:
+
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+ Valid model ids can be located at the root-level, like `gpt2`, or namespaced under a user or
+ organization name, like `facebook/musicgen-small`.
+ - A path to a *directory* containing model weights saved using
+ [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
+
+ model_args (remaining positional arguments, *optional*):
+ All remaining positional arguments will be passed to the underlying model's `__init__` method.
+
+ kwargs (remaining dictionary of keyword arguments, *optional*):
+ Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
+ `output_attentions=True`).
+
+ - To update the text encoder configuration, use the prefix *text_encoder_* for each configuration
+ parameter.
+ - To update the audio encoder configuration, use the prefix *audio_encoder_* for each configuration
+ parameter.
+ - To update the decoder configuration, use the prefix *decoder_* for each configuration parameter.
+ - To update the parent model configuration, do not use a prefix for each configuration parameter.
+
+ Behaves differently depending on whether a `config` is provided or automatically loaded.
+
+ Example:
+
+ ```python
+ >>> from transformers import MusicgenForConditionalGeneration
+
+ >>> # initialize a musicgen model from a t5 text encoder, encodec audio encoder, and musicgen decoder
+ >>> model = MusicgenForConditionalGeneration.from_sub_models_pretrained(
+ ... text_encoder_pretrained_model_name_or_path="t5-base",
+ ... audio_encoder_pretrained_model_name_or_path="facebook/encodec_24khz",
+ ... decoder_pretrained_model_name_or_path="facebook/musicgen-small",
+ ... )
+ >>> # saving model after fine-tuning
+ >>> model.save_pretrained("./musicgen-ft")
+ >>> # load fine-tuned model
+ >>> model = MusicgenForConditionalGeneration.from_pretrained("./musicgen-ft")
+ ```"""
+
+ kwargs_text_encoder = {
+ argument[len("text_encoder_") :]: value
+ for argument, value in kwargs.items()
+ if argument.startswith("text_encoder_")
+ }
+
+ kwargs_audio_encoder = {
+ argument[len("audio_encoder_") :]: value
+ for argument, value in kwargs.items()
+ if argument.startswith("audio_encoder_")
+ }
+
+ kwargs_decoder = {
+ argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
+ }
+
+ # remove text encoder, audio encoder and decoder kwargs from kwargs
+ for key in kwargs_text_encoder.keys():
+ del kwargs["text_encoder_" + key]
+ for key in kwargs_audio_encoder.keys():
+ del kwargs["audio_encoder_" + key]
+ for key in kwargs_decoder.keys():
+ del kwargs["decoder_" + key]
+
+ # Load and initialize the encoder and decoder
+ # The distinction between encoder and decoder at the model level is made
+ # by the value of the flag `is_decoder` that we need to set correctly.
+ text_encoder = kwargs_text_encoder.pop("model", None)
+ if text_encoder is None:
+ if text_encoder_pretrained_model_name_or_path is None:
+ raise ValueError(
+ "If `text_encoder_model` is not defined as an argument, a `text_encoder_pretrained_model_name_or_path` has "
+ "to be defined."
+ )
+
+ if "config" not in kwargs_text_encoder:
+ encoder_config, kwargs_text_encoder = AutoConfig.from_pretrained(
+ text_encoder_pretrained_model_name_or_path, **kwargs_text_encoder, return_unused_kwargs=True
+ )
+
+ if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
+ logger.info(
+ f"Initializing {text_encoder_pretrained_model_name_or_path} as a text_encoder model "
+ "from a decoder model. Cross-attention and casual mask are disabled."
+ )
+ encoder_config.is_decoder = False
+ encoder_config.add_cross_attention = False
+
+ kwargs_text_encoder["config"] = encoder_config
+
+ text_encoder = AutoModel.from_pretrained(
+ text_encoder_pretrained_model_name_or_path, *model_args, **kwargs_text_encoder
+ )
+
+ audio_encoder = kwargs_audio_encoder.pop("model", None)
+ if audio_encoder is None:
+ if audio_encoder_pretrained_model_name_or_path is None:
+ raise ValueError(
+ "If `audio_encoder_model` is not defined as an argument, an `audio_encoder_pretrained_model_name_or_path` has "
+ "to be defined."
+ )
+
+ if "config" not in kwargs_audio_encoder:
+ encoder_config, kwargs_audio_encoder = AutoConfig.from_pretrained(
+ audio_encoder_pretrained_model_name_or_path, **kwargs_audio_encoder, return_unused_kwargs=True
+ )
+
+ if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
+ logger.info(
+ f"Initializing {audio_encoder_pretrained_model_name_or_path} as an audio_encoder model "
+ "from a decoder model. Cross-attention and casual mask are disabled."
+ )
+ encoder_config.is_decoder = False
+ encoder_config.add_cross_attention = False
+
+ kwargs_audio_encoder["config"] = encoder_config
+
+ audio_encoder = AutoModel.from_pretrained(
+ audio_encoder_pretrained_model_name_or_path, *model_args, **kwargs_audio_encoder
+ )
+
+ decoder = kwargs_decoder.pop("model", None)
+ if decoder is None:
+ if decoder_pretrained_model_name_or_path is None:
+ raise ValueError(
+ "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
+ "to be defined."
+ )
+
+ if "config" not in kwargs_decoder:
+ decoder_config, kwargs_decoder = AutoConfig.from_pretrained(
+ decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True
+ )
+
+ if isinstance(decoder_config, MusicgenConfig):
+ decoder_config = decoder_config.decoder
+
+ if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
+ logger.info(
+ f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention"
+ f" layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if"
+ f" {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
+ )
+ decoder_config.is_decoder = True
+ decoder_config.add_cross_attention = True
+
+ kwargs_decoder["config"] = decoder_config
+
+ if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
+ logger.warning(
+ f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
+ f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
+ "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
+ "passed to `.from_sub_models_pretrained(...)` are set to `True` or do not pass a "
+ "`decoder_config` to `.from_sub_models_pretrained(...)`"
+ )
+
+ decoder = MusicgenForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
+
+ # instantiate config with corresponding kwargs
+ config = MusicgenConfig.from_sub_models_config(
+ text_encoder.config, audio_encoder.config, decoder.config, **kwargs
+ )
+ return cls(text_encoder=text_encoder, audio_encoder=audio_encoder, decoder=decoder, config=config)
+
+ @add_start_docstrings_to_model_forward(MUSICGEN_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.BoolTensor] = None,
+ input_values: Optional[torch.FloatTensor] = None,
+ padding_mask: Optional[torch.BoolTensor] = None,
+ decoder_input_ids: Optional[torch.LongTensor] = None,
+ decoder_attention_mask: Optional[torch.BoolTensor] = None,
+ encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
+ past_key_values: Tuple[Tuple[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ **kwargs,
+ ) -> Union[Tuple, Seq2SeqLMOutput]:
+ r"""
+ Returns:
+
+ Examples:
+ ```python
+ >>> from transformers import AutoProcessor, MusicgenForConditionalGeneration
+ >>> import torch
+
+ >>> processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
+ >>> model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
+
+ >>> inputs = processor(
+ ... text=["80s pop track with bassy drums and synth", "90s rock song with loud guitars and heavy drums"],
+ ... padding=True,
+ ... return_tensors="pt",
+ ... )
+
+ >>> pad_token_id = model.generation_config.pad_token_id
+ >>> decoder_input_ids = (
+ ... torch.ones((inputs.input_ids.shape[0] * model.decoder.num_codebooks, 1), dtype=torch.long)
+ ... * pad_token_id
+ ... )
+
+ >>> logits = model(**inputs, decoder_input_ids=decoder_input_ids).logits
+ >>> logits.shape # (bsz * num_codebooks, tgt_len, vocab_size)
+ torch.Size([8, 1, 2048])
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ kwargs_text_encoder = {
+ argument[len("text_encoder_")]: value
+ for argument, value in kwargs.items()
+ if argument.startswith("text_encoder_")
+ }
+
+ kwargs_audio_encoder = {
+ argument[len("audio_encoder_")]: value
+ for argument, value in kwargs.items()
+ if argument.startswith("audio_encoder_")
+ }
+
+ kwargs_decoder = {
+ argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
+ }
+
+ if encoder_outputs is None:
+ encoder_outputs = self.text_encoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ **kwargs_text_encoder,
+ )
+ elif isinstance(encoder_outputs, tuple):
+ encoder_outputs = BaseModelOutput(*encoder_outputs)
+
+ encoder_hidden_states = encoder_outputs[0]
+
+ # optionally project encoder_hidden_states
+ if (
+ self.text_encoder.config.hidden_size != self.decoder.config.hidden_size
+ and self.decoder.config.cross_attention_hidden_size is None
+ ):
+ encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
+
+ if attention_mask is not None:
+ encoder_hidden_states = encoder_hidden_states * attention_mask[..., None]
+
+ if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
+ decoder_input_ids = shift_tokens_right(
+ labels, self.config.pad_token_id, self.config.decoder_start_token_id
+ )
+
+ elif decoder_input_ids is None and decoder_inputs_embeds is None:
+ audio_encoder_outputs = self.audio_encoder(
+ input_values=input_values,
+ padding_mask=padding_mask,
+ **kwargs_audio_encoder,
+ )
+ audio_codes = audio_encoder_outputs.audio_codes
+ frames, bsz, codebooks, seq_len = audio_codes.shape
+ if frames != 1:
+ raise ValueError(
+ f"Expected 1 frame in the audio code outputs, got {frames} frames. Ensure chunking is "
+ "disabled by setting `chunk_length=None` in the audio encoder."
+ )
+ decoder_input_ids = audio_codes[0, ...].reshape(bsz * self.decoder.num_codebooks, seq_len)
+
+ # Decode
+ decoder_outputs = self.decoder(
+ input_ids=decoder_input_ids,
+ attention_mask=decoder_attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=attention_mask,
+ inputs_embeds=decoder_inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ use_cache=use_cache,
+ past_key_values=past_key_values,
+ return_dict=return_dict,
+ **kwargs_decoder,
+ )
+
+ loss = None
+ if labels is not None:
+ logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
+
+ if not return_dict:
+ if loss is not None:
+ return (loss,) + decoder_outputs + encoder_outputs
+ else:
+ return decoder_outputs + encoder_outputs
+
+ return Seq2SeqLMOutput(
+ loss=loss,
+ logits=decoder_outputs.logits,
+ past_key_values=decoder_outputs.past_key_values,
+ decoder_hidden_states=decoder_outputs.hidden_states,
+ decoder_attentions=decoder_outputs.attentions,
+ cross_attentions=decoder_outputs.cross_attentions,
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+ encoder_hidden_states=encoder_outputs.hidden_states,
+ encoder_attentions=encoder_outputs.attentions,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ decoder_input_ids,
+ past_key_values=None,
+ attention_mask=None,
+ head_mask=None,
+ decoder_attention_mask=None,
+ decoder_head_mask=None,
+ cross_attn_head_mask=None,
+ use_cache=None,
+ encoder_outputs=None,
+ decoder_delay_pattern_mask=None,
+ guidance_scale=None,
+ **kwargs,
+ ):
+ if decoder_delay_pattern_mask is None:
+ decoder_input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask(
+ decoder_input_ids,
+ self.generation_config.pad_token_id,
+ max_length=self.generation_config.max_length,
+ )
+
+ # apply the delay pattern mask
+ decoder_input_ids = self.decoder.apply_delay_pattern_mask(decoder_input_ids, decoder_delay_pattern_mask)
+
+ if guidance_scale is not None and guidance_scale > 1:
+ # for classifier free guidance we need to replicate the decoder args across the batch dim (we'll split these
+ # before sampling)
+ decoder_input_ids = decoder_input_ids.repeat((2, 1))
+ if decoder_attention_mask is not None:
+ decoder_attention_mask = decoder_attention_mask.repeat((2, 1))
+
+ # cut decoder_input_ids if past is used
+ if past_key_values is not None:
+ decoder_input_ids = decoder_input_ids[:, -1:]
+
+ return {
+ "input_ids": None, # encoder_outputs is defined. input_ids not needed
+ "encoder_outputs": encoder_outputs,
+ "past_key_values": past_key_values,
+ "decoder_input_ids": decoder_input_ids,
+ "attention_mask": attention_mask,
+ "decoder_attention_mask": decoder_attention_mask,
+ "head_mask": head_mask,
+ "decoder_head_mask": decoder_head_mask,
+ "cross_attn_head_mask": cross_attn_head_mask,
+ "use_cache": use_cache,
+ }
+
+ def _prepare_decoder_input_ids_for_generation(
+ self,
+ batch_size: int,
+ model_input_name: str,
+ model_kwargs: Dict[str, torch.Tensor],
+ decoder_start_token_id: int = None,
+ bos_token_id: int = None,
+ device: torch.device = None,
+ ) -> Tuple[torch.LongTensor, Dict[str, torch.Tensor]]:
+ """Prepares `decoder_input_ids` for generation with encoder-decoder models"""
+
+ # 1. Check whether the user has defined `decoder_input_ids` manually. To facilitate in terms of input naming,
+ # we also allow the user to pass it under `input_ids`, if the encoder does not use it as the main input.
+ if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
+ decoder_input_ids = model_kwargs.pop("decoder_input_ids")
+ elif "input_ids" in model_kwargs and model_input_name != "input_ids":
+ decoder_input_ids = model_kwargs.pop("input_ids")
+ else:
+ decoder_input_ids = None
+
+ # 2. Encoder-decoder models expect the `decoder_input_ids` to start with a special token. Let's ensure that.
+ decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
+ if device is None:
+ device = self.device
+ decoder_input_ids_start = (
+ torch.ones((batch_size * self.decoder.num_codebooks, 1), dtype=torch.long, device=device)
+ * decoder_start_token_id
+ )
+
+ # no user input -> use decoder_start_token_id as decoder_input_ids
+ if decoder_input_ids is None:
+ decoder_input_ids = decoder_input_ids_start
+
+ # user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust
+ # decoder_attention_mask if provided)
+ elif (decoder_input_ids[..., 0] != decoder_start_token_id).all().item():
+ decoder_input_ids = torch.cat([decoder_input_ids_start, decoder_input_ids], dim=-1)
+ if "decoder_attention_mask" in model_kwargs:
+ decoder_attention_mask = model_kwargs["decoder_attention_mask"]
+ decoder_attention_mask = torch.cat(
+ (torch.ones_like(decoder_attention_mask)[:, :1], decoder_attention_mask),
+ dim=-1,
+ )
+ model_kwargs["decoder_attention_mask"] = decoder_attention_mask
+
+ return decoder_input_ids, model_kwargs
+
+ def _prepare_text_encoder_kwargs_for_generation(
+ self,
+ inputs_tensor: torch.Tensor,
+ model_kwargs,
+ model_input_name: Optional[str] = None,
+ guidance_scale: Optional[float] = None,
+ ) -> Dict[str, Any]:
+ # 1. get text encoder
+ encoder = self.get_text_encoder()
+ # Compatibility with Accelerate big model inference: we need the encoder to outputs stuff on the same device
+ # as the inputs.
+ if hasattr(encoder, "_hf_hook"):
+ encoder._hf_hook.io_same_device = True
+
+ # 2. Prepare encoder args and encoder kwargs from model kwargs.
+ irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
+ encoder_kwargs = {
+ argument: value
+ for argument, value in model_kwargs.items()
+ if not any(argument.startswith(p) for p in irrelevant_prefix)
+ }
+ encoder_signature = set(inspect.signature(encoder.forward).parameters)
+ encoder_accepts_wildcard = "kwargs" in encoder_signature or "model_kwargs" in encoder_signature
+ if not encoder_accepts_wildcard:
+ encoder_kwargs = {
+ argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature
+ }
+
+ # 3. make sure that encoder returns `ModelOutput`
+ model_input_name = model_input_name if model_input_name is not None else self.text_encoder.main_input_name
+ encoder_kwargs["return_dict"] = True
+ encoder_kwargs[model_input_name] = inputs_tensor
+ last_hidden_state = encoder(**encoder_kwargs).last_hidden_state
+
+ # for classifier free guidance we need to add a 'null' input to our encoder hidden states
+ if guidance_scale is not None and guidance_scale > 1:
+ last_hidden_state = torch.concatenate([last_hidden_state, torch.zeros_like(last_hidden_state)], dim=0)
+ if "attention_mask" in model_kwargs:
+ model_kwargs["attention_mask"] = torch.concatenate(
+ [model_kwargs["attention_mask"], torch.zeros_like(model_kwargs["attention_mask"])], dim=0
+ )
+
+ model_kwargs["encoder_outputs"] = BaseModelOutput(last_hidden_state=last_hidden_state)
+
+ return model_kwargs
+
+ def _prepare_audio_encoder_kwargs_for_generation(
+ self, input_values, model_kwargs, model_input_name: Optional[str] = None
+ ):
+ # 1. get audio encoder
+ encoder = self.get_audio_encoder()
+ # Compatibility with Accelerate big model inference: we need the encoder to outputs stuff on the same device
+ # as the inputs.
+ if hasattr(encoder, "_hf_hook"):
+ encoder._hf_hook.io_same_device = True
+
+ # 2. Prepare encoder args and encoder kwargs from model kwargs.
+ irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
+ encoder_kwargs = {
+ argument: value
+ for argument, value in model_kwargs.items()
+ if not any(argument.startswith(p) for p in irrelevant_prefix)
+ }
+ encoder_signature = set(inspect.signature(encoder.forward).parameters)
+ encoder_accepts_wildcard = "kwargs" in encoder_signature or "model_kwargs" in encoder_signature
+ if not encoder_accepts_wildcard:
+ encoder_kwargs = {
+ argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature
+ }
+
+ # 3. make sure that encoder returns `ModelOutput`
+ model_input_name = model_input_name if model_input_name is not None else self.audio_encoder.main_input_name
+ encoder_kwargs["return_dict"] = True
+ encoder_kwargs[model_input_name] = input_values
+
+ audio_encoder_outputs = encoder.encode(**encoder_kwargs)
+
+ audio_codes = audio_encoder_outputs.audio_codes
+ frames, bsz, codebooks, seq_len = audio_codes.shape
+
+ if frames != 1:
+ raise ValueError(
+ f"Expected 1 frame in the audio code outputs, got {frames} frames. Ensure chunking is "
+ "disabled by setting `chunk_length=None` in the audio encoder."
+ )
+
+ decoder_input_ids = audio_codes[0, ...].reshape(bsz * self.decoder.num_codebooks, seq_len)
+
+ model_kwargs["decoder_input_ids"] = decoder_input_ids
+ model_kwargs["audio_scales"] = audio_encoder_outputs.audio_scales
+ return model_kwargs
+
+ def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
+ return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
+
+ def resize_token_embeddings(self, *args, **kwargs):
+ raise NotImplementedError(
+ "Resizing the embedding layers via the EncoderDecoderModel directly is not supported. Please use the"
+ " respective methods of the wrapped objects (model.encoder.resize_token_embeddings(...) or"
+ " model.decoder.resize_token_embeddings(...))"
+ )
+
+ def _maybe_initialize_input_ids_for_generation(
+ self,
+ inputs: Optional[torch.Tensor] = None,
+ bos_token_id: Optional[int] = None,
+ model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+ ) -> torch.LongTensor:
+ """Initializes input ids for generation, if necessary."""
+ if inputs is not None:
+ return inputs
+
+ encoder_outputs = model_kwargs.get("encoder_outputs")
+ if encoder_outputs is not None:
+ # make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding
+ shape = encoder_outputs[0].size()[:-1]
+ return torch.ones(shape, dtype=torch.long, device=self.device) * -100
+
+ if bos_token_id is None:
+ raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.")
+
+ # If there is some tensor in `model_kwargs`, we can infer the batch size from it. This is helpful with
+ # soft-prompting or in multimodal implementations built on top of decoder-only language models.
+ batch_size = 1
+ for value in model_kwargs.values():
+ if isinstance(value, torch.Tensor):
+ batch_size = value.shape[0]
+ break
+ return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id
+
+ @torch.no_grad()
+ def generate(
+ self,
+ inputs: Optional[torch.Tensor] = None,
+ generation_config: Optional[GenerationConfig] = None,
+ logits_processor: Optional[LogitsProcessorList] = None,
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
+ synced_gpus: Optional[bool] = None,
+ **kwargs,
+ ):
+ """
+
+ Generates sequences of token ids for models with a language modeling head.
+
+
+
+ Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
+ model's default generation configuration. You can override any `generation_config` by passing the corresponding
+ parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`.
+
+ For an overview of generation strategies and code examples, check out the [following
+ guide](./generation_strategies).
+
+
+
+ Parameters:
+ inputs (`torch.Tensor` of varying shape depending on the modality, *optional*):
+ The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the
+ method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs`
+ should be in the format `input_ids`. For encoder-decoder models *inputs* can represent any of
+ `input_ids`, `input_values`, `input_features`, or `pixel_values`.
+ generation_config (`~generation.GenerationConfig`, *optional*):
+ The generation configuration to be used as base parametrization for the generation call. `**kwargs`
+ passed to generate matching the attributes of `generation_config` will override them. If
+ `generation_config` is not provided, the default will be used, which had the following loading
+ priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
+ configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
+ default values, whose documentation should be checked to parameterize generation.
+ logits_processor (`LogitsProcessorList`, *optional*):
+ Custom logits processors that complement the default logits processors built from arguments and
+ generation config. If a logit processor is passed that is already created with the arguments or a
+ generation config an error is thrown. This feature is intended for advanced users.
+ stopping_criteria (`StoppingCriteriaList`, *optional*):
+ Custom stopping criteria that complement the default stopping criteria built from arguments and a
+ generation config. If a stopping criteria is passed that is already created with the arguments or a
+ generation config an error is thrown. This feature is intended for advanced users.
+ synced_gpus (`bool`, *optional*, defaults to `False`):
+ Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
+ kwargs:
+ Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
+ forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
+ specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.
+
+ Return:
+ [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
+ or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`.
+
+ If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
+ [`~utils.ModelOutput`] types are:
+
+ - [`~generation.GreedySearchDecoderOnlyOutput`],
+ - [`~generation.SampleDecoderOnlyOutput`],
+ - [`~generation.BeamSearchDecoderOnlyOutput`],
+ - [`~generation.BeamSampleDecoderOnlyOutput`]
+
+ If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
+ [`~utils.ModelOutput`] types are:
+
+ - [`~generation.GreedySearchEncoderDecoderOutput`],
+ - [`~generation.SampleEncoderDecoderOutput`],
+ - [`~generation.BeamSearchEncoderDecoderOutput`],
+ - [`~generation.BeamSampleEncoderDecoderOutput`]
+ """
+ # 1. Handle `generation_config` and kwargs that might update it, and validate the resulting objects
+ if generation_config is None:
+ generation_config = self.generation_config
+
+ generation_config = copy.deepcopy(generation_config)
+ model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
+ generation_config.validate()
+ self._validate_model_kwargs(model_kwargs.copy())
+
+ if model_kwargs.get("encoder_outputs") is not None and type(model_kwargs["encoder_outputs"]) == tuple:
+ # wrap the unconditional outputs as a BaseModelOutput for compatibility with the rest of generate
+ model_kwargs["encoder_outputs"] = BaseModelOutput(last_hidden_state=model_kwargs["encoder_outputs"][0])
+
+ # 2. Set generation parameters if not already defined
+ logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
+ stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+
+ if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
+ if model_kwargs.get("attention_mask", None) is None:
+ logger.warning(
+ "The attention mask and the pad token id were not set. As a consequence, you may observe "
+ "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
+ )
+ eos_token_id = generation_config.eos_token_id
+ if isinstance(eos_token_id, list):
+ eos_token_id = eos_token_id[0]
+ logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
+ generation_config.pad_token_id = eos_token_id
+
+ # 3. Define model inputs
+ # inputs_tensor has to be defined
+ # model_input_name is defined if model-specific keyword input is passed
+ # otherwise model_input_name is None
+ # all model-specific keyword inputs are removed from `model_kwargs`
+ inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
+ inputs, generation_config.bos_token_id, model_kwargs
+ )
+ batch_size = inputs_tensor.shape[0]
+
+ # 4. Define other model kwargs
+ model_kwargs["output_attentions"] = generation_config.output_attentions
+ model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
+ model_kwargs["use_cache"] = generation_config.use_cache
+ model_kwargs["guidance_scale"] = generation_config.guidance_scale
+
+ requires_attention_mask = "encoder_outputs" not in model_kwargs
+
+ if model_kwargs.get("attention_mask", None) is None and requires_attention_mask:
+ model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
+ inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id
+ )
+
+ if "encoder_outputs" not in model_kwargs:
+ # encoder_outputs are created and added to `model_kwargs`
+ model_kwargs = self._prepare_text_encoder_kwargs_for_generation(
+ inputs_tensor,
+ model_kwargs,
+ model_input_name,
+ guidance_scale=generation_config.guidance_scale,
+ )
+
+ if "decoder_input_ids" not in model_kwargs and "input_values" in model_kwargs:
+ model_kwargs = self._prepare_audio_encoder_kwargs_for_generation(
+ model_kwargs["input_values"],
+ model_kwargs,
+ )
+
+ # 5. Prepare `input_ids` which will be used for auto-regressive generation
+ input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
+ batch_size=batch_size,
+ model_input_name=model_input_name,
+ model_kwargs=model_kwargs,
+ decoder_start_token_id=generation_config.decoder_start_token_id,
+ bos_token_id=generation_config.bos_token_id,
+ device=inputs_tensor.device,
+ )
+
+ # 6. Prepare `max_length` depending on other stopping criteria.
+ input_ids_seq_length = input_ids.shape[-1]
+ has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
+ if has_default_max_length and generation_config.max_new_tokens is None:
+ logger.warning(
+ f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
+ "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
+ " recommend using `max_new_tokens` to control the maximum length of the generation.",
+ UserWarning,
+ )
+ elif generation_config.max_new_tokens is not None:
+ if not has_default_max_length:
+ logger.warning(
+ f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
+ f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
+ "Please refer to the documentation for more information. "
+ "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
+ )
+ generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
+
+ if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
+ raise ValueError(
+ f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than"
+ f" the maximum length ({generation_config.max_length})"
+ )
+ if input_ids_seq_length >= generation_config.max_length:
+ logger.warning(
+ f"Input length of decoder_input_ids is {input_ids_seq_length}, but `max_length` is set to"
+ f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
+ " increasing `max_new_tokens`."
+ )
+
+ # build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to MusicGen)
+ input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask(
+ input_ids,
+ pad_token_id=generation_config.decoder_start_token_id,
+ max_length=generation_config.max_length,
+ )
+ # stash the delay mask so that we don't have to recompute in each forward pass
+ model_kwargs["decoder_delay_pattern_mask"] = decoder_delay_pattern_mask
+
+ # 7. determine generation mode
+ is_greedy_gen_mode = (
+ (generation_config.num_beams == 1)
+ and (generation_config.num_beam_groups == 1)
+ and generation_config.do_sample is False
+ )
+ is_sample_gen_mode = (
+ (generation_config.num_beams == 1)
+ and (generation_config.num_beam_groups == 1)
+ and generation_config.do_sample is True
+ )
+
+ # 8. prepare distribution pre_processing samplers
+ logits_processor = self._get_logits_processor(
+ generation_config=generation_config,
+ input_ids_seq_length=input_ids_seq_length,
+ encoder_input_ids=inputs_tensor,
+ prefix_allowed_tokens_fn=None,
+ logits_processor=logits_processor,
+ )
+
+ # 9. prepare stopping criteria
+ stopping_criteria = self._get_stopping_criteria(
+ generation_config=generation_config, stopping_criteria=stopping_criteria
+ )
+
+ if is_greedy_gen_mode:
+ if generation_config.num_return_sequences > 1:
+ raise ValueError(
+ "num_return_sequences has to be 1 when doing greedy search, "
+ f"but is {generation_config.num_return_sequences}."
+ )
+
+ # 10. run greedy search
+ outputs = self.greedy_search(
+ input_ids,
+ logits_processor=logits_processor,
+ stopping_criteria=stopping_criteria,
+ pad_token_id=generation_config.pad_token_id,
+ eos_token_id=generation_config.eos_token_id,
+ output_scores=generation_config.output_scores,
+ return_dict_in_generate=generation_config.return_dict_in_generate,
+ synced_gpus=synced_gpus,
+ **model_kwargs,
+ )
+
+ elif is_sample_gen_mode:
+ # 11. prepare logits warper
+ logits_warper = self._get_logits_warper(generation_config)
+
+ # expand input_ids with `num_return_sequences` additional sequences per batch
+ input_ids, model_kwargs = self._expand_inputs_for_generation(
+ input_ids=input_ids,
+ expand_size=generation_config.num_return_sequences,
+ is_encoder_decoder=self.config.is_encoder_decoder,
+ **model_kwargs,
+ )
+
+ # 12. run sample
+ outputs = self.sample(
+ input_ids,
+ logits_processor=logits_processor,
+ logits_warper=logits_warper,
+ stopping_criteria=stopping_criteria,
+ pad_token_id=generation_config.pad_token_id,
+ eos_token_id=generation_config.eos_token_id,
+ output_scores=generation_config.output_scores,
+ return_dict_in_generate=generation_config.return_dict_in_generate,
+ synced_gpus=synced_gpus,
+ **model_kwargs,
+ )
+
+ else:
+ raise ValueError(
+ "Got incompatible mode for generation, should be one of greedy or sampling."
+ "Ensure that beam search is de-activated by setting `num_beams=1` and `num_beam_groups=1`."
+ )
+
+ if generation_config.return_dict_in_generate:
+ output_ids = outputs.sequences
+ else:
+ output_ids = outputs
+
+ # apply the pattern mask to the final ids
+ output_ids = self.decoder.apply_delay_pattern_mask(output_ids, model_kwargs["decoder_delay_pattern_mask"])
+
+ # revert the pattern delay mask by filtering the pad token id
+ output_ids = output_ids[output_ids != generation_config.pad_token_id].reshape(
+ batch_size, self.decoder.num_codebooks, -1
+ )
+
+ # append the frame dimension back to the audio codes
+ output_ids = output_ids[None, ...]
+
+ audio_scales = model_kwargs.get("audio_scales")
+ if audio_scales is None:
+ audio_scales = [None] * batch_size
+
+ output_values = self.audio_encoder.decode(
+ output_ids,
+ audio_scales=audio_scales,
+ )
+
+ if generation_config.return_dict_in_generate:
+ outputs.sequences = output_values.audio_values
+ return outputs
+ else:
+ return output_values.audio_values
+
+ def get_unconditional_inputs(self, num_samples=1):
+ """
+ Helper function to get null inputs for unconditional generation, enabling the model to be used without the
+ feature extractor or tokenizer.
+
+ Args:
+ num_samples (int, *optional*):
+ Number of audio samples to unconditionally generate.
+ max_new_tokens (int, *optional*):
+ Number of tokens to generate for each sample. More tokens means longer audio samples, at the expense of
+ longer inference (since more audio tokens need to be generated per sample).
+
+ Example:
+ ```python
+ >>> from transformers import MusicgenForConditionalGeneration
+
+ >>> model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
+
+ >>> # get the unconditional (or 'null') inputs for the model
+ >>> unconditional_inputs = model.get_unconditional_inputs(num_samples=1)
+ >>> audio_samples = model.generate(**unconditional_inputs, max_new_tokens=256)
+ ```"""
+ last_hidden_state = torch.zeros(
+ (num_samples, 1, self.config.text_encoder.hidden_size), device=self.device, dtype=self.dtype
+ )
+
+ attention_mask = torch.zeros((num_samples, 1), device=self.device, dtype=torch.long)
+
+ return MusicgenUnconditionalInput(
+ encoder_outputs=(last_hidden_state,),
+ attention_mask=attention_mask,
+ guidance_scale=1.0,
+ )
diff --git a/src/transformers/models/musicgen/processing_musicgen.py b/src/transformers/models/musicgen/processing_musicgen.py
new file mode 100644
index 0000000000..ed8d1277f2
--- /dev/null
+++ b/src/transformers/models/musicgen/processing_musicgen.py
@@ -0,0 +1,139 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Text/audio processor class for MusicGen
+"""
+from typing import List, Optional
+
+import numpy as np
+
+from ...processing_utils import ProcessorMixin
+from ...utils import to_numpy
+
+
+class MusicgenProcessor(ProcessorMixin):
+ r"""
+ Constructs a MusicGen processor which wraps an EnCodec feature extractor and a T5 tokenizer into a single processor
+ class.
+
+ [`MusicgenProcessor`] offers all the functionalities of [`EncodecFeatureExtractor`] and [`TTokenizer`]. See
+ [`~MusicgenProcessor.__call__`] and [`~MusicgenProcessor.decode`] for more information.
+
+ Args:
+ feature_extractor (`EncodecFeatureExtractor`):
+ An instance of [`EncodecFeatureExtractor`]. The feature extractor is a required input.
+ tokenizer (`T5Tokenizer`):
+ An instance of [`T5Tokenizer`]. The tokenizer is a required input.
+ """
+ feature_extractor_class = "EncodecFeatureExtractor"
+ tokenizer_class = ("T5Tokenizer", "T5TokenizerFast")
+
+ def __init__(self, feature_extractor, tokenizer):
+ super().__init__(feature_extractor, tokenizer)
+ self.current_processor = self.feature_extractor
+ self._in_target_context_manager = False
+
+ def get_decoder_prompt_ids(self, task=None, language=None, no_timestamps=True):
+ return self.tokenizer.get_decoder_prompt_ids(task=task, language=language, no_timestamps=no_timestamps)
+
+ def __call__(self, *args, **kwargs):
+ """
+ Forwards the `audio` argument to EncodecFeatureExtractor's [`~EncodecFeatureExtractor.__call__`] and the `text`
+ argument to [`~T5Tokenizer.__call__`]. Please refer to the doctsring of the above two methods for more
+ information.
+ """
+ # For backward compatibility
+ if self._in_target_context_manager:
+ return self.current_processor(*args, **kwargs)
+
+ audio = kwargs.pop("audio", None)
+ sampling_rate = kwargs.pop("sampling_rate", None)
+ text = kwargs.pop("text", None)
+ if len(args) > 0:
+ audio = args[0]
+ args = args[1:]
+
+ if audio is None and text is None:
+ raise ValueError("You need to specify either an `audio` or `text` input to process.")
+
+ if text is not None:
+ inputs = self.tokenizer(text, **kwargs)
+
+ if audio is not None:
+ audio_inputs = self.feature_extractor(audio, *args, sampling_rate=sampling_rate, **kwargs)
+
+ if audio is None:
+ return inputs
+
+ elif text is None:
+ return audio_inputs
+
+ else:
+ inputs["input_values"] = audio_inputs["input_values"]
+ if "padding_mask" in audio_inputs:
+ inputs["padding_mask"] = audio_inputs["padding_mask"]
+ return inputs
+
+ def batch_decode(self, *args, **kwargs):
+ """
+ This method is used to decode either batches of audio outputs from the MusicGen model, or batches of token ids
+ from the tokenizer. In the case of decoding token ids, this method forwards all its arguments to T5Tokenizer's
+ [`~PreTrainedTokenizer.batch_decode`]. Please refer to the docstring of this method for more information.
+ """
+ audio_values = kwargs.pop("audio", None)
+ padding_mask = kwargs.pop("padding_mask", None)
+
+ if len(args) > 0:
+ audio_values = args[0]
+ args = args[1:]
+
+ if audio_values is not None:
+ return self._decode_audio(audio_values, padding_mask=padding_mask)
+ else:
+ return self.tokenizer.batch_decode(*args, **kwargs)
+
+ def decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to T5Tokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to the
+ docstring of this method for more information.
+ """
+ return self.tokenizer.decode(*args, **kwargs)
+
+ def _decode_audio(self, audio_values, padding_mask: Optional = None) -> List[np.ndarray]:
+ """
+ This method strips any padding from the audio values to return a list of numpy audio arrays.
+ """
+ audio_values = to_numpy(audio_values)
+ bsz, channels, seq_len = audio_values.shape
+
+ if padding_mask is None:
+ return list(audio_values)
+
+ padding_mask = to_numpy(padding_mask)
+
+ # match the sequence length of the padding mask to the generated audio arrays by padding with the **non-padding**
+ # token (so that the generated audio values are **not** treated as padded tokens)
+ difference = seq_len - padding_mask.shape[-1]
+ padding_value = 1 - self.feature_extractor.padding_value
+ padding_mask = np.pad(padding_mask, ((0, 0), (0, difference)), "constant", constant_values=padding_value)
+
+ audio_values = audio_values.tolist()
+ for i in range(bsz):
+ sliced_audio = np.asarray(audio_values[i])[
+ padding_mask[i][None, :] != self.feature_extractor.padding_value
+ ]
+ audio_values[i] = sliced_audio.reshape(channels, -1)
+
+ return audio_values
diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py
index 4513032607..2c40f7143d 100644
--- a/src/transformers/utils/dummy_pt_objects.py
+++ b/src/transformers/utils/dummy_pt_objects.py
@@ -4936,6 +4936,44 @@ class MT5PreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["torch"])
+MUSICGEN_PRETRAINED_MODEL_ARCHIVE_LIST = None
+
+
+class MusicgenForCausalLM(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class MusicgenForConditionalGeneration(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class MusicgenModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class MusicgenPreTrainedModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class MusicgenProcessor(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
MVP_PRETRAINED_MODEL_ARCHIVE_LIST = None
diff --git a/tests/models/musicgen/__init__.py b/tests/models/musicgen/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/models/musicgen/test_modeling_musicgen.py b/tests/models/musicgen/test_modeling_musicgen.py
new file mode 100644
index 0000000000..00f249b09d
--- /dev/null
+++ b/tests/models/musicgen/test_modeling_musicgen.py
@@ -0,0 +1,1346 @@
+# coding=utf-8
+# Copyright 2021, The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Testing suite for the PyTorch Musicgen model. """
+import copy
+import inspect
+import math
+import unittest
+
+import numpy as np
+
+from transformers import (
+ EncodecConfig,
+ MusicgenConfig,
+ MusicgenDecoderConfig,
+ MusicgenProcessor,
+ PretrainedConfig,
+ T5Config,
+)
+from transformers.testing_utils import is_torch_available, require_torch, slow, torch_device
+from transformers.utils import cached_property
+
+from ...generation.test_utils import GenerationTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
+from ...test_pipeline_mixin import PipelineTesterMixin
+
+
+if is_torch_available():
+ import torch
+
+ from transformers import (
+ MusicgenForCausalLM,
+ MusicgenForConditionalGeneration,
+ MusicgenModel,
+ set_seed,
+ )
+ from transformers.generation import (
+ GreedySearchDecoderOnlyOutput,
+ GreedySearchEncoderDecoderOutput,
+ InfNanRemoveLogitsProcessor,
+ LogitsProcessorList,
+ SampleDecoderOnlyOutput,
+ SampleEncoderDecoderOutput,
+ )
+
+
+def _config_zero_init(config):
+ configs_no_init = copy.deepcopy(config)
+ for key in configs_no_init.__dict__.keys():
+ if "_range" in key or "_std" in key or "initializer_factor" in key or "layer_scale" in key:
+ setattr(configs_no_init, key, 1e-10)
+ if isinstance(getattr(configs_no_init, key, None), PretrainedConfig):
+ no_init_subconfig = _config_zero_init(getattr(configs_no_init, key))
+ setattr(configs_no_init, key, no_init_subconfig)
+ return configs_no_init
+
+
+def prepare_musicgen_decoder_inputs_dict(
+ config,
+ input_ids,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ cross_attn_head_mask=None,
+):
+ if attention_mask is None:
+ attention_mask = input_ids.reshape(-1, config.num_codebooks, input_ids.shape[-1])[:, 0, :]
+ attention_mask = attention_mask.ne(config.pad_token_id)
+ if head_mask is None:
+ head_mask = torch.ones(config.num_hidden_layers, config.num_attention_heads, device=torch_device)
+ if encoder_attention_mask is None and encoder_hidden_states is not None:
+ encoder_attention_mask = torch.ones(encoder_hidden_states.shape[:2], device=torch_device)
+ if cross_attn_head_mask is None:
+ cross_attn_head_mask = torch.ones(config.num_hidden_layers, config.num_attention_heads, device=torch_device)
+ return {
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ "encoder_hidden_states": encoder_hidden_states,
+ "encoder_attention_mask": encoder_attention_mask,
+ "head_mask": head_mask,
+ "cross_attn_head_mask": cross_attn_head_mask,
+ }
+
+
+class MusicgenDecoderTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=2,
+ seq_length=7,
+ is_training=False,
+ use_labels=False,
+ vocab_size=99,
+ hidden_size=16,
+ num_hidden_layers=2,
+ num_attention_heads=4,
+ intermediate_size=4,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ max_position_embeddings=100,
+ pad_token_id=99,
+ bos_token_id=99,
+ num_codebooks=4,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.seq_length = seq_length
+ self.is_training = is_training
+ self.use_labels = use_labels
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.max_position_embeddings = max_position_embeddings
+ self.pad_token_id = pad_token_id
+ self.bos_token_id = bos_token_id
+ self.num_codebooks = num_codebooks
+
+ def prepare_config_and_inputs(self):
+ input_ids = ids_tensor([self.batch_size * self.num_codebooks, self.seq_length], self.vocab_size)
+ encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
+
+ config = self.get_config()
+ inputs_dict = prepare_musicgen_decoder_inputs_dict(
+ config, input_ids, encoder_hidden_states=encoder_hidden_states
+ )
+ return config, inputs_dict
+
+ def get_config(self):
+ config = MusicgenDecoderConfig(
+ vocab_size=self.vocab_size,
+ hidden_size=self.hidden_size,
+ num_hidden_layers=self.num_hidden_layers,
+ num_attention_heads=self.num_attention_heads,
+ d_ff=self.intermediate_size,
+ pad_token_id=self.pad_token_id,
+ decoder_start_token_id=self.bos_token_id,
+ bos_token_id=self.bos_token_id,
+ num_codebooks=self.num_codebooks,
+ tie_word_embeddings=False,
+ )
+ return config
+
+ def prepare_config_and_inputs_for_common(self):
+ config, inputs_dict = self.prepare_config_and_inputs()
+ return config, inputs_dict
+
+
+@require_torch
+class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
+ all_model_classes = (MusicgenModel, MusicgenForCausalLM) if is_torch_available() else ()
+ greedy_sample_model_classes = (
+ (MusicgenForCausalLM,) if is_torch_available() else ()
+ ) # we don't want to run all the generation tests, only a specific subset
+ pipeline_model_mapping = {}
+ test_pruning = False
+ test_resize_embeddings = False
+
+ def setUp(self):
+ self.model_tester = MusicgenDecoderTester(self)
+ self.config_tester = ConfigTester(self, config_class=MusicgenDecoderConfig, hidden_size=16)
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ # override since we have to compute the input embeddings over codebooks
+ def test_inputs_embeds(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+
+ inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
+
+ input_ids = inputs["input_ids"]
+ del inputs["input_ids"]
+
+ embed_tokens = model.get_input_embeddings()
+
+ input_ids = input_ids.reshape(-1, config.num_codebooks, input_ids.shape[-1])
+
+ inputs["inputs_embeds"] = sum(
+ [embed_tokens[codebook](input_ids[:, codebook]) for codebook in range(config.num_codebooks)]
+ )
+
+ with torch.no_grad():
+ model(**inputs)[0]
+
+ # override since we have embeddings / LM heads over multiple codebooks
+ def test_model_common_attributes(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ first_embed = model.get_input_embeddings()[0]
+ self.assertIsInstance(first_embed, torch.nn.Embedding)
+ lm_heads = model.get_output_embeddings()
+ self.assertTrue(lm_heads is None or isinstance(lm_heads[0], torch.nn.Linear))
+
+ # skip as this model doesn't support all arguments tested
+ def test_model_outputs_equivalence(self):
+ pass
+
+ # skip as this model has multiple inputs embeds and lm heads that should not be tied
+ def test_tie_model_weights(self):
+ pass
+
+ # skip as this model has multiple inputs embeds and lm heads that should not be tied
+ def test_tied_weights_keys(self):
+ pass
+
+ def _get_input_ids_and_config(self, batch_size=2):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ input_ids = inputs_dict["input_ids"]
+
+ # take max batch_size
+ sequence_length = input_ids.shape[-1]
+ input_ids = input_ids[: batch_size * config.num_codebooks, :]
+
+ # generate max 3 tokens
+ max_length = input_ids.shape[-1] + 3
+ attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
+ return config, input_ids, attention_mask, max_length
+
+ @staticmethod
+ def _get_logits_processor_and_kwargs(
+ input_length,
+ eos_token_id,
+ forced_bos_token_id=None,
+ forced_eos_token_id=None,
+ max_length=None,
+ diversity_penalty=None,
+ ):
+ process_kwargs = {
+ "min_length": input_length + 1 if max_length is None else max_length - 1,
+ }
+ logits_processor = LogitsProcessorList()
+ return process_kwargs, logits_processor
+
+ # override since we don't expect the outputs of `.generate` and `.greedy_search` to be the same, since we perform
+ # additional post-processing in the former
+ def test_greedy_generate_dict_outputs(self):
+ for model_class in self.greedy_sample_model_classes:
+ # disable cache
+ config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
+ config.use_cache = False
+ model = model_class(config).to(torch_device).eval()
+ output_greedy, output_generate = self._greedy_generate(
+ model=model,
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ max_length=max_length,
+ output_scores=True,
+ output_hidden_states=True,
+ output_attentions=True,
+ return_dict_in_generate=True,
+ )
+
+ self.assertIsInstance(output_greedy, GreedySearchDecoderOnlyOutput)
+ self.assertIsInstance(output_generate, GreedySearchDecoderOnlyOutput)
+
+ self.assertNotIn(config.pad_token_id, output_generate)
+
+ # override since we don't expect the outputs of `.generate` and `.greedy_search` to be the same, since we perform
+ # additional post-processing in the former
+ def test_greedy_generate_dict_outputs_use_cache(self):
+ for model_class in self.greedy_sample_model_classes:
+ # enable cache
+ config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
+
+ config.use_cache = True
+ config.is_decoder = True
+ model = model_class(config).to(torch_device).eval()
+ output_greedy, output_generate = self._greedy_generate(
+ model=model,
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ max_length=max_length,
+ output_scores=True,
+ output_hidden_states=True,
+ output_attentions=True,
+ return_dict_in_generate=True,
+ )
+
+ self.assertIsInstance(output_greedy, GreedySearchDecoderOnlyOutput)
+ self.assertIsInstance(output_generate, GreedySearchDecoderOnlyOutput)
+
+ # override since we don't expect the outputs of `.generate` and `.sample` to be the same, since we perform
+ # additional post-processing in the former
+ def test_sample_generate(self):
+ for model_class in self.greedy_sample_model_classes:
+ config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
+ model = model_class(config).to(torch_device).eval()
+
+ process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
+ input_ids.shape[-1],
+ model.config.eos_token_id,
+ forced_bos_token_id=model.config.forced_bos_token_id,
+ forced_eos_token_id=model.config.forced_eos_token_id,
+ max_length=max_length,
+ )
+ logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=2)
+
+ # check `generate()` and `sample()` are equal
+ output_sample, output_generate = self._sample_generate(
+ model=model,
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ max_length=max_length,
+ num_return_sequences=3,
+ logits_processor=logits_processor,
+ logits_warper=logits_warper,
+ logits_warper_kwargs=logits_warper_kwargs,
+ process_kwargs=process_kwargs,
+ )
+ self.assertIsInstance(output_sample, torch.Tensor)
+ self.assertIsInstance(output_generate, torch.Tensor)
+
+ # override since we don't expect the outputs of `.generate` and `.sample` to be the same, since we perform
+ # additional post-processing in the former
+ def test_sample_generate_dict_output(self):
+ for model_class in self.greedy_sample_model_classes:
+ # disable cache
+ config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
+ config.use_cache = False
+ model = model_class(config).to(torch_device).eval()
+
+ process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
+ input_ids.shape[-1],
+ model.config.eos_token_id,
+ forced_bos_token_id=model.config.forced_bos_token_id,
+ forced_eos_token_id=model.config.forced_eos_token_id,
+ max_length=max_length,
+ )
+ logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)
+
+ output_sample, output_generate = self._sample_generate(
+ model=model,
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ max_length=max_length,
+ num_return_sequences=1,
+ logits_processor=logits_processor,
+ logits_warper=logits_warper,
+ logits_warper_kwargs=logits_warper_kwargs,
+ process_kwargs=process_kwargs,
+ output_scores=True,
+ output_hidden_states=True,
+ output_attentions=True,
+ return_dict_in_generate=True,
+ )
+
+ self.assertIsInstance(output_sample, SampleDecoderOnlyOutput)
+ self.assertIsInstance(output_generate, SampleDecoderOnlyOutput)
+
+
+def prepare_musicgen_inputs_dict(
+ config,
+ input_ids,
+ decoder_input_ids,
+ attention_mask=None,
+ decoder_attention_mask=None,
+ head_mask=None,
+ decoder_head_mask=None,
+ cross_attn_head_mask=None,
+):
+ if decoder_attention_mask is None:
+ decoder_attention_mask = decoder_input_ids.reshape(
+ -1, config.decoder.num_codebooks, decoder_input_ids.shape[-1]
+ )[:, 0, :]
+ decoder_attention_mask = decoder_attention_mask.ne(config.decoder.pad_token_id)
+ if head_mask is None:
+ head_mask = torch.ones(
+ config.text_encoder.num_hidden_layers, config.text_encoder.num_attention_heads, device=torch_device
+ )
+ if decoder_head_mask is None:
+ decoder_head_mask = torch.ones(
+ config.decoder.num_hidden_layers, config.decoder.num_attention_heads, device=torch_device
+ )
+ if cross_attn_head_mask is None:
+ cross_attn_head_mask = torch.ones(
+ config.decoder.num_hidden_layers, config.decoder.num_attention_heads, device=torch_device
+ )
+ return {
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ "decoder_input_ids": decoder_input_ids,
+ "decoder_attention_mask": decoder_attention_mask,
+ "head_mask": head_mask,
+ "decoder_head_mask": decoder_head_mask,
+ "cross_attn_head_mask": cross_attn_head_mask,
+ }
+
+
+class MusicgenTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=2,
+ seq_length=7,
+ is_training=False,
+ use_labels=False,
+ vocab_size=99,
+ hidden_size=16,
+ num_hidden_layers=2,
+ num_attention_heads=4,
+ intermediate_size=4,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ max_position_embeddings=100,
+ pad_token_id=99,
+ bos_token_id=99,
+ num_codebooks=4,
+ num_filters=4,
+ codebook_size=128,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.seq_length = seq_length
+ self.is_training = is_training
+ self.use_labels = use_labels
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.max_position_embeddings = max_position_embeddings
+ self.pad_token_id = pad_token_id
+ self.bos_token_id = bos_token_id
+ self.num_codebooks = num_codebooks
+ self.num_filters = num_filters
+ self.codebook_size = codebook_size
+
+ def prepare_config_and_inputs(self):
+ input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+ decoder_input_ids = ids_tensor([self.batch_size * self.num_codebooks, self.seq_length], self.vocab_size)
+
+ config = self.get_config()
+ inputs_dict = prepare_musicgen_inputs_dict(config, input_ids, decoder_input_ids=decoder_input_ids)
+ return config, inputs_dict
+
+ def get_config(self):
+ text_encoder_config = T5Config(
+ vocab_size=self.vocab_size,
+ d_model=self.hidden_size,
+ d_ff=self.intermediate_size,
+ num_layers=self.num_hidden_layers,
+ num_heads=self.num_attention_heads,
+ )
+ audio_encoder_config = EncodecConfig(
+ hidden_size=self.vocab_size,
+ compress=1,
+ num_filters=self.num_filters,
+ codebook_size=self.codebook_size,
+ codebook_dim=self.vocab_size,
+ )
+ decoder_config = MusicgenDecoderConfig(
+ vocab_size=self.vocab_size,
+ hidden_size=self.hidden_size,
+ num_hidden_layers=self.num_hidden_layers,
+ num_attention_heads=self.num_attention_heads,
+ ffn_dim=self.intermediate_size,
+ pad_token_id=self.pad_token_id,
+ decoder_start_token_id=self.bos_token_id,
+ bos_token_id=self.bos_token_id,
+ num_codebooks=self.num_codebooks,
+ tie_word_embeddings=False,
+ )
+ config = MusicgenConfig.from_sub_models_config(text_encoder_config, audio_encoder_config, decoder_config)
+ return config
+
+ def prepare_config_and_inputs_for_common(self):
+ config, inputs_dict = self.prepare_config_and_inputs()
+ return config, inputs_dict
+
+
+@require_torch
+class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
+ all_model_classes = (MusicgenForConditionalGeneration,) if is_torch_available() else ()
+ greedy_sample_model_classes = (MusicgenForConditionalGeneration,) if is_torch_available() else ()
+ pipeline_model_mapping = {}
+ test_pruning = False # training is not supported yet for MusicGen
+ test_headmasking = False
+ test_resize_embeddings = False
+
+ def setUp(self):
+ self.model_tester = MusicgenTester(self)
+
+ def _check_output_with_attentions(self, outputs, config, input_ids, decoder_input_ids):
+ text_encoder_config = config.text_encoder
+ decoder_config = config.decoder
+
+ encoder_attentions = outputs["encoder_attentions"]
+ self.assertEqual(len(encoder_attentions), text_encoder_config.num_hidden_layers)
+
+ self.assertEqual(
+ encoder_attentions[0].shape[-3:],
+ (text_encoder_config.num_attention_heads, input_ids.shape[-1], input_ids.shape[-1]),
+ )
+
+ decoder_attentions = outputs["decoder_attentions"]
+ num_decoder_layers = decoder_config.num_hidden_layers
+ self.assertEqual(len(decoder_attentions), num_decoder_layers)
+
+ self.assertEqual(
+ decoder_attentions[0].shape[-3:],
+ (decoder_config.num_attention_heads, decoder_input_ids.shape[-1], decoder_input_ids.shape[-1]),
+ )
+
+ cross_attentions = outputs["cross_attentions"]
+ self.assertEqual(len(cross_attentions), num_decoder_layers)
+
+ cross_attention_input_seq_len = decoder_input_ids.shape[-1]
+ self.assertEqual(
+ cross_attentions[0].shape[-3:],
+ (decoder_config.num_attention_heads, cross_attention_input_seq_len, input_ids.shape[-1]),
+ )
+
+ def check_musicgen_model_output_attentions(
+ self,
+ model_class,
+ config,
+ input_ids,
+ attention_mask,
+ decoder_input_ids,
+ decoder_attention_mask,
+ **kwargs,
+ ):
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+
+ with torch.no_grad():
+ outputs = model(
+ input_ids=input_ids,
+ decoder_input_ids=decoder_input_ids,
+ attention_mask=attention_mask,
+ decoder_attention_mask=decoder_attention_mask,
+ output_attentions=True,
+ **kwargs,
+ )
+ self._check_output_with_attentions(outputs, config, input_ids, decoder_input_ids)
+
+ def check_musicgen_model_output_attentions_from_config(
+ self,
+ model_class,
+ config,
+ input_ids,
+ attention_mask,
+ decoder_input_ids,
+ decoder_attention_mask,
+ **kwargs,
+ ):
+ # Similar to `check_musicgen_model_output_attentions`, but with `output_attentions` triggered from the
+ # config file. Contrarily to most models, changing the model's config won't work -- the defaults are loaded
+ # from the inner models' configurations.
+ config.output_attentions = True # model config -> won't work
+
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+
+ with torch.no_grad():
+ outputs = model(
+ input_ids=input_ids,
+ decoder_input_ids=decoder_input_ids,
+ attention_mask=attention_mask,
+ decoder_attention_mask=decoder_attention_mask,
+ **kwargs,
+ )
+ self.assertTrue(
+ all(key not in outputs for key in ["encoder_attentions", "decoder_attentions", "cross_attentions"])
+ )
+ config.text_encoder.output_attentions = True # inner model config -> will work
+ config.audio_encoder.output_attentions = True
+ config.decoder.output_attentions = True
+
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+
+ with torch.no_grad():
+ outputs = model(
+ input_ids=input_ids,
+ decoder_input_ids=decoder_input_ids,
+ attention_mask=attention_mask,
+ decoder_attention_mask=decoder_attention_mask,
+ **kwargs,
+ )
+ self._check_output_with_attentions(outputs, config, input_ids, decoder_input_ids)
+
+ # override since changing `output_attentions` from the top-level model config won't work
+ def test_attention_outputs(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ self.check_musicgen_model_output_attentions(model_class, config, **inputs_dict)
+ self.check_musicgen_model_output_attentions_from_config(model_class, config, **inputs_dict)
+
+ # override since we have a specific forward signature for musicgen
+ def test_forward_signature(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ signature = inspect.signature(model.forward)
+ # signature.parameters is an OrderedDict => so arg_names order is deterministic
+ arg_names = [*signature.parameters.keys()]
+
+ expected_arg_names = [
+ "input_ids",
+ "attention_mask",
+ "input_values",
+ "padding_mask",
+ "decoder_input_ids",
+ "decoder_attention_mask",
+ ]
+ expected_arg_names.extend(
+ ["head_mask", "decoder_head_mask", "cross_attn_head_mask", "encoder_outputs"]
+ if "head_mask" and "decoder_head_mask" and "cross_attn_head_mask" in arg_names
+ else ["encoder_outputs"]
+ )
+ self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
+
+ # override since changing `gradient_checkpointing` from the top-level model config won't work
+ def test_gradient_checkpointing_backward_compatibility(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ if not model_class.supports_gradient_checkpointing:
+ continue
+
+ config.text_encoder.gradient_checkpointing = True
+ config.audio_encoder.gradient_checkpointing = True
+ config.decoder.gradient_checkpointing = True
+ model = model_class(config)
+ self.assertTrue(model.is_gradient_checkpointing)
+
+ # skip as this model has multiple inputs embeds and lm heads that should not be tied
+ def test_tie_model_weights(self):
+ pass
+
+ # skip as this model has multiple inputs embeds and lm heads that should not be tied
+ def test_tied_model_weights_key_ignore(self):
+ pass
+
+ # skip as this model has multiple inputs embeds and lm heads that should not be tied
+ def test_tied_weights_keys(self):
+ pass
+
+ # override since changing `output_hidden_states` / `output_attentions` from the top-level model config won't work
+ def test_retain_grad_hidden_states_attentions(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.text_encoder.output_hidden_states = True
+ config.audio_encoder.output_hidden_states = True
+ config.decoder.output_hidden_states = True
+
+ config.text_encoder.output_attentions = True
+ config.decoder.output_attentions = True
+
+ # no need to test all models as different heads yield the same functionality
+ model_class = self.all_model_classes[0]
+ model = model_class(config)
+ model.to(torch_device)
+
+ inputs = self._prepare_for_class(inputs_dict, model_class)
+
+ outputs = model(**inputs)
+
+ output = outputs[0]
+
+ encoder_hidden_states = outputs.encoder_hidden_states[0]
+ encoder_hidden_states.retain_grad()
+
+ decoder_hidden_states = outputs.decoder_hidden_states[0]
+ decoder_hidden_states.retain_grad()
+
+ if self.has_attentions:
+ encoder_attentions = outputs.encoder_attentions[0]
+ encoder_attentions.retain_grad()
+
+ decoder_attentions = outputs.decoder_attentions[0]
+ decoder_attentions.retain_grad()
+
+ cross_attentions = outputs.cross_attentions[0]
+ cross_attentions.retain_grad()
+
+ output.flatten()[0].backward(retain_graph=True)
+
+ self.assertIsNotNone(encoder_hidden_states.grad)
+ self.assertIsNotNone(decoder_hidden_states.grad)
+
+ if self.has_attentions:
+ self.assertIsNotNone(encoder_attentions.grad)
+ self.assertIsNotNone(decoder_attentions.grad)
+ self.assertIsNotNone(cross_attentions.grad)
+
+ # override since changing `output_hidden_states` from the top-level model config won't work
+ def test_hidden_states_output(self):
+ def check_hidden_states_output(inputs_dict, config, model_class):
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+
+ hidden_states = outputs.encoder_hidden_states
+
+ expected_num_layers = self.model_tester.num_hidden_layers + 1
+ self.assertEqual(len(hidden_states), expected_num_layers)
+
+ seq_length = self.model_tester.seq_length
+ self.assertListEqual(
+ list(hidden_states[0].shape[-2:]),
+ [seq_length, self.model_tester.hidden_size],
+ )
+
+ hidden_states = outputs.decoder_hidden_states
+ self.assertIsInstance(hidden_states, (list, tuple))
+ self.assertEqual(len(hidden_states), expected_num_layers)
+
+ self.assertListEqual(
+ list(hidden_states[0].shape[-2:]),
+ [seq_length, self.model_tester.hidden_size],
+ )
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ inputs_dict["output_hidden_states"] = True
+ check_hidden_states_output(inputs_dict, config, model_class)
+
+ # check that output_hidden_states also work using config
+ del inputs_dict["output_hidden_states"]
+ config.text_encoder.output_hidden_states = True
+ config.audio_encoder.output_hidden_states = True
+ config.decoder.output_hidden_states = True
+
+ check_hidden_states_output(inputs_dict, config, model_class)
+
+ # override since the conv layers and lstm's in encodec are exceptions
+ def test_initialization(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ configs_no_init = _config_zero_init(config)
+ for model_class in self.all_model_classes:
+ model = model_class(config=configs_no_init)
+ for name, param in model.named_parameters():
+ uniform_init_parms = ["conv"]
+ ignore_init = ["lstm"]
+ if param.requires_grad:
+ if any([x in name for x in uniform_init_parms]):
+ self.assertTrue(
+ -1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0,
+ msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+ )
+ elif not any([x in name for x in ignore_init]):
+ self.assertIn(
+ ((param.data.mean() * 1e9).round() / 1e9).item(),
+ [0.0, 1.0],
+ msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+ )
+
+ # override since we have embeddings / LM heads over multiple codebooks
+ def test_model_common_attributes(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ self.assertIsInstance(model.get_input_embeddings(), torch.nn.Embedding)
+ lm_heads = model.get_output_embeddings()
+ self.assertTrue(lm_heads is None or isinstance(lm_heads[0], torch.nn.Linear))
+
+ def _get_input_ids_and_config(self, batch_size=2):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ input_ids = inputs_dict["input_ids"]
+
+ # take max batch_size
+ sequence_length = input_ids.shape[-1]
+ input_ids = input_ids[:batch_size, :]
+ attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
+
+ # generate max 3 tokens
+ decoder_input_ids = inputs_dict["decoder_input_ids"]
+ max_length = decoder_input_ids.shape[-1] + 3
+ decoder_input_ids = decoder_input_ids[: batch_size * config.decoder.num_codebooks, :]
+ return config, input_ids, attention_mask, decoder_input_ids, max_length
+
+ # override since the `input_ids` cannot be used as the `decoder_input_ids` for musicgen (input / outputs are
+ # different modalities -> different shapes)
+ def _greedy_generate(
+ self,
+ model,
+ input_ids,
+ attention_mask,
+ decoder_input_ids,
+ max_length,
+ output_scores=False,
+ output_attentions=False,
+ output_hidden_states=False,
+ return_dict_in_generate=False,
+ ):
+ logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
+ input_ids.shape[-1],
+ eos_token_id=model.config.eos_token_id,
+ forced_bos_token_id=model.config.forced_bos_token_id,
+ forced_eos_token_id=model.config.forced_eos_token_id,
+ max_length=max_length,
+ )
+
+ model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
+ output_generate = model.generate(
+ input_ids,
+ do_sample=False,
+ num_beams=1,
+ max_length=max_length,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ output_scores=output_scores,
+ return_dict_in_generate=return_dict_in_generate,
+ remove_invalid_values=True,
+ **logits_process_kwargs,
+ **model_kwargs,
+ )
+
+ encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs(
+ model,
+ input_ids,
+ attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+
+ with torch.no_grad():
+ model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
+ output_greedy = model.greedy_search(
+ decoder_input_ids,
+ max_length=max_length,
+ logits_processor=logits_processor,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ output_scores=output_scores,
+ return_dict_in_generate=return_dict_in_generate,
+ encoder_outputs=encoder_outputs,
+ **model_kwargs,
+ )
+ return output_greedy, output_generate
+
+ # override since the `input_ids` cannot be used as the `decoder_input_ids` for musicgen (input / outputs are
+ # different modalities -> different shapes)
+ def _sample_generate(
+ self,
+ model,
+ input_ids,
+ attention_mask,
+ decoder_input_ids,
+ max_length,
+ num_return_sequences,
+ logits_processor,
+ logits_warper,
+ logits_warper_kwargs,
+ process_kwargs,
+ output_scores=False,
+ output_attentions=False,
+ output_hidden_states=False,
+ return_dict_in_generate=False,
+ ):
+ torch.manual_seed(0)
+ model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
+ output_generate = model.generate(
+ input_ids,
+ do_sample=True,
+ num_beams=1,
+ max_length=max_length,
+ num_return_sequences=num_return_sequences,
+ output_scores=output_scores,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict_in_generate=return_dict_in_generate,
+ remove_invalid_values=True,
+ **logits_warper_kwargs,
+ **process_kwargs,
+ **model_kwargs,
+ )
+
+ torch.manual_seed(0)
+ encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs(
+ model,
+ input_ids,
+ attention_mask,
+ num_interleave=num_return_sequences,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+
+ # prevent flaky generation test failures
+ logits_processor.append(InfNanRemoveLogitsProcessor())
+
+ with torch.no_grad():
+ model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
+ output_sample = model.sample(
+ decoder_input_ids.repeat_interleave(num_return_sequences, dim=0),
+ max_length=max_length,
+ logits_processor=logits_processor,
+ logits_warper=logits_warper,
+ output_scores=output_scores,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict_in_generate=return_dict_in_generate,
+ encoder_outputs=encoder_outputs,
+ **model_kwargs,
+ )
+
+ return output_sample, output_generate
+
+ @staticmethod
+ def _get_logits_processor_and_kwargs(
+ input_length,
+ eos_token_id,
+ forced_bos_token_id=None,
+ forced_eos_token_id=None,
+ max_length=None,
+ diversity_penalty=None,
+ ):
+ process_kwargs = {
+ "min_length": input_length + 1 if max_length is None else max_length - 1,
+ }
+ logits_processor = LogitsProcessorList()
+ return process_kwargs, logits_processor
+
+ def test_greedy_generate_dict_outputs(self):
+ for model_class in self.greedy_sample_model_classes:
+ # disable cache
+ config, input_ids, attention_mask, decoder_input_ids, max_length = self._get_input_ids_and_config()
+ config.use_cache = False
+ model = model_class(config).to(torch_device).eval()
+ output_greedy, output_generate = self._greedy_generate(
+ model=model,
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ decoder_input_ids=decoder_input_ids,
+ max_length=max_length,
+ output_scores=True,
+ output_hidden_states=True,
+ output_attentions=True,
+ return_dict_in_generate=True,
+ )
+
+ self.assertIsInstance(output_greedy, GreedySearchEncoderDecoderOutput)
+ self.assertIsInstance(output_generate, GreedySearchEncoderDecoderOutput)
+
+ self.assertNotIn(config.pad_token_id, output_generate)
+
+ def test_greedy_generate_dict_outputs_use_cache(self):
+ for model_class in self.greedy_sample_model_classes:
+ # enable cache
+ config, input_ids, attention_mask, decoder_input_ids, max_length = self._get_input_ids_and_config()
+
+ config.use_cache = True
+ config.is_decoder = True
+ model = model_class(config).to(torch_device).eval()
+ output_greedy, output_generate = self._greedy_generate(
+ model=model,
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ decoder_input_ids=decoder_input_ids,
+ max_length=max_length,
+ output_scores=True,
+ output_hidden_states=True,
+ output_attentions=True,
+ return_dict_in_generate=True,
+ )
+
+ self.assertIsInstance(output_greedy, GreedySearchEncoderDecoderOutput)
+ self.assertIsInstance(output_generate, GreedySearchEncoderDecoderOutput)
+
+ def test_sample_generate(self):
+ for model_class in self.greedy_sample_model_classes:
+ config, input_ids, attention_mask, decoder_input_ids, max_length = self._get_input_ids_and_config()
+ model = model_class(config).to(torch_device).eval()
+
+ process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
+ input_ids.shape[-1],
+ model.config.eos_token_id,
+ forced_bos_token_id=model.config.forced_bos_token_id,
+ forced_eos_token_id=model.config.forced_eos_token_id,
+ max_length=max_length,
+ )
+ logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=2)
+
+ # check `generate()` and `sample()` are equal
+ output_sample, output_generate = self._sample_generate(
+ model=model,
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ decoder_input_ids=decoder_input_ids,
+ max_length=max_length,
+ num_return_sequences=1,
+ logits_processor=logits_processor,
+ logits_warper=logits_warper,
+ logits_warper_kwargs=logits_warper_kwargs,
+ process_kwargs=process_kwargs,
+ )
+ self.assertIsInstance(output_sample, torch.Tensor)
+ self.assertIsInstance(output_generate, torch.Tensor)
+
+ def test_sample_generate_dict_output(self):
+ for model_class in self.greedy_sample_model_classes:
+ # disable cache
+ config, input_ids, attention_mask, decoder_input_ids, max_length = self._get_input_ids_and_config()
+ config.use_cache = False
+ model = model_class(config).to(torch_device).eval()
+
+ process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
+ input_ids.shape[-1],
+ model.config.eos_token_id,
+ forced_bos_token_id=model.config.forced_bos_token_id,
+ forced_eos_token_id=model.config.forced_eos_token_id,
+ max_length=max_length,
+ )
+ logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)
+
+ output_sample, output_generate = self._sample_generate(
+ model=model,
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ decoder_input_ids=decoder_input_ids,
+ max_length=max_length,
+ num_return_sequences=3,
+ logits_processor=logits_processor,
+ logits_warper=logits_warper,
+ logits_warper_kwargs=logits_warper_kwargs,
+ process_kwargs=process_kwargs,
+ output_scores=True,
+ output_hidden_states=True,
+ output_attentions=True,
+ return_dict_in_generate=True,
+ )
+
+ self.assertIsInstance(output_sample, SampleEncoderDecoderOutput)
+ self.assertIsInstance(output_generate, SampleEncoderDecoderOutput)
+
+ def test_generate_without_input_ids(self):
+ config, _, _, _, max_length = self._get_input_ids_and_config()
+
+ # if no bos token id => cannot generate from None
+ if config.bos_token_id is None:
+ return
+
+ for model_class in self.greedy_sample_model_classes:
+ model = model_class(config).to(torch_device)
+ model.eval()
+
+ output_ids_generate = model.generate(do_sample=False, max_length=max_length, remove_invalid_values=True)
+ self.assertIsNotNone(output_ids_generate)
+
+ def test_generate_fp16(self):
+ config, input_dict = self.model_tester.prepare_config_and_inputs()
+
+ for model_class in self.greedy_sample_model_classes:
+ model = model_class(config).eval().to(torch_device)
+ if torch_device == "cuda":
+ model.half()
+ model.generate(**input_dict, max_new_tokens=10)
+ model.generate(**input_dict, do_sample=True, max_new_tokens=10)
+
+
+def get_bip_bip(bip_duration=0.125, duration=0.5, sample_rate=32000):
+ """Produces a series of 'bip bip' sounds at a given frequency."""
+ timesteps = np.arange(int(duration * sample_rate)) / sample_rate
+ wav = np.cos(2 * math.pi * 440 * timesteps)
+ time_period = (timesteps % (2 * bip_duration)) / (2 * bip_duration)
+ envelope = time_period >= 0.5
+ return wav * envelope
+
+
+def place_dict_on_device(dict_to_place, device):
+ for key in dict_to_place:
+ if dict_to_place[key] is not None and isinstance(dict_to_place[key], torch.Tensor):
+ dict_to_place[key] = dict_to_place[key].to(device)
+ return dict_to_place
+
+
+@require_torch
+class MusicgenIntegrationTests(unittest.TestCase):
+ @cached_property
+ def model(self):
+ return MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small").to(torch_device)
+
+ @cached_property
+ def processor(self):
+ return MusicgenProcessor.from_pretrained("facebook/musicgen-small")
+
+ @slow
+ def test_logits_text_prompt(self):
+ model = self.model
+ processor = self.processor
+
+ inputs = processor(text=["80s music", "Club techno"], padding=True, return_tensors="pt")
+
+ # prepare the encoder inputs
+ input_ids = inputs.input_ids.to(torch_device)
+ attention_mask = inputs.attention_mask.to(torch_device)
+
+ # prepare the decoder inputs
+ pad_token_id = model.generation_config.pad_token_id
+ decoder_input_ids = (
+ torch.ones((input_ids.shape[0] * model.decoder.num_codebooks, 1), dtype=torch.long).to(torch_device)
+ * pad_token_id
+ )
+
+ with torch.no_grad():
+ logits = model(
+ input_ids,
+ attention_mask=attention_mask,
+ decoder_input_ids=decoder_input_ids,
+ ).logits
+
+ # fmt: off
+ EXPECTED_LOGITS = torch.tensor(
+ [
+ -0.9708, -3.0149, -4.6415, -1.4754, -0.2786, -2.3523, -2.6049, -6.7467,
+ -1.0206, -3.2984, -3.3968, -1.5108, -1.5786, -3.1493, -1.1503, -0.0545,
+ ]
+ )
+ # fmt: on
+
+ self.assertTrue(logits.shape == (*decoder_input_ids.shape, model.decoder.config.vocab_size))
+ self.assertTrue(torch.allclose(logits[0, 0, :16].cpu(), EXPECTED_LOGITS, atol=1e-4))
+
+ @slow
+ def test_logits_text_audio_prompt(self):
+ model = self.model
+ processor = self.processor
+
+ audio = [get_bip_bip(duration=0.5), get_bip_bip(duration=1.0)]
+ text = ["80s music", "Club techno"]
+
+ inputs = processor(audio=audio, text=text, padding=True, return_tensors="pt")
+
+ # prepare the text encoder inputs
+ input_ids = inputs.input_ids.to(torch_device)
+ attention_mask = inputs.attention_mask.to(torch_device)
+
+ # prepare the audio encoder inputs
+ input_values = inputs.input_values.to(torch_device)
+ padding_mask = inputs.padding_mask.to(torch_device)
+
+ with torch.no_grad():
+ logits = model(
+ input_ids,
+ attention_mask=attention_mask,
+ input_values=input_values,
+ padding_mask=padding_mask,
+ ).logits
+
+ # fmt: off
+ EXPECTED_LOGITS = torch.tensor(
+ [
+ 0.1841, -2.9324, -0.7898, 0.1857, 0.4971, -2.8685, -1.6525, -1.6541,
+ 2.7757, -2.5942, -3.0959, -1.0120, -1.0147, -0.4605, -0.8885, 0.6820,
+ ]
+ )
+ # fmt: on
+
+ self.assertTrue(logits.shape == (8, 50, 2048))
+ self.assertTrue(torch.allclose(logits[0, -1, :16].cpu(), EXPECTED_LOGITS, atol=1e-4))
+
+ @slow
+ def test_generate_unconditional_greedy(self):
+ model = self.model
+
+ # only generate 1 sample with greedy - since it's deterministic all elements of the batch will be the same
+ unconditional_inputs = model.get_unconditional_inputs(num_samples=1)
+ unconditional_inputs = place_dict_on_device(unconditional_inputs, device=torch_device)
+
+ output_values = model.generate(**unconditional_inputs, do_sample=False, max_new_tokens=5)
+
+ # fmt: off
+ EXPECTED_VALUES = torch.tensor(
+ [
+ 0.0056, 0.0064, 0.0063, 0.0054, 0.0042, 0.0033, 0.0024, 0.0015,
+ 0.0015, 0.0010, 0.0004, -0.0012, -0.0036, -0.0055, -0.0067, -0.0071,
+ ]
+ )
+ # fmt: on
+
+ self.assertTrue(output_values.shape == (1, 1, 3200))
+ self.assertTrue(torch.allclose(output_values[0, 0, :16].cpu(), EXPECTED_VALUES, atol=1e-4))
+
+ @slow
+ def test_generate_unconditional_sampling(self):
+ model = self.model
+
+ # for stochastic sampling we can generate multiple outputs
+ unconditional_inputs = model.get_unconditional_inputs(num_samples=2)
+ unconditional_inputs = place_dict_on_device(unconditional_inputs, device=torch_device)
+
+ set_seed(0)
+ output_values = model.generate(**unconditional_inputs, do_sample=True, max_new_tokens=10)
+
+ # fmt: off
+ EXPECTED_VALUES = torch.tensor(
+ [
+ 0.0765, 0.0758, 0.0749, 0.0759, 0.0759, 0.0771, 0.0775, 0.0760,
+ 0.0762, 0.0765, 0.0767, 0.0760, 0.0738, 0.0714, 0.0713, 0.0730,
+ ]
+ )
+ # fmt: on
+
+ self.assertTrue(output_values.shape == (2, 1, 4480))
+ self.assertTrue(torch.allclose(output_values[0, 0, :16].cpu(), EXPECTED_VALUES, atol=1e-4))
+
+ @slow
+ def test_generate_text_prompt_greedy(self):
+ model = self.model
+ processor = self.processor
+
+ inputs = processor(text=["80s music", "Club techno"], padding=True, return_tensors="pt")
+
+ # prepare the encoder inputs
+ input_ids = inputs.input_ids.to(torch_device)
+ attention_mask = inputs.attention_mask.to(torch_device)
+
+ output_values = model.generate(
+ input_ids, attention_mask=attention_mask, do_sample=False, guidance_scale=None, max_new_tokens=10
+ )
+
+ # fmt: off
+ EXPECTED_VALUES = torch.tensor(
+ [
+ -1.1998e-04, -2.2302e-04, 4.6296e-04, 1.0524e-03, 2.4827e-04,
+ -4.0288e-05, -1.2468e-04, 4.9846e-05, 7.1485e-04, 4.4197e-04,
+ ]
+ )
+ # fmt: on
+
+ self.assertTrue(output_values.shape == (2, 1, 4480))
+ self.assertTrue(torch.allclose(output_values[0, 0, :10].cpu(), EXPECTED_VALUES, atol=1e-4))
+
+ @slow
+ def test_generate_text_prompt_greedy_with_classifier_free_guidance(self):
+ model = self.model
+ processor = self.processor
+
+ inputs = processor(text=["80s music", "Club techno"], padding=True, return_tensors="pt")
+
+ # prepare the encoder inputs
+ input_ids = inputs.input_ids.to(torch_device)
+ attention_mask = inputs.attention_mask.to(torch_device)
+
+ output_values = model.generate(
+ input_ids, attention_mask=attention_mask, do_sample=False, guidance_scale=3, max_new_tokens=10
+ )
+
+ # fmt: off
+ EXPECTED_VALUES = torch.tensor(
+ [
+ 0.0283, 0.0246, 0.0650, 0.0640, 0.0599, 0.0711, 0.0420, 0.0112,
+ 0.0511, 0.0746, 0.1363, 0.1213, 0.0185, -0.0578, -0.0908, 0.0443,
+ ]
+ )
+ # fmt: on
+
+ self.assertTrue(output_values.shape == (2, 1, 4480))
+ self.assertTrue(torch.allclose(output_values[0, 0, :16].cpu(), EXPECTED_VALUES, atol=1e-4))
+
+ @slow
+ def test_generate_text_prompt_sampling(self):
+ model = self.model
+ processor = self.processor
+
+ inputs = processor(text=["80s music", "Club techno"], padding=True, return_tensors="pt")
+
+ # prepare the encoder inputs
+ input_ids = inputs.input_ids.to(torch_device)
+ attention_mask = inputs.attention_mask.to(torch_device)
+
+ set_seed(0)
+ output_values = model.generate(
+ input_ids, attention_mask=attention_mask, do_sample=True, guidance_scale=None, max_new_tokens=10
+ )
+
+ # fmt: off
+ EXPECTED_VALUES = torch.tensor(
+ [
+ -0.0047, -0.0094, -0.0028, -0.0018, -0.0057, -0.0007, -0.0104, -0.0211,
+ -0.0097, -0.0150, -0.0066, -0.0004, -0.0201, -0.0325, -0.0326, -0.0098,
+ ]
+ )
+ # fmt: on
+
+ self.assertTrue(output_values.shape == (2, 1, 4480))
+ self.assertTrue(torch.allclose(output_values[0, 0, :16].cpu(), EXPECTED_VALUES, atol=1e-4))
+
+ @slow
+ def test_generate_text_audio_prompt(self):
+ model = self.model
+ processor = self.processor
+
+ audio = [get_bip_bip(duration=0.5), get_bip_bip(duration=1.0)]
+ text = ["80s music", "Club techno"]
+
+ inputs = processor(audio=audio, text=text, padding=True, return_tensors="pt")
+ inputs = place_dict_on_device(inputs, device=torch_device)
+
+ output_values = model.generate(**inputs, do_sample=False, guidance_scale=None, max_new_tokens=10)
+
+ # fmt: off
+ EXPECTED_VALUES = torch.tensor(
+ [
+ -0.0036, -0.0130, -0.0261, -0.0384, -0.0557, -0.0718, -0.0680, -0.0632,
+ -0.0529, -0.0403, -0.0289, -0.0198, -0.0136, -0.0101, -0.0095, -0.0040,
+ ]
+ )
+ # fmt: on
+
+ self.assertTrue(
+ output_values.shape == (2, 1, 36480)
+ ) # input values take shape 32000 and we generate from there
+ self.assertTrue(torch.allclose(output_values[0, 0, -16:].cpu(), EXPECTED_VALUES, atol=1e-4))
diff --git a/tests/models/musicgen/test_processing_musicgen.py b/tests/models/musicgen/test_processing_musicgen.py
new file mode 100644
index 0000000000..41962552ff
--- /dev/null
+++ b/tests/models/musicgen/test_processing_musicgen.py
@@ -0,0 +1,173 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for the MusicGen processor."""
+
+import random
+import shutil
+import tempfile
+import unittest
+
+import numpy as np
+
+from transformers import T5Tokenizer, T5TokenizerFast
+from transformers.testing_utils import require_sentencepiece, require_torch
+from transformers.utils.import_utils import is_speech_available, is_torch_available
+
+
+if is_torch_available():
+ pass
+
+if is_speech_available():
+ from transformers import EncodecFeatureExtractor, MusicgenProcessor
+
+
+global_rng = random.Random()
+
+
+def floats_list(shape, scale=1.0, rng=None, name=None):
+ """Creates a random float32 tensor"""
+ if rng is None:
+ rng = global_rng
+
+ values = []
+ for batch_idx in range(shape[0]):
+ values.append([])
+ for _ in range(shape[1]):
+ values[-1].append(rng.random() * scale)
+
+ return values
+
+
+@require_torch
+@require_sentencepiece
+class MusicgenProcessorTest(unittest.TestCase):
+ def setUp(self):
+ self.checkpoint = "facebook/musicgen-small"
+ self.tmpdirname = tempfile.mkdtemp()
+
+ def get_tokenizer(self, **kwargs):
+ return T5Tokenizer.from_pretrained(self.checkpoint, **kwargs)
+
+ def get_feature_extractor(self, **kwargs):
+ return EncodecFeatureExtractor.from_pretrained(self.checkpoint, **kwargs)
+
+ def tearDown(self):
+ shutil.rmtree(self.tmpdirname)
+
+ def test_save_load_pretrained_default(self):
+ tokenizer = self.get_tokenizer()
+ feature_extractor = self.get_feature_extractor()
+
+ processor = MusicgenProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
+
+ processor.save_pretrained(self.tmpdirname)
+ processor = MusicgenProcessor.from_pretrained(self.tmpdirname)
+
+ self.assertEqual(processor.tokenizer.get_vocab(), tokenizer.get_vocab())
+ self.assertIsInstance(processor.tokenizer, T5TokenizerFast)
+
+ self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor.to_json_string())
+ self.assertIsInstance(processor.feature_extractor, EncodecFeatureExtractor)
+
+ def test_save_load_pretrained_additional_features(self):
+ processor = MusicgenProcessor(tokenizer=self.get_tokenizer(), feature_extractor=self.get_feature_extractor())
+ processor.save_pretrained(self.tmpdirname)
+
+ tokenizer_add_kwargs = self.get_tokenizer(bos_token="(BOS)", eos_token="(EOS)")
+ feature_extractor_add_kwargs = self.get_feature_extractor(do_normalize=False, padding_value=1.0)
+
+ processor = MusicgenProcessor.from_pretrained(
+ self.tmpdirname, bos_token="(BOS)", eos_token="(EOS)", do_normalize=False, padding_value=1.0
+ )
+
+ self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab())
+ self.assertIsInstance(processor.tokenizer, T5TokenizerFast)
+
+ self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string())
+ self.assertIsInstance(processor.feature_extractor, EncodecFeatureExtractor)
+
+ def test_feature_extractor(self):
+ feature_extractor = self.get_feature_extractor()
+ tokenizer = self.get_tokenizer()
+
+ processor = MusicgenProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
+
+ raw_speech = floats_list((3, 1000))
+
+ input_feat_extract = feature_extractor(raw_speech, return_tensors="np")
+ input_processor = processor(raw_speech, return_tensors="np")
+
+ for key in input_feat_extract.keys():
+ self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2)
+
+ def test_tokenizer(self):
+ feature_extractor = self.get_feature_extractor()
+ tokenizer = self.get_tokenizer()
+
+ processor = MusicgenProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
+
+ input_str = "This is a test string"
+
+ encoded_processor = processor(text=input_str)
+
+ encoded_tok = tokenizer(input_str)
+
+ for key in encoded_tok.keys():
+ self.assertListEqual(encoded_tok[key], encoded_processor[key])
+
+ def test_tokenizer_decode(self):
+ feature_extractor = self.get_feature_extractor()
+ tokenizer = self.get_tokenizer()
+
+ processor = MusicgenProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
+
+ predicted_ids = [[1, 4, 5, 8, 1, 0, 8], [3, 4, 3, 1, 1, 8, 9]]
+
+ decoded_processor = processor.batch_decode(sequences=predicted_ids)
+ decoded_tok = tokenizer.batch_decode(predicted_ids)
+
+ self.assertListEqual(decoded_tok, decoded_processor)
+
+ def test_model_input_names(self):
+ feature_extractor = self.get_feature_extractor()
+ tokenizer = self.get_tokenizer()
+
+ processor = MusicgenProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
+
+ self.assertListEqual(
+ processor.model_input_names,
+ feature_extractor.model_input_names,
+ msg="`processor` and `feature_extractor` model input names do not match",
+ )
+
+ def test_decode_audio(self):
+ feature_extractor = self.get_feature_extractor(padding_side="left")
+ tokenizer = self.get_tokenizer()
+
+ processor = MusicgenProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
+
+ raw_speech = [floats_list((1, x))[0] for x in range(5, 20, 5)]
+ padding_mask = processor(raw_speech).padding_mask
+
+ generated_speech = np.asarray(floats_list((3, 20)))[:, None, :]
+ decoded_audios = processor.batch_decode(generated_speech, padding_mask=padding_mask)
+
+ self.assertIsInstance(decoded_audios, list)
+
+ for audio in decoded_audios:
+ self.assertIsInstance(audio, np.ndarray)
+
+ self.assertTrue(decoded_audios[0].shape == (1, 10))
+ self.assertTrue(decoded_audios[1].shape == (1, 15))
+ self.assertTrue(decoded_audios[2].shape == (1, 20))
diff --git a/utils/check_config_docstrings.py b/utils/check_config_docstrings.py
index 93385b127d..de47348a9e 100644
--- a/utils/check_config_docstrings.py
+++ b/utils/check_config_docstrings.py
@@ -37,6 +37,7 @@ _re_checkpoint = re.compile(r"\[(.+?)\]\((https://huggingface\.co/.+?)\)")
CONFIG_CLASSES_TO_IGNORE_FOR_DOCSTRING_CHECKPOINT_CHECK = {
"DecisionTransformerConfig",
"EncoderDecoderConfig",
+ "MusicgenConfig",
"RagConfig",
"SpeechEncoderDecoderConfig",
"TimmBackboneConfig",
diff --git a/utils/check_repo.py b/utils/check_repo.py
index 2de9fa61e1..46407eb1a5 100644
--- a/utils/check_repo.py
+++ b/utils/check_repo.py
@@ -119,6 +119,7 @@ IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
"MegatronBertEncoder", # Building part of bigger (tested) model.
"MegatronBertDecoder", # Building part of bigger (tested) model.
"MegatronBertDecoderWrapper", # Building part of bigger (tested) model.
+ "MusicgenDecoder", # Building part of bigger (tested) model.
"MvpDecoderWrapper", # Building part of bigger (tested) model.
"MvpEncoder", # Building part of bigger (tested) model.
"PegasusEncoder", # Building part of bigger (tested) model.
@@ -331,6 +332,8 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
"SpeechT5ForSpeechToSpeech",
"SpeechT5ForTextToSpeech",
"SpeechT5HifiGan",
+ "MusicgenModel",
+ "MusicgenForConditionalGeneration",
]
# Update this list for models that have multiple model types for the same
diff --git a/utils/documentation_tests.txt b/utils/documentation_tests.txt
index 0064a9999b..2bbaabb073 100644
--- a/utils/documentation_tests.txt
+++ b/utils/documentation_tests.txt
@@ -134,6 +134,9 @@ src/transformers/models/mobilenet_v1/modeling_mobilenet_v1.py
src/transformers/models/mobilenet_v2/modeling_mobilenet_v2.py
src/transformers/models/mobilevit/modeling_mobilevit.py
src/transformers/models/mobilevit/modeling_tf_mobilevit.py
+src/transformers/models/musicgen/modeling_musicgen.py
+src/transformers/models/musicgen/configuration_musicgen.py
+src/transformers/models/musicgen/processing_musicgen.py
src/transformers/models/mvp/configuration_mvp.py
src/transformers/models/nat/configuration_nat.py
src/transformers/models/nat/modeling_nat.py