Decision transformer gym (#15845)
* Created the Decision Transformer Modle * updating tests, copy to other machine * Added last hidden size to Decision Transformer modelling outputs * Removed copy of original DT file * made a temporary change to gpt2 to have it conform with the Decision Transformer version * Updated tests * Ignoring a file used to test the DT model * added comments to config file * added comments and argument descriptions to decision transformer file * Updated doc * Ran "make style" * Remove old model imports * Removed unused imports, cleaned up init file * Update docs/source/model_doc/decision_transformer.mdx added my username Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Reverted changes made to gpt2 * Removed datasets submodule * Update the modeling outputs to include gpt2 attentions, hidden states and last hidden states * Added support for return of hidden states, attentions and return dict of gpt2 model. * Updated tests to include many of the ModelTesterMixin tests. The following tests are skipped: test_generate_without_input_ids, test_pruning, test_resize_embeddings, test_head_masking, test_attention_outputs, test_hidden_states_output, test_inputs_embeds, test_model_common_attributes * Added missing line to the end of gpt2 file * Added an integration test for the Decision Transformer Test performs and autoregressive evaluation for two time steps * Set done and info to _ to fix failing test * Updated integration test to be deterministic and check expected outputs * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Removed unnecessary config options * Cleaned up commented code and old comments. * Cleaned up commented code. * Changed DecisionTransformer to Decision Transformer * Added Decision Transformer to the main README file * Added copy of GTP2 called DecisionTranformerGPT2Model * isorted imports * isorted imports * Added model to non-English README files * Ran make fix-copies and corrected some cases. * Updated index file to include Decision Transformer * Added gpt2 model as copy inside the Decision Transformer model file * Added the unit test file to the list of TEST_FILES_WITH_NO_COMMON_TESTS * Deleted redundant checkpoint files (I don't know how these got committed) * Removed testing files. (These should have never been committed) * Removed accidentally committed files * Moved the Decision Transformer test to its own directory * Add type hints for Pegasus (#16324) * Funnel type hints (#16323) * add pt funnel type hints * add tf funnel type hints * Add type hints for ProphetNet PyTorch (#16272) * [GLPN] Improve docs (#16331) * Add link to notebook * Add link * Fix bug Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local> * Added type hints for Pytorch Marian calls (#16200) * Added type hinting for forward functions in pytorch marian * typo correction * Removed type hints on functions from BART per Suraj Patil request * fix import pb * fix typo * corrected tuple call * ran black * after fix-copies Some optional tags on primitives were removed, past_key_values in MarianForCausalLM changed from Tuple of Tuple to List * Fixing copies to roformer and pegasus Co-authored-by: Clementine Fourrier <cfourrie@inria.fr> Co-authored-by: matt <rocketknight1@gmail.com> * Moved DecisionTransformOutput to modeling_decision_transformer * Moved the example usage to research project and cleaned comments * Made tests ignore the copy of gpt2 in Decision Transformer * Added module output to modelling decision transformer * removed copied gpt2 model from list of transformers models * Updated tests and created __init__ file for new test location * Update README.md Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/models/decision_transformer/configuration_decision_transformer.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Removed unneeded summary type from config file * Fixed copies * Updated pretrained config map to refer to hopper-medium checkpoint * done (#16340) * Added Decision transformer to model docs * Update src/transformers/models/decision_transformer/modeling_decision_transformer.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/models/decision_transformer/modeling_decision_transformer.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/models/decision_transformer/configuration_decision_transformer.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Add type annotations for Rembert/Splinter and copies (#16338) * undo black autoformat * minor fix to rembert forward with default * make fix-copies, make quality * Adding types to template model * Removing List from the template types * Remove `Optional` from a couple of types that don't accept `None` Co-authored-by: matt <rocketknight1@gmail.com> * [Bug template] Shift responsibilities for long-range (#16344) * Fix code repetition in serialization guide (#16346) * Adopt framework-specific blocks for content (#16342) * ✨ refactor code samples with framework-specific blocks * ✨ update training.mdx * 🖍 apply feedback * Updates the default branch from master to main (#16326) * Updates the default branch from master to main * Links from `master` to `main` * Typo * Update examples/flax/README.md Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Updated model with custom docstring example * Created the Decision Transformer Modle * updating tests, copy to other machine * Added last hidden size to Decision Transformer modelling outputs * Removed copy of original DT file * made a temporary change to gpt2 to have it conform with the Decision Transformer version * Updated tests * Ignoring a file used to test the DT model * added comments to config file * added comments and argument descriptions to decision transformer file * Updated doc * Ran "make style" * Remove old model imports * Removed unused imports, cleaned up init file * Update docs/source/model_doc/decision_transformer.mdx added my username Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Reverted changes made to gpt2 * Removed datasets submodule * Update the modeling outputs to include gpt2 attentions, hidden states and last hidden states * Added support for return of hidden states, attentions and return dict of gpt2 model. * Updated tests to include many of the ModelTesterMixin tests. The following tests are skipped: test_generate_without_input_ids, test_pruning, test_resize_embeddings, test_head_masking, test_attention_outputs, test_hidden_states_output, test_inputs_embeds, test_model_common_attributes * Added missing line to the end of gpt2 file * Added an integration test for the Decision Transformer Test performs and autoregressive evaluation for two time steps * Set done and info to _ to fix failing test * Updated integration test to be deterministic and check expected outputs * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Removed unnecessary config options * Cleaned up commented code and old comments. * Cleaned up commented code. * Changed DecisionTransformer to Decision Transformer * Added Decision Transformer to the main README file * Added copy of GTP2 called DecisionTranformerGPT2Model * isorted imports * isorted imports * Added model to non-English README files * Ran make fix-copies and corrected some cases. * Updated index file to include Decision Transformer * Added gpt2 model as copy inside the Decision Transformer model file * Added the unit test file to the list of TEST_FILES_WITH_NO_COMMON_TESTS * Deleted redundant checkpoint files (I don't know how these got committed) * Removed testing files. (These should have never been committed) * Removed accidentally committed files * Moved the Decision Transformer test to its own directory * Moved DecisionTransformOutput to modeling_decision_transformer * Moved the example usage to research project and cleaned comments * Made tests ignore the copy of gpt2 in Decision Transformer * Added module output to modelling decision transformer * removed copied gpt2 model from list of transformers models * Updated tests and created __init__ file for new test location * Update README.md Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/models/decision_transformer/configuration_decision_transformer.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Removed unneeded summary type from config file * Fixed copies * Updated pretrained config map to refer to hopper-medium checkpoint * Added Decision transformer to model docs * Update src/transformers/models/decision_transformer/modeling_decision_transformer.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/models/decision_transformer/modeling_decision_transformer.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/models/decision_transformer/configuration_decision_transformer.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Updated model with custom docstring example * Updated copies, config auto, and readme files. Co-authored-by: Lysandre Debut <lysandre@huggingface.co> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Dan Tegzes <48134725+Tegzes@users.noreply.github.com> Co-authored-by: Adam Montgomerie <adam@avanssion.com> Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local> Co-authored-by: Clémentine Fourrier <22726840+clefourrier@users.noreply.github.com> Co-authored-by: Clementine Fourrier <cfourrie@inria.fr> Co-authored-by: matt <rocketknight1@gmail.com> Co-authored-by: Francesco Saverio Zuppichini <francesco.zuppichini@gmail.com> Co-authored-by: Jacob Dineen <54680234+jacobdineen@users.noreply.github.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Omar Sanseviero <osanseviero@gmail.com> Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> Co-authored-by: Lysandre Debut <lysandre.debut@reseau.eseo.fr>
This commit is contained in:
240
examples/research_projects/decision_transformer/requirements.txt
Normal file
240
examples/research_projects/decision_transformer/requirements.txt
Normal file
@@ -0,0 +1,240 @@
|
||||
absl-py==1.0.0
|
||||
aiohttp==3.8.1
|
||||
aiosignal==1.2.0
|
||||
alembic==1.7.7
|
||||
appdirs==1.4.4
|
||||
APScheduler==3.9.1
|
||||
arrow==1.2.2
|
||||
asttokens==2.0.5
|
||||
astunparse==1.6.3
|
||||
async-timeout==4.0.2
|
||||
attrs==21.4.0
|
||||
audioread==2.1.9
|
||||
autopage==0.5.0
|
||||
backcall==0.2.0
|
||||
backoff==1.11.1
|
||||
backports.zoneinfo==0.2.1
|
||||
binaryornot==0.4.4
|
||||
black==22.1.0
|
||||
boto3==1.16.34
|
||||
botocore==1.19.63
|
||||
Brotli==1.0.9
|
||||
cachetools==5.0.0
|
||||
certifi==2021.10.8
|
||||
cffi==1.15.0
|
||||
chardet==4.0.0
|
||||
charset-normalizer==2.0.12
|
||||
chex==0.1.1
|
||||
click==8.0.4
|
||||
cliff==3.10.1
|
||||
clldutils==3.11.1
|
||||
cloudpickle==2.0.0
|
||||
cmaes==0.8.2
|
||||
cmd2==2.4.0
|
||||
codecarbon==1.2.0
|
||||
colorlog==6.6.0
|
||||
cookiecutter==1.7.2
|
||||
cryptography==36.0.2
|
||||
csvw==2.0.0
|
||||
cycler==0.11.0
|
||||
Cython==0.29.28
|
||||
dash==2.3.0
|
||||
dash-bootstrap-components==1.0.3
|
||||
dash-core-components==2.0.0
|
||||
dash-html-components==2.0.0
|
||||
dash-table==5.0.0
|
||||
datasets==2.0.0
|
||||
decorator==5.1.1
|
||||
Deprecated==1.2.13
|
||||
dill==0.3.4
|
||||
dlinfo==1.2.1
|
||||
dm-tree==0.1.6
|
||||
docker==4.4.4
|
||||
execnet==1.9.0
|
||||
executing==0.8.3
|
||||
faiss-cpu==1.7.2
|
||||
fasteners==0.17.3
|
||||
filelock==3.6.0
|
||||
fire==0.4.0
|
||||
flake8==4.0.1
|
||||
Flask==2.0.3
|
||||
Flask-Compress==1.11
|
||||
flatbuffers==2.0
|
||||
flax==0.4.0
|
||||
fonttools==4.31.1
|
||||
frozenlist==1.3.0
|
||||
fsspec==2022.2.0
|
||||
fugashi==1.1.2
|
||||
gast==0.5.3
|
||||
gitdb==4.0.9
|
||||
GitPython==3.1.18
|
||||
glfw==2.5.1
|
||||
google-auth==2.6.2
|
||||
google-auth-oauthlib==0.4.6
|
||||
google-pasta==0.2.0
|
||||
greenlet==1.1.2
|
||||
grpcio==1.44.0
|
||||
gym==0.23.1
|
||||
gym-notices==0.0.6
|
||||
h5py==3.6.0
|
||||
huggingface-hub==0.4.0
|
||||
hypothesis==6.39.4
|
||||
idna==3.3
|
||||
imageio==2.16.1
|
||||
importlib-metadata==4.11.3
|
||||
importlib-resources==5.4.0
|
||||
iniconfig==1.1.1
|
||||
ipadic==1.0.0
|
||||
ipython==8.1.1
|
||||
isodate==0.6.1
|
||||
isort==5.10.1
|
||||
itsdangerous==2.1.1
|
||||
jax==0.3.4
|
||||
jaxlib==0.3.2
|
||||
jedi==0.18.1
|
||||
Jinja2==2.11.3
|
||||
jinja2-time==0.2.0
|
||||
jmespath==0.10.0
|
||||
joblib==1.1.0
|
||||
jsonschema==4.4.0
|
||||
keras==2.8.0
|
||||
Keras-Preprocessing==1.1.2
|
||||
kiwisolver==1.4.0
|
||||
kubernetes==12.0.1
|
||||
libclang==13.0.0
|
||||
librosa==0.9.1
|
||||
llvmlite==0.38.0
|
||||
Mako==1.2.0
|
||||
Markdown==3.3.6
|
||||
MarkupSafe==1.1.1
|
||||
matplotlib==3.5.1
|
||||
matplotlib-inline==0.1.3
|
||||
mccabe==0.6.1
|
||||
msgpack==1.0.3
|
||||
mujoco-py==2.1.2.14
|
||||
multidict==6.0.2
|
||||
multiprocess==0.70.12.2
|
||||
mypy-extensions==0.4.3
|
||||
nltk==3.7
|
||||
numba==0.55.1
|
||||
numpy==1.22.3
|
||||
oauthlib==3.2.0
|
||||
onnx==1.11.0
|
||||
onnxconverter-common==1.9.0
|
||||
opt-einsum==3.3.0
|
||||
optax==0.1.1
|
||||
optuna==2.10.0
|
||||
packaging==21.3
|
||||
pandas==1.4.1
|
||||
parameterized==0.8.1
|
||||
parso==0.8.3
|
||||
pathspec==0.9.0
|
||||
pbr==5.8.1
|
||||
pexpect==4.8.0
|
||||
phonemizer==3.0.1
|
||||
pickleshare==0.7.5
|
||||
Pillow==9.0.1
|
||||
Pint==0.16.1
|
||||
plac==1.3.4
|
||||
platformdirs==2.5.1
|
||||
plotly==5.6.0
|
||||
pluggy==1.0.0
|
||||
pooch==1.6.0
|
||||
portalocker==2.0.0
|
||||
poyo==0.5.0
|
||||
prettytable==3.2.0
|
||||
prompt-toolkit==3.0.28
|
||||
protobuf==3.19.4
|
||||
psutil==5.9.0
|
||||
ptyprocess==0.7.0
|
||||
pure-eval==0.2.2
|
||||
py==1.11.0
|
||||
py-cpuinfo==8.0.0
|
||||
pyarrow==7.0.0
|
||||
pyasn1==0.4.8
|
||||
pyasn1-modules==0.2.8
|
||||
pycodestyle==2.8.0
|
||||
pycparser==2.21
|
||||
pyctcdecode==0.3.0
|
||||
pyflakes==2.4.0
|
||||
Pygments==2.11.2
|
||||
pygtrie==2.4.2
|
||||
pynvml==11.4.1
|
||||
pyOpenSSL==22.0.0
|
||||
pyparsing==3.0.7
|
||||
pyperclip==1.8.2
|
||||
pypng==0.0.21
|
||||
pyrsistent==0.18.1
|
||||
pytest==7.1.1
|
||||
pytest-forked==1.4.0
|
||||
pytest-timeout==2.1.0
|
||||
pytest-xdist==2.5.0
|
||||
python-dateutil==2.8.2
|
||||
python-slugify==6.1.1
|
||||
pytz==2022.1
|
||||
pytz-deprecation-shim==0.1.0.post0
|
||||
PyYAML==6.0
|
||||
ray==1.11.0
|
||||
redis==4.1.4
|
||||
regex==2022.3.15
|
||||
requests==2.27.1
|
||||
requests-oauthlib==1.3.1
|
||||
resampy==0.2.2
|
||||
responses==0.18.0
|
||||
rfc3986==1.5.0
|
||||
rouge-score==0.0.4
|
||||
rsa==4.8
|
||||
s3transfer==0.3.7
|
||||
sacrebleu==1.5.1
|
||||
sacremoses==0.0.49
|
||||
scikit-learn==1.0.2
|
||||
scipy==1.8.0
|
||||
segments==2.2.0
|
||||
sentencepiece==0.1.96
|
||||
sigopt==8.2.0
|
||||
six==1.16.0
|
||||
smmap==5.0.0
|
||||
sortedcontainers==2.4.0
|
||||
SoundFile==0.10.3.post1
|
||||
SQLAlchemy==1.4.32
|
||||
stack-data==0.2.0
|
||||
stevedore==3.5.0
|
||||
tabulate==0.8.9
|
||||
tenacity==8.0.1
|
||||
tensorboard==2.8.0
|
||||
tensorboard-data-server==0.6.1
|
||||
tensorboard-plugin-wit==1.8.1
|
||||
tensorboardX==2.5
|
||||
tensorflow==2.8.0
|
||||
tensorflow-io-gcs-filesystem==0.24.0
|
||||
termcolor==1.1.0
|
||||
text-unidecode==1.3
|
||||
tf-estimator-nightly==2.8.0.dev2021122109
|
||||
tf2onnx==1.9.3
|
||||
threadpoolctl==3.1.0
|
||||
timeout-decorator==0.5.0
|
||||
timm==0.5.4
|
||||
tokenizers==0.11.6
|
||||
tomli==2.0.1
|
||||
toolz==0.11.2
|
||||
torch==1.11.0
|
||||
torchaudio==0.11.0
|
||||
torchvision==0.12.0
|
||||
tqdm==4.63.0
|
||||
traitlets==5.1.1
|
||||
-e git+git@github.com:edbeeching/transformers.git@77b90113ca0a0e4058b046796c874bdc98f1da61#egg=transformers
|
||||
typing-extensions==4.1.1
|
||||
tzdata==2022.1
|
||||
tzlocal==4.1
|
||||
unidic==1.1.0
|
||||
unidic-lite==1.0.8
|
||||
uritemplate==4.1.1
|
||||
urllib3==1.26.9
|
||||
wasabi==0.9.0
|
||||
wcwidth==0.2.5
|
||||
websocket-client==1.3.1
|
||||
Werkzeug==2.0.3
|
||||
wrapt==1.14.0
|
||||
xxhash==3.0.0
|
||||
yarl==1.7.2
|
||||
zipp==3.7.0
|
||||
@@ -0,0 +1,173 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import gym
|
||||
from mujoco_py import GlfwContext
|
||||
from transformers import DecisionTransformerModel
|
||||
|
||||
|
||||
GlfwContext(offscreen=True) # Create a window to init GLFW.
|
||||
|
||||
|
||||
def get_action(model, states, actions, rewards, returns_to_go, timesteps):
|
||||
# we don't care about the past rewards in this model
|
||||
|
||||
states = states.reshape(1, -1, model.config.state_dim)
|
||||
actions = actions.reshape(1, -1, model.config.act_dim)
|
||||
returns_to_go = returns_to_go.reshape(1, -1, 1)
|
||||
timesteps = timesteps.reshape(1, -1)
|
||||
|
||||
if model.config.max_length is not None:
|
||||
states = states[:, -model.config.max_length :]
|
||||
actions = actions[:, -model.config.max_length :]
|
||||
returns_to_go = returns_to_go[:, -model.config.max_length :]
|
||||
timesteps = timesteps[:, -model.config.max_length :]
|
||||
|
||||
# pad all tokens to sequence length
|
||||
attention_mask = torch.cat(
|
||||
[torch.zeros(model.config.max_length - states.shape[1]), torch.ones(states.shape[1])]
|
||||
)
|
||||
attention_mask = attention_mask.to(dtype=torch.long, device=states.device).reshape(1, -1)
|
||||
states = torch.cat(
|
||||
[
|
||||
torch.zeros(
|
||||
(states.shape[0], model.config.max_length - states.shape[1], model.config.state_dim),
|
||||
device=states.device,
|
||||
),
|
||||
states,
|
||||
],
|
||||
dim=1,
|
||||
).to(dtype=torch.float32)
|
||||
actions = torch.cat(
|
||||
[
|
||||
torch.zeros(
|
||||
(actions.shape[0], model.config.max_length - actions.shape[1], model.config.act_dim),
|
||||
device=actions.device,
|
||||
),
|
||||
actions,
|
||||
],
|
||||
dim=1,
|
||||
).to(dtype=torch.float32)
|
||||
returns_to_go = torch.cat(
|
||||
[
|
||||
torch.zeros(
|
||||
(returns_to_go.shape[0], model.config.max_length - returns_to_go.shape[1], 1),
|
||||
device=returns_to_go.device,
|
||||
),
|
||||
returns_to_go,
|
||||
],
|
||||
dim=1,
|
||||
).to(dtype=torch.float32)
|
||||
timesteps = torch.cat(
|
||||
[
|
||||
torch.zeros(
|
||||
(timesteps.shape[0], model.config.max_length - timesteps.shape[1]), device=timesteps.device
|
||||
),
|
||||
timesteps,
|
||||
],
|
||||
dim=1,
|
||||
).to(dtype=torch.long)
|
||||
else:
|
||||
attention_mask = None
|
||||
|
||||
_, action_preds, _ = model(
|
||||
states=states,
|
||||
actions=actions,
|
||||
rewards=rewards,
|
||||
returns_to_go=returns_to_go,
|
||||
timesteps=timesteps,
|
||||
attention_mask=attention_mask,
|
||||
return_dict=False,
|
||||
)
|
||||
|
||||
return action_preds[0, -1]
|
||||
|
||||
|
||||
# build the environment
|
||||
|
||||
env = gym.make("Hopper-v3")
|
||||
state_dim = env.observation_space.shape[0]
|
||||
act_dim = env.action_space.shape[0]
|
||||
max_ep_len = 1000
|
||||
device = "cuda"
|
||||
scale = 1000.0 # normalization for rewards/returns
|
||||
TARGET_RETURN = 3600 / scale # evaluation conditioning targets, 3600 is reasonable from the paper LINK
|
||||
state_mean = np.array(
|
||||
[
|
||||
1.311279,
|
||||
-0.08469521,
|
||||
-0.5382719,
|
||||
-0.07201576,
|
||||
0.04932366,
|
||||
2.1066856,
|
||||
-0.15017354,
|
||||
0.00878345,
|
||||
-0.2848186,
|
||||
-0.18540096,
|
||||
-0.28461286,
|
||||
]
|
||||
)
|
||||
state_std = np.array(
|
||||
[
|
||||
0.17790751,
|
||||
0.05444621,
|
||||
0.21297139,
|
||||
0.14530419,
|
||||
0.6124444,
|
||||
0.85174465,
|
||||
1.4515252,
|
||||
0.6751696,
|
||||
1.536239,
|
||||
1.6160746,
|
||||
5.6072536,
|
||||
]
|
||||
)
|
||||
state_mean = torch.from_numpy(state_mean).to(device=device)
|
||||
state_std = torch.from_numpy(state_std).to(device=device)
|
||||
|
||||
# Create the decision transformer model
|
||||
model = DecisionTransformerModel.from_pretrained("edbeeching/decision-transformer-gym-hopper-medium")
|
||||
model = model.to(device)
|
||||
model.eval()
|
||||
|
||||
for ep in range(10):
|
||||
episode_return, episode_length = 0, 0
|
||||
state = env.reset()
|
||||
target_return = torch.tensor(TARGET_RETURN, device=device, dtype=torch.float32).reshape(1, 1)
|
||||
states = torch.from_numpy(state).reshape(1, state_dim).to(device=device, dtype=torch.float32)
|
||||
actions = torch.zeros((0, act_dim), device=device, dtype=torch.float32)
|
||||
rewards = torch.zeros(0, device=device, dtype=torch.float32)
|
||||
|
||||
timesteps = torch.tensor(0, device=device, dtype=torch.long).reshape(1, 1)
|
||||
for t in range(max_ep_len):
|
||||
env.render()
|
||||
# add padding
|
||||
actions = torch.cat([actions, torch.zeros((1, act_dim), device=device)], dim=0)
|
||||
rewards = torch.cat([rewards, torch.zeros(1, device=device)])
|
||||
|
||||
action = get_action(
|
||||
model,
|
||||
(states.to(dtype=torch.float32) - state_mean) / state_std,
|
||||
actions.to(dtype=torch.float32),
|
||||
rewards.to(dtype=torch.float32),
|
||||
target_return.to(dtype=torch.float32),
|
||||
timesteps.to(dtype=torch.long),
|
||||
)
|
||||
actions[-1] = action
|
||||
action = action.detach().cpu().numpy()
|
||||
|
||||
state, reward, done, _ = env.step(action)
|
||||
|
||||
cur_state = torch.from_numpy(state).to(device=device).reshape(1, state_dim)
|
||||
states = torch.cat([states, cur_state], dim=0)
|
||||
rewards[-1] = reward
|
||||
|
||||
pred_return = target_return[0, -1] - (reward / scale)
|
||||
target_return = torch.cat([target_return, pred_return.reshape(1, 1)], dim=1)
|
||||
timesteps = torch.cat([timesteps, torch.ones((1, 1), device=device, dtype=torch.long) * (t + 1)], dim=1)
|
||||
|
||||
episode_return += reward
|
||||
episode_length += 1
|
||||
|
||||
if done:
|
||||
break
|
||||
Reference in New Issue
Block a user