Decision transformer gym (#15845)
* Created the Decision Transformer Modle * updating tests, copy to other machine * Added last hidden size to Decision Transformer modelling outputs * Removed copy of original DT file * made a temporary change to gpt2 to have it conform with the Decision Transformer version * Updated tests * Ignoring a file used to test the DT model * added comments to config file * added comments and argument descriptions to decision transformer file * Updated doc * Ran "make style" * Remove old model imports * Removed unused imports, cleaned up init file * Update docs/source/model_doc/decision_transformer.mdx added my username Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Reverted changes made to gpt2 * Removed datasets submodule * Update the modeling outputs to include gpt2 attentions, hidden states and last hidden states * Added support for return of hidden states, attentions and return dict of gpt2 model. * Updated tests to include many of the ModelTesterMixin tests. The following tests are skipped: test_generate_without_input_ids, test_pruning, test_resize_embeddings, test_head_masking, test_attention_outputs, test_hidden_states_output, test_inputs_embeds, test_model_common_attributes * Added missing line to the end of gpt2 file * Added an integration test for the Decision Transformer Test performs and autoregressive evaluation for two time steps * Set done and info to _ to fix failing test * Updated integration test to be deterministic and check expected outputs * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Removed unnecessary config options * Cleaned up commented code and old comments. * Cleaned up commented code. * Changed DecisionTransformer to Decision Transformer * Added Decision Transformer to the main README file * Added copy of GTP2 called DecisionTranformerGPT2Model * isorted imports * isorted imports * Added model to non-English README files * Ran make fix-copies and corrected some cases. * Updated index file to include Decision Transformer * Added gpt2 model as copy inside the Decision Transformer model file * Added the unit test file to the list of TEST_FILES_WITH_NO_COMMON_TESTS * Deleted redundant checkpoint files (I don't know how these got committed) * Removed testing files. (These should have never been committed) * Removed accidentally committed files * Moved the Decision Transformer test to its own directory * Add type hints for Pegasus (#16324) * Funnel type hints (#16323) * add pt funnel type hints * add tf funnel type hints * Add type hints for ProphetNet PyTorch (#16272) * [GLPN] Improve docs (#16331) * Add link to notebook * Add link * Fix bug Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local> * Added type hints for Pytorch Marian calls (#16200) * Added type hinting for forward functions in pytorch marian * typo correction * Removed type hints on functions from BART per Suraj Patil request * fix import pb * fix typo * corrected tuple call * ran black * after fix-copies Some optional tags on primitives were removed, past_key_values in MarianForCausalLM changed from Tuple of Tuple to List * Fixing copies to roformer and pegasus Co-authored-by: Clementine Fourrier <cfourrie@inria.fr> Co-authored-by: matt <rocketknight1@gmail.com> * Moved DecisionTransformOutput to modeling_decision_transformer * Moved the example usage to research project and cleaned comments * Made tests ignore the copy of gpt2 in Decision Transformer * Added module output to modelling decision transformer * removed copied gpt2 model from list of transformers models * Updated tests and created __init__ file for new test location * Update README.md Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/models/decision_transformer/configuration_decision_transformer.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Removed unneeded summary type from config file * Fixed copies * Updated pretrained config map to refer to hopper-medium checkpoint * done (#16340) * Added Decision transformer to model docs * Update src/transformers/models/decision_transformer/modeling_decision_transformer.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/models/decision_transformer/modeling_decision_transformer.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/models/decision_transformer/configuration_decision_transformer.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Add type annotations for Rembert/Splinter and copies (#16338) * undo black autoformat * minor fix to rembert forward with default * make fix-copies, make quality * Adding types to template model * Removing List from the template types * Remove `Optional` from a couple of types that don't accept `None` Co-authored-by: matt <rocketknight1@gmail.com> * [Bug template] Shift responsibilities for long-range (#16344) * Fix code repetition in serialization guide (#16346) * Adopt framework-specific blocks for content (#16342) * ✨ refactor code samples with framework-specific blocks * ✨ update training.mdx * 🖍 apply feedback * Updates the default branch from master to main (#16326) * Updates the default branch from master to main * Links from `master` to `main` * Typo * Update examples/flax/README.md Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Updated model with custom docstring example * Created the Decision Transformer Modle * updating tests, copy to other machine * Added last hidden size to Decision Transformer modelling outputs * Removed copy of original DT file * made a temporary change to gpt2 to have it conform with the Decision Transformer version * Updated tests * Ignoring a file used to test the DT model * added comments to config file * added comments and argument descriptions to decision transformer file * Updated doc * Ran "make style" * Remove old model imports * Removed unused imports, cleaned up init file * Update docs/source/model_doc/decision_transformer.mdx added my username Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Reverted changes made to gpt2 * Removed datasets submodule * Update the modeling outputs to include gpt2 attentions, hidden states and last hidden states * Added support for return of hidden states, attentions and return dict of gpt2 model. * Updated tests to include many of the ModelTesterMixin tests. The following tests are skipped: test_generate_without_input_ids, test_pruning, test_resize_embeddings, test_head_masking, test_attention_outputs, test_hidden_states_output, test_inputs_embeds, test_model_common_attributes * Added missing line to the end of gpt2 file * Added an integration test for the Decision Transformer Test performs and autoregressive evaluation for two time steps * Set done and info to _ to fix failing test * Updated integration test to be deterministic and check expected outputs * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Removed unnecessary config options * Cleaned up commented code and old comments. * Cleaned up commented code. * Changed DecisionTransformer to Decision Transformer * Added Decision Transformer to the main README file * Added copy of GTP2 called DecisionTranformerGPT2Model * isorted imports * isorted imports * Added model to non-English README files * Ran make fix-copies and corrected some cases. * Updated index file to include Decision Transformer * Added gpt2 model as copy inside the Decision Transformer model file * Added the unit test file to the list of TEST_FILES_WITH_NO_COMMON_TESTS * Deleted redundant checkpoint files (I don't know how these got committed) * Removed testing files. (These should have never been committed) * Removed accidentally committed files * Moved the Decision Transformer test to its own directory * Moved DecisionTransformOutput to modeling_decision_transformer * Moved the example usage to research project and cleaned comments * Made tests ignore the copy of gpt2 in Decision Transformer * Added module output to modelling decision transformer * removed copied gpt2 model from list of transformers models * Updated tests and created __init__ file for new test location * Update README.md Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/models/decision_transformer/configuration_decision_transformer.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Removed unneeded summary type from config file * Fixed copies * Updated pretrained config map to refer to hopper-medium checkpoint * Added Decision transformer to model docs * Update src/transformers/models/decision_transformer/modeling_decision_transformer.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/models/decision_transformer/modeling_decision_transformer.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/models/decision_transformer/configuration_decision_transformer.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Updated model with custom docstring example * Updated copies, config auto, and readme files. Co-authored-by: Lysandre Debut <lysandre@huggingface.co> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Dan Tegzes <48134725+Tegzes@users.noreply.github.com> Co-authored-by: Adam Montgomerie <adam@avanssion.com> Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local> Co-authored-by: Clémentine Fourrier <22726840+clefourrier@users.noreply.github.com> Co-authored-by: Clementine Fourrier <cfourrie@inria.fr> Co-authored-by: matt <rocketknight1@gmail.com> Co-authored-by: Francesco Saverio Zuppichini <francesco.zuppichini@gmail.com> Co-authored-by: Jacob Dineen <54680234+jacobdineen@users.noreply.github.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Omar Sanseviero <osanseviero@gmail.com> Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> Co-authored-by: Lysandre Debut <lysandre.debut@reseau.eseo.fr>
This commit is contained in:
@@ -252,7 +252,8 @@ Current number of checkpoints: ** (from Facebook) released with the paper [Data2Vec: A General Framework for Self-supervised Learning in Speech, Vision and Language](https://arxiv.org/abs/2202.03555) by Alexei Baevski, Wei-Ning Hsu, Qiantong Xu, Arun Babu, Jiatao Gu, Michael Auli.
|
||||
1. **[DeBERTa](https://huggingface.co/docs/transformers/model_doc/deberta)** (from Microsoft) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen.
|
||||
1. **[DeBERTa-v2](https://huggingface.co/docs/transformers/model_doc/deberta-v2)** (from Microsoft) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen.
|
||||
1. **[DiT](https://huggingface.co/docs/transformers/main/model_doc/dit)** (from Microsoft Research) released with the paper [DiT: Self-supervised Pre-training for Document Image Transformer](https://arxiv.org/abs/2203.02378) by Junlong Li, Yiheng Xu, Tengchao Lv, Lei Cui, Cha Zhang, Furu Wei.
|
||||
1. **[Decision Transformer](https://huggingface.co/docs/transformers/model_doc/decision_transformer)** (from Berkeley/Facebook/Google) released with the paper [Decision Transformer: Reinforcement Learning via Sequence Modeling](https://arxiv.org/abs/2106.01345) by Lili Chen, Kevin Lu, Aravind Rajeswaran, Kimin Lee, Aditya Grover, Michael Laskin, Pieter Abbeel, Aravind Srinivas, Igor Mordatch.
|
||||
1. **[DiT](https://huggingface.co/docs/transformers/model_doc/dit)** (from Microsoft Research) released with the paper [DiT: Self-supervised Pre-training for Document Image Transformer](https://arxiv.org/abs/2203.02378) by Junlong Li, Yiheng Xu, Tengchao Lv, Lei Cui, Cha Zhang, Furu Wei.
|
||||
1. **[DeiT](https://huggingface.co/docs/transformers/model_doc/deit)** (from Facebook) released with the paper [Training data-efficient image transformers & distillation through attention](https://arxiv.org/abs/2012.12877) by Hugo Touvron, Matthieu Cord, Matthijs Douze, Francisco Massa, Alexandre Sablayrolles, Hervé Jégou.
|
||||
1. **[DETR](https://huggingface.co/docs/transformers/model_doc/detr)** (from Facebook) released with the paper [End-to-End Object Detection with Transformers](https://arxiv.org/abs/2005.12872) by Nicolas Carion, Francisco Massa, Gabriel Synnaeve, Nicolas Usunier, Alexander Kirillov, Sergey Zagoruyko.
|
||||
1. **[DialoGPT](https://huggingface.co/docs/transformers/model_doc/dialogpt)** (from Microsoft Research) released with the paper [DialoGPT: Large-Scale Generative Pre-training for Conversational Response Generation](https://arxiv.org/abs/1911.00536) by Yizhe Zhang, Siqi Sun, Michel Galley, Yen-Chun Chen, Chris Brockett, Xiang Gao, Jianfeng Gao, Jingjing Liu, Bill Dolan.
|
||||
|
||||
@@ -233,11 +233,12 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는
|
||||
1. **[Data2Vec](https://huggingface.co/docs/transformers/main/model_doc/data2vec)** (from Facebook) released with the paper [Data2Vec: A General Framework for Self-supervised Learning in Speech, Vision and Language](https://arxiv.org/abs/2202.03555) by Alexei Baevski, Wei-Ning Hsu, Qiantong Xu, Arun Babu, Jiatao Gu, Michael Auli.
|
||||
1. **[DeBERTa](https://huggingface.co/docs/transformers/model_doc/deberta)** (from Microsoft) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen.
|
||||
1. **[DeBERTa-v2](https://huggingface.co/docs/transformers/model_doc/deberta-v2)** (from Microsoft) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen.
|
||||
1. **[Decision Transformer](https://huggingface.co/docs/transformers/model_doc/decision_transformer)** (from Berkeley/Facebook/Google) released with the paper [Decision Transformer: Reinforcement Learning via Sequence Modeling](https://arxiv.org/abs/2106.01345) by Lili Chen, Kevin Lu, Aravind Rajeswaran, Kimin Lee, Aditya Grover, Michael Laskin, Pieter Abbeel, Aravind Srinivas, Igor Mordatch.
|
||||
1. **[DeiT](https://huggingface.co/docs/transformers/model_doc/deit)** (from Facebook) released with the paper [Training data-efficient image transformers & distillation through attention](https://arxiv.org/abs/2012.12877) by Hugo Touvron, Matthieu Cord, Matthijs Douze, Francisco Massa, Alexandre Sablayrolles, Hervé Jégou.
|
||||
1. **[DETR](https://huggingface.co/docs/transformers/model_doc/detr)** (from Facebook) released with the paper [End-to-End Object Detection with Transformers](https://arxiv.org/abs/2005.12872) by Nicolas Carion, Francisco Massa, Gabriel Synnaeve, Nicolas Usunier, Alexander Kirillov, Sergey Zagoruyko.
|
||||
1. **[DialoGPT](https://huggingface.co/docs/transformers/model_doc/dialogpt)** (from Microsoft Research) released with the paper [DialoGPT: Large-Scale Generative Pre-training for Conversational Response Generation](https://arxiv.org/abs/1911.00536) by Yizhe Zhang, Siqi Sun, Michel Galley, Yen-Chun Chen, Chris Brockett, Xiang Gao, Jianfeng Gao, Jingjing Liu, Bill Dolan.
|
||||
1. **[DistilBERT](https://huggingface.co/docs/transformers/model_doc/distilbert)** (from HuggingFace), released together with the paper [DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter](https://arxiv.org/abs/1910.01108) by Victor Sanh, Lysandre Debut and Thomas Wolf. The same method has been applied to compress GPT2 into [DistilGPT2](https://github.com/huggingface/transformers/tree/main/examples/distillation), RoBERTa into [DistilRoBERTa](https://github.com/huggingface/transformers/tree/main/examples/distillation), Multilingual BERT into [DistilmBERT](https://github.com/huggingface/transformers/tree/main/examples/distillation) and a German version of DistilBERT.
|
||||
1. **[DiT](https://huggingface.co/docs/transformers/main/model_doc/dit)** (from Microsoft Research) released with the paper [DiT: Self-supervised Pre-training for Document Image Transformer](https://arxiv.org/abs/2203.02378) by Junlong Li, Yiheng Xu, Tengchao Lv, Lei Cui, Cha Zhang, Furu Wei.
|
||||
1. **[DiT](https://huggingface.co/docs/transformers/model_doc/dit)** (from Microsoft Research) released with the paper [DiT: Self-supervised Pre-training for Document Image Transformer](https://arxiv.org/abs/2203.02378) by Junlong Li, Yiheng Xu, Tengchao Lv, Lei Cui, Cha Zhang, Furu Wei.
|
||||
1. **[DPR](https://huggingface.co/docs/transformers/model_doc/dpr)** (from Facebook) released with the paper [Dense Passage Retrieval for Open-Domain Question Answering](https://arxiv.org/abs/2004.04906) by Vladimir Karpukhin, Barlas Oğuz, Sewon Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih.
|
||||
1. **[ELECTRA](https://huggingface.co/docs/transformers/model_doc/electra)** (from Google Research/Stanford University) released with the paper [ELECTRA: Pre-training text encoders as discriminators rather than generators](https://arxiv.org/abs/2003.10555) by Kevin Clark, Minh-Thang Luong, Quoc V. Le, Christopher D. Manning.
|
||||
1. **[EncoderDecoder](https://huggingface.co/docs/transformers/model_doc/encoder-decoder)** (from Google Research) released with the paper [Leveraging Pre-trained Checkpoints for Sequence Generation Tasks](https://arxiv.org/abs/1907.12461) by Sascha Rothe, Shashi Narayan, Aliaksei Severyn.
|
||||
|
||||
@@ -257,11 +257,12 @@ conda install -c huggingface transformers
|
||||
1. **[Data2Vec](https://huggingface.co/docs/transformers/main/model_doc/data2vec)** (来自 Facebook) 伴随论文 [Data2Vec: A General Framework for Self-supervised Learning in Speech, Vision and Language](https://arxiv.org/abs/2202.03555) 由 Alexei Baevski, Wei-Ning Hsu, Qiantong Xu, Arun Babu, Jiatao Gu, Michael Auli 发布。
|
||||
1. **[DeBERTa](https://huggingface.co/docs/transformers/model_doc/deberta)** (来自 Microsoft) 伴随论文 [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) 由 Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen 发布。
|
||||
1. **[DeBERTa-v2](https://huggingface.co/docs/transformers/model_doc/deberta-v2)** (来自 Microsoft) 伴随论文 [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) 由 Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen 发布。
|
||||
1. **[Decision Transformer](https://huggingface.co/docs/transformers/model_doc/decision_transformer)** (来自 Berkeley/Facebook/Google) 伴随论文 [Decision Transformer: Reinforcement Learning via Sequence Modeling](https://arxiv.org/abs/2106.01345) 由 Lili Chen, Kevin Lu, Aravind Rajeswaran, Kimin Lee, Aditya Grover, Michael Laskin, Pieter Abbeel, Aravind Srinivas, Igor Mordatch 发布。
|
||||
1. **[DeiT](https://huggingface.co/docs/transformers/model_doc/deit)** (来自 Facebook) 伴随论文 [Training data-efficient image transformers & distillation through attention](https://arxiv.org/abs/2012.12877) 由 Hugo Touvron, Matthieu Cord, Matthijs Douze, Francisco Massa, Alexandre Sablayrolles, Hervé Jégou 发布。
|
||||
1. **[DETR](https://huggingface.co/docs/transformers/model_doc/detr)** (来自 Facebook) 伴随论文 [End-to-End Object Detection with Transformers](https://arxiv.org/abs/2005.12872) 由 Nicolas Carion, Francisco Massa, Gabriel Synnaeve, Nicolas Usunier, Alexander Kirillov, Sergey Zagoruyko 发布。
|
||||
1. **[DialoGPT](https://huggingface.co/docs/transformers/model_doc/dialogpt)** (来自 Microsoft Research) 伴随论文 [DialoGPT: Large-Scale Generative Pre-training for Conversational Response Generation](https://arxiv.org/abs/1911.00536) 由 Yizhe Zhang, Siqi Sun, Michel Galley, Yen-Chun Chen, Chris Brockett, Xiang Gao, Jianfeng Gao, Jingjing Liu, Bill Dolan 发布。
|
||||
1. **[DistilBERT](https://huggingface.co/docs/transformers/model_doc/distilbert)** (来自 HuggingFace), 伴随论文 [DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter](https://arxiv.org/abs/1910.01108) 由 Victor Sanh, Lysandre Debut and Thomas Wolf 发布。 同样的方法也应用于压缩 GPT-2 到 [DistilGPT2](https://github.com/huggingface/transformers/tree/main/examples/distillation), RoBERTa 到 [DistilRoBERTa](https://github.com/huggingface/transformers/tree/main/examples/distillation), Multilingual BERT 到 [DistilmBERT](https://github.com/huggingface/transformers/tree/main/examples/distillation) 和德语版 DistilBERT。
|
||||
1. **[DiT](https://huggingface.co/docs/transformers/main/model_doc/dit)** (来自 Microsoft Research) 伴随论文 [DiT: Self-supervised Pre-training for Document Image Transformer](https://arxiv.org/abs/2203.02378) 由 Junlong Li, Yiheng Xu, Tengchao Lv, Lei Cui, Cha Zhang, Furu Wei 发布。
|
||||
1. **[DiT](https://huggingface.co/docs/transformers/model_doc/dit)** (来自 Microsoft Research) 伴随论文 [DiT: Self-supervised Pre-training for Document Image Transformer](https://arxiv.org/abs/2203.02378) 由 Junlong Li, Yiheng Xu, Tengchao Lv, Lei Cui, Cha Zhang, Furu Wei 发布。
|
||||
1. **[DPR](https://huggingface.co/docs/transformers/model_doc/dpr)** (来自 Facebook) 伴随论文 [Dense Passage Retrieval for Open-Domain Question Answering](https://arxiv.org/abs/2004.04906) 由 Vladimir Karpukhin, Barlas Oğuz, Sewon Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih 发布。
|
||||
1. **[ELECTRA](https://huggingface.co/docs/transformers/model_doc/electra)** (来自 Google Research/Stanford University) 伴随论文 [ELECTRA: Pre-training text encoders as discriminators rather than generators](https://arxiv.org/abs/2003.10555) 由 Kevin Clark, Minh-Thang Luong, Quoc V. Le, Christopher D. Manning 发布。
|
||||
1. **[EncoderDecoder](https://huggingface.co/docs/transformers/model_doc/encoder-decoder)** (来自 Google Research) 伴随论文 [Leveraging Pre-trained Checkpoints for Sequence Generation Tasks](https://arxiv.org/abs/1907.12461) 由 Sascha Rothe, Shashi Narayan, Aliaksei Severyn 发布。
|
||||
|
||||
@@ -269,11 +269,12 @@ conda install -c huggingface transformers
|
||||
1. **[Data2Vec](https://huggingface.co/docs/transformers/main/model_doc/data2vec)** (from Facebook) released with the paper [Data2Vec: A General Framework for Self-supervised Learning in Speech, Vision and Language](https://arxiv.org/abs/2202.03555) by Alexei Baevski, Wei-Ning Hsu, Qiantong Xu, Arun Babu, Jiatao Gu, Michael Auli.
|
||||
1. **[DeBERTa](https://huggingface.co/docs/transformers/model_doc/deberta)** (from Microsoft) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen.
|
||||
1. **[DeBERTa-v2](https://huggingface.co/docs/transformers/model_doc/deberta-v2)** (from Microsoft) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen.
|
||||
1. **[Decision Transformer](https://huggingface.co/docs/transformers/model_doc/decision_transformer)** (from Berkeley/Facebook/Google) released with the paper [Decision Transformer: Reinforcement Learning via Sequence Modeling](https://arxiv.org/abs/2106.01345) by Lili Chen, Kevin Lu, Aravind Rajeswaran, Kimin Lee, Aditya Grover, Michael Laskin, Pieter Abbeel, Aravind Srinivas, Igor Mordatch.
|
||||
1. **[DeiT](https://huggingface.co/docs/transformers/model_doc/deit)** (from Facebook) released with the paper [Training data-efficient image transformers & distillation through attention](https://arxiv.org/abs/2012.12877) by Hugo Touvron, Matthieu Cord, Matthijs Douze, Francisco Massa, Alexandre Sablayrolles, Hervé Jégou.
|
||||
1. **[DETR](https://huggingface.co/docs/transformers/model_doc/detr)** (from Facebook) released with the paper [End-to-End Object Detection with Transformers](https://arxiv.org/abs/2005.12872) by Nicolas Carion, Francisco Massa, Gabriel Synnaeve, Nicolas Usunier, Alexander Kirillov, Sergey Zagoruyko.
|
||||
1. **[DialoGPT](https://huggingface.co/docs/transformers/model_doc/dialogpt)** (from Microsoft Research) released with the paper [DialoGPT: Large-Scale Generative Pre-training for Conversational Response Generation](https://arxiv.org/abs/1911.00536) by Yizhe Zhang, Siqi Sun, Michel Galley, Yen-Chun Chen, Chris Brockett, Xiang Gao, Jianfeng Gao, Jingjing Liu, Bill Dolan.
|
||||
1. **[DistilBERT](https://huggingface.co/docs/transformers/model_doc/distilbert)** (from HuggingFace), released together with the paper [DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter](https://arxiv.org/abs/1910.01108) by Victor Sanh, Lysandre Debut and Thomas Wolf. The same method has been applied to compress GPT2 into [DistilGPT2](https://github.com/huggingface/transformers/tree/main/examples/distillation), RoBERTa into [DistilRoBERTa](https://github.com/huggingface/transformers/tree/main/examples/distillation), Multilingual BERT into [DistilmBERT](https://github.com/huggingface/transformers/tree/main/examples/distillation) and a German version of DistilBERT.
|
||||
1. **[DiT](https://huggingface.co/docs/transformers/main/model_doc/dit)** (from Microsoft Research) released with the paper [DiT: Self-supervised Pre-training for Document Image Transformer](https://arxiv.org/abs/2203.02378) by Junlong Li, Yiheng Xu, Tengchao Lv, Lei Cui, Cha Zhang, Furu Wei.
|
||||
1. **[DiT](https://huggingface.co/docs/transformers/model_doc/dit)** (from Microsoft Research) released with the paper [DiT: Self-supervised Pre-training for Document Image Transformer](https://arxiv.org/abs/2203.02378) by Junlong Li, Yiheng Xu, Tengchao Lv, Lei Cui, Cha Zhang, Furu Wei.
|
||||
1. **[DPR](https://huggingface.co/docs/transformers/model_doc/dpr)** (from Facebook) released with the paper [Dense Passage Retrieval for Open-Domain Question Answering](https://arxiv.org/abs/2004.04906) by Vladimir Karpukhin, Barlas Oğuz, Sewon Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih.
|
||||
1. **[ELECTRA](https://huggingface.co/docs/transformers/model_doc/electra)** (from Google Research/Stanford University) released with the paper [ELECTRA: Pre-training text encoders as discriminators rather than generators](https://arxiv.org/abs/2003.10555) by Kevin Clark, Minh-Thang Luong, Quoc V. Le, Christopher D. Manning.
|
||||
1. **[EncoderDecoder](https://huggingface.co/docs/transformers/model_doc/encoder-decoder)** (from Google Research) released with the paper [Leveraging Pre-trained Checkpoints for Sequence Generation Tasks](https://arxiv.org/abs/1907.12461) by Sascha Rothe, Shashi Narayan, Aliaksei Severyn.
|
||||
|
||||
@@ -194,6 +194,8 @@
|
||||
title: DeBERTa
|
||||
- local: model_doc/deberta-v2
|
||||
title: DeBERTa-v2
|
||||
- local: model_doc/decision_transformer
|
||||
title: Decision Transformer
|
||||
- local: model_doc/deit
|
||||
title: DeiT
|
||||
- local: model_doc/detr
|
||||
|
||||
@@ -78,6 +78,7 @@ conversion utilities for the following models.
|
||||
1. **[Data2Vec](model_doc/data2vec)** (from Facebook) released with the paper [Data2Vec: A General Framework for Self-supervised Learning in Speech, Vision and Language](https://arxiv.org/abs/2202.03555) by Alexei Baevski, Wei-Ning Hsu, Qiantong Xu, Arun Babu, Jiatao Gu, Michael Auli.
|
||||
1. **[DeBERTa](model_doc/deberta)** (from Microsoft) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen.
|
||||
1. **[DeBERTa-v2](model_doc/deberta-v2)** (from Microsoft) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen.
|
||||
1. **[Decision Transformer](model_doc/decision_transformer)** (from Berkeley/Facebook/Google) released with the paper [Decision Transformer: Reinforcement Learning via Sequence Modeling](https://arxiv.org/abs/2106.01345) by Lili Chen, Kevin Lu, Aravind Rajeswaran, Kimin Lee, Aditya Grover, Michael Laskin, Pieter Abbeel, Aravind Srinivas, Igor Mordatch.
|
||||
1. **[DiT](model_doc/dit)** (from Microsoft Research) released with the paper [DiT: Self-supervised Pre-training for Document Image Transformer](https://arxiv.org/abs/2203.02378) by Junlong Li, Yiheng Xu, Tengchao Lv, Lei Cui, Cha Zhang, Furu Wei.
|
||||
1. **[DeiT](model_doc/deit)** (from Facebook) released with the paper [Training data-efficient image transformers & distillation through attention](https://arxiv.org/abs/2012.12877) by Hugo Touvron, Matthieu Cord, Matthijs Douze, Francisco Massa, Alexandre Sablayrolles, Hervé Jégou.
|
||||
1. **[DETR](model_doc/detr)** (from Facebook) released with the paper [End-to-End Object Detection with Transformers](https://arxiv.org/abs/2005.12872) by Nicolas Carion, Francisco Massa, Gabriel Synnaeve, Nicolas Usunier, Alexander Kirillov, Sergey Zagoruyko.
|
||||
@@ -191,6 +192,7 @@ Flax), PyTorch, and/or TensorFlow.
|
||||
| Data2VecText | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| DeBERTa | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| DeBERTa-v2 | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| Decision Transformer | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| DeiT | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| DETR | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| DistilBERT | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
|
||||
51
docs/source/model_doc/decision_transformer.mdx
Normal file
51
docs/source/model_doc/decision_transformer.mdx
Normal file
@@ -0,0 +1,51 @@
|
||||
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Decision Transformer
|
||||
|
||||
## Overview
|
||||
|
||||
The Decision Transformer model was proposed in [Decision Transformer: Reinforcement Learning via Sequence Modeling](https://arxiv.org/abs/2106.01345)
|
||||
by Lili Chen, Kevin Lu, Aravind Rajeswaran, Kimin Lee, Aditya Grover, Michael Laskin, Pieter Abbeel, Aravind Srinivas, Igor Mordatch.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
*We introduce a framework that abstracts Reinforcement Learning (RL) as a sequence modeling problem.
|
||||
This allows us to draw upon the simplicity and scalability of the Transformer architecture, and associated advances
|
||||
in language modeling such as GPT-x and BERT. In particular, we present Decision Transformer, an architecture that
|
||||
casts the problem of RL as conditional sequence modeling. Unlike prior approaches to RL that fit value functions or
|
||||
compute policy gradients, Decision Transformer simply outputs the optimal actions by leveraging a causally masked
|
||||
Transformer. By conditioning an autoregressive model on the desired return (reward), past states, and actions, our
|
||||
Decision Transformer model can generate future actions that achieve the desired return. Despite its simplicity,
|
||||
Decision Transformer matches or exceeds the performance of state-of-the-art model-free offline RL baselines on
|
||||
Atari, OpenAI Gym, and Key-to-Door tasks.*
|
||||
|
||||
Tips:
|
||||
|
||||
This version of the model is for tasks where the state is a vector, image-based states will come soon.
|
||||
|
||||
This model was contributed by [edbeeching](https://huggingface.co/edbeeching). The original code can be found [here](https://github.com/kzl/decision-transformer).
|
||||
|
||||
## DecisionTransformerConfig
|
||||
|
||||
[[autodoc]] DecisionTransformerConfig
|
||||
|
||||
|
||||
## DecisionTransformerGPT2Model
|
||||
|
||||
[[autodoc]] DecisionTransformerGPT2Model
|
||||
- forward
|
||||
|
||||
## DecisionTransformerModel
|
||||
|
||||
[[autodoc]] DecisionTransformerModel
|
||||
- forward
|
||||
240
examples/research_projects/decision_transformer/requirements.txt
Normal file
240
examples/research_projects/decision_transformer/requirements.txt
Normal file
@@ -0,0 +1,240 @@
|
||||
absl-py==1.0.0
|
||||
aiohttp==3.8.1
|
||||
aiosignal==1.2.0
|
||||
alembic==1.7.7
|
||||
appdirs==1.4.4
|
||||
APScheduler==3.9.1
|
||||
arrow==1.2.2
|
||||
asttokens==2.0.5
|
||||
astunparse==1.6.3
|
||||
async-timeout==4.0.2
|
||||
attrs==21.4.0
|
||||
audioread==2.1.9
|
||||
autopage==0.5.0
|
||||
backcall==0.2.0
|
||||
backoff==1.11.1
|
||||
backports.zoneinfo==0.2.1
|
||||
binaryornot==0.4.4
|
||||
black==22.1.0
|
||||
boto3==1.16.34
|
||||
botocore==1.19.63
|
||||
Brotli==1.0.9
|
||||
cachetools==5.0.0
|
||||
certifi==2021.10.8
|
||||
cffi==1.15.0
|
||||
chardet==4.0.0
|
||||
charset-normalizer==2.0.12
|
||||
chex==0.1.1
|
||||
click==8.0.4
|
||||
cliff==3.10.1
|
||||
clldutils==3.11.1
|
||||
cloudpickle==2.0.0
|
||||
cmaes==0.8.2
|
||||
cmd2==2.4.0
|
||||
codecarbon==1.2.0
|
||||
colorlog==6.6.0
|
||||
cookiecutter==1.7.2
|
||||
cryptography==36.0.2
|
||||
csvw==2.0.0
|
||||
cycler==0.11.0
|
||||
Cython==0.29.28
|
||||
dash==2.3.0
|
||||
dash-bootstrap-components==1.0.3
|
||||
dash-core-components==2.0.0
|
||||
dash-html-components==2.0.0
|
||||
dash-table==5.0.0
|
||||
datasets==2.0.0
|
||||
decorator==5.1.1
|
||||
Deprecated==1.2.13
|
||||
dill==0.3.4
|
||||
dlinfo==1.2.1
|
||||
dm-tree==0.1.6
|
||||
docker==4.4.4
|
||||
execnet==1.9.0
|
||||
executing==0.8.3
|
||||
faiss-cpu==1.7.2
|
||||
fasteners==0.17.3
|
||||
filelock==3.6.0
|
||||
fire==0.4.0
|
||||
flake8==4.0.1
|
||||
Flask==2.0.3
|
||||
Flask-Compress==1.11
|
||||
flatbuffers==2.0
|
||||
flax==0.4.0
|
||||
fonttools==4.31.1
|
||||
frozenlist==1.3.0
|
||||
fsspec==2022.2.0
|
||||
fugashi==1.1.2
|
||||
gast==0.5.3
|
||||
gitdb==4.0.9
|
||||
GitPython==3.1.18
|
||||
glfw==2.5.1
|
||||
google-auth==2.6.2
|
||||
google-auth-oauthlib==0.4.6
|
||||
google-pasta==0.2.0
|
||||
greenlet==1.1.2
|
||||
grpcio==1.44.0
|
||||
gym==0.23.1
|
||||
gym-notices==0.0.6
|
||||
h5py==3.6.0
|
||||
huggingface-hub==0.4.0
|
||||
hypothesis==6.39.4
|
||||
idna==3.3
|
||||
imageio==2.16.1
|
||||
importlib-metadata==4.11.3
|
||||
importlib-resources==5.4.0
|
||||
iniconfig==1.1.1
|
||||
ipadic==1.0.0
|
||||
ipython==8.1.1
|
||||
isodate==0.6.1
|
||||
isort==5.10.1
|
||||
itsdangerous==2.1.1
|
||||
jax==0.3.4
|
||||
jaxlib==0.3.2
|
||||
jedi==0.18.1
|
||||
Jinja2==2.11.3
|
||||
jinja2-time==0.2.0
|
||||
jmespath==0.10.0
|
||||
joblib==1.1.0
|
||||
jsonschema==4.4.0
|
||||
keras==2.8.0
|
||||
Keras-Preprocessing==1.1.2
|
||||
kiwisolver==1.4.0
|
||||
kubernetes==12.0.1
|
||||
libclang==13.0.0
|
||||
librosa==0.9.1
|
||||
llvmlite==0.38.0
|
||||
Mako==1.2.0
|
||||
Markdown==3.3.6
|
||||
MarkupSafe==1.1.1
|
||||
matplotlib==3.5.1
|
||||
matplotlib-inline==0.1.3
|
||||
mccabe==0.6.1
|
||||
msgpack==1.0.3
|
||||
mujoco-py==2.1.2.14
|
||||
multidict==6.0.2
|
||||
multiprocess==0.70.12.2
|
||||
mypy-extensions==0.4.3
|
||||
nltk==3.7
|
||||
numba==0.55.1
|
||||
numpy==1.22.3
|
||||
oauthlib==3.2.0
|
||||
onnx==1.11.0
|
||||
onnxconverter-common==1.9.0
|
||||
opt-einsum==3.3.0
|
||||
optax==0.1.1
|
||||
optuna==2.10.0
|
||||
packaging==21.3
|
||||
pandas==1.4.1
|
||||
parameterized==0.8.1
|
||||
parso==0.8.3
|
||||
pathspec==0.9.0
|
||||
pbr==5.8.1
|
||||
pexpect==4.8.0
|
||||
phonemizer==3.0.1
|
||||
pickleshare==0.7.5
|
||||
Pillow==9.0.1
|
||||
Pint==0.16.1
|
||||
plac==1.3.4
|
||||
platformdirs==2.5.1
|
||||
plotly==5.6.0
|
||||
pluggy==1.0.0
|
||||
pooch==1.6.0
|
||||
portalocker==2.0.0
|
||||
poyo==0.5.0
|
||||
prettytable==3.2.0
|
||||
prompt-toolkit==3.0.28
|
||||
protobuf==3.19.4
|
||||
psutil==5.9.0
|
||||
ptyprocess==0.7.0
|
||||
pure-eval==0.2.2
|
||||
py==1.11.0
|
||||
py-cpuinfo==8.0.0
|
||||
pyarrow==7.0.0
|
||||
pyasn1==0.4.8
|
||||
pyasn1-modules==0.2.8
|
||||
pycodestyle==2.8.0
|
||||
pycparser==2.21
|
||||
pyctcdecode==0.3.0
|
||||
pyflakes==2.4.0
|
||||
Pygments==2.11.2
|
||||
pygtrie==2.4.2
|
||||
pynvml==11.4.1
|
||||
pyOpenSSL==22.0.0
|
||||
pyparsing==3.0.7
|
||||
pyperclip==1.8.2
|
||||
pypng==0.0.21
|
||||
pyrsistent==0.18.1
|
||||
pytest==7.1.1
|
||||
pytest-forked==1.4.0
|
||||
pytest-timeout==2.1.0
|
||||
pytest-xdist==2.5.0
|
||||
python-dateutil==2.8.2
|
||||
python-slugify==6.1.1
|
||||
pytz==2022.1
|
||||
pytz-deprecation-shim==0.1.0.post0
|
||||
PyYAML==6.0
|
||||
ray==1.11.0
|
||||
redis==4.1.4
|
||||
regex==2022.3.15
|
||||
requests==2.27.1
|
||||
requests-oauthlib==1.3.1
|
||||
resampy==0.2.2
|
||||
responses==0.18.0
|
||||
rfc3986==1.5.0
|
||||
rouge-score==0.0.4
|
||||
rsa==4.8
|
||||
s3transfer==0.3.7
|
||||
sacrebleu==1.5.1
|
||||
sacremoses==0.0.49
|
||||
scikit-learn==1.0.2
|
||||
scipy==1.8.0
|
||||
segments==2.2.0
|
||||
sentencepiece==0.1.96
|
||||
sigopt==8.2.0
|
||||
six==1.16.0
|
||||
smmap==5.0.0
|
||||
sortedcontainers==2.4.0
|
||||
SoundFile==0.10.3.post1
|
||||
SQLAlchemy==1.4.32
|
||||
stack-data==0.2.0
|
||||
stevedore==3.5.0
|
||||
tabulate==0.8.9
|
||||
tenacity==8.0.1
|
||||
tensorboard==2.8.0
|
||||
tensorboard-data-server==0.6.1
|
||||
tensorboard-plugin-wit==1.8.1
|
||||
tensorboardX==2.5
|
||||
tensorflow==2.8.0
|
||||
tensorflow-io-gcs-filesystem==0.24.0
|
||||
termcolor==1.1.0
|
||||
text-unidecode==1.3
|
||||
tf-estimator-nightly==2.8.0.dev2021122109
|
||||
tf2onnx==1.9.3
|
||||
threadpoolctl==3.1.0
|
||||
timeout-decorator==0.5.0
|
||||
timm==0.5.4
|
||||
tokenizers==0.11.6
|
||||
tomli==2.0.1
|
||||
toolz==0.11.2
|
||||
torch==1.11.0
|
||||
torchaudio==0.11.0
|
||||
torchvision==0.12.0
|
||||
tqdm==4.63.0
|
||||
traitlets==5.1.1
|
||||
-e git+git@github.com:edbeeching/transformers.git@77b90113ca0a0e4058b046796c874bdc98f1da61#egg=transformers
|
||||
typing-extensions==4.1.1
|
||||
tzdata==2022.1
|
||||
tzlocal==4.1
|
||||
unidic==1.1.0
|
||||
unidic-lite==1.0.8
|
||||
uritemplate==4.1.1
|
||||
urllib3==1.26.9
|
||||
wasabi==0.9.0
|
||||
wcwidth==0.2.5
|
||||
websocket-client==1.3.1
|
||||
Werkzeug==2.0.3
|
||||
wrapt==1.14.0
|
||||
xxhash==3.0.0
|
||||
yarl==1.7.2
|
||||
zipp==3.7.0
|
||||
@@ -0,0 +1,173 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import gym
|
||||
from mujoco_py import GlfwContext
|
||||
from transformers import DecisionTransformerModel
|
||||
|
||||
|
||||
GlfwContext(offscreen=True) # Create a window to init GLFW.
|
||||
|
||||
|
||||
def get_action(model, states, actions, rewards, returns_to_go, timesteps):
|
||||
# we don't care about the past rewards in this model
|
||||
|
||||
states = states.reshape(1, -1, model.config.state_dim)
|
||||
actions = actions.reshape(1, -1, model.config.act_dim)
|
||||
returns_to_go = returns_to_go.reshape(1, -1, 1)
|
||||
timesteps = timesteps.reshape(1, -1)
|
||||
|
||||
if model.config.max_length is not None:
|
||||
states = states[:, -model.config.max_length :]
|
||||
actions = actions[:, -model.config.max_length :]
|
||||
returns_to_go = returns_to_go[:, -model.config.max_length :]
|
||||
timesteps = timesteps[:, -model.config.max_length :]
|
||||
|
||||
# pad all tokens to sequence length
|
||||
attention_mask = torch.cat(
|
||||
[torch.zeros(model.config.max_length - states.shape[1]), torch.ones(states.shape[1])]
|
||||
)
|
||||
attention_mask = attention_mask.to(dtype=torch.long, device=states.device).reshape(1, -1)
|
||||
states = torch.cat(
|
||||
[
|
||||
torch.zeros(
|
||||
(states.shape[0], model.config.max_length - states.shape[1], model.config.state_dim),
|
||||
device=states.device,
|
||||
),
|
||||
states,
|
||||
],
|
||||
dim=1,
|
||||
).to(dtype=torch.float32)
|
||||
actions = torch.cat(
|
||||
[
|
||||
torch.zeros(
|
||||
(actions.shape[0], model.config.max_length - actions.shape[1], model.config.act_dim),
|
||||
device=actions.device,
|
||||
),
|
||||
actions,
|
||||
],
|
||||
dim=1,
|
||||
).to(dtype=torch.float32)
|
||||
returns_to_go = torch.cat(
|
||||
[
|
||||
torch.zeros(
|
||||
(returns_to_go.shape[0], model.config.max_length - returns_to_go.shape[1], 1),
|
||||
device=returns_to_go.device,
|
||||
),
|
||||
returns_to_go,
|
||||
],
|
||||
dim=1,
|
||||
).to(dtype=torch.float32)
|
||||
timesteps = torch.cat(
|
||||
[
|
||||
torch.zeros(
|
||||
(timesteps.shape[0], model.config.max_length - timesteps.shape[1]), device=timesteps.device
|
||||
),
|
||||
timesteps,
|
||||
],
|
||||
dim=1,
|
||||
).to(dtype=torch.long)
|
||||
else:
|
||||
attention_mask = None
|
||||
|
||||
_, action_preds, _ = model(
|
||||
states=states,
|
||||
actions=actions,
|
||||
rewards=rewards,
|
||||
returns_to_go=returns_to_go,
|
||||
timesteps=timesteps,
|
||||
attention_mask=attention_mask,
|
||||
return_dict=False,
|
||||
)
|
||||
|
||||
return action_preds[0, -1]
|
||||
|
||||
|
||||
# build the environment
|
||||
|
||||
env = gym.make("Hopper-v3")
|
||||
state_dim = env.observation_space.shape[0]
|
||||
act_dim = env.action_space.shape[0]
|
||||
max_ep_len = 1000
|
||||
device = "cuda"
|
||||
scale = 1000.0 # normalization for rewards/returns
|
||||
TARGET_RETURN = 3600 / scale # evaluation conditioning targets, 3600 is reasonable from the paper LINK
|
||||
state_mean = np.array(
|
||||
[
|
||||
1.311279,
|
||||
-0.08469521,
|
||||
-0.5382719,
|
||||
-0.07201576,
|
||||
0.04932366,
|
||||
2.1066856,
|
||||
-0.15017354,
|
||||
0.00878345,
|
||||
-0.2848186,
|
||||
-0.18540096,
|
||||
-0.28461286,
|
||||
]
|
||||
)
|
||||
state_std = np.array(
|
||||
[
|
||||
0.17790751,
|
||||
0.05444621,
|
||||
0.21297139,
|
||||
0.14530419,
|
||||
0.6124444,
|
||||
0.85174465,
|
||||
1.4515252,
|
||||
0.6751696,
|
||||
1.536239,
|
||||
1.6160746,
|
||||
5.6072536,
|
||||
]
|
||||
)
|
||||
state_mean = torch.from_numpy(state_mean).to(device=device)
|
||||
state_std = torch.from_numpy(state_std).to(device=device)
|
||||
|
||||
# Create the decision transformer model
|
||||
model = DecisionTransformerModel.from_pretrained("edbeeching/decision-transformer-gym-hopper-medium")
|
||||
model = model.to(device)
|
||||
model.eval()
|
||||
|
||||
for ep in range(10):
|
||||
episode_return, episode_length = 0, 0
|
||||
state = env.reset()
|
||||
target_return = torch.tensor(TARGET_RETURN, device=device, dtype=torch.float32).reshape(1, 1)
|
||||
states = torch.from_numpy(state).reshape(1, state_dim).to(device=device, dtype=torch.float32)
|
||||
actions = torch.zeros((0, act_dim), device=device, dtype=torch.float32)
|
||||
rewards = torch.zeros(0, device=device, dtype=torch.float32)
|
||||
|
||||
timesteps = torch.tensor(0, device=device, dtype=torch.long).reshape(1, 1)
|
||||
for t in range(max_ep_len):
|
||||
env.render()
|
||||
# add padding
|
||||
actions = torch.cat([actions, torch.zeros((1, act_dim), device=device)], dim=0)
|
||||
rewards = torch.cat([rewards, torch.zeros(1, device=device)])
|
||||
|
||||
action = get_action(
|
||||
model,
|
||||
(states.to(dtype=torch.float32) - state_mean) / state_std,
|
||||
actions.to(dtype=torch.float32),
|
||||
rewards.to(dtype=torch.float32),
|
||||
target_return.to(dtype=torch.float32),
|
||||
timesteps.to(dtype=torch.long),
|
||||
)
|
||||
actions[-1] = action
|
||||
action = action.detach().cpu().numpy()
|
||||
|
||||
state, reward, done, _ = env.step(action)
|
||||
|
||||
cur_state = torch.from_numpy(state).to(device=device).reshape(1, state_dim)
|
||||
states = torch.cat([states, cur_state], dim=0)
|
||||
rewards[-1] = reward
|
||||
|
||||
pred_return = target_return[0, -1] - (reward / scale)
|
||||
target_return = torch.cat([target_return, pred_return.reshape(1, 1)], dim=1)
|
||||
timesteps = torch.cat([timesteps, torch.ones((1, 1), device=device, dtype=torch.long) * (t + 1)], dim=1)
|
||||
|
||||
episode_return += reward
|
||||
episode_length += 1
|
||||
|
||||
if done:
|
||||
break
|
||||
@@ -173,6 +173,7 @@ _import_structure = {
|
||||
"models.data2vec": ["DATA2VEC_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP", "Data2VecAudioConfig", "Data2VecTextConfig"],
|
||||
"models.deberta": ["DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "DebertaConfig", "DebertaTokenizer"],
|
||||
"models.deberta_v2": ["DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP", "DebertaV2Config"],
|
||||
"models.decision_transformer": ["DECISION_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "DecisionTransformerConfig"],
|
||||
"models.deit": ["DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "DeiTConfig"],
|
||||
"models.detr": ["DETR_PRETRAINED_CONFIG_ARCHIVE_MAP", "DetrConfig"],
|
||||
"models.dialogpt": [],
|
||||
@@ -901,6 +902,15 @@ if is_torch_available():
|
||||
"DebertaV2PreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.decision_transformer"].extend(
|
||||
[
|
||||
"DECISION_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"DecisionTransformerGPT2Model",
|
||||
"DecisionTransformerGPT2PreTrainedModel",
|
||||
"DecisionTransformerModel",
|
||||
"DecisionTransformerPreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.deit"].extend(
|
||||
[
|
||||
"DEIT_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
@@ -2509,6 +2519,10 @@ if TYPE_CHECKING:
|
||||
from .models.data2vec import DATA2VEC_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP, Data2VecAudioConfig, Data2VecTextConfig
|
||||
from .models.deberta import DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, DebertaConfig, DebertaTokenizer
|
||||
from .models.deberta_v2 import DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP, DebertaV2Config
|
||||
from .models.decision_transformer import (
|
||||
DECISION_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
DecisionTransformerConfig,
|
||||
)
|
||||
from .models.deit import DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP, DeiTConfig
|
||||
from .models.detr import DETR_PRETRAINED_CONFIG_ARCHIVE_MAP, DetrConfig
|
||||
from .models.distilbert import DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DistilBertConfig, DistilBertTokenizer
|
||||
@@ -3128,6 +3142,13 @@ if TYPE_CHECKING:
|
||||
DebertaV2Model,
|
||||
DebertaV2PreTrainedModel,
|
||||
)
|
||||
from .models.decision_transformer import (
|
||||
DECISION_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
DecisionTransformerGPT2Model,
|
||||
DecisionTransformerGPT2PreTrainedModel,
|
||||
DecisionTransformerModel,
|
||||
DecisionTransformerPreTrainedModel,
|
||||
)
|
||||
from .models.deit import (
|
||||
DEIT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
DeiTForImageClassification,
|
||||
|
||||
@@ -43,6 +43,7 @@ from . import (
|
||||
data2vec,
|
||||
deberta,
|
||||
deberta_v2,
|
||||
decision_transformer,
|
||||
deit,
|
||||
detr,
|
||||
dialogpt,
|
||||
|
||||
@@ -29,8 +29,10 @@ logger = logging.get_logger(__name__)
|
||||
CONFIG_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
# Add configs here
|
||||
("decision_transformer", "DecisionTransformerConfig"),
|
||||
("glpn", "GLPNConfig"),
|
||||
("maskformer", "MaskFormerConfig"),
|
||||
("decision_transformer", "DecisionTransformerConfig"),
|
||||
("poolformer", "PoolFormerConfig"),
|
||||
("convnext", "ConvNextConfig"),
|
||||
("van", "VanConfig"),
|
||||
@@ -222,6 +224,7 @@ CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict(
|
||||
MODEL_NAMES_MAPPING = OrderedDict(
|
||||
[
|
||||
# Add full (and cased) model names here
|
||||
("decision_transformer", "Decision Transformer"),
|
||||
("glpn", "GLPN"),
|
||||
("maskformer", "MaskFormer"),
|
||||
("poolformer", "PoolFormer"),
|
||||
|
||||
@@ -28,8 +28,11 @@ logger = logging.get_logger(__name__)
|
||||
MODEL_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
# Base model mapping
|
||||
("decision_transformer", "DecisionTransformerModel"),
|
||||
("glpn", "GLPNModel"),
|
||||
("maskformer", "MaskFormerModel"),
|
||||
("decision_transformer", "DecisionTransformerModel"),
|
||||
("decision_transformer_gpt2", "DecisionTransformerGPT2Model"),
|
||||
("poolformer", "PoolFormerModel"),
|
||||
("convnext", "ConvNextModel"),
|
||||
("van", "VanModel"),
|
||||
|
||||
60
src/transformers/models/decision_transformer/__init__.py
Normal file
60
src/transformers/models/decision_transformer/__init__.py
Normal file
@@ -0,0 +1,60 @@
|
||||
# flake8: noqa
|
||||
# There's no way to ignore "F401 '...' imported but unused" warnings in this
|
||||
# module, but to preserve other warnings. So, don't check this module at all.
|
||||
|
||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
# rely on isort to merge the imports
|
||||
from ...file_utils import _LazyModule, is_torch_available
|
||||
|
||||
|
||||
_import_structure = {
|
||||
"configuration_decision_transformer": [
|
||||
"DECISION_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
|
||||
"DecisionTransformerConfig",
|
||||
],
|
||||
}
|
||||
|
||||
if is_torch_available():
|
||||
_import_structure["modeling_decision_transformer"] = [
|
||||
"DECISION_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"DecisionTransformerGPT2Model",
|
||||
"DecisionTransformerGPT2PreTrainedModel",
|
||||
"DecisionTransformerModel",
|
||||
"DecisionTransformerPreTrainedModel",
|
||||
]
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_decision_transformer import (
|
||||
DECISION_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
DecisionTransformerConfig,
|
||||
)
|
||||
|
||||
if is_torch_available():
|
||||
from .modeling_decision_transformer import (
|
||||
DECISION_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
DecisionTransformerGPT2Model,
|
||||
DecisionTransformerGPT2PreTrainedModel,
|
||||
DecisionTransformerModel,
|
||||
DecisionTransformerPreTrainedModel,
|
||||
)
|
||||
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
|
||||
@@ -0,0 +1,174 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Team and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Decision Transformer model configuration"""
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
DECISION_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||
"edbeeching/decision-transformer-gym-hopper-medium": "https://huggingface.co/edbeeching/decision-transformer-gym-hopper-medium/resolve/main/config.json",
|
||||
# See all DecisionTransformer models at https://huggingface.co/models?filter=decision_transformer
|
||||
}
|
||||
|
||||
|
||||
class DecisionTransformerConfig(PretrainedConfig):
|
||||
"""
|
||||
This is the configuration class to store the configuration of a [`DecisionTransformerModel`]. It is used to
|
||||
instantiate a Decision Transformer model according to the specified arguments, defining the model architecture.
|
||||
Instantiating a configuration with the defaults will yield a similar configuration to that of the standard
|
||||
DecisionTransformer architecture. Many of the config options are used to instatiate the GPT2 model that is used as
|
||||
part of the architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
|
||||
Args:
|
||||
state_dim (`int`, *optional*, defaults to 17):
|
||||
The state size for the RL environment
|
||||
act_dim (`int`, *optional*, defaults to 4):
|
||||
The size of the output action space
|
||||
hidden_size (`int`, *optional*, defaults to 128):
|
||||
The size of the hidden layers
|
||||
max_ep_len (`int`, *optional*, defaults to 4096):
|
||||
The maximum length of an episode in the environment
|
||||
action_tanh (`bool`, *optional*, defaults to True):
|
||||
Whether to use a tanh activation on action prediction
|
||||
vocab_size (`int`, *optional*, defaults to 50257):
|
||||
Vocabulary size of the GPT-2 model. Defines the number of different tokens that can be represented by the
|
||||
`inputs_ids` passed when calling [`DecisionTransformerModel`].
|
||||
n_positions (`int`, *optional*, defaults to 1024):
|
||||
The maximum sequence length that this model might ever be used with. Typically set this to something large
|
||||
just in case (e.g., 512 or 1024 or 2048).
|
||||
n_embd (`int`, *optional*, defaults to 768):
|
||||
Dimensionality of the embeddings and hidden states.
|
||||
n_layer (`int`, *optional*, defaults to 12):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
n_head (`int`, *optional*, defaults to 12):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
n_inner (`int`, *optional*):
|
||||
Dimensionality of the inner feed-forward layers. If unset, will default to 4 times `n_embd`.
|
||||
activation_function (`str`, *optional*, defaults to `"gelu"`):
|
||||
Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new"]`.
|
||||
resid_pdrop (`float`, *optional*, defaults to 0.1):
|
||||
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
||||
embd_pdrop (`int`, *optional*, defaults to 0.1):
|
||||
The dropout ratio for the embeddings.
|
||||
attn_pdrop (`float`, *optional*, defaults to 0.1):
|
||||
The dropout ratio for the attention.
|
||||
layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
|
||||
The epsilon to use in the layer normalization layers.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
scale_attn_weights (`bool`, *optional*, defaults to `True`):
|
||||
Scale attention weights by dividing by sqrt(hidden_size)..
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models).
|
||||
scale_attn_by_inverse_layer_idx (`bool`, *optional*, defaults to `False`):
|
||||
Whether to additionally scale attention weights by `1 / layer_idx + 1`.
|
||||
reorder_and_upcast_attn (`bool`, *optional*, defaults to `False`):
|
||||
Whether to scale keys (K) prior to computing attention (dot-product) and upcast attention
|
||||
dot-product/softmax to float() when training with mixed precision.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import DecisionTransformerModel, DecisionTransformerConfig
|
||||
|
||||
>>> # Initializing a DecisionTransformer configuration
|
||||
>>> configuration = DecisionTransformerConfig()
|
||||
|
||||
>>> # Initializing a model from the configuration
|
||||
>>> model = DecisionTransformerConfig(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "decision_transformer"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
attribute_map = {
|
||||
"max_position_embeddings": "n_positions",
|
||||
"num_attention_heads": "n_head",
|
||||
"num_hidden_layers": "n_layer",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
state_dim=17,
|
||||
act_dim=4,
|
||||
hidden_size=128,
|
||||
max_ep_len=4096,
|
||||
action_tanh=True,
|
||||
vocab_size=1,
|
||||
n_positions=1024,
|
||||
n_embd=768,
|
||||
n_layer=3,
|
||||
n_head=1,
|
||||
n_inner=None,
|
||||
activation_function="relu",
|
||||
resid_pdrop=0.1,
|
||||
embd_pdrop=0.1,
|
||||
attn_pdrop=0.1,
|
||||
layer_norm_epsilon=1e-5,
|
||||
initializer_range=0.02,
|
||||
summary_type="cls_index",
|
||||
summary_use_proj=True,
|
||||
summary_activation=None,
|
||||
summary_proj_to_labels=True,
|
||||
summary_first_dropout=0.1,
|
||||
scale_attn_weights=True,
|
||||
use_cache=True,
|
||||
bos_token_id=50256,
|
||||
eos_token_id=50256,
|
||||
scale_attn_by_inverse_layer_idx=False,
|
||||
reorder_and_upcast_attn=False,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
self.state_dim = state_dim
|
||||
self.act_dim = act_dim
|
||||
self.hidden_size = hidden_size
|
||||
self.max_ep_len = max_ep_len
|
||||
self.action_tanh = action_tanh
|
||||
self.vocab_size = vocab_size
|
||||
self.n_positions = n_positions
|
||||
self.n_embd = n_embd
|
||||
self.n_layer = n_layer
|
||||
self.n_head = n_head
|
||||
self.n_inner = n_inner
|
||||
self.activation_function = activation_function
|
||||
self.resid_pdrop = resid_pdrop
|
||||
self.embd_pdrop = embd_pdrop
|
||||
self.attn_pdrop = attn_pdrop
|
||||
self.layer_norm_epsilon = layer_norm_epsilon
|
||||
self.initializer_range = initializer_range
|
||||
self.summary_type = summary_type
|
||||
self.summary_use_proj = summary_use_proj
|
||||
self.summary_activation = summary_activation
|
||||
self.summary_first_dropout = summary_first_dropout
|
||||
self.summary_proj_to_labels = summary_proj_to_labels
|
||||
self.scale_attn_weights = scale_attn_weights
|
||||
self.use_cache = use_cache
|
||||
self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx
|
||||
self.reorder_and_upcast_attn = reorder_and_upcast_attn
|
||||
|
||||
self.bos_token_id = bos_token_id
|
||||
self.eos_token_id = eos_token_id
|
||||
|
||||
super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
||||
953
src/transformers/models/decision_transformer/modeling_decision_transformer.py
Executable file
953
src/transformers/models/decision_transformer/modeling_decision_transformer.py
Executable file
@@ -0,0 +1,953 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Team The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" PyTorch DecisionTransformer model."""
|
||||
|
||||
import math
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from packaging import version
|
||||
from torch import nn
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...file_utils import (
|
||||
ModelOutput,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from ...modeling_utils import Conv1D, PreTrainedModel, find_pruneable_heads_and_indices, prune_conv1d_layer
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
if version.parse(torch.__version__) >= version.parse("1.6"):
|
||||
is_amp_available = True
|
||||
from torch.cuda.amp import autocast
|
||||
else:
|
||||
is_amp_available = False
|
||||
|
||||
from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
|
||||
from .configuration_decision_transformer import DecisionTransformerConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CHECKPOINT_FOR_DOC = "edbeeching/decision-transformer-gym-hopper-medium"
|
||||
_CONFIG_FOR_DOC = "DecisionTransformerConfig"
|
||||
|
||||
DECISION_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
"edbeeching/decision-transformer-gym-hopper-medium",
|
||||
# See all DecisionTransformer models at https://huggingface.co/models?filter=decision_transformer
|
||||
]
|
||||
|
||||
|
||||
# Copied from transformers.models.gpt2.modeling_gpt2.load_tf_weights_in_gpt2
|
||||
def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
|
||||
"""Load tf checkpoints in a pytorch model"""
|
||||
try:
|
||||
import re
|
||||
|
||||
import tensorflow as tf
|
||||
except ImportError:
|
||||
logger.error(
|
||||
"Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
|
||||
"https://www.tensorflow.org/install/ for installation instructions."
|
||||
)
|
||||
raise
|
||||
tf_path = os.path.abspath(gpt2_checkpoint_path)
|
||||
logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
|
||||
# Load weights from TF model
|
||||
init_vars = tf.train.list_variables(tf_path)
|
||||
names = []
|
||||
arrays = []
|
||||
for name, shape in init_vars:
|
||||
logger.info(f"Loading TF weight {name} with shape {shape}")
|
||||
array = tf.train.load_variable(tf_path, name)
|
||||
names.append(name)
|
||||
arrays.append(array.squeeze())
|
||||
|
||||
for name, array in zip(names, arrays):
|
||||
name = name[6:] # skip "model/"
|
||||
name = name.split("/")
|
||||
pointer = model
|
||||
for m_name in name:
|
||||
if re.fullmatch(r"[A-Za-z]+\d+", m_name):
|
||||
scope_names = re.split(r"(\d+)", m_name)
|
||||
else:
|
||||
scope_names = [m_name]
|
||||
if scope_names[0] == "w" or scope_names[0] == "g":
|
||||
pointer = getattr(pointer, "weight")
|
||||
elif scope_names[0] == "b":
|
||||
pointer = getattr(pointer, "bias")
|
||||
elif scope_names[0] == "wpe" or scope_names[0] == "wte":
|
||||
pointer = getattr(pointer, scope_names[0])
|
||||
pointer = getattr(pointer, "weight")
|
||||
else:
|
||||
pointer = getattr(pointer, scope_names[0])
|
||||
if len(scope_names) >= 2:
|
||||
num = int(scope_names[1])
|
||||
pointer = pointer[num]
|
||||
try:
|
||||
assert (
|
||||
pointer.shape == array.shape
|
||||
), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
|
||||
except AssertionError as e:
|
||||
e.args += (pointer.shape, array.shape)
|
||||
raise
|
||||
logger.info(f"Initialize PyTorch weight {name}")
|
||||
pointer.data = torch.from_numpy(array)
|
||||
return model
|
||||
|
||||
|
||||
# Copied from transformers.models.gpt2.modeling_gpt2.GPT2Attention with GPT2->DecisionTransformerGPT2
|
||||
class DecisionTransformerGPT2Attention(nn.Module):
|
||||
def __init__(self, config, is_cross_attention=False, layer_idx=None):
|
||||
super().__init__()
|
||||
|
||||
max_positions = config.max_position_embeddings
|
||||
self.register_buffer(
|
||||
"bias",
|
||||
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
|
||||
1, 1, max_positions, max_positions
|
||||
),
|
||||
)
|
||||
self.register_buffer("masked_bias", torch.tensor(-1e4))
|
||||
|
||||
self.embed_dim = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = self.embed_dim // self.num_heads
|
||||
self.split_size = self.embed_dim
|
||||
if self.head_dim * self.num_heads != self.embed_dim:
|
||||
raise ValueError(
|
||||
f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
|
||||
)
|
||||
|
||||
self.scale_attn_weights = config.scale_attn_weights
|
||||
self.is_cross_attention = is_cross_attention
|
||||
|
||||
# Layer-wise attention scaling, reordering, and upcasting
|
||||
self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
|
||||
self.layer_idx = layer_idx
|
||||
self.reorder_and_upcast_attn = config.reorder_and_upcast_attn
|
||||
|
||||
if self.is_cross_attention:
|
||||
self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
|
||||
self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
|
||||
else:
|
||||
self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
|
||||
self.c_proj = Conv1D(self.embed_dim, self.embed_dim)
|
||||
|
||||
self.attn_dropout = nn.Dropout(config.attn_pdrop)
|
||||
self.resid_dropout = nn.Dropout(config.resid_pdrop)
|
||||
|
||||
self.pruned_heads = set()
|
||||
|
||||
def prune_heads(self, heads):
|
||||
if len(heads) == 0:
|
||||
return
|
||||
heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads)
|
||||
index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])
|
||||
|
||||
# Prune conv1d layers
|
||||
self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
|
||||
self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
|
||||
|
||||
# Update hyper params
|
||||
self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads))
|
||||
self.num_heads = self.num_heads - len(heads)
|
||||
self.pruned_heads = self.pruned_heads.union(heads)
|
||||
|
||||
def _attn(self, query, key, value, attention_mask=None, head_mask=None):
|
||||
attn_weights = torch.matmul(query, key.transpose(-1, -2))
|
||||
|
||||
if self.scale_attn_weights:
|
||||
attn_weights = attn_weights / (value.size(-1) ** 0.5)
|
||||
|
||||
# Layer-wise attention scaling
|
||||
if self.scale_attn_by_inverse_layer_idx:
|
||||
attn_weights = attn_weights / float(self.layer_idx + 1)
|
||||
|
||||
if not self.is_cross_attention:
|
||||
# if only "normal" attention layer implements causal mask
|
||||
query_length, key_length = query.size(-2), key.size(-2)
|
||||
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
|
||||
attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
|
||||
|
||||
if attention_mask is not None:
|
||||
# Apply the attention mask
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
# Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
|
||||
attn_weights = attn_weights.type(value.dtype)
|
||||
attn_weights = self.attn_dropout(attn_weights)
|
||||
|
||||
# Mask heads if we want to
|
||||
if head_mask is not None:
|
||||
attn_weights = attn_weights * head_mask
|
||||
|
||||
attn_output = torch.matmul(attn_weights, value)
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):
|
||||
# Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
|
||||
bsz, num_heads, q_seq_len, dk = query.size()
|
||||
_, _, k_seq_len, _ = key.size()
|
||||
|
||||
# Preallocate attn_weights for `baddbmm`
|
||||
attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device)
|
||||
|
||||
# Compute Scale Factor
|
||||
scale_factor = 1.0
|
||||
if self.scale_attn_weights:
|
||||
scale_factor /= float(value.size(-1)) ** 0.5
|
||||
|
||||
if self.scale_attn_by_inverse_layer_idx:
|
||||
scale_factor /= float(self.layer_idx + 1)
|
||||
|
||||
# Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
|
||||
if is_amp_available:
|
||||
with autocast(enabled=False):
|
||||
q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
|
||||
attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
|
||||
attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
|
||||
else:
|
||||
q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
|
||||
attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
|
||||
attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
|
||||
|
||||
if not self.is_cross_attention:
|
||||
# if only "normal" attention layer implements causal mask
|
||||
query_length, key_length = query.size(-2), key.size(-2)
|
||||
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
|
||||
attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
|
||||
|
||||
if attention_mask is not None:
|
||||
# Apply the attention mask
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
# Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op if otherwise
|
||||
if attn_weights.dtype != torch.float32:
|
||||
raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32")
|
||||
attn_weights = attn_weights.type(value.dtype)
|
||||
attn_weights = self.attn_dropout(attn_weights)
|
||||
|
||||
# Mask heads if we want to
|
||||
if head_mask is not None:
|
||||
attn_weights = attn_weights * head_mask
|
||||
|
||||
attn_output = torch.matmul(attn_weights, value)
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
def _split_heads(self, tensor, num_heads, attn_head_size):
|
||||
"""
|
||||
Splits hidden_size dim into attn_head_size and num_heads
|
||||
"""
|
||||
new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
|
||||
tensor = tensor.view(new_shape)
|
||||
return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
|
||||
|
||||
def _merge_heads(self, tensor, num_heads, attn_head_size):
|
||||
"""
|
||||
Merges attn_head_size dim and num_attn_heads dim into hidden_size
|
||||
"""
|
||||
tensor = tensor.permute(0, 2, 1, 3).contiguous()
|
||||
new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
|
||||
return tensor.view(new_shape)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
layer_past=None,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
use_cache=False,
|
||||
output_attentions=False,
|
||||
):
|
||||
if encoder_hidden_states is not None:
|
||||
if not hasattr(self, "q_attn"):
|
||||
raise ValueError(
|
||||
"If class is used as cross attention, the weights `q_attn` have to be defined. "
|
||||
"Please make sure to instantiate class with `DecisionTransformerGPT2Attention(..., is_cross_attention=True)`."
|
||||
)
|
||||
|
||||
query = self.q_attn(hidden_states)
|
||||
key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
|
||||
attention_mask = encoder_attention_mask
|
||||
else:
|
||||
query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
|
||||
|
||||
query = self._split_heads(query, self.num_heads, self.head_dim)
|
||||
key = self._split_heads(key, self.num_heads, self.head_dim)
|
||||
value = self._split_heads(value, self.num_heads, self.head_dim)
|
||||
|
||||
if layer_past is not None:
|
||||
past_key, past_value = layer_past
|
||||
key = torch.cat((past_key, key), dim=-2)
|
||||
value = torch.cat((past_value, value), dim=-2)
|
||||
|
||||
if use_cache is True:
|
||||
present = (key, value)
|
||||
else:
|
||||
present = None
|
||||
|
||||
if self.reorder_and_upcast_attn:
|
||||
attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
|
||||
else:
|
||||
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
|
||||
|
||||
attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
|
||||
attn_output = self.c_proj(attn_output)
|
||||
attn_output = self.resid_dropout(attn_output)
|
||||
|
||||
outputs = (attn_output, present)
|
||||
if output_attentions:
|
||||
outputs += (attn_weights,)
|
||||
|
||||
return outputs # a, present, (attentions)
|
||||
|
||||
|
||||
# Copied from transformers.models.gpt2.modeling_gpt2.GPT2MLP with GPT2->DecisionTransformerGPT2
|
||||
class DecisionTransformerGPT2MLP(nn.Module):
|
||||
def __init__(self, intermediate_size, config):
|
||||
super().__init__()
|
||||
embed_dim = config.hidden_size
|
||||
self.c_fc = Conv1D(intermediate_size, embed_dim)
|
||||
self.c_proj = Conv1D(embed_dim, intermediate_size)
|
||||
self.act = ACT2FN[config.activation_function]
|
||||
self.dropout = nn.Dropout(config.resid_pdrop)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.c_fc(hidden_states)
|
||||
hidden_states = self.act(hidden_states)
|
||||
hidden_states = self.c_proj(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from transformers.models.gpt2.modeling_gpt2.GPT2Block with GPT2->DecisionTransformerGPT2
|
||||
class DecisionTransformerGPT2Block(nn.Module):
|
||||
def __init__(self, config, layer_idx=None):
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
|
||||
|
||||
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||
self.attn = DecisionTransformerGPT2Attention(config, layer_idx=layer_idx)
|
||||
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||
|
||||
if config.add_cross_attention:
|
||||
self.crossattention = DecisionTransformerGPT2Attention(
|
||||
config, is_cross_attention=True, layer_idx=layer_idx
|
||||
)
|
||||
self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||
|
||||
self.mlp = DecisionTransformerGPT2MLP(inner_dim, config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
layer_past=None,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
use_cache=False,
|
||||
output_attentions=False,
|
||||
):
|
||||
residual = hidden_states
|
||||
hidden_states = self.ln_1(hidden_states)
|
||||
attn_outputs = self.attn(
|
||||
hidden_states,
|
||||
layer_past=layer_past,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
|
||||
outputs = attn_outputs[1:]
|
||||
# residual connection
|
||||
hidden_states = attn_output + residual
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
# add one self-attention block for cross-attention
|
||||
if not hasattr(self, "crossattention"):
|
||||
raise ValueError(
|
||||
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
|
||||
"cross-attention layers by setting `config.add_cross_attention=True`"
|
||||
)
|
||||
residual = hidden_states
|
||||
hidden_states = self.ln_cross_attn(hidden_states)
|
||||
cross_attn_outputs = self.crossattention(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
attn_output = cross_attn_outputs[0]
|
||||
# residual connection
|
||||
hidden_states = residual + attn_output
|
||||
outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.ln_2(hidden_states)
|
||||
feed_forward_hidden_states = self.mlp(hidden_states)
|
||||
# residual connection
|
||||
hidden_states = residual + feed_forward_hidden_states
|
||||
|
||||
if use_cache:
|
||||
outputs = (hidden_states,) + outputs
|
||||
else:
|
||||
outputs = (hidden_states,) + outputs[1:]
|
||||
|
||||
return outputs # hidden_states, present, (attentions, cross_attentions)
|
||||
|
||||
|
||||
class DecisionTransformerGPT2PreTrainedModel(PreTrainedModel):
|
||||
"""
|
||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||
models.
|
||||
"""
|
||||
|
||||
config_class = DecisionTransformerConfig
|
||||
load_tf_weights = load_tf_weights_in_gpt2
|
||||
base_model_prefix = "transformer"
|
||||
is_parallelizable = True
|
||||
supports_gradient_checkpointing = True
|
||||
|
||||
def __init__(self, *inputs, **kwargs):
|
||||
super().__init__(*inputs, **kwargs)
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights."""
|
||||
if isinstance(module, (nn.Linear, Conv1D)):
|
||||
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
|
||||
# > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
|
||||
# > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
|
||||
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
|
||||
#
|
||||
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
|
||||
for name, p in module.named_parameters():
|
||||
if "c_proj" in name and "weight" in name:
|
||||
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
|
||||
p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer)))
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if isinstance(module, DecisionTransformerGPT2Model):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
|
||||
class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = ["attn.masked_bias"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
self.embed_dim = config.hidden_size
|
||||
|
||||
self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
|
||||
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
|
||||
|
||||
self.drop = nn.Dropout(config.embd_pdrop)
|
||||
self.h = nn.ModuleList(
|
||||
[DecisionTransformerGPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
|
||||
)
|
||||
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
||||
|
||||
# Model parallel
|
||||
self.model_parallel = False
|
||||
self.device_map = None
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.wte
|
||||
|
||||
def set_input_embeddings(self, new_embeddings):
|
||||
self.wte = new_embeddings
|
||||
|
||||
# Copied from transformers.models.gpt2.modeling_gpt2.GPT2Model.forward
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
past_key_values=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
use_cache=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
):
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = input_ids.size()
|
||||
input_ids = input_ids.view(-1, input_shape[-1])
|
||||
batch_size = input_ids.shape[0]
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = inputs_embeds.size()[:-1]
|
||||
batch_size = inputs_embeds.shape[0]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
|
||||
if token_type_ids is not None:
|
||||
token_type_ids = token_type_ids.view(-1, input_shape[-1])
|
||||
if position_ids is not None:
|
||||
position_ids = position_ids.view(-1, input_shape[-1])
|
||||
|
||||
if past_key_values is None:
|
||||
past_length = 0
|
||||
past_key_values = tuple([None] * len(self.h))
|
||||
else:
|
||||
past_length = past_key_values[0][0].size(-2)
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
|
||||
|
||||
# GPT2Attention mask.
|
||||
if attention_mask is not None:
|
||||
if batch_size <= 0:
|
||||
raise ValueError("batch_size has to be defined and > 0")
|
||||
attention_mask = attention_mask.view(batch_size, -1)
|
||||
# We create a 3D attention mask from a 2D tensor mask.
|
||||
# Sizes are [batch_size, 1, 1, to_seq_length]
|
||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
||||
# this attention mask is more simple than the triangular masking of causal attention
|
||||
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
||||
attention_mask = attention_mask[:, None, None, :]
|
||||
|
||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||
# masked positions, this operation will create a tensor which is 0.0 for
|
||||
# positions we want to attend and -10000.0 for masked positions.
|
||||
# Since we are adding it to the raw scores before the softmax, this is
|
||||
# effectively the same as removing these entirely.
|
||||
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
||||
attention_mask = (1.0 - attention_mask) * -10000.0
|
||||
|
||||
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||
if self.config.add_cross_attention and encoder_hidden_states is not None:
|
||||
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
||||
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
||||
if encoder_attention_mask is None:
|
||||
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
||||
encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
||||
else:
|
||||
encoder_attention_mask = None
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
# head_mask has shape n_layer x batch x n_heads x N x N
|
||||
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.wte(input_ids)
|
||||
position_embeds = self.wpe(position_ids)
|
||||
hidden_states = inputs_embeds + position_embeds
|
||||
|
||||
if token_type_ids is not None:
|
||||
token_type_embeds = self.wte(token_type_ids)
|
||||
hidden_states = hidden_states + token_type_embeds
|
||||
|
||||
hidden_states = self.drop(hidden_states)
|
||||
|
||||
output_shape = input_shape + (hidden_states.size(-1),)
|
||||
|
||||
presents = () if use_cache else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
||||
|
||||
# Model parallel
|
||||
if self.model_parallel:
|
||||
torch.cuda.set_device(hidden_states.device)
|
||||
# Ensure layer_past is on same device as hidden_states (might not be correct)
|
||||
if layer_past is not None:
|
||||
layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
|
||||
# Ensure that attention_mask is always on the same device as hidden_states
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.to(hidden_states.device)
|
||||
if isinstance(head_mask, torch.Tensor):
|
||||
head_mask = head_mask.to(hidden_states.device)
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
if use_cache:
|
||||
logger.warning(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(*inputs, use_cache, output_attentions)
|
||||
|
||||
return custom_forward
|
||||
|
||||
outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
None,
|
||||
attention_mask,
|
||||
head_mask[i],
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
)
|
||||
else:
|
||||
outputs = block(
|
||||
hidden_states,
|
||||
layer_past=layer_past,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask[i],
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
if use_cache is True:
|
||||
presents = presents + (outputs[1],)
|
||||
|
||||
if output_attentions:
|
||||
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
|
||||
if self.config.add_cross_attention:
|
||||
all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
|
||||
|
||||
# Model Parallel: If it's the last layer for that device, put things on the next device
|
||||
if self.model_parallel:
|
||||
for k, v in self.device_map.items():
|
||||
if i == v[-1] and "cuda:" + str(k) != self.last_device:
|
||||
hidden_states = hidden_states.to("cuda:" + str(k + 1))
|
||||
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
|
||||
hidden_states = hidden_states.view(output_shape)
|
||||
# Add last hidden state
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
|
||||
if v is not None
|
||||
)
|
||||
|
||||
return BaseModelOutputWithPastAndCrossAttentions(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=presents,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attentions,
|
||||
cross_attentions=all_cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DecisionTransformerOutput(ModelOutput):
|
||||
"""
|
||||
Base class for model's outputs that also contains a pooling of the last hidden states.
|
||||
|
||||
Args:
|
||||
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
||||
Sequence of hidden-states at the output of the last layer of the model.
|
||||
state_preds (`torch.FloatTensor` of shape `(batch_size, sequence_length, state_dim)`):
|
||||
Environment state predictions
|
||||
action_preds (`torch.FloatTensor` of shape `(batch_size, sequence_length, action_dim)`):
|
||||
Model action predictions
|
||||
return_preds (`torch.FloatTensor` of shape `(batch_size, sequence_length, 1)`):
|
||||
Predicted returns for each state
|
||||
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
||||
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
|
||||
shape `(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
||||
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
||||
sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
"""
|
||||
|
||||
state_preds: torch.FloatTensor = None
|
||||
action_preds: torch.FloatTensor = None
|
||||
return_preds: torch.FloatTensor = None
|
||||
hidden_states: torch.FloatTensor = None
|
||||
attentions: torch.FloatTensor = None
|
||||
last_hidden_state: torch.FloatTensor = None
|
||||
|
||||
|
||||
class DecisionTransformerPreTrainedModel(PreTrainedModel):
|
||||
"""
|
||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||
models.
|
||||
"""
|
||||
|
||||
config_class = DecisionTransformerConfig
|
||||
base_model_prefix = "decision_transformer"
|
||||
main_input_name = "states"
|
||||
supports_gradient_checkpointing = False
|
||||
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
if isinstance(module, nn.Linear):
|
||||
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
|
||||
DECISION_TRANSFORMER_START_DOCSTRING = r"""
|
||||
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
|
||||
it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
|
||||
behavior.
|
||||
|
||||
Parameters:
|
||||
config ([`~DecisionTransformerConfig`]): Model configuration class with all the parameters of the model.
|
||||
Initializing with a config file does not load the weights associated with the model, only the
|
||||
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
||||
"""
|
||||
|
||||
DECISION_TRANSFORMER_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
states (`torch.FloatTensor` of shape `(batch_size, episode_length, state_dim)`):
|
||||
The states for each step in the trajectory
|
||||
actions (`torch.FloatTensor` of shape `(batch_size, episode_length, act_dim)`):
|
||||
The actions taken by the "expert" policy for the current state, these are masked for auto regressive
|
||||
prediction
|
||||
rewards (`torch.FloatTensor` of shape `(batch_size, episode_length, 1)`):
|
||||
The rewards for each state, action
|
||||
returns_to_go (`torch.FloatTensor` of shape `(batch_size, episode_length, 1)`):
|
||||
The returns for each state in the trajectory
|
||||
timesteps (`torch.LongTensor` of shape `(batch_size, episode_length)`):
|
||||
The timestep for each step in the trajectory
|
||||
attention_mask (`torch.LongTensor` of shape `(batch_size, episode_length)`):
|
||||
Masking, used to mask the actions when performing autoregressive prediction
|
||||
"""
|
||||
|
||||
|
||||
@add_start_docstrings("The Decision Transformer Model", DECISION_TRANSFORMER_START_DOCSTRING)
|
||||
class DecisionTransformerModel(DecisionTransformerPreTrainedModel):
|
||||
"""
|
||||
|
||||
The model builds upon the GPT2 architecture to perform autoregressive prediction of actions in an offline RL
|
||||
setting. Refer to the paper for more details: https://arxiv.org/abs/2106.01345
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
# note: the only difference between this GPT2Model and the default Huggingface version
|
||||
# is that the positional embeddings are removed (since we'll add those ourselves)
|
||||
self.encoder = DecisionTransformerGPT2Model(config)
|
||||
|
||||
self.embed_timestep = nn.Embedding(config.max_ep_len, config.hidden_size)
|
||||
self.embed_return = torch.nn.Linear(1, config.hidden_size)
|
||||
self.embed_state = torch.nn.Linear(config.state_dim, config.hidden_size)
|
||||
self.embed_action = torch.nn.Linear(config.act_dim, config.hidden_size)
|
||||
|
||||
self.embed_ln = nn.LayerNorm(config.hidden_size)
|
||||
|
||||
# note: we don't predict states or returns for the paper
|
||||
self.predict_state = torch.nn.Linear(config.hidden_size, config.state_dim)
|
||||
self.predict_action = nn.Sequential(
|
||||
*([nn.Linear(config.hidden_size, config.act_dim)] + ([nn.Tanh()] if config.action_tanh else []))
|
||||
)
|
||||
self.predict_return = torch.nn.Linear(config.hidden_size, 1)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
@add_start_docstrings_to_model_forward(DECISION_TRANSFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@replace_return_docstrings(output_type=DecisionTransformerOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
self,
|
||||
states=None,
|
||||
actions=None,
|
||||
rewards=None,
|
||||
returns_to_go=None,
|
||||
timesteps=None,
|
||||
attention_mask=None,
|
||||
output_hidden_states=None,
|
||||
output_attentions=None,
|
||||
return_dict=None,
|
||||
) -> Union[Tuple, DecisionTransformerOutput]:
|
||||
r"""
|
||||
Returns:
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> from transformers import DecisionTransformerModel
|
||||
>>> import torch
|
||||
|
||||
>>> model = DecisionTransformerModel.from_pretrained("edbeeching/decision-transformer-gym-hopper-medium")
|
||||
>>> # evaluation
|
||||
>>> model = model.to(device)
|
||||
>>> model.eval()
|
||||
|
||||
>>> env = gym.make("Hopper-v3")
|
||||
>>> state_dim = env.observation_space.shape[0]
|
||||
>>> act_dim = env.action_space.shape[0]
|
||||
|
||||
>>> state = env.reset()
|
||||
>>> states = torch.from_numpy(state).reshape(1, 1, state_dim).to(device=device, dtype=torch.float32)
|
||||
>>> actions = torch.zeros((1, 1, act_dim), device=device, dtype=torch.float32)
|
||||
>>> rewards = torch.zeros(1, 1, device=device, dtype=torch.float32)
|
||||
>>> target_return = torch.tensor(TARGET_RETURN, dtype=torch.float32).reshape(1, 1)
|
||||
>>> timesteps = torch.tensor(0, device=device, dtype=torch.long).reshape(1, 1)
|
||||
>>> attention_mask = torch.zeros(1, 1, device=device, dtype=torch.float32)
|
||||
|
||||
>>> # forward pass
|
||||
>>> with torch.no_grad():
|
||||
... state_preds, action_preds, return_preds = model(
|
||||
... states=states,
|
||||
... actions=actions,
|
||||
... rewards=rewards,
|
||||
... returns_to_go=target_return,
|
||||
... timesteps=timesteps,
|
||||
... attention_mask=attention_mask,
|
||||
... return_dict=False,
|
||||
... )
|
||||
```"""
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
batch_size, seq_length = states.shape[0], states.shape[1]
|
||||
|
||||
if attention_mask is None:
|
||||
# attention mask for GPT: 1 if can be attended to, 0 if not
|
||||
attention_mask = torch.ones((batch_size, seq_length), dtype=torch.long)
|
||||
|
||||
# embed each modality with a different head
|
||||
state_embeddings = self.embed_state(states)
|
||||
action_embeddings = self.embed_action(actions)
|
||||
returns_embeddings = self.embed_return(returns_to_go)
|
||||
time_embeddings = self.embed_timestep(timesteps)
|
||||
|
||||
# time embeddings are treated similar to positional embeddings
|
||||
state_embeddings = state_embeddings + time_embeddings
|
||||
action_embeddings = action_embeddings + time_embeddings
|
||||
returns_embeddings = returns_embeddings + time_embeddings
|
||||
|
||||
# this makes the sequence look like (R_1, s_1, a_1, R_2, s_2, a_2, ...)
|
||||
# which works nice in an autoregressive sense since states predict actions
|
||||
stacked_inputs = (
|
||||
torch.stack((returns_embeddings, state_embeddings, action_embeddings), dim=1)
|
||||
.permute(0, 2, 1, 3)
|
||||
.reshape(batch_size, 3 * seq_length, self.hidden_size)
|
||||
)
|
||||
stacked_inputs = self.embed_ln(stacked_inputs)
|
||||
|
||||
# to make the attention mask fit the stacked inputs, have to stack it as well
|
||||
stacked_attention_mask = (
|
||||
torch.stack((attention_mask, attention_mask, attention_mask), dim=1)
|
||||
.permute(0, 2, 1)
|
||||
.reshape(batch_size, 3 * seq_length)
|
||||
)
|
||||
device = stacked_inputs.device
|
||||
# we feed in the input embeddings (not word indices as in NLP) to the model
|
||||
encoder_outputs = self.encoder(
|
||||
inputs_embeds=stacked_inputs,
|
||||
attention_mask=stacked_attention_mask,
|
||||
position_ids=torch.zeros(stacked_attention_mask.shape, device=device, dtype=torch.long),
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
x = encoder_outputs[0]
|
||||
|
||||
# reshape x so that the second dimension corresponds to the original
|
||||
# returns (0), states (1), or actions (2); i.e. x[:,1,t] is the token for s_t
|
||||
x = x.reshape(batch_size, seq_length, 3, self.hidden_size).permute(0, 2, 1, 3)
|
||||
|
||||
# get predictions
|
||||
return_preds = self.predict_return(x[:, 2]) # predict next return given state and action
|
||||
state_preds = self.predict_state(x[:, 2]) # predict next state given state and action
|
||||
action_preds = self.predict_action(x[:, 1]) # predict next action given state
|
||||
if not return_dict:
|
||||
return (state_preds, action_preds, return_preds)
|
||||
|
||||
return DecisionTransformerOutput(
|
||||
last_hidden_state=encoder_outputs.last_hidden_state,
|
||||
state_preds=state_preds,
|
||||
action_preds=action_preds,
|
||||
return_preds=return_preds,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
attentions=encoder_outputs.attentions,
|
||||
)
|
||||
@@ -1425,6 +1425,37 @@ class DebertaV2PreTrainedModel(metaclass=DummyObject):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
DECISION_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||
|
||||
|
||||
class DecisionTransformerGPT2Model(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class DecisionTransformerGPT2PreTrainedModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class DecisionTransformerModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class DecisionTransformerPreTrainedModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
DEIT_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||
|
||||
|
||||
|
||||
0
tests/decision_transformer/__init__.py
Normal file
0
tests/decision_transformer/__init__.py
Normal file
248
tests/decision_transformer/test_modeling_decision_transformer.py
Normal file
248
tests/decision_transformer/test_modeling_decision_transformer.py
Normal file
@@ -0,0 +1,248 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Testing suite for the PyTorch DecisionTransformer model. """
|
||||
|
||||
|
||||
import inspect
|
||||
import unittest
|
||||
|
||||
from transformers import DecisionTransformerConfig, is_torch_available
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from ..generation.test_generation_utils import GenerationTesterMixin
|
||||
from ..test_configuration_common import ConfigTester
|
||||
from ..test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import DecisionTransformerModel
|
||||
from transformers.models.decision_transformer.modeling_decision_transformer import (
|
||||
DECISION_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
)
|
||||
|
||||
|
||||
class DecisionTransformerModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=13,
|
||||
seq_length=7,
|
||||
act_dim=6,
|
||||
state_dim=17,
|
||||
hidden_size=23,
|
||||
max_length=11,
|
||||
is_training=True,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.seq_length = seq_length
|
||||
self.act_dim = act_dim
|
||||
self.state_dim = state_dim
|
||||
self.hidden_size = hidden_size
|
||||
self.max_length = max_length
|
||||
self.is_training = is_training
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
states = floats_tensor((self.batch_size, self.seq_length, self.state_dim))
|
||||
actions = floats_tensor((self.batch_size, self.seq_length, self.act_dim))
|
||||
rewards = floats_tensor((self.batch_size, self.seq_length, 1))
|
||||
returns_to_go = floats_tensor((self.batch_size, self.seq_length, 1))
|
||||
timesteps = ids_tensor((self.batch_size, self.seq_length), vocab_size=1000)
|
||||
attention_mask = random_attention_mask((self.batch_size, self.seq_length))
|
||||
|
||||
config = self.get_config()
|
||||
|
||||
return (
|
||||
config,
|
||||
states,
|
||||
actions,
|
||||
rewards,
|
||||
returns_to_go,
|
||||
timesteps,
|
||||
attention_mask,
|
||||
)
|
||||
|
||||
def get_config(self):
|
||||
return DecisionTransformerConfig(
|
||||
batch_size=self.batch_size,
|
||||
seq_length=self.seq_length,
|
||||
act_dim=self.act_dim,
|
||||
state_dim=self.state_dim,
|
||||
hidden_size=self.hidden_size,
|
||||
max_length=self.max_length,
|
||||
)
|
||||
|
||||
def create_and_check_model(
|
||||
self,
|
||||
config,
|
||||
states,
|
||||
actions,
|
||||
rewards,
|
||||
returns_to_go,
|
||||
timesteps,
|
||||
attention_mask,
|
||||
):
|
||||
model = DecisionTransformerModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(states, actions, rewards, returns_to_go, timesteps, attention_mask)
|
||||
|
||||
self.parent.assertEqual(result.state_preds.shape, states.shape)
|
||||
self.parent.assertEqual(result.action_preds.shape, actions.shape)
|
||||
self.parent.assertEqual(result.return_preds.shape, returns_to_go.shape)
|
||||
self.parent.assertEqual(
|
||||
result.last_hidden_state.shape, (self.batch_size, self.seq_length * 3, self.hidden_size)
|
||||
) # seq length *3 as there are 3 modelities: states, returns and actions
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(
|
||||
config,
|
||||
states,
|
||||
actions,
|
||||
rewards,
|
||||
returns_to_go,
|
||||
timesteps,
|
||||
attention_mask,
|
||||
) = config_and_inputs
|
||||
inputs_dict = {
|
||||
"states": states,
|
||||
"actions": actions,
|
||||
"rewards": rewards,
|
||||
"returns_to_go": returns_to_go,
|
||||
"timesteps": timesteps,
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@require_torch
|
||||
class DecisionTransformerModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (DecisionTransformerModel,) if is_torch_available() else ()
|
||||
all_generative_model_classes = ()
|
||||
|
||||
# Ignoring of a failing test from GenerationTesterMixin, as the model does not use inputs_ids
|
||||
test_generate_without_input_ids = False
|
||||
|
||||
# Ignoring of a failing tests from ModelTesterMixin, as the model does not implement these features
|
||||
test_pruning = False
|
||||
test_resize_embeddings = False
|
||||
test_head_masking = False
|
||||
test_attention_outputs = False
|
||||
test_hidden_states_output = False
|
||||
test_inputs_embeds = False
|
||||
test_model_common_attributes = False
|
||||
test_gradient_checkpointing = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = DecisionTransformerModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=DecisionTransformerConfig, hidden_size=37)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
def test_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_name in DECISION_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||
model = DecisionTransformerModel.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
def test_forward_signature(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
signature = inspect.signature(model.forward)
|
||||
# signature.parameters is an OrderedDict => so arg_names order is deterministic
|
||||
arg_names = [*signature.parameters.keys()]
|
||||
|
||||
expected_arg_names = [
|
||||
"states",
|
||||
"actions",
|
||||
"rewards",
|
||||
"returns_to_go",
|
||||
"timesteps",
|
||||
"attention_mask",
|
||||
]
|
||||
|
||||
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
|
||||
|
||||
|
||||
@require_torch
|
||||
class DecisionTransformerModelIntegrationTest(unittest.TestCase):
|
||||
@slow
|
||||
def test_autoregressive_prediction(self):
|
||||
"""
|
||||
An integration test that performs autoregressive prediction of state, action and return
|
||||
from a sequence of state, actions and returns. Test is performed over two timesteps.
|
||||
|
||||
"""
|
||||
|
||||
NUM_STEPS = 2 # number of steps of autoregressive prediction we will perform
|
||||
TARGET_RETURN = 10 # defined by the RL environment, may be normalized
|
||||
model = DecisionTransformerModel.from_pretrained("edbeeching/decision-transformer-gym-hopper-expert")
|
||||
model = model.to(torch_device)
|
||||
config = model.config
|
||||
torch.manual_seed(0)
|
||||
state = torch.randn(1, 1, config.state_dim).to(device=torch_device, dtype=torch.float32) # env.reset()
|
||||
|
||||
expected_outputs = torch.tensor([[0.2384, -0.2955, 0.8741], [0.6765, -0.0793, -0.1298]], device=torch_device)
|
||||
|
||||
returns_to_go = torch.tensor(TARGET_RETURN, device=torch_device, dtype=torch.float32).reshape(1, 1, 1)
|
||||
states = state
|
||||
actions = torch.zeros(1, 0, config.act_dim, device=torch_device, dtype=torch.float32)
|
||||
rewards = torch.zeros(1, 0, device=torch_device, dtype=torch.float32)
|
||||
timesteps = torch.tensor(0, device=torch_device, dtype=torch.long).reshape(1, 1)
|
||||
|
||||
for step in range(NUM_STEPS):
|
||||
actions = torch.cat([actions, torch.zeros(1, 1, config.act_dim, device=torch_device)], dim=1)
|
||||
rewards = torch.cat([rewards, torch.zeros(1, 1, device=torch_device)], dim=1)
|
||||
|
||||
attention_mask = torch.ones(1, states.shape[1]).to(dtype=torch.long, device=states.device)
|
||||
|
||||
with torch.no_grad():
|
||||
_, action_pred, _ = model(
|
||||
states=states,
|
||||
actions=actions,
|
||||
rewards=rewards,
|
||||
returns_to_go=returns_to_go,
|
||||
timesteps=timesteps,
|
||||
attention_mask=attention_mask,
|
||||
return_dict=False,
|
||||
)
|
||||
|
||||
self.assertEqual(action_pred.shape, actions.shape)
|
||||
self.assertTrue(torch.allclose(action_pred[0, -1], expected_outputs[step], atol=1e-4))
|
||||
state, reward, _, _ = ( # env.step(action)
|
||||
torch.randn(1, 1, config.state_dim).to(device=torch_device, dtype=torch.float32),
|
||||
1.0,
|
||||
False,
|
||||
{},
|
||||
)
|
||||
|
||||
actions[-1] = action_pred[0, -1]
|
||||
states = torch.cat([states, state], dim=1)
|
||||
pred_return = returns_to_go[0, -1] - reward
|
||||
returns_to_go = torch.cat([returns_to_go, pred_return.reshape(1, 1, 1)], dim=1)
|
||||
timesteps = torch.cat(
|
||||
[timesteps, torch.ones((1, 1), device=torch_device, dtype=torch.long) * (step + 1)], dim=1
|
||||
)
|
||||
@@ -45,6 +45,7 @@ PRIVATE_MODELS = [
|
||||
# Being in this list is an exception and should **not** be the rule.
|
||||
IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
|
||||
# models to ignore for not tested
|
||||
"DecisionTransformerGPT2Model", # Building part of bigger (tested) model.
|
||||
"SegformerDecodeHead", # Building part of bigger (tested) model.
|
||||
"PLBartEncoder", # Building part of bigger (tested) model.
|
||||
"PLBartDecoder", # Building part of bigger (tested) model.
|
||||
@@ -95,6 +96,7 @@ IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
|
||||
# Update this list with test files that don't have a tester with a `all_model_classes` variable and which don't
|
||||
# trigger the common tests.
|
||||
TEST_FILES_WITH_NO_COMMON_TESTS = [
|
||||
"decision_transformer/test_modeling_decision_transformer.py",
|
||||
"camembert/test_modeling_camembert.py",
|
||||
"mt5/test_modeling_flax_mt5.py",
|
||||
"mbart/test_modeling_mbart.py",
|
||||
@@ -108,12 +110,14 @@ TEST_FILES_WITH_NO_COMMON_TESTS = [
|
||||
"xlm_roberta/test_modeling_xlm_roberta.py",
|
||||
"vision_text_dual_encoder/test_modeling_vision_text_dual_encoder.py",
|
||||
"vision_text_dual_encoder/test_modeling_flax_vision_text_dual_encoder.py",
|
||||
"decision_transformer/test_modeling_decision_transformer.py",
|
||||
]
|
||||
|
||||
# Update this list for models that are not in any of the auto MODEL_XXX_MAPPING. Being in this list is an exception and
|
||||
# should **not** be the rule.
|
||||
IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
|
||||
# models to ignore for model xxx mapping
|
||||
"DecisionTransformerGPT2Model",
|
||||
"GLPNForDepthEstimation",
|
||||
"ViltForQuestionAnswering",
|
||||
"ViltForImagesAndTextClassification",
|
||||
|
||||
Reference in New Issue
Block a user