@@ -525,7 +525,9 @@ class GraniteMoeIntegrationTest(unittest.TestCase):
|
||||
# Expected mean on dim = -1
|
||||
EXPECTED_MEAN = torch.tensor([[-2.2122, -1.6632, -2.9269, -2.3344, -2.0143, -3.0146, -2.6839, -2.5610]])
|
||||
|
||||
self.assertTrue(torch.allclose(EXPECTED_MEAN.to(torch_device), out.logits.mean(-1), atol=1e-2, rtol=1e-2))
|
||||
self.assertTrue(
|
||||
torch.allclose(EXPECTED_MEAN.to(torch_device), out.logits.float().mean(-1), atol=1e-2, rtol=1e-2)
|
||||
)
|
||||
|
||||
# slicing logits[0, 0, 0:15]
|
||||
EXPECTED_SLICE = torch.tensor([[4.8785, -2.2890, -2.2892, -2.2885, -2.2890, -3.5007, -2.2897, -2.2892,
|
||||
@@ -535,7 +537,7 @@ class GraniteMoeIntegrationTest(unittest.TestCase):
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
EXPECTED_SLICE.to(torch_device),
|
||||
out.logits[0, 0, :15],
|
||||
out.logits[0, 0, :15].float(),
|
||||
atol=1e-3,
|
||||
rtol=1e-3,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user