From 56be9f192501577a579e47757d163afcef2249fe Mon Sep 17 00:00:00 2001 From: Yehoshua Cohen <61619195+yecohn@users.noreply.github.com> Date: Sat, 5 Oct 2024 17:03:12 +0300 Subject: [PATCH] add test for Jamba with new model jamba-tiny-dev (#33863) * add test for jamba with new model * ruff fix --------- Co-authored-by: Yehoshua Cohen --- tests/models/jamba/test_modeling_jamba.py | 38 +++++++++++------------ 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/tests/models/jamba/test_modeling_jamba.py b/tests/models/jamba/test_modeling_jamba.py index 6e1a2cf2cf..251f293f72 100644 --- a/tests/models/jamba/test_modeling_jamba.py +++ b/tests/models/jamba/test_modeling_jamba.py @@ -653,7 +653,7 @@ class JambaModelIntegrationTest(unittest.TestCase): @classmethod def setUpClass(cls): - model_id = "ai21labs/Jamba-tiny-random" + model_id = "ai21labs/Jamba-tiny-dev" cls.model = JambaForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True) cls.tokenizer = AutoTokenizer.from_pretrained(model_id) if is_torch_available() and torch.cuda.is_available(): @@ -668,7 +668,7 @@ class JambaModelIntegrationTest(unittest.TestCase): # considering differences in hardware processing and potential deviations in generated text. EXPECTED_TEXTS = { 7: "<|startoftext|>Hey how are you doing on this lovely evening? Canyon rins hugaughter glamour Rutgers Singh<|reserved_797|>cw algunas", - 8: "<|startoftext|>Hey how are you doing on this lovely evening? Canyon rins hugaughter glamour Rutgers Singh Hebrew llam bb", + 8: "<|startoftext|>Hey how are you doing on this lovely evening? I'm so glad you're here.", 9: "<|startoftext|>Hey how are you doing on this lovely evening? Canyon rins hugaughter glamour Rutgers Singh Hebrew llam bb", } @@ -688,11 +688,11 @@ class JambaModelIntegrationTest(unittest.TestCase): EXPECTED_LOGITS_NO_GRAD = torch.tensor( [ - 0.0134, -0.2197, 0.0396, -0.1011, 0.0459, 0.2793, -0.1465, 0.1660, - -0.2930, -0.0278, 0.0269, -0.5586, -0.2109, -0.1426, -0.1553, 0.1279, - 0.0713, 0.2246, 0.1660, -0.2314, -0.1187, -0.1162, -0.1377, 0.0292, - 0.1245, 0.2275, 0.0374, 0.1089, -0.1348, -0.2305, 0.1484, -0.3906, - 0.1709, -0.4590, -0.0447, 0.2422, 0.1592, -0.1855, 0.2441, -0.0562 + -7.6875, -7.6562, 8.9375, -7.7812, -7.4062, -7.9688, -8.3125, -7.4062, + -7.8125, -8.1250, -7.8125, -7.3750, -7.8438, -7.5000, -8.0625, -8.0625, + -7.5938, -7.9688, -8.2500, -7.5625, -7.7500, -7.7500, -7.6562, -7.6250, + -8.1250, -8.0625, -8.1250, -7.8750, -8.1875, -8.2500, -7.5938, -8.0000, + -7.5000, -7.7500, -7.9375, -7.4688, -8.0625, -7.3438, -8.0000, -7.5000 ] , dtype=torch.float32) # fmt: skip @@ -710,8 +710,8 @@ class JambaModelIntegrationTest(unittest.TestCase): "<|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|startoftext|>Tell me a storyptus Nets Madison El chamadamodern updximVaparsed", ], 8: [ - "<|startoftext|>Hey how are you doing on this lovely evening? Canyon rins hugaughter glamour Rutgers Singh<|reserved_797|>cw algunas", - "<|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|startoftext|>Tell me a storyptus Nets Madison El chamadamodern updximVaparsed", + "<|startoftext|>Hey how are you doing on this lovely evening? I'm so glad you're here.", + "<|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|startoftext|>Tell me a story about a woman who was born in the United States", ], 9: [ "<|startoftext|>Hey how are you doing on this lovely evening? Canyon rins hugaughter glamour Rutgers Singh<|reserved_797|>cw algunas", @@ -737,21 +737,21 @@ class JambaModelIntegrationTest(unittest.TestCase): # TODO fix logits EXPECTED_LOGITS_NO_GRAD_0 = torch.tensor( [ - 0.0166, -0.2227, 0.0396, -0.1035, 0.0459, 0.2754, -0.1445, 0.1641, - -0.2910, -0.0273, 0.0227, -0.5547, -0.2139, -0.1396, -0.1582, 0.1289, - 0.0713, 0.2256, 0.1699, -0.2295, -0.1182, -0.1167, -0.1387, 0.0261, - 0.1270, 0.2285, 0.0403, 0.1108, -0.1318, -0.2334, 0.1455, -0.3945, - 0.1729, -0.4609, -0.0410, 0.2412, 0.1572, -0.1895, 0.2402, -0.0583 + -7.7188, -7.6875, 8.8750, -7.8125, -7.4062, -8.0000, -8.3125, -7.4375, + -7.8125, -8.1250, -7.8125, -7.4062, -7.8438, -7.5312, -8.0625, -8.0625, + -7.6250, -8.0000, -8.3125, -7.5938, -7.7500, -7.7500, -7.6562, -7.6562, + -8.1250, -8.0625, -8.1250, -7.8750, -8.1875, -8.2500, -7.5938, -8.0625, + -7.5000, -7.7812, -7.9375, -7.4688, -8.0625, -7.3750, -8.0000, -7.50003 ] , dtype=torch.float32) # fmt: skip EXPECTED_LOGITS_NO_GRAD_1 = torch.tensor( [ - -0.1318, 0.2354, -0.4160, -0.0325, -0.0461, 0.0342, 0.2578, 0.0874, - 0.1484, 0.2266, -0.1182, -0.1396, -0.1494, -0.1089, -0.0019, -0.2852, - 0.1973, -0.2676, 0.0586, -0.1992, -0.2520, -0.1147, -0.1973, 0.2129, - 0.0520, 0.1699, 0.1816, 0.1289, 0.1699, -0.1216, -0.2656, -0.2891, - 0.2363, 0.2656, 0.0488, -0.1875, 0.2148, -0.1250, 0.1816, 0.0077 + -3.5469, -4.0625, 8.5000, -3.8125, -3.6406, -3.7969, -3.8125, -3.3594, + -3.7188, -3.7500, -3.7656, -3.5469, -3.7969, -4.0000, -3.5625, -3.6406, + -3.7188, -3.6094, -4.0938, -3.6719, -3.8906, -3.9844, -3.8594, -3.4219, + -3.2031, -3.4375, -3.7500, -3.6562, -3.9688, -4.1250, -3.6406, -3.57811, + -3.0312, -3.4844, -3.6094, -3.5938, -3.7656, -3.8125, -3.7500, -3.8594 ] , dtype=torch.float32) # fmt: skip