Update expected values in DecisionTransformerModelIntegrationTest (#18016)
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -206,7 +206,9 @@ class DecisionTransformerModelIntegrationTest(unittest.TestCase):
|
|||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
state = torch.randn(1, 1, config.state_dim).to(device=torch_device, dtype=torch.float32) # env.reset()
|
state = torch.randn(1, 1, config.state_dim).to(device=torch_device, dtype=torch.float32) # env.reset()
|
||||||
|
|
||||||
expected_outputs = torch.tensor([[0.2384, -0.2955, 0.8741], [0.6765, -0.0793, -0.1298]], device=torch_device)
|
expected_outputs = torch.tensor(
|
||||||
|
[[0.242793, -0.28693074, 0.8742613], [0.67815274, -0.08101085, -0.12952147]], device=torch_device
|
||||||
|
)
|
||||||
|
|
||||||
returns_to_go = torch.tensor(TARGET_RETURN, device=torch_device, dtype=torch.float32).reshape(1, 1, 1)
|
returns_to_go = torch.tensor(TARGET_RETURN, device=torch_device, dtype=torch.float32).reshape(1, 1, 1)
|
||||||
states = state
|
states = state
|
||||||
|
|||||||
Reference in New Issue
Block a user