Fix some tests (especially compile with fullgraph=True on Python<3.11) (#38319)
* fix tests * better fix for python<3.11 * fixes * style
This commit is contained in:
@@ -232,8 +232,8 @@ class CohereIntegrationTest(unittest.TestCase):
|
||||
|
||||
EXPECTED_LOGITS = torch.Tensor(
|
||||
[
|
||||
[[0.0000, 0.1866, -0.1997], [0.0000, -0.0736, 0.1785], [0.0000, -0.1965, -0.0569]],
|
||||
[[0.0000, -0.0302, 0.1488], [0.0000, -0.0402, 0.1351], [0.0000, -0.0341, 0.1116]],
|
||||
[[0.0000, 0.0285, 0.0322], [0.0000, 0.0011, 0.1105], [0.0000, -0.0018, -0.1019]],
|
||||
[[0.0000, 0.1080, 0.0454], [0.0000, -0.1808, -0.1553], [0.0000, 0.0452, 0.0369]],
|
||||
]
|
||||
).to(device=torch_device, dtype=torch.float16)
|
||||
|
||||
@@ -251,4 +251,4 @@ class CohereIntegrationTest(unittest.TestCase):
|
||||
output = model(**inputs)
|
||||
|
||||
logits = output.logits
|
||||
torch.testing.assert_close(EXPECTED_LOGITS, logits[:, :3, :3], rtol=1e-3, atol=1e-3)
|
||||
torch.testing.assert_close(EXPECTED_LOGITS, logits[:, -3:, :3], rtol=1e-3, atol=1e-3)
|
||||
|
||||
Reference in New Issue
Block a user