Add YOSO (#15091)
* Add cookiecutter files * Add cuda kernels and cpp files * Update modeling_yoso.py * Add .h files * Update configuration_yoso.py * Updates * Remove tokenizer * Code quality * Update modeling_yoso.py * Update modeling_yoso.py * Fix failing test * Update modeling_yoso.py * Fix code quality * Apply suggestions from code review Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Apply suggestions from code review * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Apply suggestions from code review and fix integration tests * Update src/transformers/models/yoso/modeling_yoso.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Apply suggestions from code review * Fix copied from statement * Fix docstring * Fix code quality * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Apply suggestions and fix mask * Apply suggestions from code review * Fix code quality * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Fix docstrings * Fix code quality * Remove trailing whitespace * Update yoso.mdx * Move kernel loading to YosoEncoder * make style * Apply suggestions from code review Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update src/transformers/models/yoso/modeling_yoso.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Add short summary to docs * Update docs/source/model_doc/yoso.mdx Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update yoso.mdx * Update docs/source/model_doc/yoso.mdx Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Remove CausalLM model and add copied from * Remove autoregressive code * Remove unused imports * add copied from for embeddings * Fix code quality * Update docs/source/model_doc/yoso.mdx Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Apply suggestion from code review Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -325,6 +325,7 @@ AWARE PRE-TRAINING](https://arxiv.org/abs/2110.05752) by Sanyuan Chen, Yu Wu, Ch
|
|||||||
1. **[XLNet](https://huggingface.co/docs/transformers/model_doc/xlnet)** (from Google/CMU) released with the paper [XLNet: Generalized Autoregressive Pretraining for Language Understanding](https://arxiv.org/abs/1906.08237) by Zhilin Yang*, Zihang Dai*, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le.
|
1. **[XLNet](https://huggingface.co/docs/transformers/model_doc/xlnet)** (from Google/CMU) released with the paper [XLNet: Generalized Autoregressive Pretraining for Language Understanding](https://arxiv.org/abs/1906.08237) by Zhilin Yang*, Zihang Dai*, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le.
|
||||||
1. **[XLSR-Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/xlsr_wav2vec2)** (from Facebook AI) released with the paper [Unsupervised Cross-Lingual Representation Learning For Speech Recognition](https://arxiv.org/abs/2006.13979) by Alexis Conneau, Alexei Baevski, Ronan Collobert, Abdelrahman Mohamed, Michael Auli.
|
1. **[XLSR-Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/xlsr_wav2vec2)** (from Facebook AI) released with the paper [Unsupervised Cross-Lingual Representation Learning For Speech Recognition](https://arxiv.org/abs/2006.13979) by Alexis Conneau, Alexei Baevski, Ronan Collobert, Abdelrahman Mohamed, Michael Auli.
|
||||||
1. **[XLS-R](https://huggingface.co/docs/master/transformers/model_doc/xls_r)** (from Facebook AI) released with the paper [XLS-R: Self-supervised Cross-lingual Speech Representation Learning at Scale](https://arxiv.org/abs/2111.09296) by Arun Babu, Changhan Wang, Andros Tjandra, Kushal Lakhotia, Qiantong Xu, Naman Goyal, Kritika Singh, Patrick von Platen, Yatharth Saraf, Juan Pino, Alexei Baevski, Alexis Conneau, Michael Auli.
|
1. **[XLS-R](https://huggingface.co/docs/master/transformers/model_doc/xls_r)** (from Facebook AI) released with the paper [XLS-R: Self-supervised Cross-lingual Speech Representation Learning at Scale](https://arxiv.org/abs/2111.09296) by Arun Babu, Changhan Wang, Andros Tjandra, Kushal Lakhotia, Qiantong Xu, Naman Goyal, Kritika Singh, Patrick von Platen, Yatharth Saraf, Juan Pino, Alexei Baevski, Alexis Conneau, Michael Auli.
|
||||||
|
1. **[YOSO](https://huggingface.co/docs/transformers/master/model_doc/yoso)** (from the University of Wisconsin - Madison) released with the paper [You Only Sample (Almost) Once: Linear Cost Self-Attention Via Bernoulli Sampling](https://arxiv.org/abs/2111.09714) by Zhanpeng Zeng, Yunyang Xiong, Sathya N. Ravi, Shailesh Acharya, Glenn Fung, Vikas Singh.
|
||||||
1. Want to contribute a new model? We have added a **detailed guide and templates** to guide you in the process of adding a new model. You can find them in the [`templates`](./templates) folder of the repository. Be sure to check the [contributing guidelines](./CONTRIBUTING.md) and contact the maintainers or open an issue to collect feedbacks before starting your PR.
|
1. Want to contribute a new model? We have added a **detailed guide and templates** to guide you in the process of adding a new model. You can find them in the [`templates`](./templates) folder of the repository. Be sure to check the [contributing guidelines](./CONTRIBUTING.md) and contact the maintainers or open an issue to collect feedbacks before starting your PR.
|
||||||
|
|
||||||
To check if each model has an implementation in Flax, PyTorch or TensorFlow, or has an associated tokenizer backed by the 🤗 Tokenizers library, refer to [this table](https://huggingface.co/docs/transformers/index#supported-frameworks).
|
To check if each model has an implementation in Flax, PyTorch or TensorFlow, or has an associated tokenizer backed by the 🤗 Tokenizers library, refer to [this table](https://huggingface.co/docs/transformers/index#supported-frameworks).
|
||||||
|
|||||||
@@ -303,6 +303,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는
|
|||||||
1. **[XLNet](https://huggingface.co/docs/transformers/model_doc/xlnet)** (from Google/CMU) released with the paper [XLNet: Generalized Autoregressive Pretraining for Language Understanding](https://arxiv.org/abs/1906.08237) by Zhilin Yang*, Zihang Dai*, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le.
|
1. **[XLNet](https://huggingface.co/docs/transformers/model_doc/xlnet)** (from Google/CMU) released with the paper [XLNet: Generalized Autoregressive Pretraining for Language Understanding](https://arxiv.org/abs/1906.08237) by Zhilin Yang*, Zihang Dai*, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le.
|
||||||
1. **[XLS-R](https://huggingface.co/docs/master/transformers/model_doc/xls_r)** (from Facebook AI) released with the paper [XLS-R: Self-supervised Cross-lingual Speech Representation Learning at Scale](https://arxiv.org/abs/2111.09296) by Arun Babu, Changhan Wang, Andros Tjandra, Kushal Lakhotia, Qiantong Xu, Naman Goyal, Kritika Singh, Patrick von Platen, Yatharth Saraf, Juan Pino, Alexei Baevski, Alexis Conneau, Michael Auli.
|
1. **[XLS-R](https://huggingface.co/docs/master/transformers/model_doc/xls_r)** (from Facebook AI) released with the paper [XLS-R: Self-supervised Cross-lingual Speech Representation Learning at Scale](https://arxiv.org/abs/2111.09296) by Arun Babu, Changhan Wang, Andros Tjandra, Kushal Lakhotia, Qiantong Xu, Naman Goyal, Kritika Singh, Patrick von Platen, Yatharth Saraf, Juan Pino, Alexei Baevski, Alexis Conneau, Michael Auli.
|
||||||
1. **[XLSR-Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/xlsr_wav2vec2)** (from Facebook AI) released with the paper [Unsupervised Cross-Lingual Representation Learning For Speech Recognition](https://arxiv.org/abs/2006.13979) by Alexis Conneau, Alexei Baevski, Ronan Collobert, Abdelrahman Mohamed, Michael Auli.
|
1. **[XLSR-Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/xlsr_wav2vec2)** (from Facebook AI) released with the paper [Unsupervised Cross-Lingual Representation Learning For Speech Recognition](https://arxiv.org/abs/2006.13979) by Alexis Conneau, Alexei Baevski, Ronan Collobert, Abdelrahman Mohamed, Michael Auli.
|
||||||
|
1. **[YOSO](https://huggingface.co/docs/transformers/master/model_doc/yoso)** (from the University of Wisconsin - Madison) released with the paper [You Only Sample (Almost) by Zhanpeng Zeng, Yunyang Xiong, Sathya N. Ravi, Shailesh Acharya, Glenn Fung, Vikas Singh.
|
||||||
1. 새로운 모델을 올리고 싶나요? 우리가 **상세한 가이드와 템플릿** 으로 새로운 모델을 올리도록 도와드릴게요. 가이드와 템플릿은 이 저장소의 [`templates`](./templates) 폴더에서 확인하실 수 있습니다. [컨트리뷰션 가이드라인](./CONTRIBUTING.md)을 꼭 확인해주시고, PR을 올리기 전에 메인테이너에게 연락하거나 이슈를 오픈해 피드백을 받으시길 바랍니다.
|
1. 새로운 모델을 올리고 싶나요? 우리가 **상세한 가이드와 템플릿** 으로 새로운 모델을 올리도록 도와드릴게요. 가이드와 템플릿은 이 저장소의 [`templates`](./templates) 폴더에서 확인하실 수 있습니다. [컨트리뷰션 가이드라인](./CONTRIBUTING.md)을 꼭 확인해주시고, PR을 올리기 전에 메인테이너에게 연락하거나 이슈를 오픈해 피드백을 받으시길 바랍니다.
|
||||||
|
|
||||||
각 모델이 Flax, PyTorch, TensorFlow으로 구현되었는지 또는 🤗 Tokenizers 라이브러리가 지원하는 토크나이저를 사용하는지 확인하려면, [이 표](https://huggingface.co/docs/transformers/index#supported-frameworks)를 확인하세요.
|
각 모델이 Flax, PyTorch, TensorFlow으로 구현되었는지 또는 🤗 Tokenizers 라이브러리가 지원하는 토크나이저를 사용하는지 확인하려면, [이 표](https://huggingface.co/docs/transformers/index#supported-frameworks)를 확인하세요.
|
||||||
|
|||||||
@@ -327,6 +327,7 @@ conda install -c huggingface transformers
|
|||||||
1. **[XLNet](https://huggingface.co/docs/transformers/model_doc/xlnet)** (来自 Google/CMU) 伴随论文 [XLNet: Generalized Autoregressive Pretraining for Language Understanding](https://arxiv.org/abs/1906.08237) 由 Zhilin Yang*, Zihang Dai*, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le 发布。
|
1. **[XLNet](https://huggingface.co/docs/transformers/model_doc/xlnet)** (来自 Google/CMU) 伴随论文 [XLNet: Generalized Autoregressive Pretraining for Language Understanding](https://arxiv.org/abs/1906.08237) 由 Zhilin Yang*, Zihang Dai*, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le 发布。
|
||||||
1. **[XLS-R](https://huggingface.co/docs/master/transformers/model_doc/xls_r)** (来自 Facebook AI) 伴随论文 [XLS-R: Self-supervised Cross-lingual Speech Representation Learning at Scale](https://arxiv.org/abs/2111.09296) 由 Arun Babu, Changhan Wang, Andros Tjandra, Kushal Lakhotia, Qiantong Xu, Naman Goyal, Kritika Singh, Patrick von Platen, Yatharth Saraf, Juan Pino, Alexei Baevski, Alexis Conneau, Michael Auli 发布。
|
1. **[XLS-R](https://huggingface.co/docs/master/transformers/model_doc/xls_r)** (来自 Facebook AI) 伴随论文 [XLS-R: Self-supervised Cross-lingual Speech Representation Learning at Scale](https://arxiv.org/abs/2111.09296) 由 Arun Babu, Changhan Wang, Andros Tjandra, Kushal Lakhotia, Qiantong Xu, Naman Goyal, Kritika Singh, Patrick von Platen, Yatharth Saraf, Juan Pino, Alexei Baevski, Alexis Conneau, Michael Auli 发布。
|
||||||
1. **[XLSR-Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/xlsr_wav2vec2)** (来自 Facebook AI) 伴随论文 [Unsupervised Cross-Lingual Representation Learning For Speech Recognition](https://arxiv.org/abs/2006.13979) 由 Alexis Conneau, Alexei Baevski, Ronan Collobert, Abdelrahman Mohamed, Michael Auli 发布。
|
1. **[XLSR-Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/xlsr_wav2vec2)** (来自 Facebook AI) 伴随论文 [Unsupervised Cross-Lingual Representation Learning For Speech Recognition](https://arxiv.org/abs/2006.13979) 由 Alexis Conneau, Alexei Baevski, Ronan Collobert, Abdelrahman Mohamed, Michael Auli 发布。
|
||||||
|
1. **[YOSO](https://huggingface.co/docs/transformers/master/model_doc/yoso)** (来自 the University of Wisconsin - Madison) 伴随论文 [You Only Sample (Almost) 由 Zhanpeng Zeng, Yunyang Xiong, Sathya N. Ravi, Shailesh Acharya, Glenn Fung, Vikas Singh 发布。
|
||||||
1. 想要贡献新的模型?我们这里有一份**详细指引和模板**来引导你添加新的模型。你可以在 [`templates`](./templates) 目录中找到他们。记得查看 [贡献指南](./CONTRIBUTING.md) 并在开始写 PR 前联系维护人员或开一个新的 issue 来获得反馈。
|
1. 想要贡献新的模型?我们这里有一份**详细指引和模板**来引导你添加新的模型。你可以在 [`templates`](./templates) 目录中找到他们。记得查看 [贡献指南](./CONTRIBUTING.md) 并在开始写 PR 前联系维护人员或开一个新的 issue 来获得反馈。
|
||||||
|
|
||||||
要检查某个模型是否已有 Flax、PyTorch 或 TensorFlow 的实现,或其是否在 🤗 Tokenizers 库中有对应词符化器(tokenizer),敬请参阅[此表](https://huggingface.co/docs/transformers/index#supported-frameworks)。
|
要检查某个模型是否已有 Flax、PyTorch 或 TensorFlow 的实现,或其是否在 🤗 Tokenizers 库中有对应词符化器(tokenizer),敬请参阅[此表](https://huggingface.co/docs/transformers/index#supported-frameworks)。
|
||||||
|
|||||||
@@ -339,6 +339,7 @@ conda install -c huggingface transformers
|
|||||||
1. **[XLNet](https://huggingface.co/docs/transformers/model_doc/xlnet)** (from Google/CMU) released with the paper [XLNet: Generalized Autoregressive Pretraining for Language Understanding](https://arxiv.org/abs/1906.08237) by Zhilin Yang*, Zihang Dai*, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le.
|
1. **[XLNet](https://huggingface.co/docs/transformers/model_doc/xlnet)** (from Google/CMU) released with the paper [XLNet: Generalized Autoregressive Pretraining for Language Understanding](https://arxiv.org/abs/1906.08237) by Zhilin Yang*, Zihang Dai*, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le.
|
||||||
1. **[XLS-R](https://huggingface.co/docs/master/transformers/model_doc/xls_r)** (from Facebook AI) released with the paper [XLS-R: Self-supervised Cross-lingual Speech Representation Learning at Scale](https://arxiv.org/abs/2111.09296) by Arun Babu, Changhan Wang, Andros Tjandra, Kushal Lakhotia, Qiantong Xu, Naman Goyal, Kritika Singh, Patrick von Platen, Yatharth Saraf, Juan Pino, Alexei Baevski, Alexis Conneau, Michael Auli.
|
1. **[XLS-R](https://huggingface.co/docs/master/transformers/model_doc/xls_r)** (from Facebook AI) released with the paper [XLS-R: Self-supervised Cross-lingual Speech Representation Learning at Scale](https://arxiv.org/abs/2111.09296) by Arun Babu, Changhan Wang, Andros Tjandra, Kushal Lakhotia, Qiantong Xu, Naman Goyal, Kritika Singh, Patrick von Platen, Yatharth Saraf, Juan Pino, Alexei Baevski, Alexis Conneau, Michael Auli.
|
||||||
1. **[XLSR-Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/xlsr_wav2vec2)** (from Facebook AI) released with the paper [Unsupervised Cross-Lingual Representation Learning For Speech Recognition](https://arxiv.org/abs/2006.13979) by Alexis Conneau, Alexei Baevski, Ronan Collobert, Abdelrahman Mohamed, Michael Auli.
|
1. **[XLSR-Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/xlsr_wav2vec2)** (from Facebook AI) released with the paper [Unsupervised Cross-Lingual Representation Learning For Speech Recognition](https://arxiv.org/abs/2006.13979) by Alexis Conneau, Alexei Baevski, Ronan Collobert, Abdelrahman Mohamed, Michael Auli.
|
||||||
|
1. **[YOSO](https://huggingface.co/docs/transformers/master/model_doc/yoso)** (from the University of Wisconsin - Madison) released with the paper [You Only Sample (Almost) by Zhanpeng Zeng, Yunyang Xiong, Sathya N. Ravi, Shailesh Acharya, Glenn Fung, Vikas Singh.
|
||||||
1. 想要貢獻新的模型?我們這裡有一份**詳細指引和模板**來引導你加入新的模型。你可以在 [`templates`](./templates) 目錄中找到它們。記得查看[貢獻指引](./CONTRIBUTING.md)並在開始寫 PR 前聯繫維護人員或開一個新的 issue 來獲得 feedbacks。
|
1. 想要貢獻新的模型?我們這裡有一份**詳細指引和模板**來引導你加入新的模型。你可以在 [`templates`](./templates) 目錄中找到它們。記得查看[貢獻指引](./CONTRIBUTING.md)並在開始寫 PR 前聯繫維護人員或開一個新的 issue 來獲得 feedbacks。
|
||||||
|
|
||||||
要檢查某個模型是否已有 Flax、PyTorch 或 TensorFlow 的實作,或其是否在🤗 Tokenizers 函式庫中有對應的 tokenizer,敬請參閱[此表](https://huggingface.co/docs/transformers/index#supported-frameworks)。
|
要檢查某個模型是否已有 Flax、PyTorch 或 TensorFlow 的實作,或其是否在🤗 Tokenizers 函式庫中有對應的 tokenizer,敬請參閱[此表](https://huggingface.co/docs/transformers/index#supported-frameworks)。
|
||||||
|
|||||||
@@ -316,6 +316,8 @@
|
|||||||
title: XLSR-Wav2Vec2
|
title: XLSR-Wav2Vec2
|
||||||
- local: model_doc/xls_r
|
- local: model_doc/xls_r
|
||||||
title: XLS-R
|
title: XLS-R
|
||||||
|
- local: model_doc/yoso
|
||||||
|
title: YOSO
|
||||||
title: Models
|
title: Models
|
||||||
- sections:
|
- sections:
|
||||||
- local: internal/modeling_utils
|
- local: internal/modeling_utils
|
||||||
|
|||||||
@@ -184,6 +184,7 @@ conversion utilities for the following models.
|
|||||||
1. **[XLNet](model_doc/xlnet)** (from Google/CMU) released with the paper [XLNet: Generalized Autoregressive Pretraining for Language Understanding](https://arxiv.org/abs/1906.08237) by Zhilin Yang*, Zihang Dai*, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le.
|
1. **[XLNet](model_doc/xlnet)** (from Google/CMU) released with the paper [XLNet: Generalized Autoregressive Pretraining for Language Understanding](https://arxiv.org/abs/1906.08237) by Zhilin Yang*, Zihang Dai*, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le.
|
||||||
1. **[XLSR-Wav2Vec2](model_doc/xlsr_wav2vec2)** (from Facebook AI) released with the paper [Unsupervised Cross-Lingual Representation Learning For Speech Recognition](https://arxiv.org/abs/2006.13979) by Alexis Conneau, Alexei Baevski, Ronan Collobert, Abdelrahman Mohamed, Michael Auli.
|
1. **[XLSR-Wav2Vec2](model_doc/xlsr_wav2vec2)** (from Facebook AI) released with the paper [Unsupervised Cross-Lingual Representation Learning For Speech Recognition](https://arxiv.org/abs/2006.13979) by Alexis Conneau, Alexei Baevski, Ronan Collobert, Abdelrahman Mohamed, Michael Auli.
|
||||||
1. **[XLS-R](https://huggingface.co/docs/master/transformers/model_doc/xls_r)** (from Facebook AI) released with the paper [XLS-R: Self-supervised Cross-lingual Speech Representation Learning at Scale](https://arxiv.org/abs/2111.09296) by Arun Babu, Changhan Wang, Andros Tjandra, Kushal Lakhotia, Qiantong Xu, Naman Goyal, Kritika Singh, Patrick von Platen, Yatharth Saraf, Juan Pino, Alexei Baevski, Alexis Conneau, Michael Auli.
|
1. **[XLS-R](https://huggingface.co/docs/master/transformers/model_doc/xls_r)** (from Facebook AI) released with the paper [XLS-R: Self-supervised Cross-lingual Speech Representation Learning at Scale](https://arxiv.org/abs/2111.09296) by Arun Babu, Changhan Wang, Andros Tjandra, Kushal Lakhotia, Qiantong Xu, Naman Goyal, Kritika Singh, Patrick von Platen, Yatharth Saraf, Juan Pino, Alexei Baevski, Alexis Conneau, Michael Auli.
|
||||||
|
1. **[YOSO](model_doc/yoso)** (from the University of Wisconsin - Madison) released with the paper [You Only Sample (Almost) Once: Linear Cost Self-Attention Via Bernoulli Sampling](https://arxiv.org/abs/2111.09714) by Zhanpeng Zeng, Yunyang Xiong, Sathya N. Ravi, Shailesh Acharya, Glenn Fung, Vikas Singh.
|
||||||
|
|
||||||
|
|
||||||
### Supported frameworks
|
### Supported frameworks
|
||||||
@@ -281,5 +282,6 @@ Flax), PyTorch, and/or TensorFlow.
|
|||||||
| XLM-RoBERTa | ✅ | ✅ | ✅ | ✅ | ❌ |
|
| XLM-RoBERTa | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||||
| XLMProphetNet | ✅ | ❌ | ✅ | ❌ | ❌ |
|
| XLMProphetNet | ✅ | ❌ | ✅ | ❌ | ❌ |
|
||||||
| XLNet | ✅ | ✅ | ✅ | ✅ | ❌ |
|
| XLNet | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||||
|
| YOSO | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||||
|
|
||||||
<!-- End table-->
|
<!-- End table-->
|
||||||
|
|||||||
91
docs/source/model_doc/yoso.mdx
Normal file
91
docs/source/model_doc/yoso.mdx
Normal file
@@ -0,0 +1,91 @@
|
|||||||
|
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||||
|
the License. You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||||
|
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||||
|
specific language governing permissions and limitations under the License.
|
||||||
|
-->
|
||||||
|
|
||||||
|
# YOSO
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The YOSO model was proposed in [You Only Sample (Almost) Once: Linear Cost Self-Attention Via Bernoulli Sampling](https://arxiv.org/abs/2111.09714)
|
||||||
|
by Zhanpeng Zeng, Yunyang Xiong, Sathya N. Ravi, Shailesh Acharya, Glenn Fung, Vikas Singh. YOSO approximates standard softmax self-attention
|
||||||
|
via a Bernoulli sampling scheme based on Locality Sensitive Hashing (LSH). In principle, all the Bernoulli random variables can be sampled with
|
||||||
|
a single hash.
|
||||||
|
|
||||||
|
The abstract from the paper is the following:
|
||||||
|
|
||||||
|
*Transformer-based models are widely used in natural language processing (NLP). Central to the transformer model is
|
||||||
|
the self-attention mechanism, which captures the interactions of token pairs in the input sequences and depends quadratically
|
||||||
|
on the sequence length. Training such models on longer sequences is expensive. In this paper, we show that a Bernoulli sampling
|
||||||
|
attention mechanism based on Locality Sensitive Hashing (LSH), decreases the quadratic complexity of such models to linear.
|
||||||
|
We bypass the quadratic cost by considering self-attention as a sum of individual tokens associated with Bernoulli random
|
||||||
|
variables that can, in principle, be sampled at once by a single hash (although in practice, this number may be a small constant).
|
||||||
|
This leads to an efficient sampling scheme to estimate self-attention which relies on specific modifications of
|
||||||
|
LSH (to enable deployment on GPU architectures). We evaluate our algorithm on the GLUE benchmark with standard 512 sequence
|
||||||
|
length where we see favorable performance relative to a standard pretrained Transformer. On the Long Range Arena (LRA) benchmark,
|
||||||
|
for evaluating performance on long sequences, our method achieves results consistent with softmax self-attention but with sizable
|
||||||
|
speed-ups and memory savings and often outperforms other efficient self-attention methods. Our code is available at this https URL*
|
||||||
|
|
||||||
|
Tips:
|
||||||
|
|
||||||
|
- The YOSO attention algorithm is implemented through custom CUDA kernels, functions written in CUDA C++ that can be executed multiple times
|
||||||
|
in parallel on a GPU.
|
||||||
|
- The kernels provide a `fast_hash` function, which approximates the random projections of the queries and keys using the Fast Hadamard Transform. Using these
|
||||||
|
hash codes, the `lsh_cumulation` function approximates self-attention via LSH-based Bernoulli sampling.
|
||||||
|
- To use the custom kernels, the user should set `config.use_expectation = False`. To ensure that the kernels are compiled successfully,
|
||||||
|
the user must install the correct version of PyTorch and cudatoolkit. By default, `config.use_expectation = True`, which uses YOSO-E and
|
||||||
|
does not require compiling CUDA kernels.
|
||||||
|
|
||||||
|
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/yoso_architecture.jpg"
|
||||||
|
alt="drawing" width="600"/>
|
||||||
|
|
||||||
|
<small> YOSO Attention Algorithm. Taken from the <a href="https://arxiv.org/abs/2111.09714">original paper</a>.</small>
|
||||||
|
|
||||||
|
This model was contributed by [novice03](https://huggingface.co/novice03). The original code can be found [here](https://github.com/mlpen/YOSO).
|
||||||
|
|
||||||
|
|
||||||
|
## YosoConfig
|
||||||
|
|
||||||
|
[[autodoc]] YosoConfig
|
||||||
|
|
||||||
|
|
||||||
|
## YosoModel
|
||||||
|
|
||||||
|
[[autodoc]] YosoModel
|
||||||
|
- forward
|
||||||
|
|
||||||
|
|
||||||
|
## YosoForMaskedLM
|
||||||
|
|
||||||
|
[[autodoc]] YosoForMaskedLM
|
||||||
|
- forward
|
||||||
|
|
||||||
|
|
||||||
|
## YosoForSequenceClassification
|
||||||
|
|
||||||
|
[[autodoc]] YosoForSequenceClassification
|
||||||
|
- forward
|
||||||
|
|
||||||
|
## YosoForMultipleChoice
|
||||||
|
|
||||||
|
[[autodoc]] YosoForMultipleChoice
|
||||||
|
- forward
|
||||||
|
|
||||||
|
|
||||||
|
## YosoForTokenClassification
|
||||||
|
|
||||||
|
[[autodoc]] YosoForTokenClassification
|
||||||
|
- forward
|
||||||
|
|
||||||
|
|
||||||
|
## YosoForQuestionAnswering
|
||||||
|
|
||||||
|
[[autodoc]] YosoForQuestionAnswering
|
||||||
|
- forward
|
||||||
@@ -333,6 +333,7 @@ _import_structure = {
|
|||||||
"models.xlm_prophetnet": ["XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMProphetNetConfig"],
|
"models.xlm_prophetnet": ["XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMProphetNetConfig"],
|
||||||
"models.xlm_roberta": ["XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMRobertaConfig"],
|
"models.xlm_roberta": ["XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMRobertaConfig"],
|
||||||
"models.xlnet": ["XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLNetConfig"],
|
"models.xlnet": ["XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLNetConfig"],
|
||||||
|
"models.yoso": ["YOSO_PRETRAINED_CONFIG_ARCHIVE_MAP", "YosoConfig"],
|
||||||
"onnx": [],
|
"onnx": [],
|
||||||
"pipelines": [
|
"pipelines": [
|
||||||
"AudioClassificationPipeline",
|
"AudioClassificationPipeline",
|
||||||
@@ -1510,6 +1511,19 @@ if is_torch_available():
|
|||||||
"load_tf_weights_in_xlnet",
|
"load_tf_weights_in_xlnet",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
_import_structure["models.yoso"].extend(
|
||||||
|
[
|
||||||
|
"YOSO_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||||
|
"YosoForMaskedLM",
|
||||||
|
"YosoForMultipleChoice",
|
||||||
|
"YosoForQuestionAnswering",
|
||||||
|
"YosoForSequenceClassification",
|
||||||
|
"YosoForTokenClassification",
|
||||||
|
"YosoLayer",
|
||||||
|
"YosoModel",
|
||||||
|
"YosoPreTrainedModel",
|
||||||
|
]
|
||||||
|
)
|
||||||
_import_structure["optimization"] = [
|
_import_structure["optimization"] = [
|
||||||
"Adafactor",
|
"Adafactor",
|
||||||
"AdamW",
|
"AdamW",
|
||||||
@@ -2454,6 +2468,7 @@ if TYPE_CHECKING:
|
|||||||
from .models.xlm_prophetnet import XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMProphetNetConfig
|
from .models.xlm_prophetnet import XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMProphetNetConfig
|
||||||
from .models.xlm_roberta import XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMRobertaConfig
|
from .models.xlm_roberta import XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMRobertaConfig
|
||||||
from .models.xlnet import XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLNetConfig
|
from .models.xlnet import XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLNetConfig
|
||||||
|
from .models.yoso import YOSO_PRETRAINED_CONFIG_ARCHIVE_MAP, YosoConfig
|
||||||
|
|
||||||
# Pipelines
|
# Pipelines
|
||||||
from .pipelines import (
|
from .pipelines import (
|
||||||
@@ -3431,6 +3446,17 @@ if TYPE_CHECKING:
|
|||||||
XLNetPreTrainedModel,
|
XLNetPreTrainedModel,
|
||||||
load_tf_weights_in_xlnet,
|
load_tf_weights_in_xlnet,
|
||||||
)
|
)
|
||||||
|
from .models.yoso import (
|
||||||
|
YOSO_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
|
YosoForMaskedLM,
|
||||||
|
YosoForMultipleChoice,
|
||||||
|
YosoForQuestionAnswering,
|
||||||
|
YosoForSequenceClassification,
|
||||||
|
YosoForTokenClassification,
|
||||||
|
YosoLayer,
|
||||||
|
YosoModel,
|
||||||
|
YosoPreTrainedModel,
|
||||||
|
)
|
||||||
|
|
||||||
# Optimization
|
# Optimization
|
||||||
from .optimization import (
|
from .optimization import (
|
||||||
|
|||||||
@@ -119,4 +119,5 @@ from . import (
|
|||||||
xlm_prophetnet,
|
xlm_prophetnet,
|
||||||
xlm_roberta,
|
xlm_roberta,
|
||||||
xlnet,
|
xlnet,
|
||||||
|
yoso,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ logger = logging.get_logger(__name__)
|
|||||||
CONFIG_MAPPING_NAMES = OrderedDict(
|
CONFIG_MAPPING_NAMES = OrderedDict(
|
||||||
[
|
[
|
||||||
# Add configs here
|
# Add configs here
|
||||||
|
("yoso", "YosoConfig"),
|
||||||
("swin", "SwinConfig"),
|
("swin", "SwinConfig"),
|
||||||
("vilt", "ViltConfig"),
|
("vilt", "ViltConfig"),
|
||||||
("vit_mae", "ViTMAEConfig"),
|
("vit_mae", "ViTMAEConfig"),
|
||||||
@@ -121,6 +122,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
|
|||||||
CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict(
|
CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict(
|
||||||
[
|
[
|
||||||
# Add archive maps here
|
# Add archive maps here
|
||||||
|
("yoso", "YOSO_PRETRAINED_CONFIG_ARCHIVE_MAP"),
|
||||||
("swin", "SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP"),
|
("swin", "SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP"),
|
||||||
("vilt", "VILT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
|
("vilt", "VILT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
|
||||||
("vit_mae", "VIT_MAE_PRETRAINED_CONFIG_ARCHIVE_MAP"),
|
("vit_mae", "VIT_MAE_PRETRAINED_CONFIG_ARCHIVE_MAP"),
|
||||||
@@ -200,6 +202,7 @@ CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict(
|
|||||||
MODEL_NAMES_MAPPING = OrderedDict(
|
MODEL_NAMES_MAPPING = OrderedDict(
|
||||||
[
|
[
|
||||||
# Add full (and cased) model names here
|
# Add full (and cased) model names here
|
||||||
|
("yoso", "YOSO"),
|
||||||
("swin", "Swin"),
|
("swin", "Swin"),
|
||||||
("vilt", "ViLT"),
|
("vilt", "ViLT"),
|
||||||
("vit_mae", "ViTMAE"),
|
("vit_mae", "ViTMAE"),
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ logger = logging.get_logger(__name__)
|
|||||||
MODEL_MAPPING_NAMES = OrderedDict(
|
MODEL_MAPPING_NAMES = OrderedDict(
|
||||||
[
|
[
|
||||||
# Base model mapping
|
# Base model mapping
|
||||||
|
("yoso", "YosoModel"),
|
||||||
("swin", "SwinModel"),
|
("swin", "SwinModel"),
|
||||||
("vilt", "ViltModel"),
|
("vilt", "ViltModel"),
|
||||||
("vit_mae", "ViTMAEModel"),
|
("vit_mae", "ViTMAEModel"),
|
||||||
@@ -155,6 +156,7 @@ MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
|
|||||||
MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
|
MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
|
||||||
[
|
[
|
||||||
# Model with LM heads mapping
|
# Model with LM heads mapping
|
||||||
|
("yoso", "YosoForMaskedLM"),
|
||||||
("nystromformer", "NystromformerForMaskedLM"),
|
("nystromformer", "NystromformerForMaskedLM"),
|
||||||
("qdqbert", "QDQBertForMaskedLM"),
|
("qdqbert", "QDQBertForMaskedLM"),
|
||||||
("fnet", "FNetForMaskedLM"),
|
("fnet", "FNetForMaskedLM"),
|
||||||
@@ -284,6 +286,7 @@ MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
|
|||||||
MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
|
MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
|
||||||
[
|
[
|
||||||
# Model for Masked LM mapping
|
# Model for Masked LM mapping
|
||||||
|
("yoso", "YosoForMaskedLM"),
|
||||||
("nystromformer", "NystromformerForMaskedLM"),
|
("nystromformer", "NystromformerForMaskedLM"),
|
||||||
("perceiver", "PerceiverForMaskedLM"),
|
("perceiver", "PerceiverForMaskedLM"),
|
||||||
("qdqbert", "QDQBertForMaskedLM"),
|
("qdqbert", "QDQBertForMaskedLM"),
|
||||||
@@ -357,6 +360,7 @@ MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
|
|||||||
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
||||||
[
|
[
|
||||||
# Model for Sequence Classification mapping
|
# Model for Sequence Classification mapping
|
||||||
|
("yoso", "YosoForSequenceClassification"),
|
||||||
("nystromformer", "NystromformerForSequenceClassification"),
|
("nystromformer", "NystromformerForSequenceClassification"),
|
||||||
("perceiver", "PerceiverForSequenceClassification"),
|
("perceiver", "PerceiverForSequenceClassification"),
|
||||||
("qdqbert", "QDQBertForSequenceClassification"),
|
("qdqbert", "QDQBertForSequenceClassification"),
|
||||||
@@ -405,6 +409,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
|||||||
MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
|
MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
|
||||||
[
|
[
|
||||||
# Model for Question Answering mapping
|
# Model for Question Answering mapping
|
||||||
|
("yoso", "YosoForQuestionAnswering"),
|
||||||
("nystromformer", "NystromformerForQuestionAnswering"),
|
("nystromformer", "NystromformerForQuestionAnswering"),
|
||||||
("qdqbert", "QDQBertForQuestionAnswering"),
|
("qdqbert", "QDQBertForQuestionAnswering"),
|
||||||
("fnet", "FNetForQuestionAnswering"),
|
("fnet", "FNetForQuestionAnswering"),
|
||||||
@@ -454,6 +459,7 @@ MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
|
|||||||
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
||||||
[
|
[
|
||||||
# Model for Token Classification mapping
|
# Model for Token Classification mapping
|
||||||
|
("yoso", "YosoForTokenClassification"),
|
||||||
("nystromformer", "NystromformerForTokenClassification"),
|
("nystromformer", "NystromformerForTokenClassification"),
|
||||||
("qdqbert", "QDQBertForTokenClassification"),
|
("qdqbert", "QDQBertForTokenClassification"),
|
||||||
("fnet", "FNetForTokenClassification"),
|
("fnet", "FNetForTokenClassification"),
|
||||||
@@ -490,6 +496,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
|||||||
MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
|
MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
|
||||||
[
|
[
|
||||||
# Model for Multiple Choice mapping
|
# Model for Multiple Choice mapping
|
||||||
|
("yoso", "YosoForMultipleChoice"),
|
||||||
("nystromformer", "NystromformerForMultipleChoice"),
|
("nystromformer", "NystromformerForMultipleChoice"),
|
||||||
("qdqbert", "QDQBertForMultipleChoice"),
|
("qdqbert", "QDQBertForMultipleChoice"),
|
||||||
("fnet", "FNetForMultipleChoice"),
|
("fnet", "FNetForMultipleChoice"),
|
||||||
|
|||||||
62
src/transformers/models/yoso/__init__.py
Normal file
62
src/transformers/models/yoso/__init__.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
# flake8: noqa
|
||||||
|
# There's no way to ignore "F401 '...' imported but unused" warnings in this
|
||||||
|
# module, but to preserve other warnings. So, don't check this module at all.
|
||||||
|
|
||||||
|
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
# rely on isort to merge the imports
|
||||||
|
from ...file_utils import _LazyModule, is_tokenizers_available, is_torch_available
|
||||||
|
|
||||||
|
|
||||||
|
_import_structure = {
|
||||||
|
"configuration_yoso": ["YOSO_PRETRAINED_CONFIG_ARCHIVE_MAP", "YosoConfig"],
|
||||||
|
}
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
_import_structure["modeling_yoso"] = [
|
||||||
|
"YOSO_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||||
|
"YosoForMaskedLM",
|
||||||
|
"YosoForMultipleChoice",
|
||||||
|
"YosoForQuestionAnswering",
|
||||||
|
"YosoForSequenceClassification",
|
||||||
|
"YosoForTokenClassification",
|
||||||
|
"YosoLayer",
|
||||||
|
"YosoModel",
|
||||||
|
"YosoPreTrainedModel",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .configuration_yoso import YOSO_PRETRAINED_CONFIG_ARCHIVE_MAP, YosoConfig
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
from .modeling_yoso import (
|
||||||
|
YOSO_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
|
YosoForMaskedLM,
|
||||||
|
YosoForMultipleChoice,
|
||||||
|
YosoForQuestionAnswering,
|
||||||
|
YosoForSequenceClassification,
|
||||||
|
YosoForTokenClassification,
|
||||||
|
YosoLayer,
|
||||||
|
YosoModel,
|
||||||
|
YosoPreTrainedModel,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
else:
|
||||||
|
import sys
|
||||||
|
|
||||||
|
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
|
||||||
10
src/transformers/models/yoso/common.h
Normal file
10
src/transformers/models/yoso/common.h
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
|
||||||
|
#define min(a, b) ((a)<(b)?(a):(b))
|
||||||
|
#define max(a, b) ((a)>(b)?(a):(b))
|
||||||
|
#define ceil_divide(a, b) ((a)/(b)+((a)%(b)!=0))
|
||||||
|
#define select(cond, a, b) ((cond)?(a):(b))
|
||||||
|
#define PI 3.141592
|
||||||
|
#define EPSILON 1e-8
|
||||||
|
#define MAX_VAL 1e12
|
||||||
|
#define MIN_VAL -1e12
|
||||||
|
#define EMPTY_VALUE -1
|
||||||
9
src/transformers/models/yoso/common_cuda.h
Normal file
9
src/transformers/models/yoso/common_cuda.h
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
|
||||||
|
#define MAX_THREADS_PER_BLOCK 1024
|
||||||
|
#define OPTIMAL_THREADS_PER_BLOCK 256
|
||||||
|
#define WARP_SIZE 32
|
||||||
|
#define MAX_NUM_BLOCK_X 2147483647
|
||||||
|
#define MAX_NUM_BLOCK_Y 65535
|
||||||
|
#define MAX_NUM_BLOCK_Z 65535
|
||||||
|
#define MAX_SHARED_MEM_PER_BLOCK 48000
|
||||||
|
#define FULL_MASK 0xffffffff
|
||||||
79
src/transformers/models/yoso/common_cuda_device.h
Normal file
79
src/transformers/models/yoso/common_cuda_device.h
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
|
||||||
|
#include "common.h"
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
__device__ int set_insert(T *set, int set_size, T value) {
|
||||||
|
int slot = value % set_size;
|
||||||
|
int start_slot = slot;
|
||||||
|
while (true) {
|
||||||
|
T prev = atomicCAS(&set[slot], EMPTY_VALUE, value);
|
||||||
|
if (prev == EMPTY_VALUE || prev == value) {
|
||||||
|
return slot;
|
||||||
|
}
|
||||||
|
slot = (slot + 1) % set_size;
|
||||||
|
if (slot == start_slot) {
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
__device__ int set_lookup(T *set, int set_size, T value) {
|
||||||
|
int slot = value % set_size;
|
||||||
|
int start_slot = slot;
|
||||||
|
while (true) {
|
||||||
|
if (set[slot] == value) {
|
||||||
|
return slot;
|
||||||
|
}
|
||||||
|
slot = (slot + 1) % set_size;
|
||||||
|
if (slot == start_slot) {
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
__device__ void init_buffer(T init_value, T *buffer, int buffer_size, int num_threads, int thread_id) {
|
||||||
|
__syncthreads();
|
||||||
|
for (int i = 0; i < buffer_size; i = i + num_threads) {
|
||||||
|
int offset_idx = i + thread_id;
|
||||||
|
if (offset_idx < buffer_size) {
|
||||||
|
buffer[offset_idx] = init_value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
__device__ void copy_data(T *src_pt, T *dist_pt, int data_length, int num_threads, int thread_id) {
|
||||||
|
__syncthreads();
|
||||||
|
for (int i = 0; i < data_length; i = i + num_threads) {
|
||||||
|
int offset_idx = i + thread_id;
|
||||||
|
if (offset_idx < data_length) {
|
||||||
|
dist_pt[offset_idx] = src_pt[offset_idx];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
__device__ void init_buffer_nonblocking(T init_value, T *buffer, int buffer_size, int num_threads, int thread_id) {
|
||||||
|
for (int i = 0; i < buffer_size; i = i + num_threads) {
|
||||||
|
int offset_idx = i + thread_id;
|
||||||
|
if (offset_idx < buffer_size) {
|
||||||
|
buffer[offset_idx] = init_value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
__device__ void copy_data_nonblocking(T *src_pt, T *dist_pt, int data_length, int num_threads, int thread_id) {
|
||||||
|
for (int i = 0; i < data_length; i = i + num_threads) {
|
||||||
|
int offset_idx = i + thread_id;
|
||||||
|
if (offset_idx < data_length) {
|
||||||
|
dist_pt[offset_idx] = src_pt[offset_idx];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
145
src/transformers/models/yoso/configuration_yoso.py
Normal file
145
src/transformers/models/yoso/configuration_yoso.py
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
""" YOSO model configuration"""
|
||||||
|
|
||||||
|
from ...configuration_utils import PretrainedConfig
|
||||||
|
from ...utils import logging
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
YOSO_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||||
|
"uw-madison/yoso-4096": "https://huggingface.co/uw-madison/yoso-4096/resolve/main/config.json",
|
||||||
|
# See all YOSO models at https://huggingface.co/models?filter=yoso
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class YosoConfig(PretrainedConfig):
|
||||||
|
r"""
|
||||||
|
This is the configuration class to store the configuration of a [`YosoModel`]. It is used to instantiate an YOSO
|
||||||
|
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
||||||
|
defaults will yield a similar configuration to that of the YOSO
|
||||||
|
[uw-madison/yoso-4096](https://huggingface.co/uw-madison/yoso-4096) architecture.
|
||||||
|
|
||||||
|
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||||
|
documentation from [`PretrainedConfig`] for more information.
|
||||||
|
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vocab_size (`int`, *optional*, defaults to 50265):
|
||||||
|
Vocabulary size of the YOSO model. Defines the number of different tokens that can be represented by the
|
||||||
|
`inputs_ids` passed when calling [`YosoModel`].
|
||||||
|
hidden_size (`int`, *optional*, defaults to 768):
|
||||||
|
Dimension of the encoder layers and the pooler layer.
|
||||||
|
num_hidden_layers (`int`, *optional*, defaults to 12):
|
||||||
|
Number of hidden layers in the Transformer encoder.
|
||||||
|
num_attention_heads (`int`, *optional*, defaults to 12):
|
||||||
|
Number of attention heads for each attention layer in the Transformer encoder.
|
||||||
|
intermediate_size (`int`, *optional*, defaults to 3072):
|
||||||
|
Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
||||||
|
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
|
||||||
|
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
||||||
|
`"relu"`, `"selu"` and `"gelu_new"` are supported.
|
||||||
|
hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
|
||||||
|
The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler.
|
||||||
|
attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
|
||||||
|
The dropout ratio for the attention probabilities.
|
||||||
|
max_position_embeddings (`int`, *optional*, defaults to 512):
|
||||||
|
The maximum sequence length that this model might ever be used with. Typically set this to something large
|
||||||
|
just in case (e.g., 512 or 1024 or 2048).
|
||||||
|
type_vocab_size (`int`, *optional*, defaults to 2):
|
||||||
|
The vocabulary size of the `token_type_ids` passed when calling [`YosoModel`].
|
||||||
|
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||||
|
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||||
|
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
|
||||||
|
The epsilon used by the layer normalization layers.
|
||||||
|
position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
|
||||||
|
Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`.
|
||||||
|
use_expectation (*bool*, *optional*, defaults to *True*):
|
||||||
|
Whether or not to use YOSO Expectation. Overrides any effect of num_hash.
|
||||||
|
hash_code_len (`int`, *optional*, defaults to 9):
|
||||||
|
The length of hashes generated by the hash functions.
|
||||||
|
num_hash (`int`, *optional*, defaults to 64):
|
||||||
|
Number of hash functions used in [`YosoSelfAttention`].
|
||||||
|
conv_window (`int`, *optional*, defaults to None):
|
||||||
|
Kernel size of depth-wise convolution.
|
||||||
|
use_fast_hash (*bool*, *optional*, defaults to *False*):
|
||||||
|
Whether or not to use custom cuda kernels which perform fast random projection via hadamard transform.
|
||||||
|
lsh_backward (*bool*, *optional*, defaults to *True*):
|
||||||
|
Whether or not to perform backpropagation using Locality Sensitive Hashing.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import YosoModel, YosoConfig
|
||||||
|
|
||||||
|
>>> # Initializing a YOSO uw-madison/yoso-4096 style configuration
|
||||||
|
>>> configuration = YosoConfig()
|
||||||
|
|
||||||
|
>>> # Initializing a model from the uw-madison/yoso-4096 style configuration
|
||||||
|
>>> model = YosoModel(configuration)
|
||||||
|
|
||||||
|
>>> # Accessing the model configuration
|
||||||
|
>>> configuration = model.config
|
||||||
|
```"""
|
||||||
|
model_type = "yoso"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vocab_size=50265,
|
||||||
|
hidden_size=768,
|
||||||
|
num_hidden_layers=12,
|
||||||
|
num_attention_heads=12,
|
||||||
|
intermediate_size=3072,
|
||||||
|
hidden_act="gelu",
|
||||||
|
hidden_dropout_prob=0.1,
|
||||||
|
attention_probs_dropout_prob=0.1,
|
||||||
|
max_position_embeddings=4096,
|
||||||
|
type_vocab_size=1,
|
||||||
|
initializer_range=0.02,
|
||||||
|
layer_norm_eps=1e-12,
|
||||||
|
position_embedding_type="absolute",
|
||||||
|
use_expectation=True,
|
||||||
|
hash_code_len=9,
|
||||||
|
num_hash=64,
|
||||||
|
conv_window=None,
|
||||||
|
use_fast_hash=True,
|
||||||
|
lsh_backward=True,
|
||||||
|
pad_token_id=1,
|
||||||
|
bos_token_id=0,
|
||||||
|
eos_token_id=2,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
||||||
|
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.num_hidden_layers = num_hidden_layers
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
self.intermediate_size = intermediate_size
|
||||||
|
self.hidden_act = hidden_act
|
||||||
|
self.hidden_dropout_prob = hidden_dropout_prob
|
||||||
|
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||||
|
self.initializer_range = initializer_range
|
||||||
|
self.type_vocab_size = type_vocab_size
|
||||||
|
self.layer_norm_eps = layer_norm_eps
|
||||||
|
self.position_embedding_type = position_embedding_type
|
||||||
|
self.use_expectation = use_expectation
|
||||||
|
self.hash_code_len = hash_code_len
|
||||||
|
self.num_hash = num_hash
|
||||||
|
self.conv_window = conv_window
|
||||||
|
self.use_fast_hash = use_fast_hash
|
||||||
|
self.lsh_backward = lsh_backward
|
||||||
109
src/transformers/models/yoso/convert_yoso_pytorch_to_pytorch.py
Normal file
109
src/transformers/models/yoso/convert_yoso_pytorch_to_pytorch.py
Normal file
@@ -0,0 +1,109 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2022 The HuggingFace Inc. team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Convert YOSO checkpoints from the original repository. URL: https://github.com/mlpen/YOSO"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from transformers import YosoConfig, YosoForMaskedLM
|
||||||
|
|
||||||
|
|
||||||
|
def rename_key(orig_key):
|
||||||
|
if "model" in orig_key:
|
||||||
|
orig_key = orig_key.replace("model.", "")
|
||||||
|
if "norm1" in orig_key:
|
||||||
|
orig_key = orig_key.replace("norm1", "attention.output.LayerNorm")
|
||||||
|
if "norm2" in orig_key:
|
||||||
|
orig_key = orig_key.replace("norm2", "output.LayerNorm")
|
||||||
|
if "norm" in orig_key:
|
||||||
|
orig_key = orig_key.replace("norm", "LayerNorm")
|
||||||
|
if "transformer" in orig_key:
|
||||||
|
layer_num = orig_key.split(".")[0].split("_")[-1]
|
||||||
|
orig_key = orig_key.replace(f"transformer_{layer_num}", f"encoder.layer.{layer_num}")
|
||||||
|
if "mha.attn" in orig_key:
|
||||||
|
orig_key = orig_key.replace("mha.attn", "attention.self")
|
||||||
|
if "mha" in orig_key:
|
||||||
|
orig_key = orig_key.replace("mha", "attention")
|
||||||
|
if "W_q" in orig_key:
|
||||||
|
orig_key = orig_key.replace("W_q", "self.query")
|
||||||
|
if "W_k" in orig_key:
|
||||||
|
orig_key = orig_key.replace("W_k", "self.key")
|
||||||
|
if "W_v" in orig_key:
|
||||||
|
orig_key = orig_key.replace("W_v", "self.value")
|
||||||
|
if "ff1" in orig_key:
|
||||||
|
orig_key = orig_key.replace("ff1", "intermediate.dense")
|
||||||
|
if "ff2" in orig_key:
|
||||||
|
orig_key = orig_key.replace("ff2", "output.dense")
|
||||||
|
if "ff" in orig_key:
|
||||||
|
orig_key = orig_key.replace("ff", "output.dense")
|
||||||
|
if "mlm_class" in orig_key:
|
||||||
|
orig_key = orig_key.replace("mlm.mlm_class", "cls.predictions.decoder")
|
||||||
|
if "mlm" in orig_key:
|
||||||
|
orig_key = orig_key.replace("mlm", "cls.predictions.transform")
|
||||||
|
if "cls" not in orig_key:
|
||||||
|
orig_key = "yoso." + orig_key
|
||||||
|
|
||||||
|
return orig_key
|
||||||
|
|
||||||
|
|
||||||
|
def convert_checkpoint_helper(max_position_embeddings, orig_state_dict):
|
||||||
|
for key in orig_state_dict.copy().keys():
|
||||||
|
val = orig_state_dict.pop(key)
|
||||||
|
|
||||||
|
if ("pooler" in key) or ("sen_class" in key):
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
orig_state_dict[rename_key(key)] = val
|
||||||
|
|
||||||
|
orig_state_dict["cls.predictions.bias"] = orig_state_dict["cls.predictions.decoder.bias"]
|
||||||
|
orig_state_dict["yoso.embeddings.position_ids"] = torch.arange(max_position_embeddings).expand((1, -1)) + 2
|
||||||
|
|
||||||
|
return orig_state_dict
|
||||||
|
|
||||||
|
|
||||||
|
def convert_yoso_checkpoint(checkpoint_path, yoso_config_file, pytorch_dump_path):
|
||||||
|
|
||||||
|
orig_state_dict = torch.load(checkpoint_path, map_location="cpu")["model_state_dict"]
|
||||||
|
config = YosoConfig.from_json_file(yoso_config_file)
|
||||||
|
model = YosoForMaskedLM(config)
|
||||||
|
|
||||||
|
new_state_dict = convert_checkpoint_helper(config.max_position_embeddings, orig_state_dict)
|
||||||
|
|
||||||
|
print(model.load_state_dict(new_state_dict))
|
||||||
|
model.eval()
|
||||||
|
model.save_pretrained(pytorch_dump_path)
|
||||||
|
|
||||||
|
print(f"Checkpoint successfuly converted. Model saved at {pytorch_dump_path}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
# Required parameters
|
||||||
|
parser.add_argument(
|
||||||
|
"--pytorch_model_path", default=None, type=str, required=True, help="Path to YOSO pytorch checkpoint."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--config_file",
|
||||||
|
default=None,
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="The json file for YOSO model config.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
convert_yoso_checkpoint(args.pytorch_model_path, args.config_file, args.pytorch_dump_path)
|
||||||
588
src/transformers/models/yoso/fast_lsh_cumulation.cu
Normal file
588
src/transformers/models/yoso/fast_lsh_cumulation.cu
Normal file
@@ -0,0 +1,588 @@
|
|||||||
|
// File from https://github.com/mlpen/YOSO/blob/main/encoders/backbones/efficient_attentions/yoso/yoso_v1/cuda/fast_lsh_cumulation.cu
|
||||||
|
|
||||||
|
#include <torch/extension.h>
|
||||||
|
#include <ATen/ATen.h>
|
||||||
|
#include "fast_lsh_cumulation.h"
|
||||||
|
#include "fast_lsh_cumulation_cuda.h"
|
||||||
|
#include "common_cuda.h"
|
||||||
|
#include "common.h"
|
||||||
|
#include <vector>
|
||||||
|
//////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
//////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
std::vector<at::Tensor> fast_hash_ver1_kernel(
|
||||||
|
at::Tensor query_mask,
|
||||||
|
at::Tensor query_vector,
|
||||||
|
at::Tensor key_mask,
|
||||||
|
at::Tensor key_vector,
|
||||||
|
int num_hash_f,
|
||||||
|
int hash_code_len,
|
||||||
|
bool use_cuda
|
||||||
|
) {
|
||||||
|
|
||||||
|
int batch_size = query_vector.size(0);
|
||||||
|
int num_query = query_vector.size(1);
|
||||||
|
int num_key = key_vector.size(1);
|
||||||
|
int vector_dim = query_vector.size(2);
|
||||||
|
|
||||||
|
int num_hash_per_part = vector_dim / hash_code_len;
|
||||||
|
int num_part = max(1, ceil_divide(num_hash_f, num_hash_per_part));
|
||||||
|
|
||||||
|
at::Tensor Dmat = 2 * at::randint(0, 2, {batch_size, 3, num_part, vector_dim}, query_mask.options()) - 1;
|
||||||
|
at::Tensor query_hash_code = at::zeros({batch_size, num_query, num_hash_f}, query_mask.options());
|
||||||
|
at::Tensor key_hash_code = at::zeros({batch_size, num_key, num_hash_f}, key_mask.options());
|
||||||
|
|
||||||
|
int *query_mask_ptr = query_mask.data_ptr<int>();
|
||||||
|
float *query_vector_ptr = query_vector.data_ptr<float>();
|
||||||
|
int *key_mask_ptr = key_mask.data_ptr<int>();
|
||||||
|
float *key_vector_ptr = key_vector.data_ptr<float>();
|
||||||
|
|
||||||
|
int *Dmat_ptr = Dmat.data_ptr<int>();
|
||||||
|
|
||||||
|
int *query_hash_code_ptr = query_hash_code.data_ptr<int>();
|
||||||
|
int *key_hash_code_ptr = key_hash_code.data_ptr<int>();
|
||||||
|
|
||||||
|
if (use_cuda) {
|
||||||
|
{
|
||||||
|
dim3 threads(vector_dim);
|
||||||
|
dim3 blocks(num_part, num_query, batch_size);
|
||||||
|
int shared_mem = vector_dim * sizeof(float);
|
||||||
|
fast_hash_ver1_cuda_kernel<<<blocks, threads, shared_mem>>>(
|
||||||
|
query_mask_ptr,
|
||||||
|
query_vector_ptr,
|
||||||
|
Dmat_ptr,
|
||||||
|
query_hash_code_ptr,
|
||||||
|
batch_size,
|
||||||
|
num_query,
|
||||||
|
vector_dim,
|
||||||
|
num_part,
|
||||||
|
num_hash_f,
|
||||||
|
hash_code_len
|
||||||
|
);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
dim3 threads(vector_dim);
|
||||||
|
dim3 blocks(num_part, num_key, batch_size);
|
||||||
|
int shared_mem = vector_dim * sizeof(float);
|
||||||
|
fast_hash_ver1_cuda_kernel<<<blocks, threads, shared_mem>>>(
|
||||||
|
key_mask_ptr,
|
||||||
|
key_vector_ptr,
|
||||||
|
Dmat_ptr,
|
||||||
|
key_hash_code_ptr,
|
||||||
|
batch_size,
|
||||||
|
num_key,
|
||||||
|
vector_dim,
|
||||||
|
num_part,
|
||||||
|
num_hash_f,
|
||||||
|
hash_code_len
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return {query_hash_code, key_hash_code};
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
at::Tensor lsh_cumulation_ver1_kernel(
|
||||||
|
at::Tensor query_mask,
|
||||||
|
at::Tensor query_hash_code,
|
||||||
|
at::Tensor key_mask,
|
||||||
|
at::Tensor key_hash_code,
|
||||||
|
at::Tensor value,
|
||||||
|
int hashtable_capacity,
|
||||||
|
bool use_cuda
|
||||||
|
) {
|
||||||
|
|
||||||
|
int batch_size = query_hash_code.size(0);
|
||||||
|
int num_hash_f = query_hash_code.size(2);
|
||||||
|
|
||||||
|
int num_query = query_hash_code.size(1);
|
||||||
|
int num_key = key_hash_code.size(1);
|
||||||
|
int value_dim = value.size(2);
|
||||||
|
|
||||||
|
at::Tensor hashtable_value = at::empty({batch_size, num_hash_f, hashtable_capacity, WARP_SIZE}, value.options());
|
||||||
|
at::Tensor cumulation_value = at::zeros({batch_size, num_query, value_dim}, value.options());
|
||||||
|
|
||||||
|
if (use_cuda) {
|
||||||
|
int threads_x = WARP_SIZE;
|
||||||
|
int threads_y = OPTIMAL_THREADS_PER_BLOCK / WARP_SIZE;
|
||||||
|
int block_x_step1 = num_key / threads_y;
|
||||||
|
int block_x_step2 = num_query / threads_y;
|
||||||
|
int block_y = batch_size;
|
||||||
|
|
||||||
|
dim3 threads(threads_x, threads_y);
|
||||||
|
dim3 blocks_step1(block_x_step1, block_y);
|
||||||
|
dim3 blocks_step2(block_x_step2, block_y);
|
||||||
|
|
||||||
|
int *query_mask_ptr = query_mask.data_ptr<int>();
|
||||||
|
int *query_hash_code_ptr = query_hash_code.data_ptr<int>();
|
||||||
|
int *key_mask_ptr = key_mask.data_ptr<int>();
|
||||||
|
int *key_hash_code_ptr = key_hash_code.data_ptr<int>();
|
||||||
|
float *value_ptr = value.data_ptr<float>();
|
||||||
|
float *hashtable_value_ptr = hashtable_value.data_ptr<float>();
|
||||||
|
float *cumulation_value_ptr = cumulation_value.data_ptr<float>();
|
||||||
|
|
||||||
|
for (int value_offset = 0; value_offset < value_dim; value_offset = value_offset + WARP_SIZE) {
|
||||||
|
|
||||||
|
cudaMemset(hashtable_value_ptr, 0, (batch_size * num_hash_f * hashtable_capacity * WARP_SIZE) * sizeof(float));
|
||||||
|
|
||||||
|
lsh_cumulation_ver1_step1_cuda_kernel<<<blocks_step1, threads>>>(
|
||||||
|
key_mask_ptr,
|
||||||
|
key_hash_code_ptr,
|
||||||
|
value_ptr,
|
||||||
|
hashtable_value_ptr,
|
||||||
|
batch_size,
|
||||||
|
num_hash_f,
|
||||||
|
hashtable_capacity,
|
||||||
|
num_key,
|
||||||
|
value_dim,
|
||||||
|
value_offset
|
||||||
|
);
|
||||||
|
|
||||||
|
lsh_cumulation_ver1_step2_cuda_kernel<<<blocks_step2, threads>>>(
|
||||||
|
query_mask_ptr,
|
||||||
|
query_hash_code_ptr,
|
||||||
|
hashtable_value_ptr,
|
||||||
|
cumulation_value_ptr,
|
||||||
|
batch_size,
|
||||||
|
num_hash_f,
|
||||||
|
hashtable_capacity,
|
||||||
|
num_query,
|
||||||
|
value_dim,
|
||||||
|
value_offset
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
return cumulation_value;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
at::Tensor lsh_weighted_cumulation_ver1_kernel(
|
||||||
|
at::Tensor query_mask,
|
||||||
|
at::Tensor query_hash_code,
|
||||||
|
at::Tensor query_weight,
|
||||||
|
at::Tensor key_mask,
|
||||||
|
at::Tensor key_hash_code,
|
||||||
|
at::Tensor key_weight,
|
||||||
|
at::Tensor value,
|
||||||
|
int hashtable_capacity,
|
||||||
|
bool use_cuda
|
||||||
|
) {
|
||||||
|
|
||||||
|
int batch_size = query_hash_code.size(0);
|
||||||
|
int num_hash_f = query_hash_code.size(2);
|
||||||
|
|
||||||
|
int num_query = query_hash_code.size(1);
|
||||||
|
int num_key = key_hash_code.size(1);
|
||||||
|
int value_dim = value.size(2);
|
||||||
|
int weight_dim = query_weight.size(2);
|
||||||
|
|
||||||
|
at::Tensor hashtable_value = at::zeros({batch_size, num_hash_f, hashtable_capacity, WARP_SIZE}, value.options());
|
||||||
|
at::Tensor cumulation_value = at::zeros({batch_size, num_query, value_dim}, value.options());
|
||||||
|
|
||||||
|
if (use_cuda) {
|
||||||
|
int threads_x = WARP_SIZE;
|
||||||
|
int threads_y = OPTIMAL_THREADS_PER_BLOCK / WARP_SIZE;
|
||||||
|
int block_x_step1 = num_key / threads_y;
|
||||||
|
int block_x_step2 = num_query / threads_y;
|
||||||
|
int block_y = batch_size;
|
||||||
|
|
||||||
|
dim3 threads(threads_x, threads_y);
|
||||||
|
dim3 blocks_step1(block_x_step1, block_y);
|
||||||
|
dim3 blocks_step2(block_x_step2, block_y);
|
||||||
|
|
||||||
|
int *query_mask_ptr = query_mask.data_ptr<int>();
|
||||||
|
int *query_hash_code_ptr = query_hash_code.data_ptr<int>();
|
||||||
|
float *query_weight_ptr = query_weight.data_ptr<float>();
|
||||||
|
int *key_mask_ptr = key_mask.data_ptr<int>();
|
||||||
|
int *key_hash_code_ptr = key_hash_code.data_ptr<int>();
|
||||||
|
float *key_weight_ptr = key_weight.data_ptr<float>();
|
||||||
|
float *value_ptr = value.data_ptr<float>();
|
||||||
|
float *hashtable_value_ptr = hashtable_value.data_ptr<float>();
|
||||||
|
float *cumulation_value_ptr = cumulation_value.data_ptr<float>();
|
||||||
|
|
||||||
|
for (int value_offset = 0; value_offset < value_dim; value_offset = value_offset + WARP_SIZE) {
|
||||||
|
for (int weight_idx = 0; weight_idx < weight_dim; weight_idx++) {
|
||||||
|
|
||||||
|
cudaMemset(hashtable_value_ptr, 0, (batch_size * num_hash_f * hashtable_capacity * WARP_SIZE) * sizeof(float));
|
||||||
|
|
||||||
|
lsh_weighted_cumulation_ver1_step1_cuda_kernel<<<blocks_step1, threads>>>(
|
||||||
|
key_mask_ptr,
|
||||||
|
key_hash_code_ptr,
|
||||||
|
key_weight_ptr,
|
||||||
|
value_ptr,
|
||||||
|
hashtable_value_ptr,
|
||||||
|
batch_size,
|
||||||
|
num_hash_f,
|
||||||
|
hashtable_capacity,
|
||||||
|
num_key,
|
||||||
|
value_dim,
|
||||||
|
weight_dim,
|
||||||
|
value_offset,
|
||||||
|
weight_idx
|
||||||
|
);
|
||||||
|
|
||||||
|
lsh_weighted_cumulation_ver1_step2_cuda_kernel<<<blocks_step2, threads>>>(
|
||||||
|
query_mask_ptr,
|
||||||
|
query_hash_code_ptr,
|
||||||
|
query_weight_ptr,
|
||||||
|
hashtable_value_ptr,
|
||||||
|
cumulation_value_ptr,
|
||||||
|
batch_size,
|
||||||
|
num_hash_f,
|
||||||
|
hashtable_capacity,
|
||||||
|
num_query,
|
||||||
|
value_dim,
|
||||||
|
weight_dim,
|
||||||
|
value_offset,
|
||||||
|
weight_idx
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
return cumulation_value;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
at::Tensor lsh_weighted_cumulation_ver2_kernel(
|
||||||
|
at::Tensor query_mask,
|
||||||
|
at::Tensor query_hash_code,
|
||||||
|
at::Tensor query_weight,
|
||||||
|
at::Tensor key_mask,
|
||||||
|
at::Tensor key_hash_code,
|
||||||
|
at::Tensor key_weight,
|
||||||
|
at::Tensor value,
|
||||||
|
int hashtable_capacity,
|
||||||
|
bool use_cuda
|
||||||
|
) {
|
||||||
|
|
||||||
|
int batch_size = query_hash_code.size(0);
|
||||||
|
int num_hash_f = query_hash_code.size(2);
|
||||||
|
|
||||||
|
int num_query = query_hash_code.size(1);
|
||||||
|
int num_key = key_hash_code.size(1);
|
||||||
|
int value_dim = value.size(2);
|
||||||
|
int weight_dim = query_weight.size(2);
|
||||||
|
|
||||||
|
at::Tensor count_sort_table = at::zeros({batch_size, num_hash_f, hashtable_capacity}, query_hash_code.options());
|
||||||
|
at::Tensor key_sorted_idxes = at::zeros({batch_size, num_hash_f, num_key}, query_hash_code.options());
|
||||||
|
at::Tensor query_info = at::zeros({batch_size, num_query, 2, num_hash_f}, query_hash_code.options());
|
||||||
|
at::Tensor cumulation_value = at::zeros({batch_size, num_query, value_dim}, value.options());
|
||||||
|
|
||||||
|
if (use_cuda) {
|
||||||
|
|
||||||
|
int *query_mask_ptr = query_mask.data_ptr<int>();
|
||||||
|
int *query_hash_code_ptr = query_hash_code.data_ptr<int>();
|
||||||
|
float *query_weight_ptr = query_weight.data_ptr<float>();
|
||||||
|
int *key_mask_ptr = key_mask.data_ptr<int>();
|
||||||
|
int *key_hash_code_ptr = key_hash_code.data_ptr<int>();
|
||||||
|
float *key_weight_ptr = key_weight.data_ptr<float>();
|
||||||
|
float *value_ptr = value.data_ptr<float>();
|
||||||
|
|
||||||
|
int *count_sort_table_ptr = count_sort_table.data_ptr<int>();
|
||||||
|
int *key_sorted_idxes_ptr = key_sorted_idxes.data_ptr<int>();
|
||||||
|
int *query_info_ptr = query_info.data_ptr<int>();
|
||||||
|
|
||||||
|
float *cumulation_value_ptr = cumulation_value.data_ptr<float>();
|
||||||
|
|
||||||
|
{
|
||||||
|
dim3 threads_step13(num_hash_f, max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f));
|
||||||
|
dim3 blocks_step13(num_key / max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f), batch_size);
|
||||||
|
dim3 threads_step2(min(hashtable_capacity, OPTIMAL_THREADS_PER_BLOCK));
|
||||||
|
dim3 blocks_step2(num_hash_f, batch_size);
|
||||||
|
int shared_mem = hashtable_capacity * sizeof(float);
|
||||||
|
count_sort_step1_cuda_kernel<<<blocks_step13, threads_step13>>>(
|
||||||
|
key_mask_ptr,
|
||||||
|
key_hash_code_ptr,
|
||||||
|
count_sort_table_ptr,
|
||||||
|
batch_size,
|
||||||
|
num_hash_f,
|
||||||
|
hashtable_capacity,
|
||||||
|
num_key
|
||||||
|
);
|
||||||
|
count_sort_step2_cuda_kernel<<<blocks_step2, threads_step2, shared_mem>>>(
|
||||||
|
count_sort_table_ptr,
|
||||||
|
batch_size,
|
||||||
|
num_hash_f,
|
||||||
|
hashtable_capacity
|
||||||
|
);
|
||||||
|
count_sort_step3_cuda_kernel<<<blocks_step13, threads_step13>>>(
|
||||||
|
key_mask_ptr,
|
||||||
|
key_hash_code_ptr,
|
||||||
|
count_sort_table_ptr,
|
||||||
|
key_sorted_idxes_ptr,
|
||||||
|
batch_size,
|
||||||
|
num_hash_f,
|
||||||
|
hashtable_capacity,
|
||||||
|
num_key
|
||||||
|
);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
dim3 threads(num_hash_f, max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f));
|
||||||
|
dim3 blocks(num_query / max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f), batch_size);
|
||||||
|
extract_query_info_cuda_kernel<<<blocks, threads>>>(
|
||||||
|
query_mask_ptr,
|
||||||
|
query_hash_code_ptr,
|
||||||
|
count_sort_table_ptr,
|
||||||
|
query_info_ptr,
|
||||||
|
batch_size,
|
||||||
|
num_hash_f,
|
||||||
|
hashtable_capacity,
|
||||||
|
num_query
|
||||||
|
);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
dim3 threads(WARP_SIZE, OPTIMAL_THREADS_PER_BLOCK / WARP_SIZE);
|
||||||
|
dim3 blocks(num_query, num_hash_f, batch_size);
|
||||||
|
int shared_mem = (weight_dim + WARP_SIZE) * sizeof(float);
|
||||||
|
lsh_weighted_cumulation_ver2_step2_cuda_kernel<<<blocks, threads, shared_mem>>>(
|
||||||
|
query_mask_ptr,
|
||||||
|
query_info_ptr,
|
||||||
|
key_sorted_idxes_ptr,
|
||||||
|
query_weight_ptr,
|
||||||
|
key_weight_ptr,
|
||||||
|
value_ptr,
|
||||||
|
cumulation_value_ptr,
|
||||||
|
batch_size,
|
||||||
|
num_hash_f,
|
||||||
|
num_query,
|
||||||
|
num_key,
|
||||||
|
value_dim,
|
||||||
|
weight_dim
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return cumulation_value;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
at::Tensor lsh_weighted_cumulation_ver3_kernel(
|
||||||
|
at::Tensor query_mask,
|
||||||
|
at::Tensor query_hash_code,
|
||||||
|
at::Tensor query_weight,
|
||||||
|
at::Tensor key_mask,
|
||||||
|
at::Tensor key_hash_code,
|
||||||
|
at::Tensor key_weight,
|
||||||
|
at::Tensor value,
|
||||||
|
int hashtable_capacity,
|
||||||
|
bool use_cuda
|
||||||
|
) {
|
||||||
|
|
||||||
|
int batch_size = query_hash_code.size(0);
|
||||||
|
int num_hash_f = query_hash_code.size(2);
|
||||||
|
|
||||||
|
int num_query = query_hash_code.size(1);
|
||||||
|
int num_key = key_hash_code.size(1);
|
||||||
|
int value_dim = value.size(2);
|
||||||
|
int weight_dim = query_weight.size(2);
|
||||||
|
|
||||||
|
at::Tensor count_sort_table = at::zeros({batch_size, num_hash_f, hashtable_capacity}, query_hash_code.options());
|
||||||
|
at::Tensor query_sorted_idxes = at::zeros({batch_size, num_hash_f, num_query}, query_hash_code.options());
|
||||||
|
at::Tensor key_info = at::zeros({batch_size, num_key, 2, num_hash_f}, query_hash_code.options());
|
||||||
|
at::Tensor cumulation_value = at::zeros({batch_size, num_query, value_dim}, value.options());
|
||||||
|
|
||||||
|
if (use_cuda) {
|
||||||
|
|
||||||
|
int *query_mask_ptr = query_mask.data_ptr<int>();
|
||||||
|
int *query_hash_code_ptr = query_hash_code.data_ptr<int>();
|
||||||
|
float *query_weight_ptr = query_weight.data_ptr<float>();
|
||||||
|
int *key_mask_ptr = key_mask.data_ptr<int>();
|
||||||
|
int *key_hash_code_ptr = key_hash_code.data_ptr<int>();
|
||||||
|
float *key_weight_ptr = key_weight.data_ptr<float>();
|
||||||
|
float *value_ptr = value.data_ptr<float>();
|
||||||
|
|
||||||
|
int *count_sort_table_ptr = count_sort_table.data_ptr<int>();
|
||||||
|
int *query_sorted_idxes_ptr = query_sorted_idxes.data_ptr<int>();
|
||||||
|
int *key_info_ptr = key_info.data_ptr<int>();
|
||||||
|
|
||||||
|
float *cumulation_value_ptr = cumulation_value.data_ptr<float>();
|
||||||
|
|
||||||
|
{
|
||||||
|
dim3 threads_step13(num_hash_f, max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f));
|
||||||
|
dim3 blocks_step13(num_query / max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f), batch_size);
|
||||||
|
dim3 threads_step2(min(hashtable_capacity, OPTIMAL_THREADS_PER_BLOCK));
|
||||||
|
dim3 blocks_step2(num_hash_f, batch_size);
|
||||||
|
int shared_mem = hashtable_capacity * sizeof(float);
|
||||||
|
count_sort_step1_cuda_kernel<<<blocks_step13, threads_step13>>>(
|
||||||
|
query_mask_ptr,
|
||||||
|
query_hash_code_ptr,
|
||||||
|
count_sort_table_ptr,
|
||||||
|
batch_size,
|
||||||
|
num_hash_f,
|
||||||
|
hashtable_capacity,
|
||||||
|
num_query
|
||||||
|
);
|
||||||
|
count_sort_step2_cuda_kernel<<<blocks_step2, threads_step2, shared_mem>>>(
|
||||||
|
count_sort_table_ptr,
|
||||||
|
batch_size,
|
||||||
|
num_hash_f,
|
||||||
|
hashtable_capacity
|
||||||
|
);
|
||||||
|
count_sort_step3_cuda_kernel<<<blocks_step13, threads_step13>>>(
|
||||||
|
query_mask_ptr,
|
||||||
|
query_hash_code_ptr,
|
||||||
|
count_sort_table_ptr,
|
||||||
|
query_sorted_idxes_ptr,
|
||||||
|
batch_size,
|
||||||
|
num_hash_f,
|
||||||
|
hashtable_capacity,
|
||||||
|
num_query
|
||||||
|
);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
dim3 threads(num_hash_f, max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f));
|
||||||
|
dim3 blocks(num_key / max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f), batch_size);
|
||||||
|
extract_query_info_cuda_kernel<<<blocks, threads>>>(
|
||||||
|
key_mask_ptr,
|
||||||
|
key_hash_code_ptr,
|
||||||
|
count_sort_table_ptr,
|
||||||
|
key_info_ptr,
|
||||||
|
batch_size,
|
||||||
|
num_hash_f,
|
||||||
|
hashtable_capacity,
|
||||||
|
num_key
|
||||||
|
);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
dim3 threads(WARP_SIZE, OPTIMAL_THREADS_PER_BLOCK / WARP_SIZE);
|
||||||
|
dim3 blocks(num_key, num_hash_f, batch_size);
|
||||||
|
int shared_mem = (weight_dim + value_dim + WARP_SIZE) * sizeof(float);
|
||||||
|
lsh_weighted_cumulation_ver3_step2_cuda_kernel<<<blocks, threads, shared_mem>>>(
|
||||||
|
query_sorted_idxes_ptr,
|
||||||
|
key_mask_ptr,
|
||||||
|
key_info_ptr,
|
||||||
|
query_weight_ptr,
|
||||||
|
key_weight_ptr,
|
||||||
|
value_ptr,
|
||||||
|
cumulation_value_ptr,
|
||||||
|
batch_size,
|
||||||
|
num_hash_f,
|
||||||
|
num_query,
|
||||||
|
num_key,
|
||||||
|
value_dim,
|
||||||
|
weight_dim
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return cumulation_value;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
at::Tensor lsh_weighted_cumulation_ver4_kernel(
|
||||||
|
at::Tensor query_mask,
|
||||||
|
at::Tensor query_hash_code,
|
||||||
|
at::Tensor query_weight,
|
||||||
|
at::Tensor key_mask,
|
||||||
|
at::Tensor key_hash_code,
|
||||||
|
at::Tensor key_weight,
|
||||||
|
at::Tensor value,
|
||||||
|
int hashtable_capacity,
|
||||||
|
bool use_cuda
|
||||||
|
) {
|
||||||
|
|
||||||
|
int batch_size = query_hash_code.size(0);
|
||||||
|
int num_hash_f = query_hash_code.size(2);
|
||||||
|
|
||||||
|
int num_query = query_hash_code.size(1);
|
||||||
|
int num_key = key_hash_code.size(1);
|
||||||
|
int value_dim = value.size(2);
|
||||||
|
int weight_dim = query_weight.size(2);
|
||||||
|
|
||||||
|
at::Tensor count_sort_table = at::zeros({batch_size, num_hash_f, hashtable_capacity}, query_hash_code.options());
|
||||||
|
at::Tensor query_sorted_idxes = at::zeros({batch_size, num_hash_f, num_query}, query_hash_code.options());
|
||||||
|
at::Tensor key_info = at::zeros({batch_size, num_key, 2, num_hash_f}, query_hash_code.options());
|
||||||
|
at::Tensor cumulation_value = at::zeros({batch_size, num_query, value_dim}, value.options());
|
||||||
|
|
||||||
|
if (use_cuda) {
|
||||||
|
|
||||||
|
int *query_mask_ptr = query_mask.data_ptr<int>();
|
||||||
|
int *query_hash_code_ptr = query_hash_code.data_ptr<int>();
|
||||||
|
float *query_weight_ptr = query_weight.data_ptr<float>();
|
||||||
|
int *key_mask_ptr = key_mask.data_ptr<int>();
|
||||||
|
int *key_hash_code_ptr = key_hash_code.data_ptr<int>();
|
||||||
|
float *key_weight_ptr = key_weight.data_ptr<float>();
|
||||||
|
float *value_ptr = value.data_ptr<float>();
|
||||||
|
|
||||||
|
int *count_sort_table_ptr = count_sort_table.data_ptr<int>();
|
||||||
|
int *query_sorted_idxes_ptr = query_sorted_idxes.data_ptr<int>();
|
||||||
|
int *key_info_ptr = key_info.data_ptr<int>();
|
||||||
|
|
||||||
|
float *cumulation_value_ptr = cumulation_value.data_ptr<float>();
|
||||||
|
|
||||||
|
{
|
||||||
|
dim3 threads_step13(num_hash_f, max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f));
|
||||||
|
dim3 blocks_step13(num_query / max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f), batch_size);
|
||||||
|
dim3 threads_step2(min(hashtable_capacity, OPTIMAL_THREADS_PER_BLOCK));
|
||||||
|
dim3 blocks_step2(num_hash_f, batch_size);
|
||||||
|
int shared_mem = hashtable_capacity * sizeof(float);
|
||||||
|
count_sort_step1_cuda_kernel<<<blocks_step13, threads_step13>>>(
|
||||||
|
query_mask_ptr,
|
||||||
|
query_hash_code_ptr,
|
||||||
|
count_sort_table_ptr,
|
||||||
|
batch_size,
|
||||||
|
num_hash_f,
|
||||||
|
hashtable_capacity,
|
||||||
|
num_query
|
||||||
|
);
|
||||||
|
count_sort_step2_cuda_kernel<<<blocks_step2, threads_step2, shared_mem>>>(
|
||||||
|
count_sort_table_ptr,
|
||||||
|
batch_size,
|
||||||
|
num_hash_f,
|
||||||
|
hashtable_capacity
|
||||||
|
);
|
||||||
|
count_sort_step3_cuda_kernel<<<blocks_step13, threads_step13>>>(
|
||||||
|
query_mask_ptr,
|
||||||
|
query_hash_code_ptr,
|
||||||
|
count_sort_table_ptr,
|
||||||
|
query_sorted_idxes_ptr,
|
||||||
|
batch_size,
|
||||||
|
num_hash_f,
|
||||||
|
hashtable_capacity,
|
||||||
|
num_query
|
||||||
|
);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
dim3 threads(num_hash_f, max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f));
|
||||||
|
dim3 blocks(num_key / max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f), batch_size);
|
||||||
|
extract_query_info_cuda_kernel<<<blocks, threads>>>(
|
||||||
|
key_mask_ptr,
|
||||||
|
key_hash_code_ptr,
|
||||||
|
count_sort_table_ptr,
|
||||||
|
key_info_ptr,
|
||||||
|
batch_size,
|
||||||
|
num_hash_f,
|
||||||
|
hashtable_capacity,
|
||||||
|
num_key
|
||||||
|
);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
dim3 threads(WARP_SIZE, OPTIMAL_THREADS_PER_BLOCK / WARP_SIZE);
|
||||||
|
dim3 blocks(num_key, batch_size);
|
||||||
|
int shared_mem = (weight_dim + value_dim + 2 * num_hash_f) * sizeof(float);
|
||||||
|
lsh_weighted_cumulation_ver4_step2_cuda_kernel<<<blocks, threads, shared_mem>>>(
|
||||||
|
query_sorted_idxes_ptr,
|
||||||
|
key_mask_ptr,
|
||||||
|
key_info_ptr,
|
||||||
|
query_weight_ptr,
|
||||||
|
key_weight_ptr,
|
||||||
|
value_ptr,
|
||||||
|
cumulation_value_ptr,
|
||||||
|
batch_size,
|
||||||
|
num_hash_f,
|
||||||
|
num_query,
|
||||||
|
num_key,
|
||||||
|
value_dim,
|
||||||
|
weight_dim
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return cumulation_value;
|
||||||
|
|
||||||
|
}
|
||||||
71
src/transformers/models/yoso/fast_lsh_cumulation.h
Normal file
71
src/transformers/models/yoso/fast_lsh_cumulation.h
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
#include <torch/extension.h>
|
||||||
|
#include <ATen/ATen.h>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
std::vector<at::Tensor> fast_hash_ver1_kernel(
|
||||||
|
at::Tensor query_mask,
|
||||||
|
at::Tensor query_vector,
|
||||||
|
at::Tensor key_mask,
|
||||||
|
at::Tensor key_vector,
|
||||||
|
int num_hash_f,
|
||||||
|
int hash_code_len,
|
||||||
|
bool use_cuda
|
||||||
|
);
|
||||||
|
|
||||||
|
at::Tensor lsh_cumulation_ver1_kernel(
|
||||||
|
at::Tensor query_mask,
|
||||||
|
at::Tensor query_hash_code,
|
||||||
|
at::Tensor key_mask,
|
||||||
|
at::Tensor key_hash_code,
|
||||||
|
at::Tensor value,
|
||||||
|
int hashtable_capacity,
|
||||||
|
bool use_cuda
|
||||||
|
);
|
||||||
|
|
||||||
|
at::Tensor lsh_weighted_cumulation_ver1_kernel(
|
||||||
|
at::Tensor query_mask,
|
||||||
|
at::Tensor query_hash_code,
|
||||||
|
at::Tensor query_weight,
|
||||||
|
at::Tensor key_mask,
|
||||||
|
at::Tensor key_hash_code,
|
||||||
|
at::Tensor key_weight,
|
||||||
|
at::Tensor value,
|
||||||
|
int hashtable_capacity,
|
||||||
|
bool use_cuda
|
||||||
|
);
|
||||||
|
|
||||||
|
at::Tensor lsh_weighted_cumulation_ver2_kernel(
|
||||||
|
at::Tensor query_mask,
|
||||||
|
at::Tensor query_hash_code,
|
||||||
|
at::Tensor query_weight,
|
||||||
|
at::Tensor key_mask,
|
||||||
|
at::Tensor key_hash_code,
|
||||||
|
at::Tensor key_weight,
|
||||||
|
at::Tensor value,
|
||||||
|
int hashtable_capacity,
|
||||||
|
bool use_cuda
|
||||||
|
);
|
||||||
|
|
||||||
|
at::Tensor lsh_weighted_cumulation_ver3_kernel(
|
||||||
|
at::Tensor query_mask,
|
||||||
|
at::Tensor query_hash_code,
|
||||||
|
at::Tensor query_weight,
|
||||||
|
at::Tensor key_mask,
|
||||||
|
at::Tensor key_hash_code,
|
||||||
|
at::Tensor key_weight,
|
||||||
|
at::Tensor value,
|
||||||
|
int hashtable_capacity,
|
||||||
|
bool use_cuda
|
||||||
|
);
|
||||||
|
|
||||||
|
at::Tensor lsh_weighted_cumulation_ver4_kernel(
|
||||||
|
at::Tensor query_mask,
|
||||||
|
at::Tensor query_hash_code,
|
||||||
|
at::Tensor query_weight,
|
||||||
|
at::Tensor key_mask,
|
||||||
|
at::Tensor key_hash_code,
|
||||||
|
at::Tensor key_weight,
|
||||||
|
at::Tensor value,
|
||||||
|
int hashtable_capacity,
|
||||||
|
bool use_cuda
|
||||||
|
);
|
||||||
825
src/transformers/models/yoso/fast_lsh_cumulation_cuda.cu
Normal file
825
src/transformers/models/yoso/fast_lsh_cumulation_cuda.cu
Normal file
@@ -0,0 +1,825 @@
|
|||||||
|
// File from https://github.com/mlpen/YOSO/blob/main/encoders/backbones/efficient_attentions/yoso/yoso_v1/cuda/fast_lsh_cumulation_cuda.cu
|
||||||
|
|
||||||
|
#include "fast_lsh_cumulation_cuda.h"
|
||||||
|
#include "common_cuda_device.h"
|
||||||
|
#include "common_cuda.h"
|
||||||
|
#include "common.h"
|
||||||
|
#include <stdio.h>
|
||||||
|
//////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
//////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
inline __device__ void fast_hadamard_transform(float *vector_buffer, int vector_dim, int dim_idx) {
|
||||||
|
int stride = vector_dim / 2;
|
||||||
|
while (stride > (WARP_SIZE / 2)) {
|
||||||
|
__syncthreads();
|
||||||
|
int sign = 1 - ((dim_idx / stride) % 2) * 2;
|
||||||
|
float val1 = vector_buffer[dim_idx];
|
||||||
|
float val2 = vector_buffer[dim_idx + sign * stride];
|
||||||
|
__syncthreads();
|
||||||
|
vector_buffer[dim_idx] = float(sign) * val1 + val2;
|
||||||
|
stride = stride / 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
float val = vector_buffer[dim_idx];
|
||||||
|
#pragma unroll
|
||||||
|
for (stride = (WARP_SIZE / 2); stride > 0; stride = stride / 2) {
|
||||||
|
int sign = 1 - ((dim_idx / stride) % 2) * 2;
|
||||||
|
val = float(sign) * val + __shfl_xor_sync(FULL_MASK, val, stride);
|
||||||
|
}
|
||||||
|
vector_buffer[dim_idx] = val;
|
||||||
|
}
|
||||||
|
|
||||||
|
__global__ void fast_hash_ver1_cuda_kernel(
|
||||||
|
int *mask, // [batch_size, num_vector]
|
||||||
|
float *vector, // [batch_size, num_vector, vector_dim]
|
||||||
|
int *Dmat, // [batch_size, 3, num_part, vector_dim]
|
||||||
|
int *hash_code, // [batch_size, num_vector, num_hash_f]
|
||||||
|
int batch_size,
|
||||||
|
int num_vector,
|
||||||
|
int vector_dim,
|
||||||
|
int num_part,
|
||||||
|
int num_hash_f,
|
||||||
|
int hash_code_len
|
||||||
|
) {
|
||||||
|
|
||||||
|
int batch_idx = blockIdx.z;
|
||||||
|
int vector_idx = blockIdx.y;
|
||||||
|
int part_idx = blockIdx.x;
|
||||||
|
|
||||||
|
int dim_idx = threadIdx.x;
|
||||||
|
|
||||||
|
int batch_idx__vector_idx = batch_idx * num_vector + vector_idx;
|
||||||
|
if (mask[batch_idx__vector_idx] == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
extern __shared__ float buffer[];
|
||||||
|
float *vector_buffer = buffer;
|
||||||
|
|
||||||
|
vector_buffer[dim_idx] = vector[batch_idx__vector_idx * vector_dim + dim_idx];
|
||||||
|
|
||||||
|
vector_buffer[dim_idx] = vector_buffer[dim_idx] * (float)Dmat[((batch_idx * 3 + 0) * num_part + part_idx) * vector_dim + dim_idx];
|
||||||
|
fast_hadamard_transform(vector_buffer, vector_dim, dim_idx);
|
||||||
|
vector_buffer[dim_idx] = vector_buffer[dim_idx] * (float)Dmat[((batch_idx * 3 + 1) * num_part + part_idx) * vector_dim + dim_idx];
|
||||||
|
fast_hadamard_transform(vector_buffer, vector_dim, dim_idx);
|
||||||
|
vector_buffer[dim_idx] = vector_buffer[dim_idx] * (float)Dmat[((batch_idx * 3 + 2) * num_part + part_idx) * vector_dim + dim_idx];
|
||||||
|
fast_hadamard_transform(vector_buffer, vector_dim, dim_idx);
|
||||||
|
|
||||||
|
int num_hash_per_part = vector_dim / hash_code_len;
|
||||||
|
if (hash_code_len == 8 || hash_code_len == 16) {
|
||||||
|
int code = select(vector_buffer[dim_idx] > 0, 1 << (dim_idx % hash_code_len), 0);
|
||||||
|
for (int offset = 1; offset < hash_code_len; offset = offset * 2) {
|
||||||
|
code += __shfl_xor_sync(FULL_MASK, code, offset);
|
||||||
|
}
|
||||||
|
if (dim_idx % hash_code_len == 0) {
|
||||||
|
int hash_f_idx = part_idx * num_hash_per_part + dim_idx / hash_code_len;
|
||||||
|
if (hash_f_idx < num_hash_f) {
|
||||||
|
hash_code[batch_idx__vector_idx * num_hash_f + hash_f_idx] = code;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
vector_buffer[dim_idx] = select(vector_buffer[dim_idx] > 0, 1 << (dim_idx % hash_code_len), 0);
|
||||||
|
__syncthreads();
|
||||||
|
if (dim_idx < num_hash_per_part) {
|
||||||
|
int code = 0;
|
||||||
|
for (int i = 0; i < hash_code_len; i++) {
|
||||||
|
code += vector_buffer[dim_idx * hash_code_len + i];
|
||||||
|
}
|
||||||
|
int hash_f_idx = part_idx * num_hash_per_part + dim_idx;
|
||||||
|
if (hash_f_idx < num_hash_f) {
|
||||||
|
hash_code[batch_idx__vector_idx * num_hash_f + hash_f_idx] = code;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
__global__ void lsh_cumulation_ver1_step1_cuda_kernel(
|
||||||
|
int *key_mask, // [batch_size, num_key]
|
||||||
|
int *key_hash_code, // [batch_size, num_key, num_hash_f]
|
||||||
|
float *value, // [batch_size, num_key, value_dim]
|
||||||
|
float *hashtable_value, // [batch_size, num_hash_f, hashtable_capacity, WARP_SIZE]
|
||||||
|
int batch_size,
|
||||||
|
int num_hash_f,
|
||||||
|
int hashtable_capacity,
|
||||||
|
int num_key,
|
||||||
|
int value_dim,
|
||||||
|
int offset_warp
|
||||||
|
) {
|
||||||
|
|
||||||
|
int warp_thread_idx = threadIdx.x;
|
||||||
|
|
||||||
|
int batch_idx = blockIdx.y;
|
||||||
|
int key_idx = blockIdx.x * blockDim.y + threadIdx.y;
|
||||||
|
|
||||||
|
int batch_idx__key_idx = batch_idx * num_key + key_idx;
|
||||||
|
if (key_mask[batch_idx__key_idx] == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (num_hash_f > WARP_SIZE) {
|
||||||
|
float warp_value = value[batch_idx__key_idx * value_dim + offset_warp + warp_thread_idx];
|
||||||
|
for (int hash_f_start = 0; hash_f_start < num_hash_f; hash_f_start = hash_f_start + WARP_SIZE) {
|
||||||
|
int warp_hashcode = key_hash_code[batch_idx__key_idx * num_hash_f + hash_f_start + warp_thread_idx];
|
||||||
|
#pragma unroll
|
||||||
|
for (int hash_f_offset = 0; hash_f_offset < WARP_SIZE; hash_f_offset++) {
|
||||||
|
int current_hashcode = warp_hashcode;
|
||||||
|
current_hashcode = __shfl_sync(FULL_MASK, current_hashcode, hash_f_offset);
|
||||||
|
int hashtable_idx = (batch_idx * num_hash_f + (hash_f_start + hash_f_offset)) * hashtable_capacity + current_hashcode;
|
||||||
|
atomicAdd(&hashtable_value[hashtable_idx * WARP_SIZE + warp_thread_idx], warp_value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
float warp_value = value[batch_idx__key_idx * value_dim + offset_warp + warp_thread_idx];
|
||||||
|
int warp_hashcode = 0;
|
||||||
|
if (warp_thread_idx < num_hash_f) {
|
||||||
|
warp_hashcode = key_hash_code[batch_idx__key_idx * num_hash_f + warp_thread_idx];
|
||||||
|
}
|
||||||
|
for (int hash_f_idx = 0; hash_f_idx < num_hash_f; hash_f_idx++) {
|
||||||
|
int current_hashcode = warp_hashcode;
|
||||||
|
current_hashcode = __shfl_sync(FULL_MASK, current_hashcode, hash_f_idx);
|
||||||
|
int hashtable_idx = (batch_idx * num_hash_f + hash_f_idx) * hashtable_capacity + current_hashcode;
|
||||||
|
atomicAdd(&hashtable_value[hashtable_idx * WARP_SIZE + warp_thread_idx], warp_value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
__global__ void lsh_cumulation_ver1_step2_cuda_kernel(
|
||||||
|
int *query_mask, // [batch_size, num_query]
|
||||||
|
int *query_hash_code, // [batch_size, num_query, num_hash_f]
|
||||||
|
float *hashtable_value, // [batch_size, num_hash_f, hashtable_capacity, WARP_SIZE]
|
||||||
|
float *cumulation_value, // [batch_size, num_query, value_dim]
|
||||||
|
int batch_size,
|
||||||
|
int num_hash_f,
|
||||||
|
int hashtable_capacity,
|
||||||
|
int num_query,
|
||||||
|
int value_dim,
|
||||||
|
int offset_warp
|
||||||
|
) {
|
||||||
|
|
||||||
|
int warp_thread_idx = threadIdx.x;
|
||||||
|
|
||||||
|
int batch_idx = blockIdx.y;
|
||||||
|
int query_idx = blockIdx.x * blockDim.y + threadIdx.y;
|
||||||
|
|
||||||
|
int batch_idx__query_idx = batch_idx * num_query + query_idx;
|
||||||
|
if (query_mask[batch_idx__query_idx] == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (num_hash_f > WARP_SIZE) {
|
||||||
|
float warp_value = 0;
|
||||||
|
for (int hash_f_start = 0; hash_f_start < num_hash_f; hash_f_start = hash_f_start + WARP_SIZE) {
|
||||||
|
int warp_hashcode = query_hash_code[batch_idx__query_idx * num_hash_f + hash_f_start + warp_thread_idx];
|
||||||
|
#pragma unroll
|
||||||
|
for (int hash_f_offset = 0; hash_f_offset < WARP_SIZE; hash_f_offset++) {
|
||||||
|
int current_hashcode = warp_hashcode;
|
||||||
|
current_hashcode = __shfl_sync(FULL_MASK, current_hashcode, hash_f_offset);
|
||||||
|
int hashtable_idx = (batch_idx * num_hash_f + (hash_f_start + hash_f_offset)) * hashtable_capacity + current_hashcode;
|
||||||
|
warp_value = warp_value + hashtable_value[hashtable_idx * WARP_SIZE + warp_thread_idx];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cumulation_value[batch_idx__query_idx * value_dim + offset_warp + warp_thread_idx] = warp_value / float(num_hash_f);
|
||||||
|
} else {
|
||||||
|
float warp_value = 0;
|
||||||
|
int warp_hashcode = 0;
|
||||||
|
if (warp_thread_idx < num_hash_f) {
|
||||||
|
warp_hashcode = query_hash_code[batch_idx__query_idx * num_hash_f + warp_thread_idx];
|
||||||
|
}
|
||||||
|
for (int hash_f_idx = 0; hash_f_idx < num_hash_f; hash_f_idx++) {
|
||||||
|
int current_hashcode = warp_hashcode;
|
||||||
|
current_hashcode = __shfl_sync(FULL_MASK, current_hashcode, hash_f_idx);
|
||||||
|
int hashtable_idx = (batch_idx * num_hash_f + hash_f_idx) * hashtable_capacity + current_hashcode;
|
||||||
|
warp_value = warp_value + hashtable_value[hashtable_idx * WARP_SIZE + warp_thread_idx];
|
||||||
|
}
|
||||||
|
cumulation_value[batch_idx__query_idx * value_dim + offset_warp + warp_thread_idx] = warp_value / float(num_hash_f);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
__global__ void lsh_weighted_cumulation_ver1_step1_cuda_kernel(
|
||||||
|
int *key_mask, // [batch_size, num_key]
|
||||||
|
int *key_hash_code, // [batch_size, num_key, num_hash_f]
|
||||||
|
float *key_weight, // [batch_size, num_key, weight_dim]
|
||||||
|
float *value, // [batch_size, num_key, value_dim]
|
||||||
|
float *hashtable_value, // [batch_size, num_hash_f, hashtable_capacity, WARP_SIZE]
|
||||||
|
int batch_size,
|
||||||
|
int num_hash_f,
|
||||||
|
int hashtable_capacity,
|
||||||
|
int num_key,
|
||||||
|
int value_dim,
|
||||||
|
int weight_dim,
|
||||||
|
int offset_warp,
|
||||||
|
int weight_idx
|
||||||
|
) {
|
||||||
|
|
||||||
|
int warp_thread_idx = threadIdx.x;
|
||||||
|
|
||||||
|
int batch_idx = blockIdx.y;
|
||||||
|
int key_idx = blockIdx.x * blockDim.y + threadIdx.y;
|
||||||
|
|
||||||
|
int batch_idx__key_idx = batch_idx * num_key + key_idx;
|
||||||
|
if (key_mask[batch_idx__key_idx] == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (num_hash_f > WARP_SIZE) {
|
||||||
|
float warp_value = key_weight[batch_idx__key_idx * weight_dim + weight_idx] * value[batch_idx__key_idx * value_dim + offset_warp + warp_thread_idx];
|
||||||
|
for (int hash_f_start = 0; hash_f_start < num_hash_f; hash_f_start = hash_f_start + WARP_SIZE) {
|
||||||
|
int warp_hashcode = key_hash_code[batch_idx__key_idx * num_hash_f + hash_f_start + warp_thread_idx];
|
||||||
|
#pragma unroll
|
||||||
|
for (int hash_f_offset = 0; hash_f_offset < WARP_SIZE; hash_f_offset++) {
|
||||||
|
int current_hashcode = warp_hashcode;
|
||||||
|
current_hashcode = __shfl_sync(FULL_MASK, current_hashcode, hash_f_offset);
|
||||||
|
int hashtable_idx = (batch_idx * num_hash_f + (hash_f_start + hash_f_offset)) * hashtable_capacity + current_hashcode;
|
||||||
|
atomicAdd(&hashtable_value[hashtable_idx * WARP_SIZE + warp_thread_idx], warp_value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
float warp_value = key_weight[batch_idx__key_idx * weight_dim + weight_idx] * value[batch_idx__key_idx * value_dim + offset_warp + warp_thread_idx];
|
||||||
|
int warp_hashcode = 0;
|
||||||
|
if (warp_thread_idx < num_hash_f) {
|
||||||
|
warp_hashcode = key_hash_code[batch_idx__key_idx * num_hash_f + warp_thread_idx];
|
||||||
|
}
|
||||||
|
for (int hash_f_idx = 0; hash_f_idx < num_hash_f; hash_f_idx++) {
|
||||||
|
int current_hashcode = warp_hashcode;
|
||||||
|
current_hashcode = __shfl_sync(FULL_MASK, current_hashcode, hash_f_idx);
|
||||||
|
int hashtable_idx = (batch_idx * num_hash_f + hash_f_idx) * hashtable_capacity + current_hashcode;
|
||||||
|
atomicAdd(&hashtable_value[hashtable_idx * WARP_SIZE + warp_thread_idx], warp_value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
__global__ void lsh_weighted_cumulation_ver1_step2_cuda_kernel(
|
||||||
|
int *query_mask, // [batch_size, num_query]
|
||||||
|
int *query_hash_code, // [batch_size, num_query, num_hash_f]
|
||||||
|
float *query_weight, // [batch_size, num_query, weight_dim]
|
||||||
|
float *hashtable_value, // [batch_size, num_hash_f, hashtable_capacity, WARP_SIZE]
|
||||||
|
float *cumulation_value, // [batch_size, num_query, value_dim]
|
||||||
|
int batch_size,
|
||||||
|
int num_hash_f,
|
||||||
|
int hashtable_capacity,
|
||||||
|
int num_query,
|
||||||
|
int value_dim,
|
||||||
|
int weight_dim,
|
||||||
|
int offset_warp,
|
||||||
|
int weight_idx
|
||||||
|
) {
|
||||||
|
|
||||||
|
int warp_thread_idx = threadIdx.x;
|
||||||
|
|
||||||
|
int batch_idx = blockIdx.y;
|
||||||
|
int query_idx = blockIdx.x * blockDim.y + threadIdx.y;
|
||||||
|
|
||||||
|
int batch_idx__query_idx = batch_idx * num_query + query_idx;
|
||||||
|
if (query_mask[batch_idx__query_idx] == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (num_hash_f > WARP_SIZE) {
|
||||||
|
float warp_value = 0;
|
||||||
|
for (int hash_f_start = 0; hash_f_start < num_hash_f; hash_f_start = hash_f_start + WARP_SIZE) {
|
||||||
|
int warp_hashcode = query_hash_code[batch_idx__query_idx * num_hash_f + hash_f_start + warp_thread_idx];
|
||||||
|
#pragma unroll
|
||||||
|
for (int hash_f_offset = 0; hash_f_offset < WARP_SIZE; hash_f_offset++) {
|
||||||
|
int current_hashcode = warp_hashcode;
|
||||||
|
current_hashcode = __shfl_sync(FULL_MASK, current_hashcode, hash_f_offset);
|
||||||
|
int hashtable_idx = (batch_idx * num_hash_f + (hash_f_start + hash_f_offset)) * hashtable_capacity + current_hashcode;
|
||||||
|
warp_value = warp_value + hashtable_value[hashtable_idx * WARP_SIZE + warp_thread_idx];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
float warp_weight = query_weight[batch_idx__query_idx * weight_dim + weight_idx];
|
||||||
|
cumulation_value[batch_idx__query_idx * value_dim + offset_warp + warp_thread_idx] += warp_weight * warp_value / float(num_hash_f);
|
||||||
|
} else {
|
||||||
|
float warp_value = 0;
|
||||||
|
int warp_hashcode = 0;
|
||||||
|
if (warp_thread_idx < num_hash_f) {
|
||||||
|
warp_hashcode = query_hash_code[batch_idx__query_idx * num_hash_f + warp_thread_idx];
|
||||||
|
}
|
||||||
|
for (int hash_f_idx = 0; hash_f_idx < num_hash_f; hash_f_idx++) {
|
||||||
|
int current_hashcode = warp_hashcode;
|
||||||
|
current_hashcode = __shfl_sync(FULL_MASK, current_hashcode, hash_f_idx);
|
||||||
|
int hashtable_idx = (batch_idx * num_hash_f + hash_f_idx) * hashtable_capacity + current_hashcode;
|
||||||
|
warp_value = warp_value + hashtable_value[hashtable_idx * WARP_SIZE + warp_thread_idx];
|
||||||
|
}
|
||||||
|
float warp_weight = query_weight[batch_idx__query_idx * weight_dim + weight_idx];
|
||||||
|
cumulation_value[batch_idx__query_idx * value_dim + offset_warp + warp_thread_idx] += warp_weight * warp_value / float(num_hash_f);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
__global__ void count_sort_step1_cuda_kernel(
|
||||||
|
int *key_mask, // [batch_size, num_key]
|
||||||
|
int *key_hash_code, // [batch_size, num_key, num_hash_f]
|
||||||
|
int *count_sort_table, // [batch_size, num_hash_f, hashtable_capacity]
|
||||||
|
int batch_size,
|
||||||
|
int num_hash_f,
|
||||||
|
int hashtable_capacity,
|
||||||
|
int num_key
|
||||||
|
) {
|
||||||
|
|
||||||
|
int batch_idx = blockIdx.y;
|
||||||
|
int key_idx = blockIdx.x * blockDim.y + threadIdx.y;
|
||||||
|
int hash_f_idx = threadIdx.x;
|
||||||
|
|
||||||
|
int batch_idx__key_idx = batch_idx * num_key + key_idx;
|
||||||
|
if (key_mask[batch_idx__key_idx] == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
int hash_code = key_hash_code[batch_idx__key_idx * num_hash_f + hash_f_idx];
|
||||||
|
atomicAdd(&count_sort_table[(batch_idx * num_hash_f + hash_f_idx) * hashtable_capacity + hash_code], 1);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
__global__ void count_sort_step2_cuda_kernel(
|
||||||
|
int *count_sort_table, // [batch_size, num_hash_f, hashtable_capacity]
|
||||||
|
int batch_size,
|
||||||
|
int num_hash_f,
|
||||||
|
int hashtable_capacity
|
||||||
|
) {
|
||||||
|
|
||||||
|
int batch_idx = blockIdx.y;
|
||||||
|
int hash_f_idx = blockIdx.x;
|
||||||
|
|
||||||
|
int num_threads = blockDim.x;
|
||||||
|
int thread_id = threadIdx.x;
|
||||||
|
|
||||||
|
int batch_idx__hash_f_idx = batch_idx * num_hash_f + hash_f_idx;
|
||||||
|
|
||||||
|
extern __shared__ float buffer[];
|
||||||
|
int *table_buffer = (int*)buffer;
|
||||||
|
|
||||||
|
if (thread_id == 0) {
|
||||||
|
table_buffer[0] = 0;
|
||||||
|
}
|
||||||
|
copy_data<int>(&count_sort_table[batch_idx__hash_f_idx * hashtable_capacity], &table_buffer[1], hashtable_capacity - 1, num_threads, thread_id);
|
||||||
|
|
||||||
|
for (int table_idx_start = 0; table_idx_start < hashtable_capacity; table_idx_start = table_idx_start + num_threads) {
|
||||||
|
int thread_value = table_buffer[table_idx_start + thread_id];
|
||||||
|
int next_thread_value = 0;
|
||||||
|
for (int offset = 1; offset < WARP_SIZE; offset = offset << 1) {
|
||||||
|
next_thread_value = __shfl_up_sync(FULL_MASK, thread_value, offset);
|
||||||
|
if (thread_id % WARP_SIZE >= offset) {
|
||||||
|
thread_value = thread_value + next_thread_value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
table_buffer[table_idx_start + thread_id] = thread_value;
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
if (hashtable_capacity > WARP_SIZE) {
|
||||||
|
if (thread_id < WARP_SIZE) {
|
||||||
|
for (int table_idx_start = WARP_SIZE; table_idx_start < hashtable_capacity; table_idx_start = table_idx_start + WARP_SIZE) {
|
||||||
|
table_buffer[table_idx_start + thread_id] += table_buffer[table_idx_start - 1];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
copy_data<int>(table_buffer, &count_sort_table[batch_idx__hash_f_idx * hashtable_capacity], hashtable_capacity, num_threads, thread_id);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
__global__ void count_sort_step3_cuda_kernel(
|
||||||
|
int *key_mask, // [batch_size, num_key]
|
||||||
|
int *key_hash_code, // [batch_size, num_key, num_hash_f]
|
||||||
|
int *count_sort_table, // [batch_size, num_hash_f, hashtable_capacity]
|
||||||
|
int *key_sorted_idxes, // [batch_size, num_hash_f, num_key]
|
||||||
|
int batch_size,
|
||||||
|
int num_hash_f,
|
||||||
|
int hashtable_capacity,
|
||||||
|
int num_key
|
||||||
|
) {
|
||||||
|
|
||||||
|
int batch_idx = blockIdx.y;
|
||||||
|
int key_idx = blockIdx.x * blockDim.y + threadIdx.y;
|
||||||
|
int hash_f_idx = threadIdx.x;
|
||||||
|
|
||||||
|
int batch_idx__key_idx = batch_idx * num_key + key_idx;
|
||||||
|
if (key_mask[batch_idx__key_idx] == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
int batch_idx__hash_f_idx = batch_idx * num_hash_f + hash_f_idx;
|
||||||
|
|
||||||
|
int hash_code = key_hash_code[batch_idx__key_idx * num_hash_f + hash_f_idx];
|
||||||
|
int sort_idx = atomicAdd(&count_sort_table[batch_idx__hash_f_idx * hashtable_capacity + hash_code], 1);
|
||||||
|
key_sorted_idxes[batch_idx__hash_f_idx * num_key + sort_idx] = key_idx;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
__global__ void extract_query_info_cuda_kernel(
|
||||||
|
int *query_mask, // [batch_size, num_query]
|
||||||
|
int *query_hash_code, // [batch_size, num_query, num_hash_f]
|
||||||
|
int *count_sort_table, // [batch_size, num_hash_f, hashtable_capacity]
|
||||||
|
int *query_info, // [batch_size, num_query, 2, num_hash_f]
|
||||||
|
int batch_size,
|
||||||
|
int num_hash_f,
|
||||||
|
int hashtable_capacity,
|
||||||
|
int num_query
|
||||||
|
) {
|
||||||
|
|
||||||
|
int batch_idx = blockIdx.y;
|
||||||
|
int query_idx = blockIdx.x * blockDim.y + threadIdx.y;
|
||||||
|
int hash_f_idx = threadIdx.x;
|
||||||
|
|
||||||
|
int batch_idx__query_idx = batch_idx * num_query + query_idx;
|
||||||
|
if (query_mask[batch_idx__query_idx] == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
int hash_code = query_hash_code[batch_idx__query_idx * num_hash_f + hash_f_idx];
|
||||||
|
int batch_idx__hash_f_idx__hash_code = (batch_idx * num_hash_f + hash_f_idx) * hashtable_capacity + hash_code;
|
||||||
|
|
||||||
|
int key_offset = select(hash_code == 0, 0, count_sort_table[batch_idx__hash_f_idx__hash_code - 1]);
|
||||||
|
int key_count = count_sort_table[batch_idx__hash_f_idx__hash_code] - key_offset;
|
||||||
|
|
||||||
|
query_info[batch_idx__query_idx * 2 * num_hash_f + hash_f_idx] = key_offset;
|
||||||
|
query_info[(batch_idx__query_idx * 2 + 1) * num_hash_f + hash_f_idx] = key_count;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
__global__ void lsh_weighted_cumulation_ver2_step2_cuda_kernel(
|
||||||
|
int *query_mask, // [batch_size, num_query]
|
||||||
|
int *query_info, // [batch_size, num_query, 2, num_hash_f]
|
||||||
|
int *key_sorted_idxes, // [batch_size, num_hash_f, num_key]
|
||||||
|
float *query_weight, // [batch_size, num_query, weight_dim]
|
||||||
|
float *key_weight, // [batch_size, num_key, weight_dim]
|
||||||
|
float *value, // [batch_size, num_key, value_dim]
|
||||||
|
float *cumulation_value, // [batch_size, num_query, value_dim]
|
||||||
|
int batch_size,
|
||||||
|
int num_hash_f,
|
||||||
|
int num_query,
|
||||||
|
int num_key,
|
||||||
|
int value_dim,
|
||||||
|
int weight_dim
|
||||||
|
) {
|
||||||
|
|
||||||
|
int batch_idx = blockIdx.z;
|
||||||
|
int hash_f_idx = blockIdx.y;
|
||||||
|
int query_idx = blockIdx.x;
|
||||||
|
|
||||||
|
int num_threads = blockDim.y * blockDim.x;
|
||||||
|
int thread_id = threadIdx.y * blockDim.x + threadIdx.x;
|
||||||
|
|
||||||
|
int num_warps = blockDim.y;
|
||||||
|
int warp_idx = threadIdx.y;
|
||||||
|
int warp_thread_idx = threadIdx.x;
|
||||||
|
|
||||||
|
int batch_idx__query_idx = batch_idx * num_query + query_idx;
|
||||||
|
if (query_mask[batch_idx__query_idx] == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
int key_offset = query_info[batch_idx__query_idx * 2 * num_hash_f + hash_f_idx];
|
||||||
|
int key_count = query_info[(batch_idx__query_idx * 2 + 1) * num_hash_f + hash_f_idx];
|
||||||
|
|
||||||
|
if (key_count == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
extern __shared__ float buffer[];
|
||||||
|
|
||||||
|
if (key_count == 1) {
|
||||||
|
if (warp_idx == 0) {
|
||||||
|
int key_idx = key_sorted_idxes[(batch_idx * num_hash_f + hash_f_idx) * num_key + key_offset];
|
||||||
|
int batch_idx__key_idx = batch_idx * num_key + key_idx;
|
||||||
|
float weight = 0;
|
||||||
|
for (int weight_offset = 0; weight_offset < weight_dim; weight_offset = weight_offset + WARP_SIZE) {
|
||||||
|
int weight_dim_idx = weight_offset + warp_thread_idx;
|
||||||
|
float val = query_weight[batch_idx__query_idx * weight_dim + weight_dim_idx] * key_weight[batch_idx__key_idx * weight_dim + weight_dim_idx];
|
||||||
|
#pragma unroll
|
||||||
|
for (int offset = 1; offset < WARP_SIZE; offset = offset << 1) {
|
||||||
|
val += __shfl_xor_sync(FULL_MASK, val, offset);
|
||||||
|
}
|
||||||
|
weight = weight + val;
|
||||||
|
}
|
||||||
|
weight = weight / float(num_hash_f);
|
||||||
|
for (int value_offset = 0; value_offset < value_dim; value_offset = value_offset + WARP_SIZE) {
|
||||||
|
int value_dim_idx = value_offset + warp_thread_idx;
|
||||||
|
float val = value[batch_idx__key_idx * value_dim + value_dim_idx];
|
||||||
|
atomicAdd(&cumulation_value[batch_idx__query_idx * value_dim + value_dim_idx], weight * val);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
float *weight_buffer = buffer;
|
||||||
|
int *key_idxes_buffer = (int*)&buffer[weight_dim];
|
||||||
|
|
||||||
|
copy_data_nonblocking<float>(&query_weight[batch_idx__query_idx * weight_dim], weight_buffer, weight_dim, num_threads, thread_id);
|
||||||
|
|
||||||
|
while (key_count > 0) {
|
||||||
|
int work_size = min(WARP_SIZE, key_count);
|
||||||
|
copy_data_nonblocking<int>(&key_sorted_idxes[(batch_idx * num_hash_f + hash_f_idx) * num_key + key_offset], key_idxes_buffer, work_size, num_threads, thread_id);
|
||||||
|
__syncthreads();
|
||||||
|
for (int work_offset = 0; work_offset < WARP_SIZE; work_offset = work_offset + num_warps) {
|
||||||
|
int work_idx = work_offset + warp_idx;
|
||||||
|
if (work_idx < key_count) {
|
||||||
|
int key_idx = key_idxes_buffer[work_idx];
|
||||||
|
int batch_idx__key_idx = batch_idx * num_key + key_idx;
|
||||||
|
float weight = 0;
|
||||||
|
for (int weight_offset = 0; weight_offset < weight_dim; weight_offset = weight_offset + WARP_SIZE) {
|
||||||
|
int weight_dim_idx = weight_offset + warp_thread_idx;
|
||||||
|
float val = weight_buffer[weight_dim_idx] * key_weight[batch_idx__key_idx * weight_dim + weight_dim_idx];
|
||||||
|
#pragma unroll
|
||||||
|
for (int offset = 1; offset < WARP_SIZE; offset = offset << 1) {
|
||||||
|
val += __shfl_xor_sync(FULL_MASK, val, offset);
|
||||||
|
}
|
||||||
|
weight = weight + val;
|
||||||
|
}
|
||||||
|
weight = weight / float(num_hash_f);
|
||||||
|
for (int value_offset = 0; value_offset < value_dim; value_offset = value_offset + WARP_SIZE) {
|
||||||
|
int value_dim_idx = value_offset + warp_thread_idx;
|
||||||
|
float val = value[batch_idx__key_idx * value_dim + value_dim_idx];
|
||||||
|
atomicAdd(&cumulation_value[batch_idx__query_idx * value_dim + value_dim_idx], weight * val);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
key_count = key_count - work_size;
|
||||||
|
key_offset = key_offset + work_size;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
__global__ void lsh_weighted_cumulation_ver3_step2_cuda_kernel(
|
||||||
|
int *query_sorted_idxes, // [batch_size, num_hash_f, num_query]
|
||||||
|
int *key_mask, // [batch_size, num_key]
|
||||||
|
int *key_info, // [batch_size, num_key, 2, num_hash_f]
|
||||||
|
float *query_weight, // [batch_size, num_query, weight_dim]
|
||||||
|
float *key_weight, // [batch_size, num_key, weight_dim]
|
||||||
|
float *value, // [batch_size, num_key, value_dim]
|
||||||
|
float *cumulation_value, // [batch_size, num_query, value_dim]
|
||||||
|
int batch_size,
|
||||||
|
int num_hash_f,
|
||||||
|
int num_query,
|
||||||
|
int num_key,
|
||||||
|
int value_dim,
|
||||||
|
int weight_dim
|
||||||
|
) {
|
||||||
|
|
||||||
|
int batch_idx = blockIdx.z;
|
||||||
|
int hash_f_idx = blockIdx.y;
|
||||||
|
int key_idx = blockIdx.x;
|
||||||
|
|
||||||
|
int num_threads = blockDim.y * blockDim.x;
|
||||||
|
int thread_id = threadIdx.y * blockDim.x + threadIdx.x;
|
||||||
|
|
||||||
|
int num_warps = blockDim.y;
|
||||||
|
int warp_idx = threadIdx.y;
|
||||||
|
int warp_thread_idx = threadIdx.x;
|
||||||
|
|
||||||
|
int batch_idx__key_idx = batch_idx * num_key + key_idx;
|
||||||
|
if (key_mask[batch_idx__key_idx] == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
int query_offset = key_info[batch_idx__key_idx * 2 * num_hash_f + hash_f_idx];
|
||||||
|
int query_count = key_info[(batch_idx__key_idx * 2 + 1) * num_hash_f + hash_f_idx];
|
||||||
|
|
||||||
|
if (query_count == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
extern __shared__ float buffer[];
|
||||||
|
|
||||||
|
if (query_count == 1) {
|
||||||
|
if (warp_idx == 0) {
|
||||||
|
int query_idx = query_sorted_idxes[(batch_idx * num_hash_f + hash_f_idx) * num_query + query_offset];
|
||||||
|
int batch_idx__query_idx = batch_idx * num_query + query_idx;
|
||||||
|
float weight = 0;
|
||||||
|
for (int weight_offset = 0; weight_offset < weight_dim; weight_offset = weight_offset + WARP_SIZE) {
|
||||||
|
int weight_dim_idx = weight_offset + warp_thread_idx;
|
||||||
|
float val = key_weight[batch_idx__key_idx * weight_dim + weight_dim_idx] * query_weight[batch_idx__query_idx * weight_dim + weight_dim_idx];
|
||||||
|
#pragma unroll
|
||||||
|
for (int offset = 1; offset < WARP_SIZE; offset = offset << 1) {
|
||||||
|
val += __shfl_xor_sync(FULL_MASK, val, offset);
|
||||||
|
}
|
||||||
|
weight = weight + val;
|
||||||
|
}
|
||||||
|
weight = weight / float(num_hash_f);
|
||||||
|
for (int value_offset = 0; value_offset < value_dim; value_offset = value_offset + WARP_SIZE) {
|
||||||
|
int value_dim_idx = value_offset + warp_thread_idx;
|
||||||
|
float val = value[batch_idx__key_idx * value_dim + value_dim_idx];
|
||||||
|
atomicAdd(&cumulation_value[batch_idx__query_idx * value_dim + value_dim_idx], weight * val);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
float *weight_buffer = buffer;
|
||||||
|
float *value_buffer = &buffer[weight_dim];
|
||||||
|
int *query_idxes_buffer = (int*)&buffer[weight_dim + value_dim];
|
||||||
|
|
||||||
|
copy_data_nonblocking<float>(&key_weight[batch_idx__key_idx * weight_dim], weight_buffer, weight_dim, num_threads, thread_id);
|
||||||
|
copy_data_nonblocking<float>(&value[batch_idx__key_idx * value_dim], value_buffer, value_dim, num_threads, thread_id);
|
||||||
|
|
||||||
|
while (query_count > 0) {
|
||||||
|
int work_size = min(WARP_SIZE, query_count);
|
||||||
|
copy_data_nonblocking<int>(&query_sorted_idxes[(batch_idx * num_hash_f + hash_f_idx) * num_query + query_offset], query_idxes_buffer, work_size, num_threads, thread_id);
|
||||||
|
__syncthreads();
|
||||||
|
for (int work_offset = 0; work_offset < WARP_SIZE; work_offset = work_offset + num_warps) {
|
||||||
|
int work_idx = work_offset + warp_idx;
|
||||||
|
if (work_idx < query_count) {
|
||||||
|
int query_idx = query_idxes_buffer[work_idx];
|
||||||
|
int batch_idx__query_idx = batch_idx * num_query + query_idx;
|
||||||
|
float weight = 0;
|
||||||
|
for (int weight_offset = 0; weight_offset < weight_dim; weight_offset = weight_offset + WARP_SIZE) {
|
||||||
|
int weight_dim_idx = weight_offset + warp_thread_idx;
|
||||||
|
float val = weight_buffer[weight_dim_idx] * query_weight[batch_idx__query_idx * weight_dim + weight_dim_idx];
|
||||||
|
#pragma unroll
|
||||||
|
for (int offset = 1; offset < WARP_SIZE; offset = offset << 1) {
|
||||||
|
val += __shfl_xor_sync(FULL_MASK, val, offset);
|
||||||
|
}
|
||||||
|
weight = weight + val;
|
||||||
|
}
|
||||||
|
weight = weight / float(num_hash_f);
|
||||||
|
for (int value_offset = 0; value_offset < value_dim; value_offset = value_offset + WARP_SIZE) {
|
||||||
|
int value_dim_idx = value_offset + warp_thread_idx;
|
||||||
|
float val = value_buffer[value_dim_idx];
|
||||||
|
atomicAdd(&cumulation_value[batch_idx__query_idx * value_dim + value_dim_idx], weight * val);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
query_count = query_count - work_size;
|
||||||
|
query_offset = query_offset + work_size;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
__global__ void lsh_weighted_cumulation_ver4_step2_cuda_kernel(
|
||||||
|
int *query_sorted_idxes, // [batch_size, num_hash_f, num_query]
|
||||||
|
int *key_mask, // [batch_size, num_key]
|
||||||
|
int *key_info, // [batch_size, num_key, 2, num_hash_f]
|
||||||
|
float *query_weight, // [batch_size, num_query, weight_dim]
|
||||||
|
float *key_weight, // [batch_size, num_key, weight_dim]
|
||||||
|
float *value, // [batch_size, num_key, value_dim]
|
||||||
|
float *cumulation_value, // [batch_size, num_query, value_dim]
|
||||||
|
int batch_size,
|
||||||
|
int num_hash_f,
|
||||||
|
int num_query,
|
||||||
|
int num_key,
|
||||||
|
int value_dim,
|
||||||
|
int weight_dim
|
||||||
|
) {
|
||||||
|
|
||||||
|
int batch_idx = blockIdx.y;
|
||||||
|
int key_idx = blockIdx.x;
|
||||||
|
|
||||||
|
int num_threads = blockDim.y * blockDim.x;
|
||||||
|
int thread_id = threadIdx.y * blockDim.x + threadIdx.x;
|
||||||
|
|
||||||
|
int num_warps = blockDim.y;
|
||||||
|
int warp_idx = threadIdx.y;
|
||||||
|
int warp_thread_idx = threadIdx.x;
|
||||||
|
|
||||||
|
int batch_idx__key_idx = batch_idx * num_key + key_idx;
|
||||||
|
if (key_mask[batch_idx__key_idx] == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
extern __shared__ float buffer[];
|
||||||
|
float *weight_buffer = buffer;
|
||||||
|
float *value_buffer = &buffer[weight_dim];
|
||||||
|
int *key_info_buffer = (int*)&buffer[weight_dim + value_dim];
|
||||||
|
|
||||||
|
copy_data_nonblocking<float>(&key_weight[batch_idx__key_idx * weight_dim], weight_buffer, weight_dim, num_threads, thread_id);
|
||||||
|
copy_data_nonblocking<float>(&value[batch_idx__key_idx * value_dim], value_buffer, value_dim, num_threads, thread_id);
|
||||||
|
copy_data_nonblocking<int>(&key_info[batch_idx__key_idx * 2 * num_hash_f], key_info_buffer, 2 * num_hash_f, num_threads, thread_id);
|
||||||
|
|
||||||
|
int *query_offset_buffer = key_info_buffer;
|
||||||
|
int *query_count_buffer = &key_info_buffer[num_hash_f];
|
||||||
|
|
||||||
|
const int hashtable_size = 1024 + OPTIMAL_THREADS_PER_BLOCK;
|
||||||
|
__shared__ int hashtable_query[hashtable_size];
|
||||||
|
__shared__ int hashtable_count[hashtable_size];
|
||||||
|
__shared__ int inserted_query[hashtable_size];
|
||||||
|
__shared__ int query_counter[1];
|
||||||
|
|
||||||
|
int hash_f_idx_base = 0;
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
|
||||||
|
init_buffer_nonblocking<int>(EMPTY_VALUE, hashtable_query, hashtable_size, num_threads, thread_id);
|
||||||
|
init_buffer_nonblocking<int>(0, hashtable_count, hashtable_size, num_threads, thread_id);
|
||||||
|
init_buffer_nonblocking<int>(EMPTY_VALUE, inserted_query, hashtable_size, num_threads, thread_id);
|
||||||
|
init_buffer_nonblocking<int>(0, query_counter, 1, num_threads, thread_id);
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
while (hash_f_idx_base < num_hash_f) {
|
||||||
|
|
||||||
|
int hash_f_idx = hash_f_idx_base + warp_idx;
|
||||||
|
int batch_idx__hash_f_idx = batch_idx * num_hash_f + hash_f_idx;
|
||||||
|
|
||||||
|
int stop_flag = 0;
|
||||||
|
|
||||||
|
int query_offset = query_offset_buffer[hash_f_idx];
|
||||||
|
int query_count = query_count_buffer[hash_f_idx];
|
||||||
|
|
||||||
|
while (query_count > 0) {
|
||||||
|
|
||||||
|
int work_size = min(query_count, WARP_SIZE);
|
||||||
|
|
||||||
|
// try inserting query to set and check whether the query is new
|
||||||
|
int found_new_query = 0;
|
||||||
|
int query_idx = -1;
|
||||||
|
if (warp_thread_idx < work_size) {
|
||||||
|
query_idx = query_sorted_idxes[batch_idx__hash_f_idx * num_query + query_offset + warp_thread_idx];
|
||||||
|
int slot = set_insert<int>(hashtable_query, hashtable_size, query_idx);
|
||||||
|
if (slot >= 0) {
|
||||||
|
found_new_query = atomicAdd(&hashtable_count[slot], 1) == 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// compute cumulative offset
|
||||||
|
int position_offset = found_new_query;
|
||||||
|
int next_position_offset = 0;
|
||||||
|
#pragma unroll
|
||||||
|
for (int offset = 1; offset < WARP_SIZE; offset = offset << 1) {
|
||||||
|
next_position_offset = __shfl_up_sync(FULL_MASK, position_offset, offset);
|
||||||
|
if (thread_id % WARP_SIZE >= offset) {
|
||||||
|
position_offset = position_offset + next_position_offset;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// get the inserted query list end index
|
||||||
|
int inserted_query_base = 0;
|
||||||
|
if (thread_id % WARP_SIZE == WARP_SIZE - 1) {
|
||||||
|
inserted_query_base = atomicAdd(query_counter, position_offset);
|
||||||
|
}
|
||||||
|
inserted_query_base = __shfl_sync(FULL_MASK, inserted_query_base, WARP_SIZE - 1);
|
||||||
|
|
||||||
|
// insert new queries to list
|
||||||
|
int insert_idx = inserted_query_base + position_offset - 1;
|
||||||
|
if (found_new_query) {
|
||||||
|
inserted_query[insert_idx] = query_idx;
|
||||||
|
}
|
||||||
|
|
||||||
|
// remove inserted queries from list
|
||||||
|
query_offset_buffer[hash_f_idx] += work_size;
|
||||||
|
query_count_buffer[hash_f_idx] -= work_size;
|
||||||
|
query_offset += work_size;
|
||||||
|
query_count -= work_size;
|
||||||
|
|
||||||
|
// if list is almost full, stop inserting
|
||||||
|
if (inserted_query_base + OPTIMAL_THREADS_PER_BLOCK > hashtable_size) {
|
||||||
|
stop_flag = 1;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
if (stop_flag) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
hash_f_idx_base = hash_f_idx_base + num_warps;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
int num_distint_query = query_counter[0];
|
||||||
|
|
||||||
|
if (num_distint_query > 0) {
|
||||||
|
for (int idx_base = 0; idx_base < num_distint_query; idx_base = idx_base + num_warps) {
|
||||||
|
int idx = idx_base + warp_idx;
|
||||||
|
if (idx < num_distint_query) {
|
||||||
|
int query_idx = inserted_query[idx];
|
||||||
|
int batch_idx__query_idx = batch_idx * num_query + query_idx;
|
||||||
|
|
||||||
|
int slot = set_lookup<int>(hashtable_query, hashtable_size, query_idx);
|
||||||
|
int duplicate_count = hashtable_count[slot];
|
||||||
|
|
||||||
|
float weight = 0;
|
||||||
|
for (int weight_idx_base = 0; weight_idx_base < weight_dim; weight_idx_base = weight_idx_base + WARP_SIZE) {
|
||||||
|
int weight_dim_idx = weight_idx_base + warp_thread_idx;
|
||||||
|
float val = weight_buffer[weight_dim_idx] * query_weight[batch_idx__query_idx * weight_dim + weight_dim_idx];
|
||||||
|
#pragma unroll
|
||||||
|
for (int offset = 1; offset < WARP_SIZE; offset = offset << 1) {
|
||||||
|
val += __shfl_xor_sync(FULL_MASK, val, offset);
|
||||||
|
}
|
||||||
|
weight = weight + val;
|
||||||
|
}
|
||||||
|
|
||||||
|
weight = (float)duplicate_count * weight / float(num_hash_f);
|
||||||
|
|
||||||
|
for (int value_idx_base = 0; value_idx_base < value_dim; value_idx_base = value_idx_base + WARP_SIZE) {
|
||||||
|
int value_dim_idx = value_idx_base + warp_thread_idx;
|
||||||
|
float val = value_buffer[value_dim_idx];
|
||||||
|
atomicAdd(&cumulation_value[batch_idx__query_idx * value_dim + value_dim_idx], weight * val);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
|
||||||
|
// all computation is completed if num_distint_query == 0
|
||||||
|
break;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
157
src/transformers/models/yoso/fast_lsh_cumulation_cuda.h
Normal file
157
src/transformers/models/yoso/fast_lsh_cumulation_cuda.h
Normal file
@@ -0,0 +1,157 @@
|
|||||||
|
__global__ void fast_hash_ver1_cuda_kernel(
|
||||||
|
int *mask, // [batch_size, num_vector]
|
||||||
|
float *vector, // [batch_size, num_vector, vector_dim]
|
||||||
|
int *Dmat, // [3, num_part, vector_dim]
|
||||||
|
int *hash_code, // [batch_size, num_vector, num_hash_f]
|
||||||
|
int batch_size,
|
||||||
|
int num_vector,
|
||||||
|
int vector_dim,
|
||||||
|
int num_part,
|
||||||
|
int num_hash_f,
|
||||||
|
int hash_code_len
|
||||||
|
);
|
||||||
|
|
||||||
|
__global__ void lsh_cumulation_ver1_step1_cuda_kernel(
|
||||||
|
int *key_mask, // [batch_size, num_key]
|
||||||
|
int *key_hash_code, // [batch_size, num_key, num_hash_f]
|
||||||
|
float *value, // [batch_size, num_key, value_dim]
|
||||||
|
float *hashtable_value, // [batch_size, num_hash_f, hashtable_capacity, value_dim]
|
||||||
|
int batch_size,
|
||||||
|
int num_hash_f,
|
||||||
|
int hashtable_capacity,
|
||||||
|
int num_key,
|
||||||
|
int value_dim,
|
||||||
|
int offset_warp
|
||||||
|
);
|
||||||
|
|
||||||
|
__global__ void lsh_cumulation_ver1_step2_cuda_kernel(
|
||||||
|
int *query_mask, // [batch_size, num_query]
|
||||||
|
int *query_hash_code, // [batch_size, num_query, num_hash_f]
|
||||||
|
float *hashtable_value, // [batch_size, num_hash_f, hashtable_capacity, value_dim]
|
||||||
|
float *cumulation_value, // [batch_size, num_query, value_dim]
|
||||||
|
int batch_size,
|
||||||
|
int num_hash_f,
|
||||||
|
int hashtable_capacity,
|
||||||
|
int num_query,
|
||||||
|
int value_dim,
|
||||||
|
int offset_warp
|
||||||
|
);
|
||||||
|
|
||||||
|
__global__ void lsh_weighted_cumulation_ver1_step1_cuda_kernel(
|
||||||
|
int *key_mask, // [batch_size, num_key]
|
||||||
|
int *key_hash_code, // [batch_size, num_key, num_hash_f]
|
||||||
|
float *key_weight, // [batch_size, num_key, weight_dim]
|
||||||
|
float *value, // [batch_size, num_key, value_dim]
|
||||||
|
float *hashtable_value, // [batch_size, num_hash_f, hashtable_capacity, WARP_SIZE]
|
||||||
|
int batch_size,
|
||||||
|
int num_hash_f,
|
||||||
|
int hashtable_capacity,
|
||||||
|
int num_key,
|
||||||
|
int value_dim,
|
||||||
|
int weight_dim,
|
||||||
|
int offset_warp,
|
||||||
|
int weight_idx
|
||||||
|
);
|
||||||
|
|
||||||
|
__global__ void lsh_weighted_cumulation_ver1_step2_cuda_kernel(
|
||||||
|
int *query_mask, // [batch_size, num_query]
|
||||||
|
int *query_hash_code, // [batch_size, num_query, num_hash_f]
|
||||||
|
float *query_weight, // [batch_size, num_query, weight_dim]
|
||||||
|
float *hashtable_value, // [batch_size, num_hash_f, hashtable_capacity, WARP_SIZE]
|
||||||
|
float *cumulation_value, // [batch_size, num_query, value_dim]
|
||||||
|
int batch_size,
|
||||||
|
int num_hash_f,
|
||||||
|
int hashtable_capacity,
|
||||||
|
int num_query,
|
||||||
|
int value_dim,
|
||||||
|
int weight_dim,
|
||||||
|
int offset_warp,
|
||||||
|
int weight_idx
|
||||||
|
);
|
||||||
|
|
||||||
|
__global__ void count_sort_step1_cuda_kernel(
|
||||||
|
int *key_mask, // [batch_size, num_key]
|
||||||
|
int *key_hash_code, // [batch_size, num_key, num_hash_f]
|
||||||
|
int *count_sort_table, // [batch_size, num_hash_f, hashtable_capacity]
|
||||||
|
int batch_size,
|
||||||
|
int num_hash_f,
|
||||||
|
int hashtable_capacity,
|
||||||
|
int num_key
|
||||||
|
);
|
||||||
|
|
||||||
|
__global__ void count_sort_step2_cuda_kernel(
|
||||||
|
int *count_sort_table, // [batch_size, num_hash_f, hashtable_capacity]
|
||||||
|
int batch_size,
|
||||||
|
int num_hash_f,
|
||||||
|
int hashtable_capacity
|
||||||
|
);
|
||||||
|
|
||||||
|
__global__ void count_sort_step3_cuda_kernel(
|
||||||
|
int *key_mask, // [batch_size, num_key]
|
||||||
|
int *key_hash_code, // [batch_size, num_key, num_hash_f]
|
||||||
|
int *count_sort_table, // [batch_size, num_hash_f, hashtable_capacity]
|
||||||
|
int *key_sorted_idxes, // [batch_size, num_hash_f, num_key]
|
||||||
|
int batch_size,
|
||||||
|
int num_hash_f,
|
||||||
|
int hashtable_capacity,
|
||||||
|
int num_key
|
||||||
|
);
|
||||||
|
|
||||||
|
__global__ void extract_query_info_cuda_kernel(
|
||||||
|
int *query_mask, // [batch_size, num_query]
|
||||||
|
int *query_hash_code, // [batch_size, num_query, num_hash_f]
|
||||||
|
int *count_sort_table, // [batch_size, num_hash_f, hashtable_capacity]
|
||||||
|
int *query_info, // [batch_size, num_query, 2, num_hash_f]
|
||||||
|
int batch_size,
|
||||||
|
int num_hash_f,
|
||||||
|
int hashtable_capacity,
|
||||||
|
int num_query
|
||||||
|
);
|
||||||
|
|
||||||
|
__global__ void lsh_weighted_cumulation_ver2_step2_cuda_kernel(
|
||||||
|
int *query_mask, // [batch_size, num_query]
|
||||||
|
int *query_info, // [batch_size, num_query, 2, num_hash_f]
|
||||||
|
int *key_sorted_idxes, // [batch_size, num_hash_f, num_key]
|
||||||
|
float *query_weight, // [batch_size, num_query, weight_dim]
|
||||||
|
float *key_weight, // [batch_size, num_key, weight_dim]
|
||||||
|
float *value, // [batch_size, num_key, value_dim]
|
||||||
|
float *cumulation_value, // [batch_size, num_query, value_dim]
|
||||||
|
int batch_size,
|
||||||
|
int num_hash_f,
|
||||||
|
int num_query,
|
||||||
|
int num_key,
|
||||||
|
int value_dim,
|
||||||
|
int weight_dim
|
||||||
|
);
|
||||||
|
|
||||||
|
__global__ void lsh_weighted_cumulation_ver3_step2_cuda_kernel(
|
||||||
|
int *query_sorted_idxes, // [batch_size, num_hash_f, num_query]
|
||||||
|
int *key_mask, // [batch_size, num_key]
|
||||||
|
int *key_info, // [batch_size, num_key, 2, num_hash_f]
|
||||||
|
float *query_weight, // [batch_size, num_query, weight_dim]
|
||||||
|
float *key_weight, // [batch_size, num_key, weight_dim]
|
||||||
|
float *value, // [batch_size, num_key, value_dim]
|
||||||
|
float *cumulation_value, // [batch_size, num_query, value_dim]
|
||||||
|
int batch_size,
|
||||||
|
int num_hash_f,
|
||||||
|
int num_query,
|
||||||
|
int num_key,
|
||||||
|
int value_dim,
|
||||||
|
int weight_dim
|
||||||
|
);
|
||||||
|
|
||||||
|
__global__ void lsh_weighted_cumulation_ver4_step2_cuda_kernel(
|
||||||
|
int *query_sorted_idxes, // [batch_size, num_hash_f, num_query]
|
||||||
|
int *key_mask, // [batch_size, num_key]
|
||||||
|
int *key_info, // [batch_size, num_key, 2, num_hash_f]
|
||||||
|
float *query_weight, // [batch_size, num_query, weight_dim]
|
||||||
|
float *key_weight, // [batch_size, num_key, weight_dim]
|
||||||
|
float *value, // [batch_size, num_key, value_dim]
|
||||||
|
float *cumulation_value, // [batch_size, num_query, value_dim]
|
||||||
|
int batch_size,
|
||||||
|
int num_hash_f,
|
||||||
|
int num_query,
|
||||||
|
int num_key,
|
||||||
|
int value_dim,
|
||||||
|
int weight_dim
|
||||||
|
);
|
||||||
128
src/transformers/models/yoso/fast_lsh_cumulation_torch.cpp
Normal file
128
src/transformers/models/yoso/fast_lsh_cumulation_torch.cpp
Normal file
@@ -0,0 +1,128 @@
|
|||||||
|
#include <torch/extension.h>
|
||||||
|
#include <ATen/ATen.h>
|
||||||
|
#include "fast_lsh_cumulation.h"
|
||||||
|
#include "common_cuda.h"
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
std::vector<at::Tensor> fast_hash(
|
||||||
|
at::Tensor query_mask,
|
||||||
|
at::Tensor query_vector,
|
||||||
|
at::Tensor key_mask,
|
||||||
|
at::Tensor key_vector,
|
||||||
|
int num_hash_f,
|
||||||
|
int hash_code_len,
|
||||||
|
bool use_cuda,
|
||||||
|
int version
|
||||||
|
) {
|
||||||
|
return fast_hash_ver1_kernel(
|
||||||
|
query_mask,
|
||||||
|
query_vector,
|
||||||
|
key_mask,
|
||||||
|
key_vector,
|
||||||
|
num_hash_f,
|
||||||
|
hash_code_len,
|
||||||
|
use_cuda
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
at::Tensor lsh_cumulation(
|
||||||
|
at::Tensor query_mask, // [batch_size, num_query]
|
||||||
|
at::Tensor query_hash_code, // [batch_size, num_query, num_hash_f]
|
||||||
|
at::Tensor key_mask, // [batch_size, num_key]
|
||||||
|
at::Tensor key_hash_code, // [batch_size, num_key, num_hash_f]
|
||||||
|
at::Tensor value, // [batch_size, num_key, value_dim]
|
||||||
|
int hashtable_capacity,
|
||||||
|
bool use_cuda,
|
||||||
|
int version
|
||||||
|
) {
|
||||||
|
return lsh_cumulation_ver1_kernel(
|
||||||
|
query_mask,
|
||||||
|
query_hash_code,
|
||||||
|
key_mask,
|
||||||
|
key_hash_code,
|
||||||
|
value,
|
||||||
|
hashtable_capacity,
|
||||||
|
use_cuda
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
at::Tensor lsh_weighted_cumulation(
|
||||||
|
at::Tensor query_mask, // [batch_size, num_query]
|
||||||
|
at::Tensor query_hash_code, // [batch_size, num_query, num_hash_f]
|
||||||
|
at::Tensor query_weight, // [batch_size, num_query, weight_dim]
|
||||||
|
at::Tensor key_mask, // [batch_size, num_key]
|
||||||
|
at::Tensor key_hash_code, // [batch_size, num_key, num_hash_f]
|
||||||
|
at::Tensor key_weight, // [batch_size, num_key, weight_dim]
|
||||||
|
at::Tensor value, // [batch_size, num_key, value_dim]
|
||||||
|
int hashtable_capacity,
|
||||||
|
bool use_cuda,
|
||||||
|
int version
|
||||||
|
) {
|
||||||
|
if (version == 1) {
|
||||||
|
return lsh_weighted_cumulation_ver1_kernel(
|
||||||
|
query_mask,
|
||||||
|
query_hash_code,
|
||||||
|
query_weight,
|
||||||
|
key_mask,
|
||||||
|
key_hash_code,
|
||||||
|
key_weight,
|
||||||
|
value,
|
||||||
|
hashtable_capacity,
|
||||||
|
use_cuda
|
||||||
|
);
|
||||||
|
} else if (version == 2) {
|
||||||
|
return lsh_weighted_cumulation_ver2_kernel(
|
||||||
|
query_mask,
|
||||||
|
query_hash_code,
|
||||||
|
query_weight,
|
||||||
|
key_mask,
|
||||||
|
key_hash_code,
|
||||||
|
key_weight,
|
||||||
|
value,
|
||||||
|
hashtable_capacity,
|
||||||
|
use_cuda
|
||||||
|
);
|
||||||
|
} else if (version == 3) {
|
||||||
|
return lsh_weighted_cumulation_ver3_kernel(
|
||||||
|
query_mask,
|
||||||
|
query_hash_code,
|
||||||
|
query_weight,
|
||||||
|
key_mask,
|
||||||
|
key_hash_code,
|
||||||
|
key_weight,
|
||||||
|
value,
|
||||||
|
hashtable_capacity,
|
||||||
|
use_cuda
|
||||||
|
);
|
||||||
|
} else if (version == 4) {
|
||||||
|
return lsh_weighted_cumulation_ver4_kernel(
|
||||||
|
query_mask,
|
||||||
|
query_hash_code,
|
||||||
|
query_weight,
|
||||||
|
key_mask,
|
||||||
|
key_hash_code,
|
||||||
|
key_weight,
|
||||||
|
value,
|
||||||
|
hashtable_capacity,
|
||||||
|
use_cuda
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
return lsh_weighted_cumulation_ver3_kernel(
|
||||||
|
query_mask,
|
||||||
|
query_hash_code,
|
||||||
|
query_weight,
|
||||||
|
key_mask,
|
||||||
|
key_hash_code,
|
||||||
|
key_weight,
|
||||||
|
value,
|
||||||
|
hashtable_capacity,
|
||||||
|
use_cuda
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||||
|
m.def("fast_hash", &fast_hash, "Fast Hash (CUDA)");
|
||||||
|
m.def("lsh_cumulation", &lsh_cumulation, "LSH Cumulation (CUDA)");
|
||||||
|
m.def("lsh_weighted_cumulation", &lsh_weighted_cumulation, "LSH Weighted Cumulation (CUDA)");
|
||||||
|
}
|
||||||
1324
src/transformers/models/yoso/modeling_yoso.py
Normal file
1324
src/transformers/models/yoso/modeling_yoso.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -4060,6 +4060,65 @@ def load_tf_weights_in_xlnet(*args, **kwargs):
|
|||||||
requires_backends(load_tf_weights_in_xlnet, ["torch"])
|
requires_backends(load_tf_weights_in_xlnet, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
YOSO_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||||
|
|
||||||
|
|
||||||
|
class YosoForMaskedLM(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class YosoForMultipleChoice(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class YosoForQuestionAnswering(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class YosoForSequenceClassification(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class YosoForTokenClassification(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class YosoLayer(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class YosoModel(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class YosoPreTrainedModel(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
class Adafactor(metaclass=DummyObject):
|
class Adafactor(metaclass=DummyObject):
|
||||||
_backends = ["torch"]
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
|||||||
401
tests/test_modeling_yoso.py
Normal file
401
tests/test_modeling_yoso.py
Normal file
@@ -0,0 +1,401 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
""" Testing suite for the PyTorch YOSO model. """
|
||||||
|
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from tests.test_modeling_common import floats_tensor
|
||||||
|
from transformers import YosoConfig, is_torch_available
|
||||||
|
from transformers.testing_utils import require_torch, slow, torch_device
|
||||||
|
|
||||||
|
from .test_configuration_common import ConfigTester
|
||||||
|
from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from transformers import (
|
||||||
|
YosoForMaskedLM,
|
||||||
|
YosoForMultipleChoice,
|
||||||
|
YosoForQuestionAnswering,
|
||||||
|
YosoForSequenceClassification,
|
||||||
|
YosoForTokenClassification,
|
||||||
|
YosoModel,
|
||||||
|
)
|
||||||
|
from transformers.models.yoso.modeling_yoso import YOSO_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||||
|
|
||||||
|
|
||||||
|
class YosoModelTester:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
parent,
|
||||||
|
batch_size=13,
|
||||||
|
seq_length=7,
|
||||||
|
is_training=True,
|
||||||
|
use_input_mask=True,
|
||||||
|
use_token_type_ids=True,
|
||||||
|
use_labels=True,
|
||||||
|
vocab_size=99,
|
||||||
|
hidden_size=32,
|
||||||
|
num_hidden_layers=5,
|
||||||
|
num_attention_heads=4,
|
||||||
|
intermediate_size=37,
|
||||||
|
hidden_act="gelu",
|
||||||
|
hidden_dropout_prob=0.1,
|
||||||
|
attention_probs_dropout_prob=0.1,
|
||||||
|
max_position_embeddings=512,
|
||||||
|
type_vocab_size=16,
|
||||||
|
type_sequence_label_size=2,
|
||||||
|
initializer_range=0.02,
|
||||||
|
num_labels=3,
|
||||||
|
num_choices=4,
|
||||||
|
scope=None,
|
||||||
|
):
|
||||||
|
self.parent = parent
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.seq_length = seq_length
|
||||||
|
self.is_training = is_training
|
||||||
|
self.use_input_mask = use_input_mask
|
||||||
|
self.use_token_type_ids = use_token_type_ids
|
||||||
|
self.use_labels = use_labels
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.num_hidden_layers = num_hidden_layers
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
self.intermediate_size = intermediate_size
|
||||||
|
self.hidden_act = hidden_act
|
||||||
|
self.hidden_dropout_prob = hidden_dropout_prob
|
||||||
|
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.type_vocab_size = type_vocab_size
|
||||||
|
self.type_sequence_label_size = type_sequence_label_size
|
||||||
|
self.initializer_range = initializer_range
|
||||||
|
self.num_labels = num_labels
|
||||||
|
self.num_choices = num_choices
|
||||||
|
self.scope = scope
|
||||||
|
|
||||||
|
def prepare_config_and_inputs(self):
|
||||||
|
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||||
|
|
||||||
|
input_mask = None
|
||||||
|
if self.use_input_mask:
|
||||||
|
input_mask = random_attention_mask([self.batch_size, self.seq_length])
|
||||||
|
|
||||||
|
token_type_ids = None
|
||||||
|
if self.use_token_type_ids:
|
||||||
|
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
|
||||||
|
|
||||||
|
sequence_labels = None
|
||||||
|
token_labels = None
|
||||||
|
choice_labels = None
|
||||||
|
if self.use_labels:
|
||||||
|
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
|
||||||
|
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
|
||||||
|
choice_labels = ids_tensor([self.batch_size], self.num_choices)
|
||||||
|
|
||||||
|
config = self.get_config()
|
||||||
|
|
||||||
|
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
|
|
||||||
|
def get_config(self):
|
||||||
|
return YosoConfig(
|
||||||
|
vocab_size=self.vocab_size,
|
||||||
|
hidden_size=self.hidden_size,
|
||||||
|
num_hidden_layers=self.num_hidden_layers,
|
||||||
|
num_attention_heads=self.num_attention_heads,
|
||||||
|
intermediate_size=self.intermediate_size,
|
||||||
|
hidden_act=self.hidden_act,
|
||||||
|
hidden_dropout_prob=self.hidden_dropout_prob,
|
||||||
|
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
||||||
|
max_position_embeddings=self.max_position_embeddings,
|
||||||
|
type_vocab_size=self.type_vocab_size,
|
||||||
|
is_decoder=False,
|
||||||
|
initializer_range=self.initializer_range,
|
||||||
|
)
|
||||||
|
|
||||||
|
def prepare_config_and_inputs_for_decoder(self):
|
||||||
|
(
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
token_type_ids,
|
||||||
|
input_mask,
|
||||||
|
sequence_labels,
|
||||||
|
token_labels,
|
||||||
|
choice_labels,
|
||||||
|
) = self.prepare_config_and_inputs()
|
||||||
|
|
||||||
|
config.is_decoder = True
|
||||||
|
encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
|
||||||
|
encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
|
||||||
|
|
||||||
|
return (
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
token_type_ids,
|
||||||
|
input_mask,
|
||||||
|
sequence_labels,
|
||||||
|
token_labels,
|
||||||
|
choice_labels,
|
||||||
|
encoder_hidden_states,
|
||||||
|
encoder_attention_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_and_check_model(
|
||||||
|
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
|
):
|
||||||
|
model = YosoModel(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
|
||||||
|
result = model(input_ids, token_type_ids=token_type_ids)
|
||||||
|
result = model(input_ids)
|
||||||
|
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||||
|
|
||||||
|
def create_and_check_model_as_decoder(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
token_type_ids,
|
||||||
|
input_mask,
|
||||||
|
sequence_labels,
|
||||||
|
token_labels,
|
||||||
|
choice_labels,
|
||||||
|
encoder_hidden_states,
|
||||||
|
encoder_attention_mask,
|
||||||
|
):
|
||||||
|
config.add_cross_attention = True
|
||||||
|
model = YosoModel(config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
result = model(
|
||||||
|
input_ids,
|
||||||
|
attention_mask=input_mask,
|
||||||
|
token_type_ids=token_type_ids,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
)
|
||||||
|
result = model(
|
||||||
|
input_ids,
|
||||||
|
attention_mask=input_mask,
|
||||||
|
token_type_ids=token_type_ids,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
)
|
||||||
|
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
|
||||||
|
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||||
|
|
||||||
|
def create_and_check_for_masked_lm(
|
||||||
|
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
|
):
|
||||||
|
model = YosoForMaskedLM(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
|
||||||
|
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||||
|
|
||||||
|
def create_and_check_for_question_answering(
|
||||||
|
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
|
):
|
||||||
|
model = YosoForQuestionAnswering(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
result = model(
|
||||||
|
input_ids,
|
||||||
|
attention_mask=input_mask,
|
||||||
|
token_type_ids=token_type_ids,
|
||||||
|
start_positions=sequence_labels,
|
||||||
|
end_positions=sequence_labels,
|
||||||
|
)
|
||||||
|
self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
|
||||||
|
self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
|
||||||
|
|
||||||
|
def create_and_check_for_sequence_classification(
|
||||||
|
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
|
):
|
||||||
|
config.num_labels = self.num_labels
|
||||||
|
model = YosoForSequenceClassification(config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
|
||||||
|
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
|
||||||
|
|
||||||
|
def create_and_check_for_token_classification(
|
||||||
|
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
|
):
|
||||||
|
config.num_labels = self.num_labels
|
||||||
|
model = YosoForTokenClassification(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
|
||||||
|
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
|
||||||
|
|
||||||
|
def create_and_check_for_multiple_choice(
|
||||||
|
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
|
):
|
||||||
|
config.num_choices = self.num_choices
|
||||||
|
model = YosoForMultipleChoice(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||||
|
multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||||
|
multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||||
|
result = model(
|
||||||
|
multiple_choice_inputs_ids,
|
||||||
|
attention_mask=multiple_choice_input_mask,
|
||||||
|
token_type_ids=multiple_choice_token_type_ids,
|
||||||
|
labels=choice_labels,
|
||||||
|
)
|
||||||
|
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))
|
||||||
|
|
||||||
|
def prepare_config_and_inputs_for_common(self):
|
||||||
|
config_and_inputs = self.prepare_config_and_inputs()
|
||||||
|
(
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
token_type_ids,
|
||||||
|
input_mask,
|
||||||
|
sequence_labels,
|
||||||
|
token_labels,
|
||||||
|
choice_labels,
|
||||||
|
) = config_and_inputs
|
||||||
|
inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
|
||||||
|
return config, inputs_dict
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
class YosoModelTest(ModelTesterMixin, unittest.TestCase):
|
||||||
|
|
||||||
|
all_model_classes = (
|
||||||
|
(
|
||||||
|
YosoModel,
|
||||||
|
YosoForMaskedLM,
|
||||||
|
YosoForMultipleChoice,
|
||||||
|
YosoForQuestionAnswering,
|
||||||
|
YosoForSequenceClassification,
|
||||||
|
YosoForTokenClassification,
|
||||||
|
)
|
||||||
|
if is_torch_available()
|
||||||
|
else ()
|
||||||
|
)
|
||||||
|
test_pruning = False
|
||||||
|
test_headmasking = False
|
||||||
|
test_torchscript = False
|
||||||
|
|
||||||
|
all_generative_model_classes = ()
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.model_tester = YosoModelTester(self)
|
||||||
|
self.config_tester = ConfigTester(self, config_class=YosoConfig, hidden_size=37)
|
||||||
|
|
||||||
|
def test_config(self):
|
||||||
|
self.config_tester.run_common_tests()
|
||||||
|
|
||||||
|
def test_model(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_model_various_embeddings(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
for type in ["absolute", "relative_key", "relative_key_query"]:
|
||||||
|
config_and_inputs[0].position_embedding_type = type
|
||||||
|
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_for_masked_lm(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_for_multiple_choice(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_for_question_answering(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_for_question_answering(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_for_sequence_classification(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_for_token_classification(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_model_from_pretrained(self):
|
||||||
|
for model_name in YOSO_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||||
|
model = YosoModel.from_pretrained(model_name)
|
||||||
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
|
def test_attention_outputs(self):
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
class YosoModelIntegrationTest(unittest.TestCase):
|
||||||
|
@slow
|
||||||
|
def test_inference_no_head(self):
|
||||||
|
model = YosoModel.from_pretrained("uw-madison/yoso-4096")
|
||||||
|
input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]])
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
output = model(input_ids)[0]
|
||||||
|
|
||||||
|
expected_shape = torch.Size((1, 6, 768))
|
||||||
|
self.assertEqual(output.shape, expected_shape)
|
||||||
|
|
||||||
|
expected_slice = torch.tensor(
|
||||||
|
[[[-0.0611, 0.1242, 0.0840], [0.0280, -0.0048, 0.1125], [0.0106, 0.0226, 0.0751]]]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_inference_masked_lm(self):
|
||||||
|
model = YosoForMaskedLM.from_pretrained("uw-madison/yoso-4096")
|
||||||
|
input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]])
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
output = model(input_ids)[0]
|
||||||
|
|
||||||
|
vocab_size = 50265
|
||||||
|
|
||||||
|
expected_shape = torch.Size((1, 6, vocab_size))
|
||||||
|
self.assertEqual(output.shape, expected_shape)
|
||||||
|
|
||||||
|
expected_slice = torch.tensor(
|
||||||
|
[[[-2.1313, -3.7285, -2.2407], [-2.7047, -3.3314, -2.6408], [0.0629, -2.5166, -0.3356]]]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_inference_masked_lm_long_input(self):
|
||||||
|
model = YosoForMaskedLM.from_pretrained("uw-madison/yoso-4096")
|
||||||
|
input_ids = torch.arange(4096).unsqueeze(0)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
output = model(input_ids)[0]
|
||||||
|
|
||||||
|
vocab_size = 50265
|
||||||
|
|
||||||
|
expected_shape = torch.Size((1, 4096, vocab_size))
|
||||||
|
self.assertEqual(output.shape, expected_shape)
|
||||||
|
|
||||||
|
expected_slice = torch.tensor(
|
||||||
|
[[[-2.3914, -4.3742, -5.0956], [-4.0988, -4.2384, -7.0406], [-3.1427, -3.7192, -6.6800]]]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))
|
||||||
Reference in New Issue
Block a user