Update expected values (after switching to A10) - part 3 (#39179)

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2025-07-02 22:48:30 +02:00
committed by GitHub
parent 9326fc332d
commit 37a239ca50
12 changed files with 181 additions and 68 deletions

View File

@@ -19,6 +19,7 @@ import unittest
from transformers import SwitchTransformersConfig, is_torch_available
from transformers.testing_utils import (
Expectations,
require_tokenizers,
require_torch,
require_torch_accelerator,
@@ -1035,18 +1036,28 @@ class SwitchTransformerModelIntegrationTests(unittest.TestCase):
decoder_input_ids = torch.ones((32, 64), dtype=torch.long).to(torch_device)
# fmt: off
EXPECTED_MEAN_LOGITS = torch.Tensor(
[
-0.204102, -0.193359, 0.523438, -0.296875, 0.108887,
0.0211182, 0.605469, -0.100586, -0.0551758, 0.296875,
0.0090332, 0.174805, 0.139648, -0.170898, -0.0981445,
0.0245361, 0.0373535, 0.050293, -0.212891, 0.129883,
0.390625, -0.203125, -0.122559, -0.180664, 0.0437012,
-0.349609, -0.0250244, -0.104004, -0.15918, -0.133789
]
).to(torch.bfloat16)
expectations = Expectations(
{
(None, None): [
-0.204102, -0.193359, 0.523438, -0.296875, 0.108887,
0.0211182, 0.605469, -0.100586, -0.0551758, 0.296875,
0.0090332, 0.174805, 0.139648, -0.170898, -0.0981445,
0.0245361, 0.0373535, 0.050293, -0.212891, 0.129883,
0.390625, -0.203125, -0.122559, -0.180664, 0.0437012,
-0.349609, -0.0250244, -0.104004, -0.15918, -0.133789
],
("cuda", 8): [
-0.2051, -0.1914, 0.5352, -0.2988, 0.1108, 0.0200, 0.6094, -0.1025,
-0.0549, 0.2988, -0.0018, 0.1758, 0.1348, -0.1689, -0.1035, 0.0266,
0.0383, 0.0493, -0.2119, 0.1328, 0.3906, -0.2041, -0.1240, -0.1836,
0.0454, -0.3477, -0.0256, -0.1050, -0.1572, -0.1338
],
}
)
EXPECTED_MEAN_LOGITS = torch.tensor(expectations.get_expectation()).to(torch_device, dtype=torch.bfloat16)
# fmt: on
hf_logits = model(input_ids, decoder_input_ids=decoder_input_ids).last_hidden_state.cpu()
hf_logits = model(input_ids, decoder_input_ids=decoder_input_ids).last_hidden_state
hf_logits = hf_logits[0, 0, :30]
torch.testing.assert_close(hf_logits, EXPECTED_MEAN_LOGITS, rtol=6e-3, atol=9e-3)