add test for Jamba with new model jamba-tiny-dev (#33863)

* add test for jamba with new model

* ruff fix

---------

Co-authored-by: Yehoshua Cohen <yehoshuaco@ai21.com>
This commit is contained in:
Yehoshua Cohen
2024-10-05 17:03:12 +03:00
committed by GitHub
parent a7e4e1a77c
commit 56be9f1925

View File

@@ -653,7 +653,7 @@ class JambaModelIntegrationTest(unittest.TestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
model_id = "ai21labs/Jamba-tiny-random" model_id = "ai21labs/Jamba-tiny-dev"
cls.model = JambaForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True) cls.model = JambaForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True)
cls.tokenizer = AutoTokenizer.from_pretrained(model_id) cls.tokenizer = AutoTokenizer.from_pretrained(model_id)
if is_torch_available() and torch.cuda.is_available(): if is_torch_available() and torch.cuda.is_available():
@@ -668,7 +668,7 @@ class JambaModelIntegrationTest(unittest.TestCase):
# considering differences in hardware processing and potential deviations in generated text. # considering differences in hardware processing and potential deviations in generated text.
EXPECTED_TEXTS = { EXPECTED_TEXTS = {
7: "<|startoftext|>Hey how are you doing on this lovely evening? Canyon rins hugaughter glamour Rutgers Singh<|reserved_797|>cw algunas", 7: "<|startoftext|>Hey how are you doing on this lovely evening? Canyon rins hugaughter glamour Rutgers Singh<|reserved_797|>cw algunas",
8: "<|startoftext|>Hey how are you doing on this lovely evening? Canyon rins hugaughter glamour Rutgers Singh Hebrew llam bb", 8: "<|startoftext|>Hey how are you doing on this lovely evening? I'm so glad you're here.",
9: "<|startoftext|>Hey how are you doing on this lovely evening? Canyon rins hugaughter glamour Rutgers Singh Hebrew llam bb", 9: "<|startoftext|>Hey how are you doing on this lovely evening? Canyon rins hugaughter glamour Rutgers Singh Hebrew llam bb",
} }
@@ -688,11 +688,11 @@ class JambaModelIntegrationTest(unittest.TestCase):
EXPECTED_LOGITS_NO_GRAD = torch.tensor( EXPECTED_LOGITS_NO_GRAD = torch.tensor(
[ [
0.0134, -0.2197, 0.0396, -0.1011, 0.0459, 0.2793, -0.1465, 0.1660, -7.6875, -7.6562, 8.9375, -7.7812, -7.4062, -7.9688, -8.3125, -7.4062,
-0.2930, -0.0278, 0.0269, -0.5586, -0.2109, -0.1426, -0.1553, 0.1279, -7.8125, -8.1250, -7.8125, -7.3750, -7.8438, -7.5000, -8.0625, -8.0625,
0.0713, 0.2246, 0.1660, -0.2314, -0.1187, -0.1162, -0.1377, 0.0292, -7.5938, -7.9688, -8.2500, -7.5625, -7.7500, -7.7500, -7.6562, -7.6250,
0.1245, 0.2275, 0.0374, 0.1089, -0.1348, -0.2305, 0.1484, -0.3906, -8.1250, -8.0625, -8.1250, -7.8750, -8.1875, -8.2500, -7.5938, -8.0000,
0.1709, -0.4590, -0.0447, 0.2422, 0.1592, -0.1855, 0.2441, -0.0562 -7.5000, -7.7500, -7.9375, -7.4688, -8.0625, -7.3438, -8.0000, -7.5000
] ]
, dtype=torch.float32) # fmt: skip , dtype=torch.float32) # fmt: skip
@@ -710,8 +710,8 @@ class JambaModelIntegrationTest(unittest.TestCase):
"<|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|startoftext|>Tell me a storyptus Nets Madison El chamadamodern updximVaparsed", "<|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|startoftext|>Tell me a storyptus Nets Madison El chamadamodern updximVaparsed",
], ],
8: [ 8: [
"<|startoftext|>Hey how are you doing on this lovely evening? Canyon rins hugaughter glamour Rutgers Singh<|reserved_797|>cw algunas", "<|startoftext|>Hey how are you doing on this lovely evening? I'm so glad you're here.",
"<|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|startoftext|>Tell me a storyptus Nets Madison El chamadamodern updximVaparsed", "<|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|startoftext|>Tell me a story about a woman who was born in the United States",
], ],
9: [ 9: [
"<|startoftext|>Hey how are you doing on this lovely evening? Canyon rins hugaughter glamour Rutgers Singh<|reserved_797|>cw algunas", "<|startoftext|>Hey how are you doing on this lovely evening? Canyon rins hugaughter glamour Rutgers Singh<|reserved_797|>cw algunas",
@@ -737,21 +737,21 @@ class JambaModelIntegrationTest(unittest.TestCase):
# TODO fix logits # TODO fix logits
EXPECTED_LOGITS_NO_GRAD_0 = torch.tensor( EXPECTED_LOGITS_NO_GRAD_0 = torch.tensor(
[ [
0.0166, -0.2227, 0.0396, -0.1035, 0.0459, 0.2754, -0.1445, 0.1641, -7.7188, -7.6875, 8.8750, -7.8125, -7.4062, -8.0000, -8.3125, -7.4375,
-0.2910, -0.0273, 0.0227, -0.5547, -0.2139, -0.1396, -0.1582, 0.1289, -7.8125, -8.1250, -7.8125, -7.4062, -7.8438, -7.5312, -8.0625, -8.0625,
0.0713, 0.2256, 0.1699, -0.2295, -0.1182, -0.1167, -0.1387, 0.0261, -7.6250, -8.0000, -8.3125, -7.5938, -7.7500, -7.7500, -7.6562, -7.6562,
0.1270, 0.2285, 0.0403, 0.1108, -0.1318, -0.2334, 0.1455, -0.3945, -8.1250, -8.0625, -8.1250, -7.8750, -8.1875, -8.2500, -7.5938, -8.0625,
0.1729, -0.4609, -0.0410, 0.2412, 0.1572, -0.1895, 0.2402, -0.0583 -7.5000, -7.7812, -7.9375, -7.4688, -8.0625, -7.3750, -8.0000, -7.50003
] ]
, dtype=torch.float32) # fmt: skip , dtype=torch.float32) # fmt: skip
EXPECTED_LOGITS_NO_GRAD_1 = torch.tensor( EXPECTED_LOGITS_NO_GRAD_1 = torch.tensor(
[ [
-0.1318, 0.2354, -0.4160, -0.0325, -0.0461, 0.0342, 0.2578, 0.0874, -3.5469, -4.0625, 8.5000, -3.8125, -3.6406, -3.7969, -3.8125, -3.3594,
0.1484, 0.2266, -0.1182, -0.1396, -0.1494, -0.1089, -0.0019, -0.2852, -3.7188, -3.7500, -3.7656, -3.5469, -3.7969, -4.0000, -3.5625, -3.6406,
0.1973, -0.2676, 0.0586, -0.1992, -0.2520, -0.1147, -0.1973, 0.2129, -3.7188, -3.6094, -4.0938, -3.6719, -3.8906, -3.9844, -3.8594, -3.4219,
0.0520, 0.1699, 0.1816, 0.1289, 0.1699, -0.1216, -0.2656, -0.2891, -3.2031, -3.4375, -3.7500, -3.6562, -3.9688, -4.1250, -3.6406, -3.57811,
0.2363, 0.2656, 0.0488, -0.1875, 0.2148, -0.1250, 0.1816, 0.0077 -3.0312, -3.4844, -3.6094, -3.5938, -3.7656, -3.8125, -3.7500, -3.8594
] ]
, dtype=torch.float32) # fmt: skip , dtype=torch.float32) # fmt: skip