add test for Jamba with new model jamba-tiny-dev (#33863)
* add test for jamba with new model * ruff fix --------- Co-authored-by: Yehoshua Cohen <yehoshuaco@ai21.com>
This commit is contained in:
@@ -653,7 +653,7 @@ class JambaModelIntegrationTest(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
model_id = "ai21labs/Jamba-tiny-random"
|
||||
model_id = "ai21labs/Jamba-tiny-dev"
|
||||
cls.model = JambaForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True)
|
||||
cls.tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
if is_torch_available() and torch.cuda.is_available():
|
||||
@@ -668,7 +668,7 @@ class JambaModelIntegrationTest(unittest.TestCase):
|
||||
# considering differences in hardware processing and potential deviations in generated text.
|
||||
EXPECTED_TEXTS = {
|
||||
7: "<|startoftext|>Hey how are you doing on this lovely evening? Canyon rins hugaughter glamour Rutgers Singh<|reserved_797|>cw algunas",
|
||||
8: "<|startoftext|>Hey how are you doing on this lovely evening? Canyon rins hugaughter glamour Rutgers Singh Hebrew llam bb",
|
||||
8: "<|startoftext|>Hey how are you doing on this lovely evening? I'm so glad you're here.",
|
||||
9: "<|startoftext|>Hey how are you doing on this lovely evening? Canyon rins hugaughter glamour Rutgers Singh Hebrew llam bb",
|
||||
}
|
||||
|
||||
@@ -688,11 +688,11 @@ class JambaModelIntegrationTest(unittest.TestCase):
|
||||
|
||||
EXPECTED_LOGITS_NO_GRAD = torch.tensor(
|
||||
[
|
||||
0.0134, -0.2197, 0.0396, -0.1011, 0.0459, 0.2793, -0.1465, 0.1660,
|
||||
-0.2930, -0.0278, 0.0269, -0.5586, -0.2109, -0.1426, -0.1553, 0.1279,
|
||||
0.0713, 0.2246, 0.1660, -0.2314, -0.1187, -0.1162, -0.1377, 0.0292,
|
||||
0.1245, 0.2275, 0.0374, 0.1089, -0.1348, -0.2305, 0.1484, -0.3906,
|
||||
0.1709, -0.4590, -0.0447, 0.2422, 0.1592, -0.1855, 0.2441, -0.0562
|
||||
-7.6875, -7.6562, 8.9375, -7.7812, -7.4062, -7.9688, -8.3125, -7.4062,
|
||||
-7.8125, -8.1250, -7.8125, -7.3750, -7.8438, -7.5000, -8.0625, -8.0625,
|
||||
-7.5938, -7.9688, -8.2500, -7.5625, -7.7500, -7.7500, -7.6562, -7.6250,
|
||||
-8.1250, -8.0625, -8.1250, -7.8750, -8.1875, -8.2500, -7.5938, -8.0000,
|
||||
-7.5000, -7.7500, -7.9375, -7.4688, -8.0625, -7.3438, -8.0000, -7.5000
|
||||
]
|
||||
, dtype=torch.float32) # fmt: skip
|
||||
|
||||
@@ -710,8 +710,8 @@ class JambaModelIntegrationTest(unittest.TestCase):
|
||||
"<|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|startoftext|>Tell me a storyptus Nets Madison El chamadamodern updximVaparsed",
|
||||
],
|
||||
8: [
|
||||
"<|startoftext|>Hey how are you doing on this lovely evening? Canyon rins hugaughter glamour Rutgers Singh<|reserved_797|>cw algunas",
|
||||
"<|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|startoftext|>Tell me a storyptus Nets Madison El chamadamodern updximVaparsed",
|
||||
"<|startoftext|>Hey how are you doing on this lovely evening? I'm so glad you're here.",
|
||||
"<|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|startoftext|>Tell me a story about a woman who was born in the United States",
|
||||
],
|
||||
9: [
|
||||
"<|startoftext|>Hey how are you doing on this lovely evening? Canyon rins hugaughter glamour Rutgers Singh<|reserved_797|>cw algunas",
|
||||
@@ -737,21 +737,21 @@ class JambaModelIntegrationTest(unittest.TestCase):
|
||||
# TODO fix logits
|
||||
EXPECTED_LOGITS_NO_GRAD_0 = torch.tensor(
|
||||
[
|
||||
0.0166, -0.2227, 0.0396, -0.1035, 0.0459, 0.2754, -0.1445, 0.1641,
|
||||
-0.2910, -0.0273, 0.0227, -0.5547, -0.2139, -0.1396, -0.1582, 0.1289,
|
||||
0.0713, 0.2256, 0.1699, -0.2295, -0.1182, -0.1167, -0.1387, 0.0261,
|
||||
0.1270, 0.2285, 0.0403, 0.1108, -0.1318, -0.2334, 0.1455, -0.3945,
|
||||
0.1729, -0.4609, -0.0410, 0.2412, 0.1572, -0.1895, 0.2402, -0.0583
|
||||
-7.7188, -7.6875, 8.8750, -7.8125, -7.4062, -8.0000, -8.3125, -7.4375,
|
||||
-7.8125, -8.1250, -7.8125, -7.4062, -7.8438, -7.5312, -8.0625, -8.0625,
|
||||
-7.6250, -8.0000, -8.3125, -7.5938, -7.7500, -7.7500, -7.6562, -7.6562,
|
||||
-8.1250, -8.0625, -8.1250, -7.8750, -8.1875, -8.2500, -7.5938, -8.0625,
|
||||
-7.5000, -7.7812, -7.9375, -7.4688, -8.0625, -7.3750, -8.0000, -7.50003
|
||||
]
|
||||
, dtype=torch.float32) # fmt: skip
|
||||
|
||||
EXPECTED_LOGITS_NO_GRAD_1 = torch.tensor(
|
||||
[
|
||||
-0.1318, 0.2354, -0.4160, -0.0325, -0.0461, 0.0342, 0.2578, 0.0874,
|
||||
0.1484, 0.2266, -0.1182, -0.1396, -0.1494, -0.1089, -0.0019, -0.2852,
|
||||
0.1973, -0.2676, 0.0586, -0.1992, -0.2520, -0.1147, -0.1973, 0.2129,
|
||||
0.0520, 0.1699, 0.1816, 0.1289, 0.1699, -0.1216, -0.2656, -0.2891,
|
||||
0.2363, 0.2656, 0.0488, -0.1875, 0.2148, -0.1250, 0.1816, 0.0077
|
||||
-3.5469, -4.0625, 8.5000, -3.8125, -3.6406, -3.7969, -3.8125, -3.3594,
|
||||
-3.7188, -3.7500, -3.7656, -3.5469, -3.7969, -4.0000, -3.5625, -3.6406,
|
||||
-3.7188, -3.6094, -4.0938, -3.6719, -3.8906, -3.9844, -3.8594, -3.4219,
|
||||
-3.2031, -3.4375, -3.7500, -3.6562, -3.9688, -4.1250, -3.6406, -3.57811,
|
||||
-3.0312, -3.4844, -3.6094, -3.5938, -3.7656, -3.8125, -3.7500, -3.8594
|
||||
]
|
||||
, dtype=torch.float32) # fmt: skip
|
||||
|
||||
|
||||
Reference in New Issue
Block a user