Add Musicgen (#24109)
* Add Audiocraft * add cross attention * style * add for lm * convert and verify * introduce t5 * split configs * load t5 + lm * clean conversion * copy from t5 * style * start pattern provider * make generation work * style * fix pos embs * propagate shape changes * propagate shape changes * style * delay pattern: pad tokens at end * audiocraft -> musicgen * fix inits * add mdx * style * fix pad token in processor * override generate and add todos * add init to test * undo pattern delay mask after gen * remove cfg logits processor * remove cfg logits processor * remove logits processor in favour of mask * clean pos embs * make fix copies * update readmes * clean pos emb * refactor encoder/decoder * make fix copies * update conversion * fix config imports * update config docs * make style * send pattern mask to device * pattern mask with delay * recover prompted audio tokens * fix docstrings * laydown test file * pattern edge case * remove t5 ref * add processing class * config refactor * better pattern comment * check if mask is not present * check if mask is not present * refactor to auto class * remove encoder configs * fix processor * processor import * start updating conversion * start updating tests * make style * convert t5, encodec, lm * convert as composite * also convert processor * run generate * classifier free gen * comments and clean up * make style * docs for logit proc * docstring for uncond gen * start lm tests * work tests * let the lm generate * refactor: reshape inside forward * undo greedy loop changes * from_enc_dec -> from_sub_model * fix input id shapes in docstrings * Apply suggestions from code review Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * undo generate changes * from sub model config * Update src/transformers/models/musicgen/modeling_musicgen.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * make generate work again * generate uncond -> get uncond inputs * remove prefix allowed tokens fn * better error message * logit proc checks * Apply suggestions from code review Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * make decoder only tests work * composite fast tests * make style * uncond generation * feat extr padding * make audio prompt work * fix inputs docstrings * unconditional inputs: dict -> model output * clean up tests * more clean up tests * make style * t5 encoder -> auto text encoder * remove comments * deal with frames * fix auto text * slow tests * nice mdx * remove can generate * todo - hub id * convert m/l * make fix copies * only import generation with torch * ignore decoder from tests * don't wrap uncond inputs * make style * cleaner uncond inputs * add example to musicgen forward * fix docs * ignore MusicGen Model/ForConditionalGeneration in auto mapping * add doc section to toctree * add to doc tests * add processor tests * fix push to hub in conversion * tips for decoder only loading * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * fix conversion for s / m / l checkpoints * import stopping criteria from module * remove from pipeline tests * fix uncond docstring * decode audio method * fix docs * org: sanchit-gandhi -> facebook * fix max pos embeddings * remove auto doc (not compatible with shapes) * bump max pos emb * make style * fix doc * fix config doc * fix config doc * ignore musicgen config from docstring * make style * fix config * fix config for doctest * consistent from_sub_models * don't automap decoder * fix mdx save audio file * fix mdx save audio file * processor batch decode for audio * remove keys to ignore * update doc md * update generation config * allow changes for default generation config * update tests * make style * fix docstring for uncond * fix processor test * fix processor test --------- Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -411,6 +411,7 @@ Current number of checkpoints: ** (from Apple) released with the paper [Separable Self-attention for Mobile Vision Transformers](https://arxiv.org/abs/2206.02680) by Sachin Mehta and Mohammad Rastegari.
|
||||
1. **[MPNet](https://huggingface.co/docs/transformers/model_doc/mpnet)** (from Microsoft Research) released with the paper [MPNet: Masked and Permuted Pre-training for Language Understanding](https://arxiv.org/abs/2004.09297) by Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu.
|
||||
1. **[MT5](https://huggingface.co/docs/transformers/model_doc/mt5)** (from Google AI) released with the paper [mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/abs/2010.11934) by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel.
|
||||
1. **[MusicGen](https://huggingface.co/docs/transformers/main/model_doc/musicgen)** (from Meta) released with the paper [Simple and Controllable Music Generation](https://arxiv.org/abs/2306.05284) by Jade Copet, Felix Kreuk, Itai Gat, Tal Remez, David Kant, Gabriel Synnaeve, Yossi Adi and Alexandre Défossez.
|
||||
1. **[MVP](https://huggingface.co/docs/transformers/model_doc/mvp)** (from RUC AI Box) released with the paper [MVP: Multi-task Supervised Pre-training for Natural Language Generation](https://arxiv.org/abs/2206.12131) by Tianyi Tang, Junyi Li, Wayne Xin Zhao and Ji-Rong Wen.
|
||||
1. **[NAT](https://huggingface.co/docs/transformers/model_doc/nat)** (from SHI Labs) released with the paper [Neighborhood Attention Transformer](https://arxiv.org/abs/2204.07143) by Ali Hassani, Steven Walton, Jiachen Li, Shen Li, and Humphrey Shi.
|
||||
1. **[Nezha](https://huggingface.co/docs/transformers/model_doc/nezha)** (from Huawei Noah’s Ark Lab) released with the paper [NEZHA: Neural Contextualized Representation for Chinese Language Understanding](https://arxiv.org/abs/1909.00204) by Junqiu Wei, Xiaozhe Ren, Xiaoguang Li, Wenyong Huang, Yi Liao, Yasheng Wang, Jiashu Lin, Xin Jiang, Xiao Chen and Qun Liu.
|
||||
|
||||
@@ -386,6 +386,7 @@ Número actual de puntos de control: ** (from Apple) released with the paper [Separable Self-attention for Mobile Vision Transformers](https://arxiv.org/abs/2206.02680) by Sachin Mehta and Mohammad Rastegari.
|
||||
1. **[MPNet](https://huggingface.co/docs/transformers/model_doc/mpnet)** (from Microsoft Research) released with the paper [MPNet: Masked and Permuted Pre-training for Language Understanding](https://arxiv.org/abs/2004.09297) by Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu.
|
||||
1. **[MT5](https://huggingface.co/docs/transformers/model_doc/mt5)** (from Google AI) released with the paper [mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/abs/2010.11934) by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel.
|
||||
1. **[MusicGen](https://huggingface.co/docs/transformers/main/model_doc/musicgen)** (from Meta) released with the paper [Simple and Controllable Music Generation](https://arxiv.org/abs/2306.05284) by Jade Copet, Felix Kreuk, Itai Gat, Tal Remez, David Kant, Gabriel Synnaeve, Yossi Adi and Alexandre Défossez.
|
||||
1. **[MVP](https://huggingface.co/docs/transformers/model_doc/mvp)** (from RUC AI Box) released with the paper [MVP: Multi-task Supervised Pre-training for Natural Language Generation](https://arxiv.org/abs/2206.12131) by Tianyi Tang, Junyi Li, Wayne Xin Zhao and Ji-Rong Wen.
|
||||
1. **[NAT](https://huggingface.co/docs/transformers/model_doc/nat)** (from SHI Labs) released with the paper [Neighborhood Attention Transformer](https://arxiv.org/abs/2204.07143) by Ali Hassani, Steven Walton, Jiachen Li, Shen Li, and Humphrey Shi.
|
||||
1. **[Nezha](https://huggingface.co/docs/transformers/model_doc/nezha)** (from Huawei Noah’s Ark Lab) released with the paper [NEZHA: Neural Contextualized Representation for Chinese Language Understanding](https://arxiv.org/abs/1909.00204) by Junqiu Wei, Xiaozhe Ren, Xiaoguang Li, Wenyong Huang, Yi Liao, Yasheng Wang, Jiashu Lin, Xin Jiang, Xiao Chen and Qun Liu.
|
||||
|
||||
@@ -358,6 +358,7 @@ conda install -c huggingface transformers
|
||||
1. **[MobileViTV2](https://huggingface.co/docs/transformers/model_doc/mobilevitv2)** (Apple से) Sachin Mehta and Mohammad Rastegari. द्वाराअनुसंधान पत्र [Separable Self-attention for Mobile Vision Transformers](https://arxiv.org/abs/2206.02680) के साथ जारी किया गया
|
||||
1. **[MPNet](https://huggingface.co/docs/transformers/model_doc/mpnet)** (from Microsoft Research) released with the paper [MPNet: Masked and Permuted Pre-training for Language Understanding](https://arxiv.org/abs/2004.09297) by Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu.
|
||||
1. **[MT5](https://huggingface.co/docs/transformers/model_doc/mt5)** (Google AI से) साथ वाला पेपर [mT5: एक व्यापक बहुभाषी पूर्व-प्रशिक्षित टेक्स्ट-टू-टेक्स्ट ट्रांसफॉर्मर]( https://arxiv.org/abs/2010.11934) लिंटिंग ज़ू, नोआ कॉन्सटेंट, एडम रॉबर्ट्स, मिहिर काले, रामी अल-रफू, आदित्य सिद्धांत, आदित्य बरुआ, कॉलिन रैफेल द्वारा पोस्ट किया गया।
|
||||
1. **[MusicGen](https://huggingface.co/docs/transformers/main/model_doc/musicgen)** (from Meta) released with the paper [Simple and Controllable Music Generation](https://arxiv.org/abs/2306.05284) by Jade Copet, Felix Kreuk, Itai Gat, Tal Remez, David Kant, Gabriel Synnaeve, Yossi Adi and Alexandre Défossez.
|
||||
1. **[MVP](https://huggingface.co/docs/transformers/model_doc/mvp)** (from RUC AI Box) released with the paper [MVP: Multi-task Supervised Pre-training for Natural Language Generation](https://arxiv.org/abs/2206.12131) by Tianyi Tang, Junyi Li, Wayne Xin Zhao and Ji-Rong Wen.
|
||||
1. **[NAT](https://huggingface.co/docs/transformers/model_doc/nat)** (from SHI Labs) released with the paper [Neighborhood Attention Transformer](https://arxiv.org/abs/2204.07143) by Ali Hassani, Steven Walton, Jiachen Li, Shen Li, and Humphrey Shi.
|
||||
1. **[Nezha](https://huggingface.co/docs/transformers/model_doc/nezha)** (हुआवेई नूह के आर्क लैब से) साथ में कागज़ [NEZHA: चीनी भाषा समझ के लिए तंत्रिका प्रासंगिक प्रतिनिधित्व](https :/ /arxiv.org/abs/1909.00204) जुन्किउ वेई, ज़ियाओज़े रेन, ज़िआओगुआंग ली, वेनयोंग हुआंग, यी लियाओ, याशेंग वांग, जियाशू लिन, शिन जियांग, जिओ चेन और कुन लियू द्वारा।
|
||||
|
||||
@@ -420,6 +420,7 @@ Flax、PyTorch、TensorFlowをcondaでインストールする方法は、それ
|
||||
1. **[MobileViTV2](https://huggingface.co/docs/transformers/model_doc/mobilevitv2)** (Apple から) Sachin Mehta and Mohammad Rastegari. から公開された研究論文 [Separable Self-attention for Mobile Vision Transformers](https://arxiv.org/abs/2206.02680)
|
||||
1. **[MPNet](https://huggingface.co/docs/transformers/model_doc/mpnet)** (Microsoft Research から) Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu から公開された研究論文: [MPNet: Masked and Permuted Pre-training for Language Understanding](https://arxiv.org/abs/2004.09297)
|
||||
1. **[MT5](https://huggingface.co/docs/transformers/model_doc/mt5)** (Google AI から) Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel から公開された研究論文: [mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/abs/2010.11934)
|
||||
1. **[MusicGen](https://huggingface.co/docs/transformers/main/model_doc/musicgen)** (from Meta) released with the paper [Simple and Controllable Music Generation](https://arxiv.org/abs/2306.05284) by Jade Copet, Felix Kreuk, Itai Gat, Tal Remez, David Kant, Gabriel Synnaeve, Yossi Adi and Alexandre Défossez.
|
||||
1. **[MVP](https://huggingface.co/docs/transformers/model_doc/mvp)** (RUC AI Box から) Tianyi Tang, Junyi Li, Wayne Xin Zhao and Ji-Rong Wen から公開された研究論文: [MVP: Multi-task Supervised Pre-training for Natural Language Generation](https://arxiv.org/abs/2206.12131)
|
||||
1. **[NAT](https://huggingface.co/docs/transformers/model_doc/nat)** (SHI Labs から) Ali Hassani, Steven Walton, Jiachen Li, Shen Li, and Humphrey Shi から公開された研究論文: [Neighborhood Attention Transformer](https://arxiv.org/abs/2204.07143)
|
||||
1. **[Nezha](https://huggingface.co/docs/transformers/model_doc/nezha)** (Huawei Noah’s Ark Lab から) Junqiu Wei, Xiaozhe Ren, Xiaoguang Li, Wenyong Huang, Yi Liao, Yasheng Wang, Jiashu Lin, Xin Jiang, Xiao Chen and Qun Liu から公開された研究論文: [NEZHA: Neural Contextualized Representation for Chinese Language Understanding](https://arxiv.org/abs/1909.00204)
|
||||
|
||||
@@ -335,6 +335,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는
|
||||
1. **[MobileViTV2](https://huggingface.co/docs/transformers/model_doc/mobilevitv2)** (Apple 에서 제공)은 Sachin Mehta and Mohammad Rastegari.의 [Separable Self-attention for Mobile Vision Transformers](https://arxiv.org/abs/2206.02680)논문과 함께 발표했습니다.
|
||||
1. **[MPNet](https://huggingface.co/docs/transformers/model_doc/mpnet)** (Microsoft Research 에서) Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu 의 [MPNet: Masked and Permuted Pre-training for Language Understanding](https://arxiv.org/abs/2004.09297) 논문과 함께 발표했습니다.
|
||||
1. **[MT5](https://huggingface.co/docs/transformers/model_doc/mt5)** (Google AI 에서) Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel 의 [mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/abs/2010.11934) 논문과 함께 발표했습니다.
|
||||
1. **[MusicGen](https://huggingface.co/docs/transformers/main/model_doc/musicgen)** (from Meta) released with the paper [Simple and Controllable Music Generation](https://arxiv.org/abs/2306.05284) by Jade Copet, Felix Kreuk, Itai Gat, Tal Remez, David Kant, Gabriel Synnaeve, Yossi Adi and Alexandre Défossez.
|
||||
1. **[MVP](https://huggingface.co/docs/transformers/model_doc/mvp)** (RUC AI Box 에서) Tianyi Tang, Junyi Li, Wayne Xin Zhao and Ji-Rong Wen 의 [MVP: Multi-task Supervised Pre-training for Natural Language Generation](https://arxiv.org/abs/2206.12131) 논문과 함께 발표했습니다.
|
||||
1. **[NAT](https://huggingface.co/docs/transformers/model_doc/nat)** (SHI Labs 에서) Ali Hassani, Steven Walton, Jiachen Li, Shen Li, and Humphrey Shi 의 [Neighborhood Attention Transformer](https://arxiv.org/abs/2204.07143) 논문과 함께 발표했습니다.
|
||||
1. **[Nezha](https://huggingface.co/docs/transformers/model_doc/nezha)** (Huawei Noah’s Ark Lab 에서) Junqiu Wei, Xiaozhe Ren, Xiaoguang Li, Wenyong Huang, Yi Liao, Yasheng Wang, Jiashu Lin, Xin Jiang, Xiao Chen and Qun Liu 의 [NEZHA: Neural Contextualized Representation for Chinese Language Understanding](https://arxiv.org/abs/1909.00204) 논문과 함께 발표했습니다.
|
||||
|
||||
@@ -359,6 +359,7 @@ conda install -c huggingface transformers
|
||||
1. **[MobileViTV2](https://huggingface.co/docs/transformers/model_doc/mobilevitv2)** (来自 Apple) 伴随论文 [Separable Self-attention for Mobile Vision Transformers](https://arxiv.org/abs/2206.02680) 由 Sachin Mehta and Mohammad Rastegari 发布。
|
||||
1. **[MPNet](https://huggingface.co/docs/transformers/model_doc/mpnet)** (来自 Microsoft Research) 伴随论文 [MPNet: Masked and Permuted Pre-training for Language Understanding](https://arxiv.org/abs/2004.09297) 由 Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu 发布。
|
||||
1. **[MT5](https://huggingface.co/docs/transformers/model_doc/mt5)** (来自 Google AI) 伴随论文 [mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/abs/2010.11934) 由 Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel 发布。
|
||||
1. **[MusicGen](https://huggingface.co/docs/transformers/main/model_doc/musicgen)** (from Meta) released with the paper [Simple and Controllable Music Generation](https://arxiv.org/abs/2306.05284) by Jade Copet, Felix Kreuk, Itai Gat, Tal Remez, David Kant, Gabriel Synnaeve, Yossi Adi and Alexandre Défossez.
|
||||
1. **[MVP](https://huggingface.co/docs/transformers/model_doc/mvp)** (来自 中国人民大学 AI Box) 伴随论文 [MVP: Multi-task Supervised Pre-training for Natural Language Generation](https://arxiv.org/abs/2206.12131) 由 Tianyi Tang, Junyi Li, Wayne Xin Zhao and Ji-Rong Wen 发布。
|
||||
1. **[NAT](https://huggingface.co/docs/transformers/model_doc/nat)** (来自 SHI Labs) 伴随论文 [Neighborhood Attention Transformer](https://arxiv.org/abs/2204.07143) 由 Ali Hassani, Steven Walton, Jiachen Li, Shen Li, and Humphrey Shi 发布。
|
||||
1. **[Nezha](https://huggingface.co/docs/transformers/model_doc/nezha)** (来自华为诺亚方舟实验室) 伴随论文 [NEZHA: Neural Contextualized Representation for Chinese Language Understanding](https://arxiv.org/abs/1909.00204) 由 Junqiu Wei, Xiaozhe Ren, Xiaoguang Li, Wenyong Huang, Yi Liao, Yasheng Wang, Jiashu Lin, Xin Jiang, Xiao Chen and Qun Liu 发布。
|
||||
|
||||
@@ -371,6 +371,7 @@ conda install -c huggingface transformers
|
||||
1. **[MobileViTV2](https://huggingface.co/docs/transformers/model_doc/mobilevitv2)** (from Apple) released with the paper [Separable Self-attention for Mobile Vision Transformers](https://arxiv.org/abs/2206.02680) by Sachin Mehta and Mohammad Rastegari.
|
||||
1. **[MPNet](https://huggingface.co/docs/transformers/model_doc/mpnet)** (from Microsoft Research) released with the paper [MPNet: Masked and Permuted Pre-training for Language Understanding](https://arxiv.org/abs/2004.09297) by Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu.
|
||||
1. **[MT5](https://huggingface.co/docs/transformers/model_doc/mt5)** (from Google AI) released with the paper [mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/abs/2010.11934) by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel.
|
||||
1. **[MusicGen](https://huggingface.co/docs/transformers/main/model_doc/musicgen)** (from Meta) released with the paper [Simple and Controllable Music Generation](https://arxiv.org/abs/2306.05284) by Jade Copet, Felix Kreuk, Itai Gat, Tal Remez, David Kant, Gabriel Synnaeve, Yossi Adi and Alexandre Défossez.
|
||||
1. **[MVP](https://huggingface.co/docs/transformers/model_doc/mvp)** (from RUC AI Box) released with the paper [MVP: Multi-task Supervised Pre-training for Natural Language Generation](https://arxiv.org/abs/2206.12131) by Tianyi Tang, Junyi Li, Wayne Xin Zhao and Ji-Rong Wen.
|
||||
1. **[NAT](https://huggingface.co/docs/transformers/model_doc/nat)** (from SHI Labs) released with the paper [Neighborhood Attention Transformer](https://arxiv.org/abs/2204.07143) by Ali Hassani, Steven Walton, Jiachen Li, Shen Li, and Humphrey Shi.
|
||||
1. **[Nezha](https://huggingface.co/docs/transformers/model_doc/nezha)** (from Huawei Noah’s Ark Lab) released with the paper [NEZHA: Neural Contextualized Representation for Chinese Language Understanding](https://arxiv.org/abs/1909.00204) by Junqiu Wei, Xiaozhe Ren, Xiaoguang Li, Wenyong Huang, Yi Liao, Yasheng Wang, Jiashu Lin, Xin Jiang, Xiao Chen and Qun Liu.
|
||||
|
||||
@@ -549,6 +549,8 @@
|
||||
title: MCTCT
|
||||
- local: model_doc/mms
|
||||
title: MMS
|
||||
- local: model_doc/musicgen
|
||||
title: MusicGen
|
||||
- local: model_doc/sew
|
||||
title: SEW
|
||||
- local: model_doc/sew-d
|
||||
|
||||
@@ -175,6 +175,7 @@ The documentation is organized into five sections:
|
||||
1. **[MobileViTV2](model_doc/mobilevitv2)** (from Apple) released with the paper [Separable Self-attention for Mobile Vision Transformers](https://arxiv.org/abs/2206.02680) by Sachin Mehta and Mohammad Rastegari.
|
||||
1. **[MPNet](model_doc/mpnet)** (from Microsoft Research) released with the paper [MPNet: Masked and Permuted Pre-training for Language Understanding](https://arxiv.org/abs/2004.09297) by Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu.
|
||||
1. **[MT5](model_doc/mt5)** (from Google AI) released with the paper [mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/abs/2010.11934) by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel.
|
||||
1. **[MusicGen](model_doc/musicgen)** (from Meta) released with the paper [Simple and Controllable Music Generation](https://arxiv.org/abs/2306.05284) by Jade Copet, Felix Kreuk, Itai Gat, Tal Remez, David Kant, Gabriel Synnaeve, Yossi Adi and Alexandre Défossez.
|
||||
1. **[MVP](model_doc/mvp)** (from RUC AI Box) released with the paper [MVP: Multi-task Supervised Pre-training for Natural Language Generation](https://arxiv.org/abs/2206.12131) by Tianyi Tang, Junyi Li, Wayne Xin Zhao and Ji-Rong Wen.
|
||||
1. **[NAT](model_doc/nat)** (from SHI Labs) released with the paper [Neighborhood Attention Transformer](https://arxiv.org/abs/2204.07143) by Ali Hassani, Steven Walton, Jiachen Li, Shen Li, and Humphrey Shi.
|
||||
1. **[Nezha](model_doc/nezha)** (from Huawei Noah’s Ark Lab) released with the paper [NEZHA: Neural Contextualized Representation for Chinese Language Understanding](https://arxiv.org/abs/1909.00204) by Junqiu Wei, Xiaozhe Ren, Xiaoguang Li, Wenyong Huang, Yi Liao, Yasheng Wang, Jiashu Lin, Xin Jiang, Xiao Chen and Qun Liu.
|
||||
@@ -380,6 +381,7 @@ Flax), PyTorch, and/or TensorFlow.
|
||||
| MobileViTV2 | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| MPNet | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| MT5 | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| MusicGen | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| MVP | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| NAT | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| Nezha | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
|
||||
277
docs/source/en/model_doc/musicgen.md
Normal file
277
docs/source/en/model_doc/musicgen.md
Normal file
@@ -0,0 +1,277 @@
|
||||
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# MusicGen
|
||||
|
||||
## Overview
|
||||
|
||||
The MusicGen model was proposed in the paper [Simple and Controllable Music Generation](https://arxiv.org/abs/2306.05284)
|
||||
by Jade Copet, Felix Kreuk, Itai Gat, Tal Remez, David Kant, Gabriel Synnaeve, Yossi Adi and Alexandre Défossez.
|
||||
|
||||
MusicGen is a single stage auto-regressive Transformer model capable of generating high-quality music samples conditioned
|
||||
on text descriptions or audio prompts. The text descriptions are passed through a frozen text encoder model to obtain a
|
||||
sequence of hidden-state representations. MusicGen is then trained to predict discrete audio tokens, or *audio codes*,
|
||||
conditioned on these hidden-states. These audio tokens are then decoded using an audio compression model, such as EnCodec,
|
||||
to recover the audio waveform.
|
||||
|
||||
Through an efficient token interleaving pattern, MusicGen does not require a self-supervised semantic representation of
|
||||
the text/audio prompts, thus eliminating the need to cascade multiple models to predict a set of codebooks (e.g.
|
||||
hierarchically or upsampling). Instead, it is able to generate all the codebooks in a single forward pass.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
*We tackle the task of conditional music generation. We introduce MusicGen, a single Language Model (LM) that operates
|
||||
over several streams of compressed discrete music representation, i.e., tokens. Unlike prior work, MusicGen is comprised
|
||||
of a single-stage transformer LM together with efficient token interleaving patterns, which eliminates the need for
|
||||
cascading several models, e.g., hierarchically or upsampling. Following this approach, we demonstrate how MusicGen
|
||||
can generate high-quality samples, while being conditioned on textual description or melodic features, allowing better
|
||||
controls over the generated output. We conduct extensive empirical evaluation, considering both automatic and human
|
||||
studies, showing the proposed approach is superior to the evaluated baselines on a standard text-to-music benchmark.
|
||||
Through ablation studies, we shed light over the importance of each of the components comprising MusicGen.*
|
||||
|
||||
This model was contributed by [sanchit-gandhi](https://huggingface.co/sanchit-gandhi). The original code can be found
|
||||
[here](https://github.com/facebookresearch/audiocraft). The pre-trained checkpoints can be found on the
|
||||
[Hugging Face Hub](https://huggingface.co/models?sort=downloads&search=facebook%2Fmusicgen-).
|
||||
|
||||
## Generation
|
||||
|
||||
MusicGen is compatible with two generation modes: greedy and sampling. In practice, sampling leads to significantly
|
||||
better results than greedy, thus we encourage sampling mode to be used where possible. Sampling is enabled by default,
|
||||
and can be explicitly specified by setting `do_sample=True` in the call to [`MusicgenForConditionalGeneration.generate`],
|
||||
or by overriding the model's generation config (see below).
|
||||
|
||||
### Unconditional Generation
|
||||
|
||||
The inputs for unconditional (or 'null') generation can be obtained through the method
|
||||
[`MusicgenForConditionalGeneration.get_unconditional_inputs`]:
|
||||
|
||||
```python
|
||||
>>> from transformers import MusicgenForConditionalGeneration
|
||||
|
||||
>>> model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
|
||||
>>> unconditional_inputs = model.get_unconditional_inputs(num_samples=1)
|
||||
|
||||
>>> audio_values = model.generate(**unconditional_inputs, do_sample=True, max_new_tokens=256)
|
||||
```
|
||||
|
||||
The audio outputs are a three-dimensional Torch tensor of shape `(batch_size, num_channels, sequence_length)`. To listen
|
||||
to the generated audio samples, you can either play them in an ipynb notebook:
|
||||
|
||||
```python
|
||||
from IPython.display import Audio
|
||||
|
||||
sampling_rate = model.config.audio_encoder.sampling_rate
|
||||
Audio(audio_values[0].numpy(), rate=sampling_rate)
|
||||
```
|
||||
|
||||
Or save them as a `.wav` file using a third-party library, e.g. `scipy`:
|
||||
|
||||
```python
|
||||
>>> import scipy
|
||||
|
||||
>>> sampling_rate = model.config.audio_encoder.sampling_rate
|
||||
>>> scipy.io.wavfile.write("musicgen_out.wav", rate=sampling_rate, data=audio_values[0, 0].numpy())
|
||||
```
|
||||
|
||||
### Text-Conditional Generation
|
||||
|
||||
The model can generate an audio sample conditioned on a text prompt through use of the [`MusicgenProcessor`] to pre-process
|
||||
the inputs:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoProcessor, MusicgenForConditionalGeneration
|
||||
|
||||
>>> processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
|
||||
>>> model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
|
||||
|
||||
>>> inputs = processor(
|
||||
... text=["80s pop track with bassy drums and synth", "90s rock song with loud guitars and heavy drums"],
|
||||
... padding=True,
|
||||
... return_tensors="pt",
|
||||
... )
|
||||
>>> audio_values = model.generate(**inputs, do_sample=True, guidance_scale=3, max_new_tokens=256)
|
||||
```
|
||||
|
||||
The `guidance_scale` is used in classifier free guidance (CFG), setting the weighting between the conditional logits
|
||||
(which are predicted from the text prompts) and the unconditional logits (which are predicted from an unconditional or
|
||||
'null' prompt). Higher guidance scale encourages the model to generate samples that are more closely linked to the input
|
||||
prompt, usually at the expense of poorer audio quality. CFG is enabled by setting `guidance_scale > 1`. For best results,
|
||||
use `guidance_scale=3` (default).
|
||||
|
||||
### Audio-Prompted Generation
|
||||
|
||||
The same [`MusicgenProcessor`] can be used to pre-process an audio prompt that is used for audio continuation. In the
|
||||
following example, we load an audio file using the 🤗 Datasets library, which can be pip installed through the command
|
||||
below:
|
||||
|
||||
```
|
||||
pip install --upgrade pip
|
||||
pip install datasets[audio]
|
||||
```
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoProcessor, MusicgenForConditionalGeneration
|
||||
>>> from datasets import load_dataset
|
||||
|
||||
>>> processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
|
||||
>>> model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
|
||||
|
||||
>>> dataset = load_dataset("sanchit-gandhi/gtzan", split="train", streaming=True)
|
||||
>>> sample = next(iter(dataset))["audio"]
|
||||
|
||||
>>> # take the first half of the audio sample
|
||||
>>> sample["array"] = sample["array"][: len(sample["array"]) // 2]
|
||||
|
||||
>>> inputs = processor(
|
||||
... audio=sample["array"],
|
||||
... sampling_rate=sample["sampling_rate"],
|
||||
... text=["80s blues track with groovy saxophone"],
|
||||
... padding=True,
|
||||
... return_tensors="pt",
|
||||
... )
|
||||
>>> audio_values = model.generate(**inputs, do_sample=True, guidance_scale=3, max_new_tokens=256)
|
||||
```
|
||||
|
||||
For batched audio-prompted generation, the generated `audio_values` can be post-processed to remove padding by using the
|
||||
[`MusicgenProcessor`] class:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoProcessor, MusicgenForConditionalGeneration
|
||||
>>> from datasets import load_dataset
|
||||
|
||||
>>> processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
|
||||
>>> model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
|
||||
|
||||
>>> dataset = load_dataset("sanchit-gandhi/gtzan", split="train", streaming=True)
|
||||
>>> sample = next(iter(dataset))["audio"]
|
||||
|
||||
>>> # take the first quarter of the audio sample
|
||||
>>> sample_1 = sample["array"][: len(sample["array"]) // 4]
|
||||
|
||||
>>> # take the first half of the audio sample
|
||||
>>> sample_2 = sample["array"][: len(sample["array"]) // 2]
|
||||
|
||||
>>> inputs = processor(
|
||||
... audio=[sample_1, sample_2],
|
||||
... sampling_rate=sample["sampling_rate"],
|
||||
... text=["80s blues track with groovy saxophone", "90s rock song with loud guitars and heavy drums"],
|
||||
... padding=True,
|
||||
... return_tensors="pt",
|
||||
... )
|
||||
>>> audio_values = model.generate(**inputs, do_sample=True, guidance_scale=3, max_new_tokens=256)
|
||||
|
||||
>>> # post-process to remove padding from the batched audio
|
||||
>>> audio_values = processor.batch_decode(audio_values, padding_mask=inputs.padding_mask)
|
||||
```
|
||||
|
||||
### Generation Configuration
|
||||
|
||||
The default parameters that control the generation process, such as sampling, guidance scale and number of generated
|
||||
tokens, can be found in the model's generation config, and updated as desired:
|
||||
|
||||
```python
|
||||
>>> from transformers import MusicgenForConditionalGeneration
|
||||
|
||||
>>> model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
|
||||
|
||||
>>> # inspect the default generation config
|
||||
>>> model.generation_config
|
||||
|
||||
>>> # increase the guidance scale to 4.0
|
||||
>>> model.generation_config.guidance_scale = 4.0
|
||||
|
||||
>>> # decrease the max length to 256 tokens
|
||||
>>> model.generation_config.max_length = 256
|
||||
```
|
||||
|
||||
Note that any arguments passed to the generate method will **supersede** those in the generation config, so setting
|
||||
`do_sample=False` in the call to generate will supersede the setting of `model.generation_config.do_sample` in the
|
||||
generation config.
|
||||
|
||||
## Model Structure
|
||||
|
||||
The MusicGen model can be de-composed into three distinct stages:
|
||||
1. Text encoder: maps the text inputs to a sequence of hidden-state representations. The pre-trained MusicGen models use a frozen text encoder from either T5 or Flan-T5
|
||||
2. MusicGen decoder: a language model (LM) that auto-regressively generates audio tokens (or codes) conditional on the encoder hidden-state representations
|
||||
3. Audio encoder/decoder: used to encode an audio prompt to use as prompt tokens, and recover the audio waveform from the audio tokens predicted by the decoder
|
||||
|
||||
Thus, the MusicGen model can either be used as a standalone decoder model, corresponding to the class [`MusicgenForCausalLM`],
|
||||
or as a composite model that includes the text encoder and audio encoder/decoder, corresponding to the class
|
||||
[`MusicgenForConditionalGeneration`].
|
||||
|
||||
Since the text encoder and audio encoder/decoder models are frozen during training, the MusicGen decoder [`MusicgenForCausalLM`]
|
||||
can be trained standalone on a dataset of encoder hidden-states and audio codes. For inference, the trained decoder can
|
||||
be combined with the frozen text encoder and audio encoder/decoders to recover the composite [`MusicgenForConditionalGeneration`]
|
||||
model.
|
||||
|
||||
Below, we demonstrate how to construct the composite [`MusicgenForConditionalGeneration`] model from its three constituent
|
||||
parts, as would typically be done following training of the MusicGen decoder LM:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoConfig, AutoModelForTextEncoding, AutoModel, MusicgenForCausalLM, MusicgenForConditionalGeneration
|
||||
|
||||
>>> text_encoder = AutoModelForTextEncoding.from_pretrained("t5-base")
|
||||
>>> audio_encoder = AutoModel.from_pretrained("facebook/encodec_32khz")
|
||||
>>> decoder_config = AutoConfig.from_pretrained("facebook/musicgen-small").decoder
|
||||
>>> decoder = MusicgenForCausalLM.from_pretrained("facebook/musicgen-small", **decoder_config)
|
||||
|
||||
>>> model = MusicgenForConditionalGeneration.from_sub_models_pretrained(text_encoder, audio_encoder, decoder)
|
||||
```
|
||||
|
||||
If only the decoder needs to be loaded from the pre-trained checkpoint for the composite model, it can be loaded by first
|
||||
specifying the correct config, or be accessed through the `.decoder` attribute of the composite model:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoConfig, MusicgenForCausalLM, MusicgenForConditionalGeneration
|
||||
|
||||
>>> # Option 1: get decoder config and pass to `.from_pretrained`
|
||||
>>> decoder_config = AutoConfig.from_pretrained("facebook/musicgen-small").decoder
|
||||
>>> decoder = MusicgenForCausalLM.from_pretrained("facebook/musicgen-small", **decoder_config)
|
||||
|
||||
>>> # Option 2: load the entire composite model, but only return the decoder
|
||||
>>> decoder = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small").decoder
|
||||
```
|
||||
|
||||
Tips:
|
||||
* MusicGen is trained on the 32kHz checkpoint of Encodec. You should ensure you use a compatible version of the Encodec model.
|
||||
* Sampling mode tends to deliver better results than greedy - you can toggle sampling with the variable `do_sample` in the call to [`MusicgenForConditionalGeneration.generate`]
|
||||
|
||||
## MusicgenDecoderConfig
|
||||
|
||||
[[autodoc]] MusicgenDecoderConfig
|
||||
|
||||
## MusicgenConfig
|
||||
|
||||
[[autodoc]] MusicgenConfig
|
||||
|
||||
## MusicgenProcessor
|
||||
|
||||
[[autodoc]] MusicgenProcessor
|
||||
|
||||
## MusicgenModel
|
||||
|
||||
[[autodoc]] MusicgenModel
|
||||
- forward
|
||||
|
||||
## MusicgenForCausalLM
|
||||
|
||||
[[autodoc]] MusicgenForCausalLM
|
||||
- forward
|
||||
|
||||
## MusicgenForConditionalGeneration
|
||||
|
||||
[[autodoc]] MusicgenForConditionalGeneration
|
||||
- forward
|
||||
@@ -37,7 +37,7 @@ You can finetune other architectures for causal language modeling following the
|
||||
Choose one of the following architectures:
|
||||
|
||||
<!--This tip is automatically generated by `make fix-copies`, do not fill manually!-->
|
||||
[BART](../model_doc/bart), [BERT](../model_doc/bert), [Bert Generation](../model_doc/bert-generation), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BioGpt](../model_doc/biogpt), [Blenderbot](../model_doc/blenderbot), [BlenderbotSmall](../model_doc/blenderbot-small), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CodeGen](../model_doc/codegen), [CPM-Ant](../model_doc/cpmant), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [GIT](../model_doc/git), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT NeoX Japanese](../model_doc/gpt_neox_japanese), [GPT-J](../model_doc/gptj), [LLaMA](../model_doc/llama), [Marian](../model_doc/marian), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [MVP](../model_doc/mvp), [OpenLlama](../model_doc/open-llama), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Pegasus](../model_doc/pegasus), [PLBart](../model_doc/plbart), [ProphetNet](../model_doc/prophetnet), [QDQBert](../model_doc/qdqbert), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [RWKV](../model_doc/rwkv), [Speech2Text2](../model_doc/speech_to_text_2), [Transformer-XL](../model_doc/transfo-xl), [TrOCR](../model_doc/trocr), [XGLM](../model_doc/xglm), [XLM](../model_doc/xlm), [XLM-ProphetNet](../model_doc/xlm-prophetnet), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod)
|
||||
[BART](../model_doc/bart), [BERT](../model_doc/bert), [Bert Generation](../model_doc/bert-generation), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BioGpt](../model_doc/biogpt), [Blenderbot](../model_doc/blenderbot), [BlenderbotSmall](../model_doc/blenderbot-small), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CodeGen](../model_doc/codegen), [CPM-Ant](../model_doc/cpmant), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [GIT](../model_doc/git), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT NeoX Japanese](../model_doc/gpt_neox_japanese), [GPT-J](../model_doc/gptj), [LLaMA](../model_doc/llama), [Marian](../model_doc/marian), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [MusicGen](../model_doc/musicgen), [MVP](../model_doc/mvp), [OpenLlama](../model_doc/open-llama), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Pegasus](../model_doc/pegasus), [PLBart](../model_doc/plbart), [ProphetNet](../model_doc/prophetnet), [QDQBert](../model_doc/qdqbert), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [RWKV](../model_doc/rwkv), [Speech2Text2](../model_doc/speech_to_text_2), [Transformer-XL](../model_doc/transfo-xl), [TrOCR](../model_doc/trocr), [XGLM](../model_doc/xglm), [XLM](../model_doc/xlm), [XLM-ProphetNet](../model_doc/xlm-prophetnet), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod)
|
||||
|
||||
|
||||
<!--End of the generated tip-->
|
||||
|
||||
@@ -402,6 +402,11 @@ _import_structure = {
|
||||
"models.mobilevitv2": ["MOBILEVITV2_PRETRAINED_CONFIG_ARCHIVE_MAP", "MobileViTV2Config"],
|
||||
"models.mpnet": ["MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "MPNetConfig", "MPNetTokenizer"],
|
||||
"models.mt5": ["MT5Config"],
|
||||
"models.musicgen": [
|
||||
"MUSICGEN_PRETRAINED_CONFIG_ARCHIVE_MAP",
|
||||
"MusicgenConfig",
|
||||
"MusicgenDecoderConfig",
|
||||
],
|
||||
"models.mvp": ["MvpConfig", "MvpTokenizer"],
|
||||
"models.nat": ["NAT_PRETRAINED_CONFIG_ARCHIVE_MAP", "NatConfig"],
|
||||
"models.nezha": ["NEZHA_PRETRAINED_CONFIG_ARCHIVE_MAP", "NezhaConfig"],
|
||||
@@ -2134,6 +2139,16 @@ else:
|
||||
_import_structure["models.mt5"].extend(
|
||||
["MT5EncoderModel", "MT5ForConditionalGeneration", "MT5ForQuestionAnswering", "MT5Model", "MT5PreTrainedModel"]
|
||||
)
|
||||
_import_structure["models.musicgen"].extend(
|
||||
[
|
||||
"MUSICGEN_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"MusicgenForCausalLM",
|
||||
"MusicgenForConditionalGeneration",
|
||||
"MusicgenModel",
|
||||
"MusicgenPreTrainedModel",
|
||||
"MusicgenProcessor",
|
||||
]
|
||||
)
|
||||
_import_structure["models.mvp"].extend(
|
||||
[
|
||||
"MVP_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
@@ -4249,6 +4264,11 @@ if TYPE_CHECKING:
|
||||
from .models.mobilevitv2 import MOBILEVITV2_PRETRAINED_CONFIG_ARCHIVE_MAP, MobileViTV2Config
|
||||
from .models.mpnet import MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP, MPNetConfig, MPNetTokenizer
|
||||
from .models.mt5 import MT5Config
|
||||
from .models.musicgen import (
|
||||
MUSICGEN_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
MusicgenConfig,
|
||||
MusicgenDecoderConfig,
|
||||
)
|
||||
from .models.mvp import MvpConfig, MvpTokenizer
|
||||
from .models.nat import NAT_PRETRAINED_CONFIG_ARCHIVE_MAP, NatConfig
|
||||
from .models.nezha import NEZHA_PRETRAINED_CONFIG_ARCHIVE_MAP, NezhaConfig
|
||||
@@ -5709,6 +5729,14 @@ if TYPE_CHECKING:
|
||||
MT5Model,
|
||||
MT5PreTrainedModel,
|
||||
)
|
||||
from .models.musicgen import (
|
||||
MUSICGEN_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
MusicgenForCausalLM,
|
||||
MusicgenForConditionalGeneration,
|
||||
MusicgenModel,
|
||||
MusicgenPreTrainedModel,
|
||||
MusicgenProcessor,
|
||||
)
|
||||
from .models.mvp import (
|
||||
MVP_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
MvpForCausalLM,
|
||||
|
||||
@@ -185,6 +185,10 @@ class GenerationConfig(PushToHubMixin):
|
||||
Dictionary that maps a sequence of tokens to its bias term. Positive biases increase the odds of the
|
||||
sequence being selected, while negative biases do the opposite. Check
|
||||
[`~generation.SequenceBiasLogitsProcessor`] for further documentation and examples.
|
||||
guidance_scale (`float`, *optional*):
|
||||
The guidance scale for classifier free guidance (CFG). CFG is enabled by setting `guidance_scale > 1`.
|
||||
Higher guidance scale encourages the model to generate samples that are more closely linked to the input
|
||||
prompt, usually at the expense of poorer quality.
|
||||
|
||||
> Parameters that define the output variables of `generate`
|
||||
|
||||
@@ -265,6 +269,7 @@ class GenerationConfig(PushToHubMixin):
|
||||
self.begin_suppress_tokens = kwargs.pop("begin_suppress_tokens", None)
|
||||
self.forced_decoder_ids = kwargs.pop("forced_decoder_ids", None)
|
||||
self.sequence_bias = kwargs.pop("sequence_bias", None)
|
||||
self.guidance_scale = kwargs.pop("guidance_scale", None)
|
||||
|
||||
# Parameters that define the output variables of `generate`
|
||||
self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
|
||||
|
||||
@@ -1065,3 +1065,40 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
|
||||
scores[k, : self.timestamp_begin] = -float("inf")
|
||||
|
||||
return scores
|
||||
|
||||
|
||||
class ClassifierFreeGuidanceLogitsProcessor(LogitsProcessor):
|
||||
r"""Logits processor for classifier free guidance (CFG). The scores are split over the batch dimension,
|
||||
where the first half correspond to the conditional logits (predicted from the input prompt) and the second half
|
||||
correspond to the unconditional logits (predicted from an empty or 'null' prompt). The processor computes a
|
||||
weighted average across the conditional and unconditional logits, parameterised by the `guidance_scale`.
|
||||
|
||||
Args:
|
||||
guidance_scale (float):
|
||||
The guidance scale for classifier free guidance (CFG). CFG is enabled by setting `guidance_scale > 1`.
|
||||
Higher guidance scale encourages the model to generate samples that are more closely linked to the input
|
||||
prompt, usually at the expense of poorer quality.
|
||||
"""
|
||||
|
||||
def __init__(self, guidance_scale):
|
||||
if guidance_scale > 1:
|
||||
self.guidance_scale = guidance_scale
|
||||
else:
|
||||
raise ValueError(
|
||||
"Require guidance scale >1 to use the classifier free guidance processor, got guidance scale "
|
||||
f"{guidance_scale}."
|
||||
)
|
||||
|
||||
def __call__(self, input_ids, scores):
|
||||
# simple check to make sure we have compatible batch sizes between our
|
||||
# logits scores (cond + uncond) and input ids (cond only)
|
||||
if scores.shape[0] != 2 * input_ids.shape[0]:
|
||||
raise ValueError(
|
||||
f"Logits should have twice the batch size of the input ids, the first half of batches corresponding to "
|
||||
f"the conditional inputs, and the second half of batches corresponding to the unconditional inputs. Got "
|
||||
f"batch size {scores.shape[0]} for the logits and {input_ids.shape[0]} for the input ids."
|
||||
)
|
||||
unguided_bsz = scores.shape[0] // 2
|
||||
cond_logits, uncond_logits = scores.split(unguided_bsz, dim=0)
|
||||
scores = uncond_logits + (cond_logits - uncond_logits) * self.guidance_scale
|
||||
return scores
|
||||
|
||||
@@ -38,6 +38,7 @@ from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint
|
||||
from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
|
||||
from .configuration_utils import GenerationConfig
|
||||
from .logits_process import (
|
||||
ClassifierFreeGuidanceLogitsProcessor,
|
||||
EncoderNoRepeatNGramLogitsProcessor,
|
||||
EncoderRepetitionPenaltyLogitsProcessor,
|
||||
EpsilonLogitsWarper,
|
||||
@@ -940,6 +941,8 @@ class GenerationMixin:
|
||||
)
|
||||
if generation_config.forced_decoder_ids is not None:
|
||||
processors.append(ForceTokensLogitsProcessor(generation_config.forced_decoder_ids))
|
||||
if generation_config.guidance_scale is not None and generation_config.guidance_scale > 1:
|
||||
processors.append(ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale))
|
||||
processors = self._merge_criteria_processor_list(processors, logits_processor)
|
||||
# `LogitNormalization` should always be the last logit processor, when present
|
||||
if generation_config.renormalize_logits is True:
|
||||
|
||||
@@ -135,6 +135,7 @@ from . import (
|
||||
mobilevitv2,
|
||||
mpnet,
|
||||
mt5,
|
||||
musicgen,
|
||||
mvp,
|
||||
nat,
|
||||
nezha,
|
||||
|
||||
@@ -138,6 +138,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
|
||||
("mobilevitv2", "MobileViTV2Config"),
|
||||
("mpnet", "MPNetConfig"),
|
||||
("mt5", "MT5Config"),
|
||||
("musicgen", "MusicgenConfig"),
|
||||
("mvp", "MvpConfig"),
|
||||
("nat", "NatConfig"),
|
||||
("nezha", "NezhaConfig"),
|
||||
@@ -327,6 +328,7 @@ CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict(
|
||||
("mobilevit", "MOBILEVIT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
|
||||
("mobilevitv2", "MOBILEVITV2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
|
||||
("mpnet", "MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
|
||||
("musicgen", "MUSICGEN_PRETRAINED_CONFIG_ARCHIVE_MAP"),
|
||||
("mvp", "MVP_PRETRAINED_CONFIG_ARCHIVE_MAP"),
|
||||
("nat", "NAT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
|
||||
("nezha", "NEZHA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
|
||||
@@ -532,6 +534,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
|
||||
("mobilevitv2", "MobileViTV2"),
|
||||
("mpnet", "MPNet"),
|
||||
("mt5", "MT5"),
|
||||
("musicgen", "MusicGen"),
|
||||
("mvp", "MVP"),
|
||||
("nat", "NAT"),
|
||||
("nezha", "Nezha"),
|
||||
|
||||
@@ -389,6 +389,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
||||
("mbart", "MBartForCausalLM"),
|
||||
("mega", "MegaForCausalLM"),
|
||||
("megatron-bert", "MegatronBertForCausalLM"),
|
||||
("musicgen", "MusicgenForCausalLM"),
|
||||
("mvp", "MvpForCausalLM"),
|
||||
("open-llama", "OpenLlamaForCausalLM"),
|
||||
("openai-gpt", "OpenAIGPTLMHeadModel"),
|
||||
|
||||
67
src/transformers/models/musicgen/__init__.py
Normal file
67
src/transformers/models/musicgen/__init__.py
Normal file
@@ -0,0 +1,67 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
|
||||
|
||||
|
||||
_import_structure = {
|
||||
"configuration_musicgen": [
|
||||
"MUSICGEN_PRETRAINED_CONFIG_ARCHIVE_MAP",
|
||||
"MusicgenConfig",
|
||||
"MusicgenDecoderConfig",
|
||||
],
|
||||
"processing_musicgen": ["MusicgenProcessor"],
|
||||
}
|
||||
|
||||
try:
|
||||
if not is_torch_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_musicgen"] = [
|
||||
"MUSICGEN_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"MusicgenForConditionalGeneration",
|
||||
"MusicgenForCausalLM",
|
||||
"MusicgenModel",
|
||||
"MusicgenPreTrainedModel",
|
||||
]
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_musicgen import (
|
||||
MUSICGEN_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
MusicgenConfig,
|
||||
MusicgenDecoderConfig,
|
||||
)
|
||||
from .processing_musicgen import MusicgenProcessor
|
||||
|
||||
try:
|
||||
if not is_torch_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .modeling_musicgen import (
|
||||
MUSICGEN_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
MusicgenForCausalLM,
|
||||
MusicgenForConditionalGeneration,
|
||||
MusicgenModel,
|
||||
MusicgenPreTrainedModel,
|
||||
)
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
|
||||
243
src/transformers/models/musicgen/configuration_musicgen.py
Normal file
243
src/transformers/models/musicgen/configuration_musicgen.py
Normal file
@@ -0,0 +1,243 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" MusicGen model configuration"""
|
||||
import copy
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import logging
|
||||
from ..auto.configuration_auto import AutoConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
MUSICGEN_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||
"facebook/musicgen-small": "https://huggingface.co/facebook/musicgen-small/resolve/main/config.json",
|
||||
# See all Musicgen models at https://huggingface.co/models?filter=musicgen
|
||||
}
|
||||
|
||||
|
||||
class MusicgenDecoderConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of an [`MusicgenDecoder`]. It is used to instantiate a
|
||||
MusicGen decoder according to the specified arguments, defining the model architecture. Instantiating a
|
||||
configuration with the defaults will yield a similar configuration to that of the MusicGen
|
||||
[facebook/musicgen-small](https://huggingface.co/facebook/musicgen-small) architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 2048):
|
||||
Vocabulary size of the MusicgenDecoder model. Defines the number of different tokens that can be
|
||||
represented by the `inputs_ids` passed when calling [`MusicgenDecoder`].
|
||||
hidden_size (`int`, *optional*, defaults to 1024):
|
||||
Dimensionality of the layers and the pooler layer.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 24):
|
||||
Number of decoder layers.
|
||||
num_attention_heads (`int`, *optional*, defaults to 16):
|
||||
Number of attention heads for each attention layer in the Transformer block.
|
||||
ffn_dim (`int`, *optional*, defaults to 4096):
|
||||
Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer block.
|
||||
activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
|
||||
The non-linear activation function (function or string) in the decoder and pooler. If string, `"gelu"`,
|
||||
`"relu"`, `"silu"` and `"gelu_new"` are supported.
|
||||
dropout (`float`, *optional*, defaults to 0.1):
|
||||
The dropout probability for all fully connected layers in the embeddings, text_encoder, and pooler.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
activation_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for activations inside the fully connected layer.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 2048):
|
||||
The maximum sequence length that this model might ever be used with. Typically, set this to something large
|
||||
just in case (e.g., 512 or 1024 or 2048).
|
||||
initializer_factor (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
layerdrop (`float`, *optional*, defaults to 0.0):
|
||||
The LayerDrop probability for the decoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
|
||||
for more details.
|
||||
scale_embedding (`bool`, *optional*, defaults to `False`):
|
||||
Scale embeddings by diving by sqrt(hidden_size).
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether the model should return the last key/values attentions (not used by all models)
|
||||
num_codebooks (`int`, *optional*, defaults to 4):
|
||||
The number of parallel codebooks forwarded to the model.
|
||||
tie_word_embeddings(`bool`, *optional*, defaults to `False`):
|
||||
Whether input and output word embeddings should be tied.
|
||||
"""
|
||||
model_type = "musicgen_decoder"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=2048,
|
||||
max_position_embeddings=2048,
|
||||
num_hidden_layers=24,
|
||||
ffn_dim=4096,
|
||||
num_attention_heads=16,
|
||||
layerdrop=0.0,
|
||||
use_cache=True,
|
||||
activation_function="gelu",
|
||||
hidden_size=1024,
|
||||
dropout=0.1,
|
||||
attention_dropout=0.0,
|
||||
activation_dropout=0.0,
|
||||
initializer_factor=0.02,
|
||||
scale_embedding=False,
|
||||
num_codebooks=4,
|
||||
pad_token_id=2048,
|
||||
bos_token_id=2048,
|
||||
eos_token_id=None,
|
||||
tie_word_embeddings=False,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.ffn_dim = ffn_dim
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.dropout = dropout
|
||||
self.attention_dropout = attention_dropout
|
||||
self.activation_dropout = activation_dropout
|
||||
self.activation_function = activation_function
|
||||
self.initializer_factor = initializer_factor
|
||||
self.layerdrop = layerdrop
|
||||
self.use_cache = use_cache
|
||||
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
|
||||
self.num_codebooks = num_codebooks
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class MusicgenConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`MusicgenModel`]. It is used to instantiate a
|
||||
MusicGen model according to the specified arguments, defining the text encoder, audio encoder and MusicGen decoder
|
||||
configs.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
kwargs (*optional*):
|
||||
Dictionary of keyword arguments. Notably:
|
||||
|
||||
- **text_encoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that
|
||||
defines the text encoder config.
|
||||
- **audio_encoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that
|
||||
defines the audio encoder config.
|
||||
- **decoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that defines
|
||||
the decoder config.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import (
|
||||
... MusicgenConfig,
|
||||
... MusicgenDecoderConfig,
|
||||
... T5Config,
|
||||
... EncodecConfig,
|
||||
... MusicgenForConditionalGeneration,
|
||||
... )
|
||||
|
||||
>>> # Initializing text encoder, audio encoder, and decoder model configurations
|
||||
>>> text_encoder_config = T5Config()
|
||||
>>> audio_encoder_config = EncodecConfig()
|
||||
>>> decoder_config = MusicgenDecoderConfig()
|
||||
|
||||
>>> configuration = MusicgenConfig.from_sub_models_config(
|
||||
... text_encoder_config, audio_encoder_config, decoder_config
|
||||
... )
|
||||
|
||||
>>> # Initializing a MusicgenForConditionalGeneration (with random weights) from the facebook/musicgen-small style configuration
|
||||
>>> model = MusicgenForConditionalGeneration(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
>>> config_text_encoder = model.config.text_encoder
|
||||
>>> config_audio_encoder = model.config.audio_encoder
|
||||
>>> config_decoder = model.config.decoder
|
||||
|
||||
>>> # Saving the model, including its configuration
|
||||
>>> model.save_pretrained("musicgen-model")
|
||||
|
||||
>>> # loading model and config from pretrained folder
|
||||
>>> musicgen_config = MusicgenConfig.from_pretrained("musicgen-model")
|
||||
>>> model = MusicgenForConditionalGeneration.from_pretrained("musicgen-model", config=musicgen_config)
|
||||
```"""
|
||||
|
||||
model_type = "musicgen"
|
||||
is_composition = True
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if "text_encoder" not in kwargs or "audio_encoder" not in kwargs or "decoder" not in kwargs:
|
||||
raise ValueError("Config has to be initialized with text_encoder, audio_encoder and decoder config")
|
||||
|
||||
text_encoder_config = kwargs.pop("text_encoder")
|
||||
text_encoder_model_type = text_encoder_config.pop("model_type")
|
||||
|
||||
audio_encoder_config = kwargs.pop("audio_encoder")
|
||||
audio_encoder_model_type = audio_encoder_config.pop("model_type")
|
||||
|
||||
decoder_config = kwargs.pop("decoder")
|
||||
|
||||
self.text_encoder = AutoConfig.for_model(text_encoder_model_type, **text_encoder_config)
|
||||
self.audio_encoder = AutoConfig.for_model(audio_encoder_model_type, **audio_encoder_config)
|
||||
self.decoder = MusicgenDecoderConfig(**decoder_config)
|
||||
self.is_encoder_decoder = True
|
||||
|
||||
@classmethod
|
||||
def from_sub_models_config(
|
||||
cls,
|
||||
text_encoder_config: PretrainedConfig,
|
||||
audio_encoder_config: PretrainedConfig,
|
||||
decoder_config: MusicgenDecoderConfig,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Instantiate a [`MusicgenConfig`] (or a derived class) from text encoder, audio encoder and decoder
|
||||
configurations.
|
||||
|
||||
Returns:
|
||||
[`MusicgenConfig`]: An instance of a configuration object
|
||||
"""
|
||||
|
||||
return cls(
|
||||
text_encoder=text_encoder_config.to_dict(),
|
||||
audio_encoder=audio_encoder_config.to_dict(),
|
||||
decoder=decoder_config.to_dict(),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def to_dict(self):
|
||||
"""
|
||||
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
|
||||
|
||||
Returns:
|
||||
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
|
||||
"""
|
||||
output = copy.deepcopy(self.__dict__)
|
||||
output["text_encoder"] = self.text_encoder.to_dict()
|
||||
output["audio_encoder"] = self.audio_encoder.to_dict()
|
||||
output["decoder"] = self.decoder.to_dict()
|
||||
output["model_type"] = self.__class__.model_type
|
||||
return output
|
||||
@@ -0,0 +1,209 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert MusicGen checkpoints from the original repository."""
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import Dict, OrderedDict, Tuple
|
||||
|
||||
import torch
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
from transformers import (
|
||||
AutoFeatureExtractor,
|
||||
AutoTokenizer,
|
||||
EncodecModel,
|
||||
MusicgenDecoderConfig,
|
||||
MusicgenForConditionalGeneration,
|
||||
MusicgenProcessor,
|
||||
T5EncoderModel,
|
||||
)
|
||||
from transformers.models.musicgen.modeling_musicgen import MusicgenForCausalLM
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
EXPECTED_MISSING_KEYS = ["model.decoder.embed_positions.weights"]
|
||||
|
||||
|
||||
def rename_keys(name):
|
||||
if "emb" in name:
|
||||
name = name.replace("emb", "model.decoder.embed_tokens")
|
||||
if "transformer" in name:
|
||||
name = name.replace("transformer", "model.decoder")
|
||||
if "cross_attention" in name:
|
||||
name = name.replace("cross_attention", "encoder_attn")
|
||||
if "linear1" in name:
|
||||
name = name.replace("linear1", "fc1")
|
||||
if "linear2" in name:
|
||||
name = name.replace("linear2", "fc2")
|
||||
if "norm1" in name:
|
||||
name = name.replace("norm1", "self_attn_layer_norm")
|
||||
if "norm_cross" in name:
|
||||
name = name.replace("norm_cross", "encoder_attn_layer_norm")
|
||||
if "norm2" in name:
|
||||
name = name.replace("norm2", "final_layer_norm")
|
||||
if "out_norm" in name:
|
||||
name = name.replace("out_norm", "model.decoder.layer_norm")
|
||||
if "linears" in name:
|
||||
name = name.replace("linears", "lm_heads")
|
||||
if "condition_provider.conditioners.description.output_proj" in name:
|
||||
name = name.replace("condition_provider.conditioners.description.output_proj", "enc_to_dec_proj")
|
||||
return name
|
||||
|
||||
|
||||
def rename_state_dict(state_dict: OrderedDict, hidden_size: int) -> Tuple[Dict, Dict]:
|
||||
"""Function that takes the fairseq Musicgen state dict and renames it according to the HF
|
||||
module names. It further partitions the state dict into the decoder (LM) state dict, and that for the
|
||||
encoder-decoder projection."""
|
||||
keys = list(state_dict.keys())
|
||||
enc_dec_proj_state_dict = {}
|
||||
for key in keys:
|
||||
val = state_dict.pop(key)
|
||||
key = rename_keys(key)
|
||||
if "in_proj_weight" in key:
|
||||
# split fused qkv proj
|
||||
state_dict[key.replace("in_proj_weight", "q_proj.weight")] = val[:hidden_size, :]
|
||||
state_dict[key.replace("in_proj_weight", "k_proj.weight")] = val[hidden_size : 2 * hidden_size, :]
|
||||
state_dict[key.replace("in_proj_weight", "v_proj.weight")] = val[-hidden_size:, :]
|
||||
elif "enc_to_dec_proj" in key:
|
||||
enc_dec_proj_state_dict[key[len("enc_to_dec_proj.") :]] = val
|
||||
else:
|
||||
state_dict[key] = val
|
||||
return state_dict, enc_dec_proj_state_dict
|
||||
|
||||
|
||||
def decoder_config_from_checkpoint(checkpoint: str) -> MusicgenDecoderConfig:
|
||||
if checkpoint == "small":
|
||||
# default config values
|
||||
hidden_size = 1024
|
||||
num_hidden_layers = 24
|
||||
num_attention_heads = 16
|
||||
elif checkpoint == "medium":
|
||||
hidden_size = 1536
|
||||
num_hidden_layers = 48
|
||||
num_attention_heads = 24
|
||||
elif checkpoint == "large":
|
||||
hidden_size = 2048
|
||||
num_hidden_layers = 48
|
||||
num_attention_heads = 32
|
||||
else:
|
||||
raise ValueError(f"Checkpoint should be one of `['small', 'medium', 'large']`, got {checkpoint}.")
|
||||
config = MusicgenDecoderConfig(
|
||||
hidden_size=hidden_size,
|
||||
ffn_dim=hidden_size * 4,
|
||||
num_hidden_layers=num_hidden_layers,
|
||||
num_attention_heads=num_attention_heads,
|
||||
)
|
||||
return config
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_musicgen_checkpoint(checkpoint, pytorch_dump_folder=None, repo_id=None, device="cpu"):
|
||||
fairseq_model = MusicGen.get_pretrained(checkpoint, device=device)
|
||||
decoder_config = decoder_config_from_checkpoint(checkpoint)
|
||||
|
||||
decoder_state_dict = fairseq_model.lm.state_dict()
|
||||
decoder_state_dict, enc_dec_proj_state_dict = rename_state_dict(
|
||||
decoder_state_dict, hidden_size=decoder_config.hidden_size
|
||||
)
|
||||
|
||||
text_encoder = T5EncoderModel.from_pretrained("t5-base")
|
||||
audio_encoder = EncodecModel.from_pretrained("facebook/encodec_32khz")
|
||||
decoder = MusicgenForCausalLM(decoder_config).eval()
|
||||
|
||||
# load all decoder weights - expect that we'll be missing embeddings and enc-dec projection
|
||||
missing_keys, unexpected_keys = decoder.load_state_dict(decoder_state_dict, strict=False)
|
||||
|
||||
for key in missing_keys.copy():
|
||||
if key.startswith(("text_encoder", "audio_encoder")) or key in EXPECTED_MISSING_KEYS:
|
||||
missing_keys.remove(key)
|
||||
|
||||
if len(missing_keys) > 0:
|
||||
raise ValueError(f"Missing key(s) in state_dict: {missing_keys}")
|
||||
|
||||
if len(unexpected_keys) > 0:
|
||||
raise ValueError(f"Unexpected key(s) in state_dict: {unexpected_keys}")
|
||||
|
||||
# init the composite model
|
||||
model = MusicgenForConditionalGeneration(text_encoder=text_encoder, audio_encoder=audio_encoder, decoder=decoder)
|
||||
|
||||
# load the pre-trained enc-dec projection (from the decoder state dict)
|
||||
model.enc_to_dec_proj.load_state_dict(enc_dec_proj_state_dict)
|
||||
|
||||
# check we can do a forward pass
|
||||
input_ids = torch.arange(0, 8, dtype=torch.long).reshape(2, -1)
|
||||
decoder_input_ids = input_ids.reshape(2 * 4, -1)
|
||||
|
||||
with torch.no_grad():
|
||||
logits = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids).logits
|
||||
|
||||
if logits.shape != (8, 1, 2048):
|
||||
raise ValueError("Incorrect shape for logits")
|
||||
|
||||
# now construct the processor
|
||||
tokenizer = AutoTokenizer.from_pretrained("t5-base")
|
||||
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/encodec_32khz", padding_side="left")
|
||||
|
||||
processor = MusicgenProcessor(feature_extractor=feature_extractor, tokenizer=tokenizer)
|
||||
|
||||
# set the appropriate bos/pad token ids
|
||||
model.generation_config.decoder_start_token_id = 2048
|
||||
model.generation_config.pad_token_id = 2048
|
||||
|
||||
# set other default generation config params
|
||||
model.generation_config.max_length = int(30 * audio_encoder.config.frame_rate)
|
||||
model.generation_config.do_sample = True
|
||||
model.generation_config.guidance_scale = 3.0
|
||||
|
||||
if pytorch_dump_folder is not None:
|
||||
Path(pytorch_dump_folder).mkdir(exist_ok=True)
|
||||
logger.info(f"Saving model {checkpoint} to {pytorch_dump_folder}")
|
||||
model.save_pretrained(pytorch_dump_folder)
|
||||
processor.save_pretrained(pytorch_dump_folder)
|
||||
|
||||
if repo_id:
|
||||
logger.info(f"Pushing model {checkpoint} to {repo_id}")
|
||||
model.push_to_hub(repo_id)
|
||||
processor.push_to_hub(repo_id)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--checkpoint",
|
||||
default="small",
|
||||
type=str,
|
||||
help="Checkpoint size of the MusicGen model you'd like to convert. Can be one of: `['small', 'medium', 'large']`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder",
|
||||
required=True,
|
||||
default=None,
|
||||
type=str,
|
||||
help="Path to the output PyTorch model directory.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push_to_hub", default=None, type=str, help="Where to upload the converted model on the 🤗 hub."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device", default="cpu", type=str, help="Torch device to run the conversion, either cpu or cuda."
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
convert_musicgen_checkpoint(args.checkpoint, args.pytorch_dump_folder, args.push_to_hub)
|
||||
2512
src/transformers/models/musicgen/modeling_musicgen.py
Normal file
2512
src/transformers/models/musicgen/modeling_musicgen.py
Normal file
File diff suppressed because it is too large
Load Diff
139
src/transformers/models/musicgen/processing_musicgen.py
Normal file
139
src/transformers/models/musicgen/processing_musicgen.py
Normal file
@@ -0,0 +1,139 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Text/audio processor class for MusicGen
|
||||
"""
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ...processing_utils import ProcessorMixin
|
||||
from ...utils import to_numpy
|
||||
|
||||
|
||||
class MusicgenProcessor(ProcessorMixin):
|
||||
r"""
|
||||
Constructs a MusicGen processor which wraps an EnCodec feature extractor and a T5 tokenizer into a single processor
|
||||
class.
|
||||
|
||||
[`MusicgenProcessor`] offers all the functionalities of [`EncodecFeatureExtractor`] and [`TTokenizer`]. See
|
||||
[`~MusicgenProcessor.__call__`] and [`~MusicgenProcessor.decode`] for more information.
|
||||
|
||||
Args:
|
||||
feature_extractor (`EncodecFeatureExtractor`):
|
||||
An instance of [`EncodecFeatureExtractor`]. The feature extractor is a required input.
|
||||
tokenizer (`T5Tokenizer`):
|
||||
An instance of [`T5Tokenizer`]. The tokenizer is a required input.
|
||||
"""
|
||||
feature_extractor_class = "EncodecFeatureExtractor"
|
||||
tokenizer_class = ("T5Tokenizer", "T5TokenizerFast")
|
||||
|
||||
def __init__(self, feature_extractor, tokenizer):
|
||||
super().__init__(feature_extractor, tokenizer)
|
||||
self.current_processor = self.feature_extractor
|
||||
self._in_target_context_manager = False
|
||||
|
||||
def get_decoder_prompt_ids(self, task=None, language=None, no_timestamps=True):
|
||||
return self.tokenizer.get_decoder_prompt_ids(task=task, language=language, no_timestamps=no_timestamps)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
"""
|
||||
Forwards the `audio` argument to EncodecFeatureExtractor's [`~EncodecFeatureExtractor.__call__`] and the `text`
|
||||
argument to [`~T5Tokenizer.__call__`]. Please refer to the doctsring of the above two methods for more
|
||||
information.
|
||||
"""
|
||||
# For backward compatibility
|
||||
if self._in_target_context_manager:
|
||||
return self.current_processor(*args, **kwargs)
|
||||
|
||||
audio = kwargs.pop("audio", None)
|
||||
sampling_rate = kwargs.pop("sampling_rate", None)
|
||||
text = kwargs.pop("text", None)
|
||||
if len(args) > 0:
|
||||
audio = args[0]
|
||||
args = args[1:]
|
||||
|
||||
if audio is None and text is None:
|
||||
raise ValueError("You need to specify either an `audio` or `text` input to process.")
|
||||
|
||||
if text is not None:
|
||||
inputs = self.tokenizer(text, **kwargs)
|
||||
|
||||
if audio is not None:
|
||||
audio_inputs = self.feature_extractor(audio, *args, sampling_rate=sampling_rate, **kwargs)
|
||||
|
||||
if audio is None:
|
||||
return inputs
|
||||
|
||||
elif text is None:
|
||||
return audio_inputs
|
||||
|
||||
else:
|
||||
inputs["input_values"] = audio_inputs["input_values"]
|
||||
if "padding_mask" in audio_inputs:
|
||||
inputs["padding_mask"] = audio_inputs["padding_mask"]
|
||||
return inputs
|
||||
|
||||
def batch_decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method is used to decode either batches of audio outputs from the MusicGen model, or batches of token ids
|
||||
from the tokenizer. In the case of decoding token ids, this method forwards all its arguments to T5Tokenizer's
|
||||
[`~PreTrainedTokenizer.batch_decode`]. Please refer to the docstring of this method for more information.
|
||||
"""
|
||||
audio_values = kwargs.pop("audio", None)
|
||||
padding_mask = kwargs.pop("padding_mask", None)
|
||||
|
||||
if len(args) > 0:
|
||||
audio_values = args[0]
|
||||
args = args[1:]
|
||||
|
||||
if audio_values is not None:
|
||||
return self._decode_audio(audio_values, padding_mask=padding_mask)
|
||||
else:
|
||||
return self.tokenizer.batch_decode(*args, **kwargs)
|
||||
|
||||
def decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to T5Tokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to the
|
||||
docstring of this method for more information.
|
||||
"""
|
||||
return self.tokenizer.decode(*args, **kwargs)
|
||||
|
||||
def _decode_audio(self, audio_values, padding_mask: Optional = None) -> List[np.ndarray]:
|
||||
"""
|
||||
This method strips any padding from the audio values to return a list of numpy audio arrays.
|
||||
"""
|
||||
audio_values = to_numpy(audio_values)
|
||||
bsz, channels, seq_len = audio_values.shape
|
||||
|
||||
if padding_mask is None:
|
||||
return list(audio_values)
|
||||
|
||||
padding_mask = to_numpy(padding_mask)
|
||||
|
||||
# match the sequence length of the padding mask to the generated audio arrays by padding with the **non-padding**
|
||||
# token (so that the generated audio values are **not** treated as padded tokens)
|
||||
difference = seq_len - padding_mask.shape[-1]
|
||||
padding_value = 1 - self.feature_extractor.padding_value
|
||||
padding_mask = np.pad(padding_mask, ((0, 0), (0, difference)), "constant", constant_values=padding_value)
|
||||
|
||||
audio_values = audio_values.tolist()
|
||||
for i in range(bsz):
|
||||
sliced_audio = np.asarray(audio_values[i])[
|
||||
padding_mask[i][None, :] != self.feature_extractor.padding_value
|
||||
]
|
||||
audio_values[i] = sliced_audio.reshape(channels, -1)
|
||||
|
||||
return audio_values
|
||||
@@ -4936,6 +4936,44 @@ class MT5PreTrainedModel(metaclass=DummyObject):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
MUSICGEN_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||
|
||||
|
||||
class MusicgenForCausalLM(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class MusicgenForConditionalGeneration(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class MusicgenModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class MusicgenPreTrainedModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class MusicgenProcessor(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
MVP_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||
|
||||
|
||||
|
||||
0
tests/models/musicgen/__init__.py
Normal file
0
tests/models/musicgen/__init__.py
Normal file
1346
tests/models/musicgen/test_modeling_musicgen.py
Normal file
1346
tests/models/musicgen/test_modeling_musicgen.py
Normal file
File diff suppressed because it is too large
Load Diff
173
tests/models/musicgen/test_processing_musicgen.py
Normal file
173
tests/models/musicgen/test_processing_musicgen.py
Normal file
@@ -0,0 +1,173 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for the MusicGen processor."""
|
||||
|
||||
import random
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers import T5Tokenizer, T5TokenizerFast
|
||||
from transformers.testing_utils import require_sentencepiece, require_torch
|
||||
from transformers.utils.import_utils import is_speech_available, is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
pass
|
||||
|
||||
if is_speech_available():
|
||||
from transformers import EncodecFeatureExtractor, MusicgenProcessor
|
||||
|
||||
|
||||
global_rng = random.Random()
|
||||
|
||||
|
||||
def floats_list(shape, scale=1.0, rng=None, name=None):
|
||||
"""Creates a random float32 tensor"""
|
||||
if rng is None:
|
||||
rng = global_rng
|
||||
|
||||
values = []
|
||||
for batch_idx in range(shape[0]):
|
||||
values.append([])
|
||||
for _ in range(shape[1]):
|
||||
values[-1].append(rng.random() * scale)
|
||||
|
||||
return values
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_sentencepiece
|
||||
class MusicgenProcessorTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.checkpoint = "facebook/musicgen-small"
|
||||
self.tmpdirname = tempfile.mkdtemp()
|
||||
|
||||
def get_tokenizer(self, **kwargs):
|
||||
return T5Tokenizer.from_pretrained(self.checkpoint, **kwargs)
|
||||
|
||||
def get_feature_extractor(self, **kwargs):
|
||||
return EncodecFeatureExtractor.from_pretrained(self.checkpoint, **kwargs)
|
||||
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.tmpdirname)
|
||||
|
||||
def test_save_load_pretrained_default(self):
|
||||
tokenizer = self.get_tokenizer()
|
||||
feature_extractor = self.get_feature_extractor()
|
||||
|
||||
processor = MusicgenProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
|
||||
|
||||
processor.save_pretrained(self.tmpdirname)
|
||||
processor = MusicgenProcessor.from_pretrained(self.tmpdirname)
|
||||
|
||||
self.assertEqual(processor.tokenizer.get_vocab(), tokenizer.get_vocab())
|
||||
self.assertIsInstance(processor.tokenizer, T5TokenizerFast)
|
||||
|
||||
self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor.to_json_string())
|
||||
self.assertIsInstance(processor.feature_extractor, EncodecFeatureExtractor)
|
||||
|
||||
def test_save_load_pretrained_additional_features(self):
|
||||
processor = MusicgenProcessor(tokenizer=self.get_tokenizer(), feature_extractor=self.get_feature_extractor())
|
||||
processor.save_pretrained(self.tmpdirname)
|
||||
|
||||
tokenizer_add_kwargs = self.get_tokenizer(bos_token="(BOS)", eos_token="(EOS)")
|
||||
feature_extractor_add_kwargs = self.get_feature_extractor(do_normalize=False, padding_value=1.0)
|
||||
|
||||
processor = MusicgenProcessor.from_pretrained(
|
||||
self.tmpdirname, bos_token="(BOS)", eos_token="(EOS)", do_normalize=False, padding_value=1.0
|
||||
)
|
||||
|
||||
self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab())
|
||||
self.assertIsInstance(processor.tokenizer, T5TokenizerFast)
|
||||
|
||||
self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string())
|
||||
self.assertIsInstance(processor.feature_extractor, EncodecFeatureExtractor)
|
||||
|
||||
def test_feature_extractor(self):
|
||||
feature_extractor = self.get_feature_extractor()
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
processor = MusicgenProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
|
||||
|
||||
raw_speech = floats_list((3, 1000))
|
||||
|
||||
input_feat_extract = feature_extractor(raw_speech, return_tensors="np")
|
||||
input_processor = processor(raw_speech, return_tensors="np")
|
||||
|
||||
for key in input_feat_extract.keys():
|
||||
self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2)
|
||||
|
||||
def test_tokenizer(self):
|
||||
feature_extractor = self.get_feature_extractor()
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
processor = MusicgenProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
|
||||
|
||||
input_str = "This is a test string"
|
||||
|
||||
encoded_processor = processor(text=input_str)
|
||||
|
||||
encoded_tok = tokenizer(input_str)
|
||||
|
||||
for key in encoded_tok.keys():
|
||||
self.assertListEqual(encoded_tok[key], encoded_processor[key])
|
||||
|
||||
def test_tokenizer_decode(self):
|
||||
feature_extractor = self.get_feature_extractor()
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
processor = MusicgenProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
|
||||
|
||||
predicted_ids = [[1, 4, 5, 8, 1, 0, 8], [3, 4, 3, 1, 1, 8, 9]]
|
||||
|
||||
decoded_processor = processor.batch_decode(sequences=predicted_ids)
|
||||
decoded_tok = tokenizer.batch_decode(predicted_ids)
|
||||
|
||||
self.assertListEqual(decoded_tok, decoded_processor)
|
||||
|
||||
def test_model_input_names(self):
|
||||
feature_extractor = self.get_feature_extractor()
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
processor = MusicgenProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
|
||||
|
||||
self.assertListEqual(
|
||||
processor.model_input_names,
|
||||
feature_extractor.model_input_names,
|
||||
msg="`processor` and `feature_extractor` model input names do not match",
|
||||
)
|
||||
|
||||
def test_decode_audio(self):
|
||||
feature_extractor = self.get_feature_extractor(padding_side="left")
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
processor = MusicgenProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
|
||||
|
||||
raw_speech = [floats_list((1, x))[0] for x in range(5, 20, 5)]
|
||||
padding_mask = processor(raw_speech).padding_mask
|
||||
|
||||
generated_speech = np.asarray(floats_list((3, 20)))[:, None, :]
|
||||
decoded_audios = processor.batch_decode(generated_speech, padding_mask=padding_mask)
|
||||
|
||||
self.assertIsInstance(decoded_audios, list)
|
||||
|
||||
for audio in decoded_audios:
|
||||
self.assertIsInstance(audio, np.ndarray)
|
||||
|
||||
self.assertTrue(decoded_audios[0].shape == (1, 10))
|
||||
self.assertTrue(decoded_audios[1].shape == (1, 15))
|
||||
self.assertTrue(decoded_audios[2].shape == (1, 20))
|
||||
@@ -37,6 +37,7 @@ _re_checkpoint = re.compile(r"\[(.+?)\]\((https://huggingface\.co/.+?)\)")
|
||||
CONFIG_CLASSES_TO_IGNORE_FOR_DOCSTRING_CHECKPOINT_CHECK = {
|
||||
"DecisionTransformerConfig",
|
||||
"EncoderDecoderConfig",
|
||||
"MusicgenConfig",
|
||||
"RagConfig",
|
||||
"SpeechEncoderDecoderConfig",
|
||||
"TimmBackboneConfig",
|
||||
|
||||
@@ -119,6 +119,7 @@ IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
|
||||
"MegatronBertEncoder", # Building part of bigger (tested) model.
|
||||
"MegatronBertDecoder", # Building part of bigger (tested) model.
|
||||
"MegatronBertDecoderWrapper", # Building part of bigger (tested) model.
|
||||
"MusicgenDecoder", # Building part of bigger (tested) model.
|
||||
"MvpDecoderWrapper", # Building part of bigger (tested) model.
|
||||
"MvpEncoder", # Building part of bigger (tested) model.
|
||||
"PegasusEncoder", # Building part of bigger (tested) model.
|
||||
@@ -331,6 +332,8 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
|
||||
"SpeechT5ForSpeechToSpeech",
|
||||
"SpeechT5ForTextToSpeech",
|
||||
"SpeechT5HifiGan",
|
||||
"MusicgenModel",
|
||||
"MusicgenForConditionalGeneration",
|
||||
]
|
||||
|
||||
# Update this list for models that have multiple model types for the same
|
||||
|
||||
@@ -134,6 +134,9 @@ src/transformers/models/mobilenet_v1/modeling_mobilenet_v1.py
|
||||
src/transformers/models/mobilenet_v2/modeling_mobilenet_v2.py
|
||||
src/transformers/models/mobilevit/modeling_mobilevit.py
|
||||
src/transformers/models/mobilevit/modeling_tf_mobilevit.py
|
||||
src/transformers/models/musicgen/modeling_musicgen.py
|
||||
src/transformers/models/musicgen/configuration_musicgen.py
|
||||
src/transformers/models/musicgen/processing_musicgen.py
|
||||
src/transformers/models/mvp/configuration_mvp.py
|
||||
src/transformers/models/nat/configuration_nat.py
|
||||
src/transformers/models/nat/modeling_nat.py
|
||||
|
||||
Reference in New Issue
Block a user