Fix TVLT (torch device issue) (#21710)

* fix tvlt ci

* fix tvlt ci

* fix tvlt ci

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2023-02-21 11:37:49 +01:00
committed by GitHub
parent 4c6346cc3e
commit 03aaac3502
2 changed files with 7 additions and 4 deletions

View File

@@ -590,7 +590,7 @@ class TvltModelIntegrationTest(unittest.TestCase):
outputs = model(**inputs)
# verify the logits
expected_last_hidden_state_slice = torch.tensor([[-0.0186, -0.0691], [0.0242, -0.0398]])
expected_last_hidden_state_slice = torch.tensor([[-0.0186, -0.0691], [0.0242, -0.0398]], device=torch_device)
self.assertTrue(
torch.allclose(outputs.last_hidden_state[:, :2, :2], expected_last_hidden_state_slice, atol=1e-4)
)