Fix some tests (especially compile with fullgraph=True on Python<3.11) (#38319)

* fix tests

* better fix for python<3.11

* fixes

* style
This commit is contained in:
Cyril Vallez
2025-05-23 17:11:40 +02:00
committed by GitHub
parent a63bc17416
commit 896833c183
7 changed files with 48 additions and 83 deletions

View File

@@ -232,8 +232,8 @@ class CohereIntegrationTest(unittest.TestCase):
EXPECTED_LOGITS = torch.Tensor(
[
[[0.0000, 0.1866, -0.1997], [0.0000, -0.0736, 0.1785], [0.0000, -0.1965, -0.0569]],
[[0.0000, -0.0302, 0.1488], [0.0000, -0.0402, 0.1351], [0.0000, -0.0341, 0.1116]],
[[0.0000, 0.0285, 0.0322], [0.0000, 0.0011, 0.1105], [0.0000, -0.0018, -0.1019]],
[[0.0000, 0.1080, 0.0454], [0.0000, -0.1808, -0.1553], [0.0000, 0.0452, 0.0369]],
]
).to(device=torch_device, dtype=torch.float16)
@@ -251,4 +251,4 @@ class CohereIntegrationTest(unittest.TestCase):
output = model(**inputs)
logits = output.logits
torch.testing.assert_close(EXPECTED_LOGITS, logits[:, :3, :3], rtol=1e-3, atol=1e-3)
torch.testing.assert_close(EXPECTED_LOGITS, logits[:, -3:, :3], rtol=1e-3, atol=1e-3)