Update expected values (after switching to A10) - part 3 (#39179)
* fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -19,6 +19,7 @@ import unittest
|
||||
|
||||
from transformers import SwitchTransformersConfig, is_torch_available
|
||||
from transformers.testing_utils import (
|
||||
Expectations,
|
||||
require_tokenizers,
|
||||
require_torch,
|
||||
require_torch_accelerator,
|
||||
@@ -1035,18 +1036,28 @@ class SwitchTransformerModelIntegrationTests(unittest.TestCase):
|
||||
decoder_input_ids = torch.ones((32, 64), dtype=torch.long).to(torch_device)
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_MEAN_LOGITS = torch.Tensor(
|
||||
[
|
||||
-0.204102, -0.193359, 0.523438, -0.296875, 0.108887,
|
||||
0.0211182, 0.605469, -0.100586, -0.0551758, 0.296875,
|
||||
0.0090332, 0.174805, 0.139648, -0.170898, -0.0981445,
|
||||
0.0245361, 0.0373535, 0.050293, -0.212891, 0.129883,
|
||||
0.390625, -0.203125, -0.122559, -0.180664, 0.0437012,
|
||||
-0.349609, -0.0250244, -0.104004, -0.15918, -0.133789
|
||||
]
|
||||
).to(torch.bfloat16)
|
||||
expectations = Expectations(
|
||||
{
|
||||
(None, None): [
|
||||
-0.204102, -0.193359, 0.523438, -0.296875, 0.108887,
|
||||
0.0211182, 0.605469, -0.100586, -0.0551758, 0.296875,
|
||||
0.0090332, 0.174805, 0.139648, -0.170898, -0.0981445,
|
||||
0.0245361, 0.0373535, 0.050293, -0.212891, 0.129883,
|
||||
0.390625, -0.203125, -0.122559, -0.180664, 0.0437012,
|
||||
-0.349609, -0.0250244, -0.104004, -0.15918, -0.133789
|
||||
],
|
||||
("cuda", 8): [
|
||||
-0.2051, -0.1914, 0.5352, -0.2988, 0.1108, 0.0200, 0.6094, -0.1025,
|
||||
-0.0549, 0.2988, -0.0018, 0.1758, 0.1348, -0.1689, -0.1035, 0.0266,
|
||||
0.0383, 0.0493, -0.2119, 0.1328, 0.3906, -0.2041, -0.1240, -0.1836,
|
||||
0.0454, -0.3477, -0.0256, -0.1050, -0.1572, -0.1338
|
||||
],
|
||||
}
|
||||
)
|
||||
EXPECTED_MEAN_LOGITS = torch.tensor(expectations.get_expectation()).to(torch_device, dtype=torch.bfloat16)
|
||||
# fmt: on
|
||||
hf_logits = model(input_ids, decoder_input_ids=decoder_input_ids).last_hidden_state.cpu()
|
||||
|
||||
hf_logits = model(input_ids, decoder_input_ids=decoder_input_ids).last_hidden_state
|
||||
hf_logits = hf_logits[0, 0, :30]
|
||||
|
||||
torch.testing.assert_close(hf_logits, EXPECTED_MEAN_LOGITS, rtol=6e-3, atol=9e-3)
|
||||
|
||||
Reference in New Issue
Block a user