switch to device agnostic device calling for test cases (#38247)
* use device agnostic APIs in test cases Signed-off-by: Matrix Yao <matrix.yao@intel.com> * fix style Signed-off-by: Matrix Yao <matrix.yao@intel.com> * add one more Signed-off-by: YAO Matrix <matrix.yao@intel.com> * xpu now supports integer device id, aligning to CUDA behaviors Signed-off-by: Matrix Yao <matrix.yao@intel.com> * update to use device_properties Signed-off-by: Matrix Yao <matrix.yao@intel.com> * fix style Signed-off-by: Matrix Yao <matrix.yao@intel.com> * update comment Signed-off-by: Matrix Yao <matrix.yao@intel.com> * fix comments Signed-off-by: Matrix Yao <matrix.yao@intel.com> * fix style Signed-off-by: Matrix Yao <matrix.yao@intel.com> --------- Signed-off-by: Matrix Yao <matrix.yao@intel.com> Signed-off-by: YAO Matrix <matrix.yao@intel.com> Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -21,6 +21,8 @@ import pytest
|
||||
|
||||
from transformers import AutoTokenizer, JambaConfig, is_torch_available
|
||||
from transformers.testing_utils import (
|
||||
Expectations,
|
||||
get_device_properties,
|
||||
require_bitsandbytes,
|
||||
require_flash_attn,
|
||||
require_torch,
|
||||
@@ -554,30 +556,32 @@ class JambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
class JambaModelIntegrationTest(unittest.TestCase):
|
||||
model = None
|
||||
tokenizer = None
|
||||
# This variable is used to determine which CUDA device are we using for our runners (A10 or T4)
|
||||
# This variable is used to determine which acclerator are we using for our runners (e.g. A10 or T4)
|
||||
# Depending on the hardware we get different logits / generations
|
||||
cuda_compute_capability_major_version = None
|
||||
device_properties = None
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
model_id = "ai21labs/Jamba-tiny-dev"
|
||||
cls.model = JambaForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True)
|
||||
cls.tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
if is_torch_available() and torch.cuda.is_available():
|
||||
# 8 is for A100 / A10 and 7 for T4
|
||||
cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
|
||||
cls.device_properties = get_device_properties()
|
||||
|
||||
@slow
|
||||
def test_simple_generate(self):
|
||||
# Key 9 for MI300, Key 8 for A100/A10, and Key 7 for T4.
|
||||
# ("cuda", 8) for A100/A10, and ("cuda", 7) for T4.
|
||||
#
|
||||
# Note: Key 9 is currently set for MI300, but may need potential future adjustments for H100s,
|
||||
# considering differences in hardware processing and potential deviations in generated text.
|
||||
EXPECTED_TEXTS = {
|
||||
7: "<|startoftext|>Hey how are you doing on this lovely evening? Canyon rins hugaughter glamour Rutgers Singh<|reserved_797|>cw algunas",
|
||||
8: "<|startoftext|>Hey how are you doing on this lovely evening? I'm so glad you're here.",
|
||||
9: "<|startoftext|>Hey how are you doing on this lovely evening? Canyon rins hugaughter glamour Rutgers Singh Hebrew llam bb",
|
||||
}
|
||||
# fmt: off
|
||||
EXPECTED_TEXTS = Expectations(
|
||||
{
|
||||
("cuda", 7): "<|startoftext|>Hey how are you doing on this lovely evening? Canyon rins hugaughter glamour Rutgers Singh<|reserved_797|>cw algunas",
|
||||
("cuda", 8): "<|startoftext|>Hey how are you doing on this lovely evening? I'm so glad you're here.",
|
||||
("rocm", 9): "<|startoftext|>Hey how are you doing on this lovely evening? Canyon rins hugaughter glamour Rutgers Singh Hebrew llam bb",
|
||||
}
|
||||
)
|
||||
# fmt: on
|
||||
expected_sentence = EXPECTED_TEXTS.get_expectation()
|
||||
|
||||
self.model.to(torch_device)
|
||||
|
||||
@@ -586,10 +590,10 @@ class JambaModelIntegrationTest(unittest.TestCase):
|
||||
].to(torch_device)
|
||||
out = self.model.generate(input_ids, do_sample=False, max_new_tokens=10)
|
||||
output_sentence = self.tokenizer.decode(out[0, :])
|
||||
self.assertEqual(output_sentence, EXPECTED_TEXTS[self.cuda_compute_capability_major_version])
|
||||
self.assertEqual(output_sentence, expected_sentence)
|
||||
|
||||
# TODO: there are significant differences in the logits across major cuda versions, which shouldn't exist
|
||||
if self.cuda_compute_capability_major_version == 8:
|
||||
if self.device_properties == ("cuda", 8):
|
||||
with torch.no_grad():
|
||||
logits = self.model(input_ids=input_ids).logits
|
||||
|
||||
@@ -607,24 +611,19 @@ class JambaModelIntegrationTest(unittest.TestCase):
|
||||
|
||||
@slow
|
||||
def test_simple_batched_generate_with_padding(self):
|
||||
# Key 9 for MI300, Key 8 for A100/A10, and Key 7 for T4.
|
||||
# ("cuda", 8) for A100/A10, and ("cuda", 7) for T4.
|
||||
#
|
||||
# Note: Key 9 is currently set for MI300, but may need potential future adjustments for H100s,
|
||||
# considering differences in hardware processing and potential deviations in generated text.
|
||||
EXPECTED_TEXTS = {
|
||||
7: [
|
||||
"<|startoftext|>Hey how are you doing on this lovely evening? Canyon rins hugaughter glamour Rutgers Singh Hebrew cases Cats",
|
||||
"<|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|startoftext|>Tell me a storyptus Nets Madison El chamadamodern updximVaparsed",
|
||||
],
|
||||
8: [
|
||||
"<|startoftext|>Hey how are you doing on this lovely evening? I'm so glad you're here.",
|
||||
"<|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|startoftext|>Tell me a story about a woman who was born in the United States",
|
||||
],
|
||||
9: [
|
||||
"<|startoftext|>Hey how are you doing on this lovely evening? Canyon rins hugaughter glamour Rutgers Singh<|reserved_797|>cw algunas",
|
||||
"<|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|startoftext|>Tell me a storyptus Nets Madison El chamadamodern updximVaparsed",
|
||||
],
|
||||
}
|
||||
# fmt: off
|
||||
EXPECTED_TEXTS = Expectations(
|
||||
{
|
||||
("cuda", 7): ["<|startoftext|>Hey how are you doing on this lovely evening? Canyon rins hugaughter glamour Rutgers Singh Hebrew cases Cats", "<|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|startoftext|>Tell me a storyptus Nets Madison El chamadamodern updximVaparsed",],
|
||||
("cuda", 8): ["<|startoftext|>Hey how are you doing on this lovely evening? I'm so glad you're here.", "<|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|startoftext|>Tell me a story about a woman who was born in the United States",],
|
||||
("rocm", 9): ["<|startoftext|>Hey how are you doing on this lovely evening? Canyon rins hugaughter glamour Rutgers Singh<|reserved_797|>cw algunas", "<|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|startoftext|>Tell me a storyptus Nets Madison El chamadamodern updximVaparsed",],
|
||||
}
|
||||
)
|
||||
# fmt: on
|
||||
expected_sentences = EXPECTED_TEXTS.get_expectation()
|
||||
|
||||
self.model.to(torch_device)
|
||||
|
||||
@@ -633,11 +632,11 @@ class JambaModelIntegrationTest(unittest.TestCase):
|
||||
).to(torch_device)
|
||||
out = self.model.generate(**inputs, do_sample=False, max_new_tokens=10)
|
||||
output_sentences = self.tokenizer.batch_decode(out)
|
||||
self.assertEqual(output_sentences[0], EXPECTED_TEXTS[self.cuda_compute_capability_major_version][0])
|
||||
self.assertEqual(output_sentences[1], EXPECTED_TEXTS[self.cuda_compute_capability_major_version][1])
|
||||
self.assertEqual(output_sentences[0], expected_sentences[0])
|
||||
self.assertEqual(output_sentences[1], expected_sentences[1])
|
||||
|
||||
# TODO: there are significant differences in the logits across major cuda versions, which shouldn't exist
|
||||
if self.cuda_compute_capability_major_version == 8:
|
||||
if self.device_properties == ("cuda", 8):
|
||||
with torch.no_grad():
|
||||
logits = self.model(input_ids=inputs["input_ids"]).logits
|
||||
|
||||
|
||||
Reference in New Issue
Block a user