CI: AMD MI300 tests fix (#30797)
* add fix * update import * updated dicts and comments * remove prints * Update testing_utils.py
This commit is contained in:
@@ -553,6 +553,10 @@ class MixtralIntegrationTest(unittest.TestCase):
|
||||
)
|
||||
# TODO: might need to tweak it in case the logits do not match on our daily runners
|
||||
# these logits have been obtained with the original megablocks implementation.
|
||||
# Key 9 for MI300, Key 8 for A100/A10, and Key 7 for T4.
|
||||
#
|
||||
# Note: Key 9 is currently set for MI300, but may need future adjustments for H100s,
|
||||
# considering differences in hardware processing and potential deviations in output.
|
||||
EXPECTED_LOGITS = {
|
||||
7: torch.Tensor([[0.1670, 0.1620, 0.6094], [-0.8906, -0.1588, -0.6060], [0.1572, 0.1290, 0.7246]]).to(
|
||||
torch_device
|
||||
@@ -560,6 +564,9 @@ class MixtralIntegrationTest(unittest.TestCase):
|
||||
8: torch.Tensor([[0.1631, 0.1621, 0.6094], [-0.8906, -0.1621, -0.6094], [0.1572, 0.1270, 0.7227]]).to(
|
||||
torch_device
|
||||
),
|
||||
9: torch.Tensor([[0.1641, 0.1621, 0.6094], [-0.8906, -0.1631, -0.6094], [0.1572, 0.1260, 0.7227]]).to(
|
||||
torch_device
|
||||
),
|
||||
}
|
||||
with torch.no_grad():
|
||||
logits = model(dummy_input).logits
|
||||
@@ -583,6 +590,11 @@ class MixtralIntegrationTest(unittest.TestCase):
|
||||
)
|
||||
|
||||
# TODO: might need to tweak it in case the logits do not match on our daily runners
|
||||
#
|
||||
# Key 9 for MI300, Key 8 for A100/A10, and Key 7 for T4.
|
||||
#
|
||||
# Note: Key 9 is currently set for MI300, but may need future adjustments for H100s,
|
||||
# considering differences in hardware processing and potential deviations in output.
|
||||
EXPECTED_LOGITS_LEFT = {
|
||||
7: torch.Tensor(
|
||||
[[0.1750, 0.0537, 0.7007], [0.1750, 0.0537, 0.7007], [0.1750, 0.0537, 0.7007]],
|
||||
@@ -590,6 +602,9 @@ class MixtralIntegrationTest(unittest.TestCase):
|
||||
8: torch.Tensor([[0.1914, 0.0508, 0.7188], [0.1953, 0.0510, 0.7227], [0.1973, 0.0562, 0.7148]]).to(
|
||||
torch_device
|
||||
),
|
||||
9: torch.Tensor([[0.1904, 0.0513, 0.7227], [0.1943, 0.0518, 0.7227], [0.1982, 0.0557, 0.7148]]).to(
|
||||
torch_device
|
||||
),
|
||||
}
|
||||
|
||||
EXPECTED_LOGITS_LEFT_UNPADDED = {
|
||||
@@ -599,6 +614,9 @@ class MixtralIntegrationTest(unittest.TestCase):
|
||||
8: torch.Tensor([[0.2217, 0.5195, -0.3828], [0.8203, -0.2295, 0.6055], [0.2676, -0.7109, 0.2461]]).to(
|
||||
torch_device
|
||||
),
|
||||
9: torch.Tensor([[0.2236, 0.5195, -0.3828], [0.8203, -0.2285, 0.6055], [0.2637, -0.7109, 0.2451]]).to(
|
||||
torch_device
|
||||
),
|
||||
}
|
||||
|
||||
EXPECTED_LOGITS_RIGHT_UNPADDED = {
|
||||
@@ -608,6 +626,9 @@ class MixtralIntegrationTest(unittest.TestCase):
|
||||
8: torch.Tensor([[0.2178, 0.1260, -0.1621], [-0.3496, 0.2988, -1.0312], [0.0693, 0.7930, 0.8008]]).to(
|
||||
torch_device
|
||||
),
|
||||
9: torch.Tensor([[0.2197, 0.1250, -0.1611], [-0.3516, 0.3008, -1.0312], [0.0684, 0.7930, 0.8008]]).to(
|
||||
torch_device
|
||||
),
|
||||
}
|
||||
|
||||
with torch.no_grad():
|
||||
|
||||
Reference in New Issue
Block a user