Expectations test utils (#36569)
* Add expectation classes + tests
* Use typing Union instead of |
* Use bits to track score in properties cmp method
* Add exceptions and tests + comments
* Remove compute cap minor as it is not needed currently
* Simplify. Remove Properties class
* Add example Exceptions usage
* Expectations as dict subclass
* Update example Exceptions usage
* Refactor. Improve type name. Document score fn.
* Rename to DeviceProperties.
This commit is contained in:
@@ -20,12 +20,7 @@ import unittest
|
||||
import pytest
|
||||
|
||||
from transformers import AutoTokenizer, BambaConfig, is_torch_available
|
||||
from transformers.testing_utils import (
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from transformers.testing_utils import Expectations, require_torch, require_torch_gpu, slow, torch_device
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
@@ -503,15 +498,18 @@ class BambaModelIntegrationTest(unittest.TestCase):
|
||||
cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
|
||||
|
||||
def test_simple_generate(self):
|
||||
# Key 9 for MI300, Key 8 for A100/A10, and Key 7 for T4.
|
||||
#
|
||||
# Note: Key 9 is currently set for MI300, but may need potential future adjustments for H100s,
|
||||
# considering differences in hardware processing and potential deviations in generated text.
|
||||
EXPECTED_TEXTS = {
|
||||
# 7: "",
|
||||
8: "<|begin_of_text|>Hey how are you doing on this lovely evening? I hope you are all having a good time.",
|
||||
9: "<|begin_of_text|>Hey how are you doing on this lovely evening? I hope you are doing well. I am here",
|
||||
}
|
||||
expectations = Expectations(
|
||||
{
|
||||
(
|
||||
"cuda",
|
||||
8,
|
||||
): "<|begin_of_text|>Hey how are you doing on this lovely evening? I hope you are all having a good time.",
|
||||
(
|
||||
"rocm",
|
||||
9,
|
||||
): "<|begin_of_text|>Hey how are you doing on this lovely evening? I hope you are doing well. I am here",
|
||||
}
|
||||
)
|
||||
|
||||
self.model.to(torch_device)
|
||||
|
||||
@@ -520,7 +518,8 @@ class BambaModelIntegrationTest(unittest.TestCase):
|
||||
].to(torch_device)
|
||||
out = self.model.generate(input_ids, do_sample=False, max_new_tokens=10)
|
||||
output_sentence = self.tokenizer.decode(out[0, :])
|
||||
self.assertEqual(output_sentence, EXPECTED_TEXTS[self.cuda_compute_capability_major_version])
|
||||
expected = expectations.get_expectation()
|
||||
self.assertEqual(output_sentence, expected)
|
||||
|
||||
# TODO: there are significant differences in the logits across major cuda versions, which shouldn't exist
|
||||
if self.cuda_compute_capability_major_version == 8:
|
||||
|
||||
Reference in New Issue
Block a user