Add Multi Resolution Analysis (MRA) (New PR) (#24513)
* Add all files * Update masked_language_modeling.md * fix mlm models * fix conflicts * fix conflicts * fix copies * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Reduce seq_len and hidden_size in ModelTester * remove output_attentions * fix conflicts * remove copied from statements * Apply suggestions from code review Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --------- Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -410,6 +410,7 @@ Current number of checkpoints: ** (from Apple) released with the paper [MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer](https://arxiv.org/abs/2110.02178) by Sachin Mehta and Mohammad Rastegari.
|
||||
1. **[MobileViTV2](https://huggingface.co/docs/transformers/model_doc/mobilevitv2)** (from Apple) released with the paper [Separable Self-attention for Mobile Vision Transformers](https://arxiv.org/abs/2206.02680) by Sachin Mehta and Mohammad Rastegari.
|
||||
1. **[MPNet](https://huggingface.co/docs/transformers/model_doc/mpnet)** (from Microsoft Research) released with the paper [MPNet: Masked and Permuted Pre-training for Language Understanding](https://arxiv.org/abs/2004.09297) by Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu.
|
||||
1. **[MRA](https://huggingface.co/docs/transformers/main/model_doc/mra)** (from the University of Wisconsin - Madison) released with the paper [Multi Resolution Analysis (MRA) for Approximate Self-Attention](https://arxiv.org/abs/2207.10284) by Zhanpeng Zeng, Sourav Pal, Jeffery Kline, Glenn M Fung, Vikas Singh.
|
||||
1. **[MT5](https://huggingface.co/docs/transformers/model_doc/mt5)** (from Google AI) released with the paper [mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/abs/2010.11934) by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel.
|
||||
1. **[MusicGen](https://huggingface.co/docs/transformers/main/model_doc/musicgen)** (from Meta) released with the paper [Simple and Controllable Music Generation](https://arxiv.org/abs/2306.05284) by Jade Copet, Felix Kreuk, Itai Gat, Tal Remez, David Kant, Gabriel Synnaeve, Yossi Adi and Alexandre Défossez.
|
||||
1. **[MVP](https://huggingface.co/docs/transformers/model_doc/mvp)** (from RUC AI Box) released with the paper [MVP: Multi-task Supervised Pre-training for Natural Language Generation](https://arxiv.org/abs/2206.12131) by Tianyi Tang, Junyi Li, Wayne Xin Zhao and Ji-Rong Wen.
|
||||
|
||||
@@ -385,6 +385,7 @@ Número actual de puntos de control: ** (from Apple) released with the paper [MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer](https://arxiv.org/abs/2110.02178) by Sachin Mehta and Mohammad Rastegari.
|
||||
1. **[MobileViTV2](https://huggingface.co/docs/transformers/model_doc/mobilevitv2)** (from Apple) released with the paper [Separable Self-attention for Mobile Vision Transformers](https://arxiv.org/abs/2206.02680) by Sachin Mehta and Mohammad Rastegari.
|
||||
1. **[MPNet](https://huggingface.co/docs/transformers/model_doc/mpnet)** (from Microsoft Research) released with the paper [MPNet: Masked and Permuted Pre-training for Language Understanding](https://arxiv.org/abs/2004.09297) by Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu.
|
||||
1. **[MRA](https://huggingface.co/docs/transformers/main/model_doc/mra)** (from the University of Wisconsin - Madison) released with the paper [Multi Resolution Analysis (MRA) by Zhanpeng Zeng, Sourav Pal, Jeffery Kline, Glenn M Fung, Vikas Singh.
|
||||
1. **[MT5](https://huggingface.co/docs/transformers/model_doc/mt5)** (from Google AI) released with the paper [mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/abs/2010.11934) by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel.
|
||||
1. **[MusicGen](https://huggingface.co/docs/transformers/main/model_doc/musicgen)** (from Meta) released with the paper [Simple and Controllable Music Generation](https://arxiv.org/abs/2306.05284) by Jade Copet, Felix Kreuk, Itai Gat, Tal Remez, David Kant, Gabriel Synnaeve, Yossi Adi and Alexandre Défossez.
|
||||
1. **[MVP](https://huggingface.co/docs/transformers/model_doc/mvp)** (from RUC AI Box) released with the paper [MVP: Multi-task Supervised Pre-training for Natural Language Generation](https://arxiv.org/abs/2206.12131) by Tianyi Tang, Junyi Li, Wayne Xin Zhao and Ji-Rong Wen.
|
||||
|
||||
@@ -357,6 +357,7 @@ conda install -c huggingface transformers
|
||||
1. **[MobileViT](https://huggingface.co/docs/transformers/model_doc/mobilevit)** (Apple से) साथ में कागज [MobileViT: लाइट-वेट, जनरल-पर्पस, और मोबाइल-फ्रेंडली विजन ट्रांसफॉर्मर] (https://arxiv.org/abs/2110.02178) सचिन मेहता और मोहम्मद रस्तगरी द्वारा पोस्ट किया गया।
|
||||
1. **[MobileViTV2](https://huggingface.co/docs/transformers/model_doc/mobilevitv2)** (Apple से) Sachin Mehta and Mohammad Rastegari. द्वाराअनुसंधान पत्र [Separable Self-attention for Mobile Vision Transformers](https://arxiv.org/abs/2206.02680) के साथ जारी किया गया
|
||||
1. **[MPNet](https://huggingface.co/docs/transformers/model_doc/mpnet)** (from Microsoft Research) released with the paper [MPNet: Masked and Permuted Pre-training for Language Understanding](https://arxiv.org/abs/2004.09297) by Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu.
|
||||
1. **[MRA](https://huggingface.co/docs/transformers/main/model_doc/mra)** (the University of Wisconsin - Madison से) Zhanpeng Zeng, Sourav Pal, Jeffery Kline, Glenn M Fung, Vikas Singh. द्वाराअनुसंधान पत्र [Multi Resolution Analysis (MRA) के साथ जारी किया गया
|
||||
1. **[MT5](https://huggingface.co/docs/transformers/model_doc/mt5)** (Google AI से) साथ वाला पेपर [mT5: एक व्यापक बहुभाषी पूर्व-प्रशिक्षित टेक्स्ट-टू-टेक्स्ट ट्रांसफॉर्मर]( https://arxiv.org/abs/2010.11934) लिंटिंग ज़ू, नोआ कॉन्सटेंट, एडम रॉबर्ट्स, मिहिर काले, रामी अल-रफू, आदित्य सिद्धांत, आदित्य बरुआ, कॉलिन रैफेल द्वारा पोस्ट किया गया।
|
||||
1. **[MusicGen](https://huggingface.co/docs/transformers/main/model_doc/musicgen)** (from Meta) released with the paper [Simple and Controllable Music Generation](https://arxiv.org/abs/2306.05284) by Jade Copet, Felix Kreuk, Itai Gat, Tal Remez, David Kant, Gabriel Synnaeve, Yossi Adi and Alexandre Défossez.
|
||||
1. **[MVP](https://huggingface.co/docs/transformers/model_doc/mvp)** (from RUC AI Box) released with the paper [MVP: Multi-task Supervised Pre-training for Natural Language Generation](https://arxiv.org/abs/2206.12131) by Tianyi Tang, Junyi Li, Wayne Xin Zhao and Ji-Rong Wen.
|
||||
|
||||
@@ -419,6 +419,7 @@ Flax、PyTorch、TensorFlowをcondaでインストールする方法は、それ
|
||||
1. **[MobileViT](https://huggingface.co/docs/transformers/model_doc/mobilevit)** (Apple から) Sachin Mehta and Mohammad Rastegari から公開された研究論文: [MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer](https://arxiv.org/abs/2110.02178)
|
||||
1. **[MobileViTV2](https://huggingface.co/docs/transformers/model_doc/mobilevitv2)** (Apple から) Sachin Mehta and Mohammad Rastegari. から公開された研究論文 [Separable Self-attention for Mobile Vision Transformers](https://arxiv.org/abs/2206.02680)
|
||||
1. **[MPNet](https://huggingface.co/docs/transformers/model_doc/mpnet)** (Microsoft Research から) Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu から公開された研究論文: [MPNet: Masked and Permuted Pre-training for Language Understanding](https://arxiv.org/abs/2004.09297)
|
||||
1. **[MRA](https://huggingface.co/docs/transformers/main/model_doc/mra)** (the University of Wisconsin - Madison から) Zhanpeng Zeng, Sourav Pal, Jeffery Kline, Glenn M Fung, Vikas Singh. から公開された研究論文 [Multi Resolution Analysis (MRA)
|
||||
1. **[MT5](https://huggingface.co/docs/transformers/model_doc/mt5)** (Google AI から) Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel から公開された研究論文: [mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/abs/2010.11934)
|
||||
1. **[MusicGen](https://huggingface.co/docs/transformers/main/model_doc/musicgen)** (from Meta) released with the paper [Simple and Controllable Music Generation](https://arxiv.org/abs/2306.05284) by Jade Copet, Felix Kreuk, Itai Gat, Tal Remez, David Kant, Gabriel Synnaeve, Yossi Adi and Alexandre Défossez.
|
||||
1. **[MVP](https://huggingface.co/docs/transformers/model_doc/mvp)** (RUC AI Box から) Tianyi Tang, Junyi Li, Wayne Xin Zhao and Ji-Rong Wen から公開された研究論文: [MVP: Multi-task Supervised Pre-training for Natural Language Generation](https://arxiv.org/abs/2206.12131)
|
||||
|
||||
@@ -334,6 +334,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는
|
||||
1. **[MobileViT](https://huggingface.co/docs/transformers/model_doc/mobilevit)** (Apple 에서) Sachin Mehta and Mohammad Rastegari 의 [MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer](https://arxiv.org/abs/2110.02178) 논문과 함께 발표했습니다.
|
||||
1. **[MobileViTV2](https://huggingface.co/docs/transformers/model_doc/mobilevitv2)** (Apple 에서 제공)은 Sachin Mehta and Mohammad Rastegari.의 [Separable Self-attention for Mobile Vision Transformers](https://arxiv.org/abs/2206.02680)논문과 함께 발표했습니다.
|
||||
1. **[MPNet](https://huggingface.co/docs/transformers/model_doc/mpnet)** (Microsoft Research 에서) Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu 의 [MPNet: Masked and Permuted Pre-training for Language Understanding](https://arxiv.org/abs/2004.09297) 논문과 함께 발표했습니다.
|
||||
1. **[MRA](https://huggingface.co/docs/transformers/main/model_doc/mra)** (the University of Wisconsin - Madison 에서 제공)은 Zhanpeng Zeng, Sourav Pal, Jeffery Kline, Glenn M Fung, Vikas Singh.의 [Multi Resolution Analysis (MRA)논문과 함께 발표했습니다.
|
||||
1. **[MT5](https://huggingface.co/docs/transformers/model_doc/mt5)** (Google AI 에서) Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel 의 [mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/abs/2010.11934) 논문과 함께 발표했습니다.
|
||||
1. **[MusicGen](https://huggingface.co/docs/transformers/main/model_doc/musicgen)** (from Meta) released with the paper [Simple and Controllable Music Generation](https://arxiv.org/abs/2306.05284) by Jade Copet, Felix Kreuk, Itai Gat, Tal Remez, David Kant, Gabriel Synnaeve, Yossi Adi and Alexandre Défossez.
|
||||
1. **[MVP](https://huggingface.co/docs/transformers/model_doc/mvp)** (RUC AI Box 에서) Tianyi Tang, Junyi Li, Wayne Xin Zhao and Ji-Rong Wen 의 [MVP: Multi-task Supervised Pre-training for Natural Language Generation](https://arxiv.org/abs/2206.12131) 논문과 함께 발표했습니다.
|
||||
|
||||
@@ -358,6 +358,7 @@ conda install -c huggingface transformers
|
||||
1. **[MobileViT](https://huggingface.co/docs/transformers/model_doc/mobilevit)** (来自 Apple) 伴随论文 [MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer](https://arxiv.org/abs/2110.02178) 由 Sachin Mehta and Mohammad Rastegari 发布。
|
||||
1. **[MobileViTV2](https://huggingface.co/docs/transformers/model_doc/mobilevitv2)** (来自 Apple) 伴随论文 [Separable Self-attention for Mobile Vision Transformers](https://arxiv.org/abs/2206.02680) 由 Sachin Mehta and Mohammad Rastegari 发布。
|
||||
1. **[MPNet](https://huggingface.co/docs/transformers/model_doc/mpnet)** (来自 Microsoft Research) 伴随论文 [MPNet: Masked and Permuted Pre-training for Language Understanding](https://arxiv.org/abs/2004.09297) 由 Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu 发布。
|
||||
1. **[MRA](https://huggingface.co/docs/transformers/main/model_doc/mra)** (来自 the University of Wisconsin - Madison) 伴随论文 [Multi Resolution Analysis (MRA) 由 Zhanpeng Zeng, Sourav Pal, Jeffery Kline, Glenn M Fung, Vikas Singh 发布。
|
||||
1. **[MT5](https://huggingface.co/docs/transformers/model_doc/mt5)** (来自 Google AI) 伴随论文 [mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/abs/2010.11934) 由 Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel 发布。
|
||||
1. **[MusicGen](https://huggingface.co/docs/transformers/main/model_doc/musicgen)** (from Meta) released with the paper [Simple and Controllable Music Generation](https://arxiv.org/abs/2306.05284) by Jade Copet, Felix Kreuk, Itai Gat, Tal Remez, David Kant, Gabriel Synnaeve, Yossi Adi and Alexandre Défossez.
|
||||
1. **[MVP](https://huggingface.co/docs/transformers/model_doc/mvp)** (来自 中国人民大学 AI Box) 伴随论文 [MVP: Multi-task Supervised Pre-training for Natural Language Generation](https://arxiv.org/abs/2206.12131) 由 Tianyi Tang, Junyi Li, Wayne Xin Zhao and Ji-Rong Wen 发布。
|
||||
|
||||
@@ -370,6 +370,7 @@ conda install -c huggingface transformers
|
||||
1. **[MobileViT](https://huggingface.co/docs/transformers/model_doc/mobilevit)** (from Apple) released with the paper [MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer](https://arxiv.org/abs/2110.02178) by Sachin Mehta and Mohammad Rastegari.
|
||||
1. **[MobileViTV2](https://huggingface.co/docs/transformers/model_doc/mobilevitv2)** (from Apple) released with the paper [Separable Self-attention for Mobile Vision Transformers](https://arxiv.org/abs/2206.02680) by Sachin Mehta and Mohammad Rastegari.
|
||||
1. **[MPNet](https://huggingface.co/docs/transformers/model_doc/mpnet)** (from Microsoft Research) released with the paper [MPNet: Masked and Permuted Pre-training for Language Understanding](https://arxiv.org/abs/2004.09297) by Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu.
|
||||
1. **[MRA](https://huggingface.co/docs/transformers/main/model_doc/mra)** (from the University of Wisconsin - Madison) released with the paper [Multi Resolution Analysis (MRA) by Zhanpeng Zeng, Sourav Pal, Jeffery Kline, Glenn M Fung, Vikas Singh.
|
||||
1. **[MT5](https://huggingface.co/docs/transformers/model_doc/mt5)** (from Google AI) released with the paper [mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/abs/2010.11934) by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel.
|
||||
1. **[MusicGen](https://huggingface.co/docs/transformers/main/model_doc/musicgen)** (from Meta) released with the paper [Simple and Controllable Music Generation](https://arxiv.org/abs/2306.05284) by Jade Copet, Felix Kreuk, Itai Gat, Tal Remez, David Kant, Gabriel Synnaeve, Yossi Adi and Alexandre Défossez.
|
||||
1. **[MVP](https://huggingface.co/docs/transformers/model_doc/mvp)** (from RUC AI Box) released with the paper [MVP: Multi-task Supervised Pre-training for Natural Language Generation](https://arxiv.org/abs/2206.12131) by Tianyi Tang, Junyi Li, Wayne Xin Zhao and Ji-Rong Wen.
|
||||
|
||||
@@ -361,6 +361,8 @@
|
||||
title: MobileBERT
|
||||
- local: model_doc/mpnet
|
||||
title: MPNet
|
||||
- local: model_doc/mra
|
||||
title: MRA
|
||||
- local: model_doc/mt5
|
||||
title: MT5
|
||||
- local: model_doc/mvp
|
||||
|
||||
@@ -174,6 +174,7 @@ The documentation is organized into five sections:
|
||||
1. **[MobileViT](model_doc/mobilevit)** (from Apple) released with the paper [MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer](https://arxiv.org/abs/2110.02178) by Sachin Mehta and Mohammad Rastegari.
|
||||
1. **[MobileViTV2](model_doc/mobilevitv2)** (from Apple) released with the paper [Separable Self-attention for Mobile Vision Transformers](https://arxiv.org/abs/2206.02680) by Sachin Mehta and Mohammad Rastegari.
|
||||
1. **[MPNet](model_doc/mpnet)** (from Microsoft Research) released with the paper [MPNet: Masked and Permuted Pre-training for Language Understanding](https://arxiv.org/abs/2004.09297) by Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu.
|
||||
1. **[MRA](model_doc/mra)** (from the University of Wisconsin - Madison) released with the paper [Multi Resolution Analysis (MRA) for Approximate Self-Attention](https://arxiv.org/abs/2207.10284) by Zhanpeng Zeng, Sourav Pal, Jeffery Kline, Glenn M Fung, Vikas Singh.
|
||||
1. **[MT5](model_doc/mt5)** (from Google AI) released with the paper [mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/abs/2010.11934) by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel.
|
||||
1. **[MusicGen](model_doc/musicgen)** (from Meta) released with the paper [Simple and Controllable Music Generation](https://arxiv.org/abs/2306.05284) by Jade Copet, Felix Kreuk, Itai Gat, Tal Remez, David Kant, Gabriel Synnaeve, Yossi Adi and Alexandre Défossez.
|
||||
1. **[MVP](model_doc/mvp)** (from RUC AI Box) released with the paper [MVP: Multi-task Supervised Pre-training for Natural Language Generation](https://arxiv.org/abs/2206.12131) by Tianyi Tang, Junyi Li, Wayne Xin Zhao and Ji-Rong Wen.
|
||||
@@ -381,6 +382,7 @@ Flax), PyTorch, and/or TensorFlow.
|
||||
| MobileViT | ❌ | ❌ | ✅ | ✅ | ❌ |
|
||||
| MobileViTV2 | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| MPNet | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| MRA | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| MT5 | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| MusicGen | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| MVP | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
|
||||
68
docs/source/en/model_doc/mra.md
Normal file
68
docs/source/en/model_doc/mra.md
Normal file
@@ -0,0 +1,68 @@
|
||||
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# MRA
|
||||
|
||||
## Overview
|
||||
|
||||
The MRA model was proposed in [Multi Resolution Analysis (MRA) for Approximate Self-Attention](https://arxiv.org/abs/2207.10284) by Zhanpeng Zeng, Sourav Pal, Jeffery Kline, Glenn M Fung, and Vikas Singh.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
*Transformers have emerged as a preferred model for many tasks in natural langugage processing and vision. Recent efforts on training and deploying Transformers more efficiently have identified many strategies to approximate the self-attention matrix, a key module in a Transformer architecture. Effective ideas include various prespecified sparsity patterns, low-rank basis expansions and combinations thereof. In this paper, we revisit classical Multiresolution Analysis (MRA) concepts such as Wavelets, whose potential value in this setting remains underexplored thus far. We show that simple approximations based on empirical feedback and design choices informed by modern hardware and implementation challenges, eventually yield a MRA-based approach for self-attention with an excellent performance profile across most criteria of interest. We undertake an extensive set of experiments and demonstrate that this multi-resolution scheme outperforms most efficient self-attention proposals and is favorable for both short and long sequences. Code is available at https://github.com/mlpen/mra-attention.*
|
||||
|
||||
This model was contributed by [novice03](https://huggingface.co/novice03).
|
||||
The original code can be found [here](https://github.com/mlpen/mra-attention).
|
||||
|
||||
|
||||
## MraConfig
|
||||
|
||||
[[autodoc]] MraConfig
|
||||
|
||||
|
||||
## MraModel
|
||||
|
||||
[[autodoc]] MraModel
|
||||
- forward
|
||||
|
||||
|
||||
## MraForMaskedLM
|
||||
|
||||
[[autodoc]] MraForMaskedLM
|
||||
- forward
|
||||
|
||||
|
||||
## MraForSequenceClassification
|
||||
|
||||
[[autodoc]] MraForSequenceClassification
|
||||
- forward
|
||||
|
||||
## MraForMultipleChoice
|
||||
|
||||
[[autodoc]] MraForMultipleChoice
|
||||
- forward
|
||||
|
||||
|
||||
## MraForTokenClassification
|
||||
|
||||
[[autodoc]] MraForTokenClassification
|
||||
- forward
|
||||
|
||||
|
||||
## MraForQuestionAnswering
|
||||
|
||||
[[autodoc]] MraForQuestionAnswering
|
||||
- forward
|
||||
@@ -35,7 +35,7 @@ Choose one of the following architectures:
|
||||
|
||||
<!--This tip is automatically generated by `make fix-copies`, do not fill manually!-->
|
||||
|
||||
[ALBERT](../model_doc/albert), [BART](../model_doc/bart), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [CamemBERT](../model_doc/camembert), [ConvBERT](../model_doc/convbert), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ESM](../model_doc/esm), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [I-BERT](../model_doc/ibert), [LayoutLM](../model_doc/layoutlm), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [MVP](../model_doc/mvp), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [Perceiver](../model_doc/perceiver), [QDQBert](../model_doc/qdqbert), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [SqueezeBERT](../model_doc/squeezebert), [TAPAS](../model_doc/tapas), [Wav2Vec2](../model_doc/wav2vec2), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso)
|
||||
[ALBERT](../model_doc/albert), [BART](../model_doc/bart), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [CamemBERT](../model_doc/camembert), [ConvBERT](../model_doc/convbert), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ESM](../model_doc/esm), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [I-BERT](../model_doc/ibert), [LayoutLM](../model_doc/layoutlm), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [MRA](../model_doc/mra), [MVP](../model_doc/mvp), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [Perceiver](../model_doc/perceiver), [QDQBert](../model_doc/qdqbert), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [SqueezeBERT](../model_doc/squeezebert), [TAPAS](../model_doc/tapas), [Wav2Vec2](../model_doc/wav2vec2), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso)
|
||||
|
||||
<!--End of the generated tip-->
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ The task illustrated in this tutorial is supported by the following model archit
|
||||
|
||||
<!--This tip is automatically generated by `make fix-copies`, do not fill manually!-->
|
||||
|
||||
[ALBERT](../model_doc/albert), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [ConvBERT](../model_doc/convbert), [Data2VecText](../model_doc/data2vec-text), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [I-BERT](../model_doc/ibert), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [QDQBert](../model_doc/qdqbert), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [SqueezeBERT](../model_doc/squeezebert), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso)
|
||||
[ALBERT](../model_doc/albert), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [ConvBERT](../model_doc/convbert), [Data2VecText](../model_doc/data2vec-text), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [I-BERT](../model_doc/ibert), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [MRA](../model_doc/mra), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [QDQBert](../model_doc/qdqbert), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [SqueezeBERT](../model_doc/squeezebert), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso)
|
||||
|
||||
<!--End of the generated tip-->
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ The task illustrated in this tutorial is supported by the following model archit
|
||||
|
||||
<!--This tip is automatically generated by `make fix-copies`, do not fill manually!-->
|
||||
|
||||
[ALBERT](../model_doc/albert), [BART](../model_doc/bart), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [ConvBERT](../model_doc/convbert), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [OpenAI GPT-2](../model_doc/gpt2), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT-J](../model_doc/gptj), [I-BERT](../model_doc/ibert), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LED](../model_doc/led), [LiLT](../model_doc/lilt), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [LXMERT](../model_doc/lxmert), [MarkupLM](../model_doc/markuplm), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [MT5](../model_doc/mt5), [MVP](../model_doc/mvp), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [OPT](../model_doc/opt), [QDQBert](../model_doc/qdqbert), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [Splinter](../model_doc/splinter), [SqueezeBERT](../model_doc/squeezebert), [T5](../model_doc/t5), [UMT5](../model_doc/umt5), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso)
|
||||
[ALBERT](../model_doc/albert), [BART](../model_doc/bart), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [ConvBERT](../model_doc/convbert), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [OpenAI GPT-2](../model_doc/gpt2), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT-J](../model_doc/gptj), [I-BERT](../model_doc/ibert), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LED](../model_doc/led), [LiLT](../model_doc/lilt), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [LXMERT](../model_doc/lxmert), [MarkupLM](../model_doc/markuplm), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [MRA](../model_doc/mra), [MT5](../model_doc/mt5), [MVP](../model_doc/mvp), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [OPT](../model_doc/opt), [QDQBert](../model_doc/qdqbert), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [Splinter](../model_doc/splinter), [SqueezeBERT](../model_doc/squeezebert), [T5](../model_doc/t5), [UMT5](../model_doc/umt5), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso)
|
||||
|
||||
|
||||
<!--End of the generated tip-->
|
||||
|
||||
@@ -32,7 +32,7 @@ The task illustrated in this tutorial is supported by the following model archit
|
||||
|
||||
<!--This tip is automatically generated by `make fix-copies`, do not fill manually!-->
|
||||
|
||||
[ALBERT](../model_doc/albert), [BART](../model_doc/bart), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BioGpt](../model_doc/biogpt), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [ConvBERT](../model_doc/convbert), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [ESM](../model_doc/esm), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT-J](../model_doc/gptj), [I-BERT](../model_doc/ibert), [LayoutLM](../model_doc/layoutlm), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LED](../model_doc/led), [LiLT](../model_doc/lilt), [LLaMA](../model_doc/llama), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [MarkupLM](../model_doc/markuplm), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [MVP](../model_doc/mvp), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [OpenLlama](../model_doc/open-llama), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Perceiver](../model_doc/perceiver), [PLBart](../model_doc/plbart), [QDQBert](../model_doc/qdqbert), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [SqueezeBERT](../model_doc/squeezebert), [TAPAS](../model_doc/tapas), [Transformer-XL](../model_doc/transfo-xl), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso)
|
||||
[ALBERT](../model_doc/albert), [BART](../model_doc/bart), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BioGpt](../model_doc/biogpt), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [ConvBERT](../model_doc/convbert), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [ESM](../model_doc/esm), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT-J](../model_doc/gptj), [I-BERT](../model_doc/ibert), [LayoutLM](../model_doc/layoutlm), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LED](../model_doc/led), [LiLT](../model_doc/lilt), [LLaMA](../model_doc/llama), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [MarkupLM](../model_doc/markuplm), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [MRA](../model_doc/mra), [MVP](../model_doc/mvp), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [OpenLlama](../model_doc/open-llama), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Perceiver](../model_doc/perceiver), [PLBart](../model_doc/plbart), [QDQBert](../model_doc/qdqbert), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [SqueezeBERT](../model_doc/squeezebert), [TAPAS](../model_doc/tapas), [Transformer-XL](../model_doc/transfo-xl), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso)
|
||||
|
||||
|
||||
<!--End of the generated tip-->
|
||||
|
||||
@@ -32,7 +32,7 @@ The task illustrated in this tutorial is supported by the following model archit
|
||||
|
||||
<!--This tip is automatically generated by `make fix-copies`, do not fill manually!-->
|
||||
|
||||
[ALBERT](../model_doc/albert), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BioGpt](../model_doc/biogpt), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [ConvBERT](../model_doc/convbert), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [ESM](../model_doc/esm), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [I-BERT](../model_doc/ibert), [LayoutLM](../model_doc/layoutlm), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LiLT](../model_doc/lilt), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [MarkupLM](../model_doc/markuplm), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [QDQBert](../model_doc/qdqbert), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [SqueezeBERT](../model_doc/squeezebert), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso)
|
||||
[ALBERT](../model_doc/albert), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BioGpt](../model_doc/biogpt), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [ConvBERT](../model_doc/convbert), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [ESM](../model_doc/esm), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [I-BERT](../model_doc/ibert), [LayoutLM](../model_doc/layoutlm), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LiLT](../model_doc/lilt), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [MarkupLM](../model_doc/markuplm), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [MRA](../model_doc/mra), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [QDQBert](../model_doc/qdqbert), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [SqueezeBERT](../model_doc/squeezebert), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso)
|
||||
|
||||
<!--End of the generated tip-->
|
||||
|
||||
|
||||
@@ -402,6 +402,7 @@ _import_structure = {
|
||||
"models.mobilevit": ["MOBILEVIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "MobileViTConfig"],
|
||||
"models.mobilevitv2": ["MOBILEVITV2_PRETRAINED_CONFIG_ARCHIVE_MAP", "MobileViTV2Config"],
|
||||
"models.mpnet": ["MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "MPNetConfig", "MPNetTokenizer"],
|
||||
"models.mra": ["MRA_PRETRAINED_CONFIG_ARCHIVE_MAP", "MraConfig"],
|
||||
"models.mt5": ["MT5Config"],
|
||||
"models.musicgen": [
|
||||
"MUSICGEN_PRETRAINED_CONFIG_ARCHIVE_MAP",
|
||||
@@ -2136,6 +2137,18 @@ else:
|
||||
"MPNetPreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.mra"].extend(
|
||||
[
|
||||
"MRA_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"MraForMaskedLM",
|
||||
"MraForMultipleChoice",
|
||||
"MraForQuestionAnswering",
|
||||
"MraForSequenceClassification",
|
||||
"MraForTokenClassification",
|
||||
"MraModel",
|
||||
"MraPreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.mt5"].extend(
|
||||
["MT5EncoderModel", "MT5ForConditionalGeneration", "MT5ForQuestionAnswering", "MT5Model", "MT5PreTrainedModel"]
|
||||
)
|
||||
@@ -4276,6 +4289,7 @@ if TYPE_CHECKING:
|
||||
from .models.mobilevit import MOBILEVIT_PRETRAINED_CONFIG_ARCHIVE_MAP, MobileViTConfig
|
||||
from .models.mobilevitv2 import MOBILEVITV2_PRETRAINED_CONFIG_ARCHIVE_MAP, MobileViTV2Config
|
||||
from .models.mpnet import MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP, MPNetConfig, MPNetTokenizer
|
||||
from .models.mra import MRA_PRETRAINED_CONFIG_ARCHIVE_MAP, MraConfig
|
||||
from .models.mt5 import MT5Config
|
||||
from .models.musicgen import (
|
||||
MUSICGEN_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
@@ -5733,6 +5747,16 @@ if TYPE_CHECKING:
|
||||
MPNetModel,
|
||||
MPNetPreTrainedModel,
|
||||
)
|
||||
from .models.mra import (
|
||||
MRA_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
MraForMaskedLM,
|
||||
MraForMultipleChoice,
|
||||
MraForQuestionAnswering,
|
||||
MraForSequenceClassification,
|
||||
MraForTokenClassification,
|
||||
MraModel,
|
||||
MraPreTrainedModel,
|
||||
)
|
||||
from .models.mt5 import (
|
||||
MT5EncoderModel,
|
||||
MT5ForConditionalGeneration,
|
||||
|
||||
383
src/transformers/kernels/mra/cuda_kernel.cu
Normal file
383
src/transformers/kernels/mra/cuda_kernel.cu
Normal file
@@ -0,0 +1,383 @@
|
||||
#include "cuda_kernel.h"
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
//////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
__global__ void index_max_cuda_kernel(
|
||||
float *index_vals, // [batch_size, 32, num_block]
|
||||
int *indices, // [batch_size, num_block]
|
||||
float *max_vals, // [batch_size, A_num_block * 32]
|
||||
float *max_vals_scatter, // [batch_size, 32, num_block]
|
||||
long batch_size,
|
||||
long A_num_block,
|
||||
long B_num_block,
|
||||
long num_block
|
||||
) {
|
||||
|
||||
long batch_idx = blockIdx.x;
|
||||
|
||||
long thread_idx = threadIdx.x;
|
||||
long num_thread = blockDim.x;
|
||||
|
||||
extern __shared__ float buffer[];
|
||||
int *max_buffer = (int*)buffer;
|
||||
|
||||
for (int i = 0; i < A_num_block * 32; i = i + num_thread) {
|
||||
int idx = i + thread_idx;
|
||||
if (idx < A_num_block * 32) {
|
||||
max_buffer[idx] = -1e8;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
int *indices_pt = &indices[batch_idx * num_block];
|
||||
float *index_vals_pt = &index_vals[batch_idx * num_block * 32];
|
||||
|
||||
for (int idx_start = 0; idx_start < 32 * num_block; idx_start = idx_start + num_thread) {
|
||||
int idx = idx_start + thread_idx;
|
||||
int A_block_idx = indices_pt[idx % num_block] / B_num_block;
|
||||
atomicMax(&max_buffer[A_block_idx * 32 + idx / num_block], (int)(index_vals_pt[idx] * 1000));
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
float *max_vals_pt = &max_vals[batch_idx * A_num_block * 32];
|
||||
for (int i = 0; i < A_num_block * 32; i = i + num_thread) {
|
||||
int idx = i + thread_idx;
|
||||
if (idx < A_num_block * 32) {
|
||||
max_vals_pt[idx] = (float)max_buffer[idx] / 1000.;
|
||||
}
|
||||
}
|
||||
|
||||
float *max_vals_scatter_pt = &max_vals_scatter[batch_idx * num_block * 32];
|
||||
for (int idx_start = 0; idx_start < 32 * num_block; idx_start = idx_start + num_thread) {
|
||||
int idx = idx_start + thread_idx;
|
||||
int A_block_idx = indices_pt[idx % num_block] / B_num_block;
|
||||
max_vals_scatter_pt[idx] = (float)max_buffer[A_block_idx * 32 + idx / num_block] / 1000.;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
__global__ void mm_to_sparse_cuda_kernel(
|
||||
float *dense_A, // [batch_size, A_num_block, dim, 32]
|
||||
float *dense_B, // [batch_size, B_num_block, dim, 32]
|
||||
int *indices, // [batch_size, num_block]
|
||||
float *sparse_C, // [batch_size, num_block, 32, 32]
|
||||
long batch_size,
|
||||
long A_num_block,
|
||||
long B_num_block,
|
||||
long dim,
|
||||
long num_block
|
||||
) {
|
||||
|
||||
long batch_idx = blockIdx.y;
|
||||
long block_idx = blockIdx.x * blockDim.y + threadIdx.y;
|
||||
|
||||
long thread_idx = threadIdx.x;
|
||||
|
||||
__shared__ float buffer[4096];
|
||||
float *A_buffer = &buffer[threadIdx.y * 1024]; // [2, 8, 32]
|
||||
float *B_buffer = &buffer[threadIdx.y * 1024 + 512]; // [2, 8, 32]
|
||||
|
||||
long batch_idx__block_idx = batch_idx * num_block + block_idx;
|
||||
|
||||
long AB_block_idx = indices[batch_idx__block_idx];
|
||||
float *dense_A_pt = &dense_A[(batch_idx * A_num_block + AB_block_idx / B_num_block) * dim * 32];
|
||||
float *dense_B_pt = &dense_B[(batch_idx * B_num_block + AB_block_idx % B_num_block) * dim * 32];
|
||||
|
||||
int reg_1_idx = thread_idx / 8; // [0000000011111111222222223333333344444444555555556666666677777777]
|
||||
int reg_2_idx = thread_idx % 8; // [0123456701234567012345670123456701234567012345670123456701234567]
|
||||
|
||||
float reg_1[8];
|
||||
float reg_2[8];
|
||||
|
||||
float reg_array[16] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
A_buffer[i * 64 + thread_idx] = dense_A_pt[i * 64 + thread_idx];
|
||||
B_buffer[i * 64 + thread_idx] = dense_B_pt[i * 64 + thread_idx];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
reg_1[i] = A_buffer[reg_1_idx * 4 + i];
|
||||
reg_2[i] = B_buffer[reg_2_idx * 4 + i];
|
||||
}
|
||||
|
||||
for (int dim_stride = 1; dim_stride < (dim / 8); dim_stride++) {
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
A_buffer[(dim_stride % 2) * 256 + i * 64 + thread_idx] = dense_A_pt[dim_stride * 256 + i * 64 + thread_idx];
|
||||
B_buffer[(dim_stride % 2) * 256 + i * 64 + thread_idx] = dense_B_pt[dim_stride * 256 + i * 64 + thread_idx];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int mini_dim_idx = 1; mini_dim_idx < 8; mini_dim_idx++) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
reg_1[(mini_dim_idx % 2) * 4 + i] = A_buffer[((dim_stride - 1) % 2) * 256 + mini_dim_idx * 32 + reg_1_idx * 4 + i];
|
||||
reg_2[(mini_dim_idx % 2) * 4 + i] = B_buffer[((dim_stride - 1) % 2) * 256 + mini_dim_idx * 32 + reg_2_idx * 4 + i];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 4; j++) {
|
||||
reg_array[i * 4 + j] += reg_1[((mini_dim_idx - 1) % 2) * 4 + i] * reg_2[((mini_dim_idx - 1) % 2) * 4 + j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
reg_1[i] = A_buffer[(dim_stride % 2) * 256 + reg_1_idx * 4 + i];
|
||||
reg_2[i] = B_buffer[(dim_stride % 2) * 256 + reg_2_idx * 4 + i];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 4; j++) {
|
||||
reg_array[i * 4 + j] += reg_1[4 + i] * reg_2[4 + j];
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int mini_dim_idx = 1; mini_dim_idx < 8; mini_dim_idx++) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
reg_1[(mini_dim_idx % 2) * 4 + i] = A_buffer[256 + mini_dim_idx * 32 + reg_1_idx * 4 + i];
|
||||
reg_2[(mini_dim_idx % 2) * 4 + i] = B_buffer[256 + mini_dim_idx * 32 + reg_2_idx * 4 + i];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 4; j++) {
|
||||
reg_array[i * 4 + j] += reg_1[((mini_dim_idx - 1) % 2) * 4 + i] * reg_2[((mini_dim_idx - 1) % 2) * 4 + j];
|
||||
}
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 4; j++) {
|
||||
reg_array[i * 4 + j] += reg_1[4 + i] * reg_2[4 + j];
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
float *C_buffer = &buffer[threadIdx.y * 1024]; // [32, 32]
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 4; j++) {
|
||||
C_buffer[(reg_2_idx * 4 + j) * 32 + reg_1_idx * 4 + i] = reg_array[i * 4 + j];
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
float *sparse_C_pt = &sparse_C[batch_idx__block_idx * 1024];
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 16; i++) {
|
||||
sparse_C_pt[i * 64 + thread_idx] = C_buffer[i * 64 + thread_idx];
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
__global__ void sparse_dense_mm_cuda_kernel(
|
||||
float *sparse_A, // [batch_size, num_block, 32, 32]
|
||||
int *indices, // [batch_size, num_block]
|
||||
float *dense_B, // [batch_size, B_num_block, dim, 32]
|
||||
float *dense_C, // [batch_size, A_num_block, dim, 32]
|
||||
long batch_size,
|
||||
long A_num_block,
|
||||
long B_num_block,
|
||||
long dim,
|
||||
long num_block
|
||||
) {
|
||||
|
||||
long batch_idx = blockIdx.y;
|
||||
long block_idx = blockIdx.x * blockDim.y + threadIdx.y;
|
||||
|
||||
long thread_idx = threadIdx.x;
|
||||
|
||||
__shared__ float buffer[6144];
|
||||
float *A_buffer = &buffer[threadIdx.y * 3072]; // [32, 32]
|
||||
float *B_buffer = &buffer[threadIdx.y * 3072 + 1024]; // [32, 64]
|
||||
|
||||
long batch_idx__block_idx = batch_idx * num_block + block_idx;
|
||||
|
||||
float *sparse_A_pt = &sparse_A[batch_idx__block_idx * 1024];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++) {
|
||||
A_buffer[i * 128 + thread_idx] = sparse_A_pt[i * 128 + thread_idx];
|
||||
}
|
||||
|
||||
long AB_block_idx = indices[batch_idx__block_idx];
|
||||
float *dense_B_pt = &dense_B[(batch_idx * B_num_block + AB_block_idx % B_num_block) * 32 * dim];
|
||||
float *dense_C_pt = &dense_C[(batch_idx * A_num_block + AB_block_idx / B_num_block) * 32 * dim];
|
||||
|
||||
// [0000000011111111222222223333333344444444555555556666666677777777]
|
||||
// [0123456701234567012345670123456701234567012345670123456701234567]
|
||||
int reg_1_idx = thread_idx / 8;
|
||||
int reg_2_idx = thread_idx % 8;
|
||||
|
||||
float reg_1[8];
|
||||
float reg_2[8];
|
||||
|
||||
float reg_array[16];
|
||||
|
||||
for (int dim_stride = 0; dim_stride < dim; dim_stride = dim_stride + 64) {
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 16; i++) {
|
||||
B_buffer[i * 128 + thread_idx] = dense_B_pt[dim_stride * 32 + i * 128 + thread_idx];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 16; i++) {
|
||||
reg_array[i] = 0;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
reg_1[i] = B_buffer[(reg_1_idx * 4 + i) * 32];
|
||||
reg_2[i] = A_buffer[reg_2_idx * 4 + i];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int mini_dim_idx = 1; mini_dim_idx < 32; mini_dim_idx++) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
reg_1[(mini_dim_idx % 2) * 4 + i] = B_buffer[(reg_1_idx * 4 + i) * 32 + mini_dim_idx];
|
||||
reg_2[(mini_dim_idx % 2) * 4 + i] = A_buffer[mini_dim_idx * 32 + reg_2_idx * 4 + i];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 4; j++) {
|
||||
reg_array[i * 4 + j] += reg_1[((mini_dim_idx - 1) % 2) * 4 + i] * reg_2[((mini_dim_idx - 1) % 2) * 4 + j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 4; j++) {
|
||||
reg_array[i * 4 + j] += reg_1[4 + i] * reg_2[4 + j];
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
float *C_buffer = &buffer[threadIdx.y * 3072 + 1024]; // [64, 32]
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 4; j++) {
|
||||
C_buffer[(reg_1_idx * 4 + i) * 32 + reg_2_idx * 4 + j] = reg_array[i * 4 + j];
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 16; i++) {
|
||||
atomicAdd(&dense_C_pt[dim_stride * 32 + i * 128 + thread_idx], C_buffer[i * 128 + thread_idx]);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
__global__ void reduce_sum_cuda_kernel(
|
||||
float *sparse_A, // [batch_size, num_block, 32, 32]
|
||||
int *indices, // [batch_size, num_block]
|
||||
float *dense_C, // [batch_size, A_num_block, 32]
|
||||
long batch_size,
|
||||
long A_num_block,
|
||||
long B_num_block,
|
||||
long num_block
|
||||
) {
|
||||
|
||||
long batch_idx = blockIdx.y;
|
||||
long block_idx = blockIdx.x * blockDim.y + threadIdx.y;
|
||||
|
||||
long thread_idx = threadIdx.x;
|
||||
|
||||
long batch_idx__block_idx = batch_idx * num_block + block_idx;
|
||||
|
||||
long AB_block_idx = indices[batch_idx__block_idx];
|
||||
float *sparse_A_pt = &sparse_A[batch_idx__block_idx * 1024];
|
||||
|
||||
float reg_array[16];
|
||||
float value = 0;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++) {
|
||||
reg_array[i] = sparse_A_pt[i * 32 + thread_idx];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int stride = 8; stride < 32; stride = stride + 8) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++) {
|
||||
reg_array[(stride + i) % 16] = sparse_A_pt[(stride + i) * 32 + thread_idx];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++) {
|
||||
value = value + reg_array[(stride - 8 + i) % 16];
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++) {
|
||||
value = value + reg_array[8 + i];
|
||||
}
|
||||
|
||||
float *dense_C_pt = &dense_C[(batch_idx * A_num_block + AB_block_idx / B_num_block) * 32];
|
||||
|
||||
atomicAdd(&dense_C_pt[thread_idx], value);
|
||||
|
||||
}
|
||||
|
||||
__global__ void scatter_cuda_kernel(
|
||||
float *dense_A, // [batch_size, A_num_block, 32]
|
||||
int *indices, // [batch_size, num_block]
|
||||
float *sparse_C, // [batch_size, num_block, 32, 32]
|
||||
long batch_size,
|
||||
long A_num_block,
|
||||
long B_num_block,
|
||||
long num_block
|
||||
) {
|
||||
|
||||
long batch_idx = blockIdx.y;
|
||||
long block_idx = blockIdx.x * blockDim.y + threadIdx.y;
|
||||
|
||||
long thread_idx = threadIdx.x;
|
||||
|
||||
long batch_idx__block_idx = batch_idx * num_block + block_idx;
|
||||
|
||||
long AB_block_idx = indices[batch_idx__block_idx];
|
||||
float *dense_A_pt = &dense_A[(batch_idx * A_num_block + AB_block_idx / B_num_block) * 32];
|
||||
float *sparse_C_pt = &sparse_C[(batch_idx * num_block + block_idx) * 1024];
|
||||
|
||||
float value = dense_A_pt[thread_idx];
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 32; i++) {
|
||||
sparse_C_pt[i * 32 + thread_idx] = value;
|
||||
}
|
||||
|
||||
}
|
||||
59
src/transformers/kernels/mra/cuda_kernel.h
Normal file
59
src/transformers/kernels/mra/cuda_kernel.h
Normal file
@@ -0,0 +1,59 @@
|
||||
|
||||
#define WARP_SIZE 32
|
||||
#define FULL_MASK 0xffffffff
|
||||
#define OPTIMAL_THREADS 256
|
||||
|
||||
__global__ void index_max_cuda_kernel(
|
||||
float *index_vals, // [batch_size, 32, num_block]
|
||||
int *indices, // [batch_size, num_block]
|
||||
float *max_vals, // [batch_size, A_num_block * 32]
|
||||
float *max_vals_scatter, // [batch_size, 32, num_block]
|
||||
long batch_size,
|
||||
long A_num_block,
|
||||
long B_num_block,
|
||||
long num_block
|
||||
);
|
||||
|
||||
__global__ void mm_to_sparse_cuda_kernel(
|
||||
float *dense_A, // [batch_size, A_num_block, dim, 32]
|
||||
float *dense_B, // [batch_size, B_num_block, dim, 32]
|
||||
int *indices, // [batch_size, num_block]
|
||||
float *sparse_C, // [batch_size, num_block, 32, 32]
|
||||
long batch_size,
|
||||
long A_num_block,
|
||||
long B_num_block,
|
||||
long dim,
|
||||
long num_block
|
||||
);
|
||||
|
||||
__global__ void sparse_dense_mm_cuda_kernel(
|
||||
float *sparse_A, // [batch_size, num_block, 32, 32]
|
||||
int *indices, // [batch_size, num_block]
|
||||
float *dense_B, // [batch_size, B_num_block, dim, 32]
|
||||
float *dense_C, // [batch_size, A_num_block, dim, 32]
|
||||
long batch_size,
|
||||
long A_num_block,
|
||||
long B_num_block,
|
||||
long dim,
|
||||
long num_block
|
||||
);
|
||||
|
||||
__global__ void reduce_sum_cuda_kernel(
|
||||
float *sparse_A, // [batch_size, num_block, 32, 32]
|
||||
int *indices, // [batch_size, num_block]
|
||||
float *dense_C, // [batch_size, A_num_block, 32]
|
||||
long batch_size,
|
||||
long A_num_block,
|
||||
long B_num_block,
|
||||
long num_block
|
||||
);
|
||||
|
||||
__global__ void scatter_cuda_kernel(
|
||||
float *dense_A, // [batch_size, A_num_block, 32]
|
||||
int *indices, // [batch_size, num_block]
|
||||
float *sparse_C, // [batch_size, num_block, 32, 32]
|
||||
long batch_size,
|
||||
long A_num_block,
|
||||
long B_num_block,
|
||||
long num_block
|
||||
);
|
||||
154
src/transformers/kernels/mra/cuda_launch.cu
Normal file
154
src/transformers/kernels/mra/cuda_launch.cu
Normal file
@@ -0,0 +1,154 @@
|
||||
#include <torch/extension.h>
|
||||
#include <ATen/ATen.h>
|
||||
#include "cuda_launch.h"
|
||||
#include "cuda_kernel.h"
|
||||
#include <vector>
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
//////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
std::vector<at::Tensor> index_max_kernel(
|
||||
at::Tensor index_vals, // [batch_size, 32, num_block]
|
||||
at::Tensor indices, // [batch_size, num_block],
|
||||
int A_num_block,
|
||||
int B_num_block
|
||||
) {
|
||||
int batch_size = indices.size(0);
|
||||
int num_block = indices.size(1);
|
||||
|
||||
at::Tensor max_vals = at::zeros({batch_size, A_num_block * 32}, index_vals.options());
|
||||
at::Tensor max_vals_scatter = at::zeros({batch_size, 32, num_block}, index_vals.options());
|
||||
|
||||
dim3 threads(256);
|
||||
dim3 blocks(batch_size);
|
||||
int shared_mem = A_num_block * 32 * sizeof(float);
|
||||
|
||||
index_max_cuda_kernel<<<blocks, threads, shared_mem>>>(
|
||||
index_vals.data_ptr<float>(),
|
||||
indices.data_ptr<int>(),
|
||||
max_vals.data_ptr<float>(),
|
||||
max_vals_scatter.data_ptr<float>(),
|
||||
batch_size,
|
||||
A_num_block,
|
||||
B_num_block,
|
||||
num_block
|
||||
);
|
||||
|
||||
return {max_vals, max_vals_scatter};
|
||||
}
|
||||
|
||||
at::Tensor mm_to_sparse_kernel(
|
||||
at::Tensor dense_A, // [batch_size, A_num_block, dim, 32]
|
||||
at::Tensor dense_B, // [batch_size, B_num_block, dim, 32]
|
||||
at::Tensor indices // [batch_size, num_block]
|
||||
) {
|
||||
int batch_size = dense_A.size(0);
|
||||
int A_num_block = dense_A.size(1);
|
||||
int B_num_block = dense_B.size(1);
|
||||
int dim = dense_A.size(2);
|
||||
int num_block = indices.size(1);
|
||||
|
||||
at::Tensor sparse_C = at::zeros({batch_size, num_block, 32, 32}, dense_A.options());
|
||||
|
||||
dim3 threads(64, 4);
|
||||
dim3 blocks(num_block / 4, batch_size);
|
||||
|
||||
mm_to_sparse_cuda_kernel<<<blocks, threads>>>(
|
||||
dense_A.data_ptr<float>(),
|
||||
dense_B.data_ptr<float>(),
|
||||
indices.data_ptr<int>(),
|
||||
sparse_C.data_ptr<float>(),
|
||||
batch_size,
|
||||
A_num_block,
|
||||
B_num_block,
|
||||
dim,
|
||||
num_block
|
||||
);
|
||||
|
||||
return sparse_C;
|
||||
}
|
||||
|
||||
at::Tensor sparse_dense_mm_kernel(
|
||||
at::Tensor sparse_A, // [batch_size, num_block, 32, 32]
|
||||
at::Tensor indices, // [batch_size, num_block]
|
||||
at::Tensor dense_B, // [batch_size, B_num_block, dim, 32]
|
||||
int A_num_block
|
||||
) {
|
||||
int batch_size = sparse_A.size(0);
|
||||
int num_block = sparse_A.size(1);
|
||||
int B_num_block = dense_B.size(1);
|
||||
int dim = dense_B.size(2);
|
||||
|
||||
at::Tensor dense_C = at::zeros({batch_size, A_num_block, dim, 32}, dense_B.options());
|
||||
|
||||
dim3 threads(128, 2);
|
||||
dim3 blocks(num_block / 2, batch_size);
|
||||
|
||||
sparse_dense_mm_cuda_kernel<<<blocks, threads>>>(
|
||||
sparse_A.data_ptr<float>(),
|
||||
indices.data_ptr<int>(),
|
||||
dense_B.data_ptr<float>(),
|
||||
dense_C.data_ptr<float>(),
|
||||
batch_size,
|
||||
A_num_block,
|
||||
B_num_block,
|
||||
dim,
|
||||
num_block
|
||||
);
|
||||
|
||||
return dense_C;
|
||||
}
|
||||
|
||||
at::Tensor reduce_sum_kernel(
|
||||
at::Tensor sparse_A, // [batch_size, num_block, 32, 32]
|
||||
at::Tensor indices, // [batch_size, num_block]
|
||||
int A_num_block,
|
||||
int B_num_block
|
||||
) {
|
||||
int batch_size = sparse_A.size(0);
|
||||
int num_block = sparse_A.size(1);
|
||||
|
||||
at::Tensor dense_C = at::zeros({batch_size, A_num_block, 32}, sparse_A.options());
|
||||
|
||||
dim3 threads(32, 4);
|
||||
dim3 blocks(num_block / 4, batch_size);
|
||||
|
||||
reduce_sum_cuda_kernel<<<blocks, threads>>>(
|
||||
sparse_A.data_ptr<float>(),
|
||||
indices.data_ptr<int>(),
|
||||
dense_C.data_ptr<float>(),
|
||||
batch_size,
|
||||
A_num_block,
|
||||
B_num_block,
|
||||
num_block
|
||||
);
|
||||
|
||||
return dense_C;
|
||||
}
|
||||
|
||||
at::Tensor scatter_kernel(
|
||||
at::Tensor dense_A, // [batch_size, A_num_block, 32]
|
||||
at::Tensor indices, // [batch_size, num_block]
|
||||
int B_num_block
|
||||
) {
|
||||
int batch_size = dense_A.size(0);
|
||||
int A_num_block = dense_A.size(1);
|
||||
int num_block = indices.size(1);
|
||||
|
||||
at::Tensor sparse_C = at::zeros({batch_size, num_block, 32, 32}, dense_A.options());
|
||||
|
||||
dim3 threads(32, 4);
|
||||
dim3 blocks(num_block / 4, batch_size);
|
||||
|
||||
scatter_cuda_kernel<<<blocks, threads>>>(
|
||||
dense_A.data_ptr<float>(),
|
||||
indices.data_ptr<int>(),
|
||||
sparse_C.data_ptr<float>(),
|
||||
batch_size,
|
||||
A_num_block,
|
||||
B_num_block,
|
||||
num_block
|
||||
);
|
||||
|
||||
return sparse_C;
|
||||
}
|
||||
39
src/transformers/kernels/mra/cuda_launch.h
Normal file
39
src/transformers/kernels/mra/cuda_launch.h
Normal file
@@ -0,0 +1,39 @@
|
||||
#include <torch/extension.h>
|
||||
#include <ATen/ATen.h>
|
||||
#include <vector>
|
||||
|
||||
#define min(a, b) ((a)<(b)?(a):(b))
|
||||
#define max(a, b) ((a)>(b)?(a):(b))
|
||||
|
||||
std::vector<at::Tensor> index_max_kernel(
|
||||
at::Tensor index_vals,
|
||||
at::Tensor indices,
|
||||
int A_num_block,
|
||||
int B_num_block
|
||||
);
|
||||
|
||||
at::Tensor mm_to_sparse_kernel(
|
||||
at::Tensor dense_A,
|
||||
at::Tensor dense_B,
|
||||
at::Tensor indices
|
||||
);
|
||||
|
||||
at::Tensor sparse_dense_mm_kernel(
|
||||
at::Tensor sparse_A,
|
||||
at::Tensor indices,
|
||||
at::Tensor dense_B,
|
||||
int A_num_block
|
||||
);
|
||||
|
||||
at::Tensor reduce_sum_kernel(
|
||||
at::Tensor sparse_A,
|
||||
at::Tensor indices,
|
||||
int A_num_block,
|
||||
int B_num_block
|
||||
);
|
||||
|
||||
at::Tensor scatter_kernel(
|
||||
at::Tensor dense_A,
|
||||
at::Tensor indices,
|
||||
int B_num_block
|
||||
);
|
||||
78
src/transformers/kernels/mra/torch_extension.cpp
Normal file
78
src/transformers/kernels/mra/torch_extension.cpp
Normal file
@@ -0,0 +1,78 @@
|
||||
#include <torch/extension.h>
|
||||
#include <ATen/ATen.h>
|
||||
#include "cuda_launch.h"
|
||||
#include <vector>
|
||||
|
||||
std::vector<at::Tensor> index_max(
|
||||
at::Tensor index_vals,
|
||||
at::Tensor indices,
|
||||
int A_num_block,
|
||||
int B_num_block
|
||||
) {
|
||||
return index_max_kernel(
|
||||
index_vals,
|
||||
indices,
|
||||
A_num_block,
|
||||
B_num_block
|
||||
);
|
||||
}
|
||||
|
||||
at::Tensor mm_to_sparse(
|
||||
at::Tensor dense_A,
|
||||
at::Tensor dense_B,
|
||||
at::Tensor indices
|
||||
) {
|
||||
return mm_to_sparse_kernel(
|
||||
dense_A,
|
||||
dense_B,
|
||||
indices
|
||||
);
|
||||
}
|
||||
|
||||
at::Tensor sparse_dense_mm(
|
||||
at::Tensor sparse_A,
|
||||
at::Tensor indices,
|
||||
at::Tensor dense_B,
|
||||
int A_num_block
|
||||
) {
|
||||
return sparse_dense_mm_kernel(
|
||||
sparse_A,
|
||||
indices,
|
||||
dense_B,
|
||||
A_num_block
|
||||
);
|
||||
}
|
||||
|
||||
at::Tensor reduce_sum(
|
||||
at::Tensor sparse_A,
|
||||
at::Tensor indices,
|
||||
int A_num_block,
|
||||
int B_num_block
|
||||
) {
|
||||
return reduce_sum_kernel(
|
||||
sparse_A,
|
||||
indices,
|
||||
A_num_block,
|
||||
B_num_block
|
||||
);
|
||||
}
|
||||
|
||||
at::Tensor scatter(
|
||||
at::Tensor dense_A,
|
||||
at::Tensor indices,
|
||||
int B_num_block
|
||||
) {
|
||||
return scatter_kernel(
|
||||
dense_A,
|
||||
indices,
|
||||
B_num_block
|
||||
);
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("index_max", &index_max, "index_max (CUDA)");
|
||||
m.def("mm_to_sparse", &mm_to_sparse, "mm_to_sparse (CUDA)");
|
||||
m.def("sparse_dense_mm", &sparse_dense_mm, "sparse_dense_mm (CUDA)");
|
||||
m.def("reduce_sum", &reduce_sum, "reduce_sum (CUDA)");
|
||||
m.def("scatter", &scatter, "scatter (CUDA)");
|
||||
}
|
||||
@@ -134,6 +134,7 @@ from . import (
|
||||
mobilevit,
|
||||
mobilevitv2,
|
||||
mpnet,
|
||||
mra,
|
||||
mt5,
|
||||
musicgen,
|
||||
mvp,
|
||||
|
||||
@@ -137,6 +137,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
|
||||
("mobilevit", "MobileViTConfig"),
|
||||
("mobilevitv2", "MobileViTV2Config"),
|
||||
("mpnet", "MPNetConfig"),
|
||||
("mra", "MraConfig"),
|
||||
("mt5", "MT5Config"),
|
||||
("musicgen", "MusicgenConfig"),
|
||||
("mvp", "MvpConfig"),
|
||||
@@ -329,6 +330,7 @@ CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict(
|
||||
("mobilevit", "MOBILEVIT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
|
||||
("mobilevitv2", "MOBILEVITV2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
|
||||
("mpnet", "MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
|
||||
("mra", "MRA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
|
||||
("musicgen", "MUSICGEN_PRETRAINED_CONFIG_ARCHIVE_MAP"),
|
||||
("mvp", "MVP_PRETRAINED_CONFIG_ARCHIVE_MAP"),
|
||||
("nat", "NAT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
|
||||
@@ -534,6 +536,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
|
||||
("mobilevit", "MobileViT"),
|
||||
("mobilevitv2", "MobileViTV2"),
|
||||
("mpnet", "MPNet"),
|
||||
("mra", "MRA"),
|
||||
("mt5", "MT5"),
|
||||
("musicgen", "MusicGen"),
|
||||
("mvp", "MVP"),
|
||||
|
||||
@@ -134,6 +134,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
||||
("mobilevit", "MobileViTModel"),
|
||||
("mobilevitv2", "MobileViTV2Model"),
|
||||
("mpnet", "MPNetModel"),
|
||||
("mra", "MraModel"),
|
||||
("mt5", "MT5Model"),
|
||||
("mvp", "MvpModel"),
|
||||
("nat", "NatModel"),
|
||||
@@ -247,6 +248,7 @@ MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
|
||||
("megatron-bert", "MegatronBertForPreTraining"),
|
||||
("mobilebert", "MobileBertForPreTraining"),
|
||||
("mpnet", "MPNetForMaskedLM"),
|
||||
("mra", "MraForMaskedLM"),
|
||||
("mvp", "MvpForConditionalGeneration"),
|
||||
("nezha", "NezhaForPreTraining"),
|
||||
("nllb-moe", "NllbMoeForConditionalGeneration"),
|
||||
@@ -326,6 +328,7 @@ MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
|
||||
("megatron-bert", "MegatronBertForCausalLM"),
|
||||
("mobilebert", "MobileBertForMaskedLM"),
|
||||
("mpnet", "MPNetForMaskedLM"),
|
||||
("mra", "MraForMaskedLM"),
|
||||
("mvp", "MvpForConditionalGeneration"),
|
||||
("nezha", "NezhaForMaskedLM"),
|
||||
("nllb-moe", "NllbMoeForConditionalGeneration"),
|
||||
@@ -572,6 +575,7 @@ MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
|
||||
("megatron-bert", "MegatronBertForMaskedLM"),
|
||||
("mobilebert", "MobileBertForMaskedLM"),
|
||||
("mpnet", "MPNetForMaskedLM"),
|
||||
("mra", "MraForMaskedLM"),
|
||||
("mvp", "MvpForConditionalGeneration"),
|
||||
("nezha", "NezhaForMaskedLM"),
|
||||
("nystromformer", "NystromformerForMaskedLM"),
|
||||
@@ -704,6 +708,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
||||
("megatron-bert", "MegatronBertForSequenceClassification"),
|
||||
("mobilebert", "MobileBertForSequenceClassification"),
|
||||
("mpnet", "MPNetForSequenceClassification"),
|
||||
("mra", "MraForSequenceClassification"),
|
||||
("mvp", "MvpForSequenceClassification"),
|
||||
("nezha", "NezhaForSequenceClassification"),
|
||||
("nystromformer", "NystromformerForSequenceClassification"),
|
||||
@@ -771,6 +776,7 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
|
||||
("megatron-bert", "MegatronBertForQuestionAnswering"),
|
||||
("mobilebert", "MobileBertForQuestionAnswering"),
|
||||
("mpnet", "MPNetForQuestionAnswering"),
|
||||
("mra", "MraForQuestionAnswering"),
|
||||
("mt5", "MT5ForQuestionAnswering"),
|
||||
("mvp", "MvpForQuestionAnswering"),
|
||||
("nezha", "NezhaForQuestionAnswering"),
|
||||
@@ -856,6 +862,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
||||
("megatron-bert", "MegatronBertForTokenClassification"),
|
||||
("mobilebert", "MobileBertForTokenClassification"),
|
||||
("mpnet", "MPNetForTokenClassification"),
|
||||
("mra", "MraForTokenClassification"),
|
||||
("nezha", "NezhaForTokenClassification"),
|
||||
("nystromformer", "NystromformerForTokenClassification"),
|
||||
("qdqbert", "QDQBertForTokenClassification"),
|
||||
@@ -899,6 +906,7 @@ MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
|
||||
("megatron-bert", "MegatronBertForMultipleChoice"),
|
||||
("mobilebert", "MobileBertForMultipleChoice"),
|
||||
("mpnet", "MPNetForMultipleChoice"),
|
||||
("mra", "MraForMultipleChoice"),
|
||||
("nezha", "NezhaForMultipleChoice"),
|
||||
("nystromformer", "NystromformerForMultipleChoice"),
|
||||
("qdqbert", "QDQBertForMultipleChoice"),
|
||||
|
||||
@@ -214,6 +214,7 @@ else:
|
||||
("mluke", ("MLukeTokenizer" if is_sentencepiece_available() else None, None)),
|
||||
("mobilebert", ("MobileBertTokenizer", "MobileBertTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("mpnet", ("MPNetTokenizer", "MPNetTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("mra", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
|
||||
(
|
||||
"mt5",
|
||||
(
|
||||
|
||||
68
src/transformers/models/mra/__init__.py
Normal file
68
src/transformers/models/mra/__init__.py
Normal file
@@ -0,0 +1,68 @@
|
||||
# flake8: noqa
|
||||
# There's no way to ignore "F401 '...' imported but unused" warnings in this
|
||||
# module, but to preserve other warnings. So, don't check this module at all.
|
||||
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
# rely on isort to merge the imports
|
||||
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tokenizers_available, is_torch_available
|
||||
|
||||
|
||||
_import_structure = {"configuration_mra": ["MRA_PRETRAINED_CONFIG_ARCHIVE_MAP", "MraConfig"]}
|
||||
|
||||
try:
|
||||
if not is_torch_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_mra"] = [
|
||||
"MRA_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"MraForMaskedLM",
|
||||
"MraForMultipleChoice",
|
||||
"MraForQuestionAnswering",
|
||||
"MraForSequenceClassification",
|
||||
"MraForTokenClassification",
|
||||
"MraLayer",
|
||||
"MraModel",
|
||||
"MraPreTrainedModel",
|
||||
]
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_mra import MRA_PRETRAINED_CONFIG_ARCHIVE_MAP, MraConfig
|
||||
|
||||
try:
|
||||
if not is_torch_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .modeling_mra import (
|
||||
MRA_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
MraForMaskedLM,
|
||||
MraForMultipleChoice,
|
||||
MraForQuestionAnswering,
|
||||
MraForSequenceClassification,
|
||||
MraForTokenClassification,
|
||||
MraLayer,
|
||||
MraModel,
|
||||
MraPreTrainedModel,
|
||||
)
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
|
||||
137
src/transformers/models/mra/configuration_mra.py
Normal file
137
src/transformers/models/mra/configuration_mra.py
Normal file
@@ -0,0 +1,137 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" MRA model configuration"""
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
MRA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||
"uw-madison/mra-base-512-4": "https://huggingface.co/uw-madison/mra-base-512-4/resolve/main/config.json",
|
||||
}
|
||||
|
||||
|
||||
class MraConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`MraModel`]. It is used to instantiate an MRA
|
||||
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
||||
defaults will yield a similar configuration to that of the Mra
|
||||
[uw-madison/mra-base-512-4](https://huggingface.co/uw-madison/mra-base-512-4) architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 50265):
|
||||
Vocabulary size of the Mra model. Defines the number of different tokens that can be represented by the
|
||||
`inputs_ids` passed when calling [`MraModel`].
|
||||
hidden_size (`int`, *optional*, defaults to 768):
|
||||
Dimension of the encoder layers and the pooler layer.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 12):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 12):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
intermediate_size (`int`, *optional*, defaults to 3072):
|
||||
Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
|
||||
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
||||
`"relu"`, `"selu"` and `"gelu_new"` are supported.
|
||||
hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
|
||||
The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler.
|
||||
attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
|
||||
The dropout ratio for the attention probabilities.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 512):
|
||||
The maximum sequence length that this model might ever be used with. Typically set this to something large
|
||||
just in case (e.g., 512 or 1024 or 2048).
|
||||
type_vocab_size (`int`, *optional*, defaults to 1):
|
||||
The vocabulary size of the `token_type_ids` passed when calling [`MraModel`].
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
layer_norm_eps (`float`, *optional*, defaults to 1e-5):
|
||||
The epsilon used by the layer normalization layers.
|
||||
position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
|
||||
Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`.
|
||||
block_per_row (`int`, *optional*, defaults to 4):
|
||||
Used to set the budget for the high resolution scale.
|
||||
approx_mode (`str`, *optional*, defaults to `"full"`):
|
||||
Controls whether both low and high resolution approximations are used. Set to `"full"` for both low and
|
||||
high resolution and `"sparse"` for only low resolution.
|
||||
initial_prior_first_n_blocks (`int`, *optional*, defaults to 0):
|
||||
The initial number of blocks for which high resolution is used.
|
||||
initial_prior_diagonal_n_blocks (`int`, *optional*, defaults to 0):
|
||||
The number of diagonal blocks for which high resolution is used.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import MraConfig, MraModel
|
||||
|
||||
>>> # Initializing a Mra uw-madison/mra-base-512-4 style configuration
|
||||
>>> configuration = MraConfig()
|
||||
|
||||
>>> # Initializing a model (with random weights) from the uw-madison/mra-base-512-4 style configuration
|
||||
>>> model = MraModel(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
model_type = "mra"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=50265,
|
||||
hidden_size=768,
|
||||
num_hidden_layers=12,
|
||||
num_attention_heads=12,
|
||||
intermediate_size=3072,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=1,
|
||||
initializer_range=0.02,
|
||||
layer_norm_eps=1e-5,
|
||||
position_embedding_type="absolute",
|
||||
block_per_row=4,
|
||||
approx_mode="full",
|
||||
initial_prior_first_n_blocks=0,
|
||||
initial_prior_diagonal_n_blocks=0,
|
||||
pad_token_id=1,
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_act = hidden_act
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.initializer_range = initializer_range
|
||||
self.type_vocab_size = type_vocab_size
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
self.position_embedding_type = position_embedding_type
|
||||
self.block_per_row = block_per_row
|
||||
self.approx_mode = approx_mode
|
||||
self.initial_prior_first_n_blocks = initial_prior_first_n_blocks
|
||||
self.initial_prior_diagonal_n_blocks = initial_prior_diagonal_n_blocks
|
||||
110
src/transformers/models/mra/convert_mra_pytorch_to_pytorch.py
Normal file
110
src/transformers/models/mra/convert_mra_pytorch_to_pytorch.py
Normal file
@@ -0,0 +1,110 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert MRA checkpoints from the original repository. URL: https://github.com/mlpen/mra-attention"""
|
||||
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import MraConfig, MraForMaskedLM
|
||||
|
||||
|
||||
def rename_key(orig_key):
|
||||
if "model" in orig_key:
|
||||
orig_key = orig_key.replace("model.", "")
|
||||
if "norm1" in orig_key:
|
||||
orig_key = orig_key.replace("norm1", "attention.output.LayerNorm")
|
||||
if "norm2" in orig_key:
|
||||
orig_key = orig_key.replace("norm2", "output.LayerNorm")
|
||||
if "norm" in orig_key:
|
||||
orig_key = orig_key.replace("norm", "LayerNorm")
|
||||
if "transformer" in orig_key:
|
||||
layer_num = orig_key.split(".")[0].split("_")[-1]
|
||||
orig_key = orig_key.replace(f"transformer_{layer_num}", f"encoder.layer.{layer_num}")
|
||||
if "mha.attn" in orig_key:
|
||||
orig_key = orig_key.replace("mha.attn", "attention.self")
|
||||
if "mha" in orig_key:
|
||||
orig_key = orig_key.replace("mha", "attention")
|
||||
if "W_q" in orig_key:
|
||||
orig_key = orig_key.replace("W_q", "self.query")
|
||||
if "W_k" in orig_key:
|
||||
orig_key = orig_key.replace("W_k", "self.key")
|
||||
if "W_v" in orig_key:
|
||||
orig_key = orig_key.replace("W_v", "self.value")
|
||||
if "ff.0" in orig_key:
|
||||
orig_key = orig_key.replace("ff.0", "intermediate.dense")
|
||||
if "ff.2" in orig_key:
|
||||
orig_key = orig_key.replace("ff.2", "output.dense")
|
||||
if "ff" in orig_key:
|
||||
orig_key = orig_key.replace("ff", "output.dense")
|
||||
if "mlm_class" in orig_key:
|
||||
orig_key = orig_key.replace("mlm.mlm_class", "cls.predictions.decoder")
|
||||
if "mlm" in orig_key:
|
||||
orig_key = orig_key.replace("mlm", "cls.predictions.transform")
|
||||
if "backbone.backbone.encoders" in orig_key:
|
||||
orig_key = orig_key.replace("backbone.backbone.encoders", "encoder.layer")
|
||||
if "cls" not in orig_key:
|
||||
orig_key = "mra." + orig_key
|
||||
|
||||
return orig_key
|
||||
|
||||
|
||||
def convert_checkpoint_helper(max_position_embeddings, orig_state_dict):
|
||||
for key in orig_state_dict.copy().keys():
|
||||
val = orig_state_dict.pop(key)
|
||||
|
||||
if ("pooler" in key) or ("sen_class" in key):
|
||||
continue
|
||||
else:
|
||||
orig_state_dict[rename_key(key)] = val
|
||||
|
||||
orig_state_dict["cls.predictions.bias"] = orig_state_dict["cls.predictions.decoder.bias"]
|
||||
orig_state_dict["mra.embeddings.position_ids"] = torch.arange(max_position_embeddings).expand((1, -1)) + 2
|
||||
|
||||
return orig_state_dict
|
||||
|
||||
|
||||
def convert_mra_checkpoint(checkpoint_path, mra_config_file, pytorch_dump_path):
|
||||
orig_state_dict = torch.load(checkpoint_path, map_location="cpu")["model_state_dict"]
|
||||
config = MraConfig.from_json_file(mra_config_file)
|
||||
model = MraForMaskedLM(config)
|
||||
|
||||
new_state_dict = convert_checkpoint_helper(config.max_position_embeddings, orig_state_dict)
|
||||
|
||||
print(model.load_state_dict(new_state_dict))
|
||||
model.eval()
|
||||
model.save_pretrained(pytorch_dump_path)
|
||||
|
||||
print(f"Checkpoint successfuly converted. Model saved at {pytorch_dump_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--pytorch_model_path", default=None, type=str, required=True, help="Path to Mra pytorch checkpoint."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config_file",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The json file for Mra model config.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_mra_checkpoint(args.pytorch_model_path, args.config_file, args.pytorch_dump_path)
|
||||
1501
src/transformers/models/mra/modeling_mra.py
Normal file
1501
src/transformers/models/mra/modeling_mra.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -4901,6 +4901,58 @@ class MPNetPreTrainedModel(metaclass=DummyObject):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
MRA_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||
|
||||
|
||||
class MraForMaskedLM(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class MraForMultipleChoice(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class MraForQuestionAnswering(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class MraForSequenceClassification(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class MraForTokenClassification(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class MraModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class MraPreTrainedModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class MT5EncoderModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
0
tests/models/mra/__init__.py
Normal file
0
tests/models/mra/__init__.py
Normal file
406
tests/models/mra/test_modeling_mra.py
Normal file
406
tests/models/mra/test_modeling_mra.py
Normal file
@@ -0,0 +1,406 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Testing suite for the PyTorch MRA model. """
|
||||
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import MraConfig, is_torch_available
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
MraForMaskedLM,
|
||||
MraForMultipleChoice,
|
||||
MraForQuestionAnswering,
|
||||
MraForSequenceClassification,
|
||||
MraForTokenClassification,
|
||||
MraModel,
|
||||
)
|
||||
from transformers.models.mra.modeling_mra import MRA_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||
|
||||
|
||||
class MraModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=2,
|
||||
seq_length=8,
|
||||
is_training=True,
|
||||
use_input_mask=True,
|
||||
use_token_type_ids=True,
|
||||
use_labels=True,
|
||||
vocab_size=99,
|
||||
hidden_size=16,
|
||||
num_hidden_layers=5,
|
||||
num_attention_heads=2,
|
||||
intermediate_size=36,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.0,
|
||||
attention_probs_dropout_prob=0.0,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=16,
|
||||
type_sequence_label_size=2,
|
||||
initializer_range=0.02,
|
||||
num_labels=3,
|
||||
num_choices=4,
|
||||
scope=None,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.seq_length = seq_length
|
||||
self.is_training = is_training
|
||||
self.use_input_mask = use_input_mask
|
||||
self.use_token_type_ids = use_token_type_ids
|
||||
self.use_labels = use_labels
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_act = hidden_act
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.type_vocab_size = type_vocab_size
|
||||
self.type_sequence_label_size = type_sequence_label_size
|
||||
self.initializer_range = initializer_range
|
||||
self.num_labels = num_labels
|
||||
self.num_choices = num_choices
|
||||
self.scope = scope
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
|
||||
input_mask = None
|
||||
if self.use_input_mask:
|
||||
input_mask = random_attention_mask([self.batch_size, self.seq_length])
|
||||
|
||||
token_type_ids = None
|
||||
if self.use_token_type_ids:
|
||||
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
|
||||
|
||||
sequence_labels = None
|
||||
token_labels = None
|
||||
choice_labels = None
|
||||
if self.use_labels:
|
||||
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
|
||||
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
|
||||
choice_labels = ids_tensor([self.batch_size], self.num_choices)
|
||||
|
||||
config = self.get_config()
|
||||
|
||||
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
||||
def get_config(self):
|
||||
return MraConfig(
|
||||
vocab_size=self.vocab_size,
|
||||
hidden_size=self.hidden_size,
|
||||
num_hidden_layers=self.num_hidden_layers,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
intermediate_size=self.intermediate_size,
|
||||
hidden_act=self.hidden_act,
|
||||
hidden_dropout_prob=self.hidden_dropout_prob,
|
||||
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
type_vocab_size=self.type_vocab_size,
|
||||
is_decoder=False,
|
||||
initializer_range=self.initializer_range,
|
||||
)
|
||||
|
||||
def get_pipeline_config(self):
|
||||
config = self.get_config()
|
||||
config.vocab_size = 300
|
||||
return config
|
||||
|
||||
def prepare_config_and_inputs_for_decoder(self):
|
||||
(
|
||||
config,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
) = self.prepare_config_and_inputs()
|
||||
|
||||
config.is_decoder = True
|
||||
encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
|
||||
encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
|
||||
|
||||
return (
|
||||
config,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
)
|
||||
|
||||
def create_and_check_model(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
model = MraModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
|
||||
result = model(input_ids, token_type_ids=token_type_ids)
|
||||
result = model(input_ids)
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||
|
||||
def create_and_check_model_as_decoder(
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
):
|
||||
config.add_cross_attention = True
|
||||
model = MraModel(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(
|
||||
input_ids,
|
||||
attention_mask=input_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
)
|
||||
result = model(
|
||||
input_ids,
|
||||
attention_mask=input_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
)
|
||||
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||
|
||||
def create_and_check_for_masked_lm(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
model = MraForMaskedLM(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
|
||||
def create_and_check_for_question_answering(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
model = MraForQuestionAnswering(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(
|
||||
input_ids,
|
||||
attention_mask=input_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
start_positions=sequence_labels,
|
||||
end_positions=sequence_labels,
|
||||
)
|
||||
self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
|
||||
self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
|
||||
|
||||
def create_and_check_for_sequence_classification(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
config.num_labels = self.num_labels
|
||||
model = MraForSequenceClassification(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
|
||||
|
||||
def create_and_check_for_token_classification(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
config.num_labels = self.num_labels
|
||||
model = MraForTokenClassification(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
|
||||
|
||||
def create_and_check_for_multiple_choice(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
config.num_choices = self.num_choices
|
||||
model = MraForMultipleChoice(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
result = model(
|
||||
multiple_choice_inputs_ids,
|
||||
attention_mask=multiple_choice_input_mask,
|
||||
token_type_ids=multiple_choice_token_type_ids,
|
||||
labels=choice_labels,
|
||||
)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(
|
||||
config,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
) = config_and_inputs
|
||||
inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@require_torch
|
||||
class MraModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (
|
||||
(
|
||||
MraModel,
|
||||
MraForMaskedLM,
|
||||
MraForMultipleChoice,
|
||||
MraForQuestionAnswering,
|
||||
MraForSequenceClassification,
|
||||
MraForTokenClassification,
|
||||
)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
test_pruning = False
|
||||
test_headmasking = False
|
||||
test_torchscript = False
|
||||
has_attentions = False
|
||||
|
||||
all_generative_model_classes = ()
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = MraModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=MraConfig, hidden_size=37)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
def test_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
def test_model_various_embeddings(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
for type in ["absolute", "relative_key", "relative_key_query"]:
|
||||
config_and_inputs[0].position_embedding_type = type
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
def test_for_masked_lm(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
|
||||
|
||||
def test_for_multiple_choice(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)
|
||||
|
||||
def test_for_question_answering(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_question_answering(*config_and_inputs)
|
||||
|
||||
def test_for_sequence_classification(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)
|
||||
|
||||
def test_for_token_classification(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_name in MRA_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||
model = MraModel.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
@unittest.skip(reason="MRA does not output attentions")
|
||||
def test_attention_outputs(self):
|
||||
return
|
||||
|
||||
|
||||
@require_torch
|
||||
class MraModelIntegrationTest(unittest.TestCase):
|
||||
@slow
|
||||
def test_inference_no_head(self):
|
||||
model = MraModel.from_pretrained("uw-madison/mra-base-512-4")
|
||||
input_ids = torch.arange(256).unsqueeze(0)
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(input_ids)[0]
|
||||
|
||||
expected_shape = torch.Size((1, 256, 768))
|
||||
self.assertEqual(output.shape, expected_shape)
|
||||
|
||||
expected_slice = torch.tensor(
|
||||
[[[-0.0140, 0.0830, -0.0381], [0.1546, 0.1402, 0.0220], [0.1162, 0.0851, 0.0165]]]
|
||||
)
|
||||
|
||||
self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))
|
||||
|
||||
@slow
|
||||
def test_inference_masked_lm(self):
|
||||
model = MraForMaskedLM.from_pretrained("uw-madison/mra-base-512-4")
|
||||
input_ids = torch.arange(256).unsqueeze(0)
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(input_ids)[0]
|
||||
|
||||
vocab_size = 50265
|
||||
|
||||
expected_shape = torch.Size((1, 256, vocab_size))
|
||||
self.assertEqual(output.shape, expected_shape)
|
||||
|
||||
expected_slice = torch.tensor(
|
||||
[[[9.2595, -3.6038, 11.8819], [9.3869, -3.2693, 11.0956], [11.8524, -3.4938, 13.1210]]]
|
||||
)
|
||||
|
||||
self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))
|
||||
|
||||
@slow
|
||||
def test_inference_masked_lm_long_input(self):
|
||||
model = MraForMaskedLM.from_pretrained("uw-madison/mra-base-4096-8-d3")
|
||||
input_ids = torch.arange(4096).unsqueeze(0)
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(input_ids)[0]
|
||||
|
||||
vocab_size = 50265
|
||||
|
||||
expected_shape = torch.Size((1, 4096, vocab_size))
|
||||
self.assertEqual(output.shape, expected_shape)
|
||||
|
||||
expected_slice = torch.tensor(
|
||||
[[[5.4789, -2.3564, 7.5064], [7.9067, -1.3369, 9.9668], [9.0712, -1.8106, 7.0380]]]
|
||||
)
|
||||
|
||||
self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))
|
||||
Reference in New Issue
Block a user