switch to device agnostic device calling for test cases (#38247)

* use device agnostic APIs in test cases

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

* fix style

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

* add one more

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* xpu now supports integer device id, aligning to CUDA behaviors

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

* update to use device_properties

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

* fix style

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

* update comment

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

* fix comments

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

* fix style

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

---------

Signed-off-by: Matrix Yao <matrix.yao@intel.com>
Signed-off-by: YAO Matrix <matrix.yao@intel.com>
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yao Matrix
2025-05-26 16:18:53 +08:00
committed by GitHub
parent cba279f46c
commit a5a0c7b888
39 changed files with 259 additions and 389 deletions

View File

@@ -21,6 +21,8 @@ import pytest
from transformers import AutoTokenizer, JambaConfig, is_torch_available
from transformers.testing_utils import (
Expectations,
get_device_properties,
require_bitsandbytes,
require_flash_attn,
require_torch,
@@ -554,30 +556,32 @@ class JambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
class JambaModelIntegrationTest(unittest.TestCase):
model = None
tokenizer = None
# This variable is used to determine which CUDA device are we using for our runners (A10 or T4)
# This variable is used to determine which acclerator are we using for our runners (e.g. A10 or T4)
# Depending on the hardware we get different logits / generations
cuda_compute_capability_major_version = None
device_properties = None
@classmethod
def setUpClass(cls):
model_id = "ai21labs/Jamba-tiny-dev"
cls.model = JambaForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True)
cls.tokenizer = AutoTokenizer.from_pretrained(model_id)
if is_torch_available() and torch.cuda.is_available():
# 8 is for A100 / A10 and 7 for T4
cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
cls.device_properties = get_device_properties()
@slow
def test_simple_generate(self):
# Key 9 for MI300, Key 8 for A100/A10, and Key 7 for T4.
# ("cuda", 8) for A100/A10, and ("cuda", 7) for T4.
#
# Note: Key 9 is currently set for MI300, but may need potential future adjustments for H100s,
# considering differences in hardware processing and potential deviations in generated text.
EXPECTED_TEXTS = {
7: "<|startoftext|>Hey how are you doing on this lovely evening? Canyon rins hugaughter glamour Rutgers Singh<|reserved_797|>cw algunas",
8: "<|startoftext|>Hey how are you doing on this lovely evening? I'm so glad you're here.",
9: "<|startoftext|>Hey how are you doing on this lovely evening? Canyon rins hugaughter glamour Rutgers Singh Hebrew llam bb",
}
# fmt: off
EXPECTED_TEXTS = Expectations(
{
("cuda", 7): "<|startoftext|>Hey how are you doing on this lovely evening? Canyon rins hugaughter glamour Rutgers Singh<|reserved_797|>cw algunas",
("cuda", 8): "<|startoftext|>Hey how are you doing on this lovely evening? I'm so glad you're here.",
("rocm", 9): "<|startoftext|>Hey how are you doing on this lovely evening? Canyon rins hugaughter glamour Rutgers Singh Hebrew llam bb",
}
)
# fmt: on
expected_sentence = EXPECTED_TEXTS.get_expectation()
self.model.to(torch_device)
@@ -586,10 +590,10 @@ class JambaModelIntegrationTest(unittest.TestCase):
].to(torch_device)
out = self.model.generate(input_ids, do_sample=False, max_new_tokens=10)
output_sentence = self.tokenizer.decode(out[0, :])
self.assertEqual(output_sentence, EXPECTED_TEXTS[self.cuda_compute_capability_major_version])
self.assertEqual(output_sentence, expected_sentence)
# TODO: there are significant differences in the logits across major cuda versions, which shouldn't exist
if self.cuda_compute_capability_major_version == 8:
if self.device_properties == ("cuda", 8):
with torch.no_grad():
logits = self.model(input_ids=input_ids).logits
@@ -607,24 +611,19 @@ class JambaModelIntegrationTest(unittest.TestCase):
@slow
def test_simple_batched_generate_with_padding(self):
# Key 9 for MI300, Key 8 for A100/A10, and Key 7 for T4.
# ("cuda", 8) for A100/A10, and ("cuda", 7) for T4.
#
# Note: Key 9 is currently set for MI300, but may need potential future adjustments for H100s,
# considering differences in hardware processing and potential deviations in generated text.
EXPECTED_TEXTS = {
7: [
"<|startoftext|>Hey how are you doing on this lovely evening? Canyon rins hugaughter glamour Rutgers Singh Hebrew cases Cats",
"<|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|startoftext|>Tell me a storyptus Nets Madison El chamadamodern updximVaparsed",
],
8: [
"<|startoftext|>Hey how are you doing on this lovely evening? I'm so glad you're here.",
"<|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|startoftext|>Tell me a story about a woman who was born in the United States",
],
9: [
"<|startoftext|>Hey how are you doing on this lovely evening? Canyon rins hugaughter glamour Rutgers Singh<|reserved_797|>cw algunas",
"<|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|startoftext|>Tell me a storyptus Nets Madison El chamadamodern updximVaparsed",
],
}
# fmt: off
EXPECTED_TEXTS = Expectations(
{
("cuda", 7): ["<|startoftext|>Hey how are you doing on this lovely evening? Canyon rins hugaughter glamour Rutgers Singh Hebrew cases Cats", "<|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|startoftext|>Tell me a storyptus Nets Madison El chamadamodern updximVaparsed",],
("cuda", 8): ["<|startoftext|>Hey how are you doing on this lovely evening? I'm so glad you're here.", "<|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|startoftext|>Tell me a story about a woman who was born in the United States",],
("rocm", 9): ["<|startoftext|>Hey how are you doing on this lovely evening? Canyon rins hugaughter glamour Rutgers Singh<|reserved_797|>cw algunas", "<|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|startoftext|>Tell me a storyptus Nets Madison El chamadamodern updximVaparsed",],
}
)
# fmt: on
expected_sentences = EXPECTED_TEXTS.get_expectation()
self.model.to(torch_device)
@@ -633,11 +632,11 @@ class JambaModelIntegrationTest(unittest.TestCase):
).to(torch_device)
out = self.model.generate(**inputs, do_sample=False, max_new_tokens=10)
output_sentences = self.tokenizer.batch_decode(out)
self.assertEqual(output_sentences[0], EXPECTED_TEXTS[self.cuda_compute_capability_major_version][0])
self.assertEqual(output_sentences[1], EXPECTED_TEXTS[self.cuda_compute_capability_major_version][1])
self.assertEqual(output_sentences[0], expected_sentences[0])
self.assertEqual(output_sentences[1], expected_sentences[1])
# TODO: there are significant differences in the logits across major cuda versions, which shouldn't exist
if self.cuda_compute_capability_major_version == 8:
if self.device_properties == ("cuda", 8):
with torch.no_grad():
logits = self.model(input_ids=inputs["input_ids"]).logits