enable 6 modeling cases on XPU (#37571)
Signed-off-by: YAO Matrix <matrix.yao@intel.com>
This commit is contained in:
@@ -520,7 +520,13 @@ class MptIntegrationTests(unittest.TestCase):
|
||||
|
||||
outputs = model(dummy_input, output_hidden_states=True)
|
||||
|
||||
expected_slice = torch.Tensor([-0.2520, -0.2178, -0.1953]).to(torch_device, torch.bfloat16)
|
||||
expected_slices = Expectations(
|
||||
{
|
||||
("xpu", 3): torch.Tensor([-0.2090, -0.2061, -0.1465]),
|
||||
("cuda", 7): torch.Tensor([-0.2520, -0.2178, -0.1953]),
|
||||
}
|
||||
)
|
||||
expected_slice = expected_slices.get_expectation().to(torch_device, torch.bfloat16)
|
||||
predicted_slice = outputs.hidden_states[-1][0, 0, :3]
|
||||
|
||||
torch.testing.assert_close(expected_slice, predicted_slice, rtol=1e-3, atol=1e-3)
|
||||
|
||||
Reference in New Issue
Block a user