[GPTNeoX] Fix BC issue with 4.36 (#28602)
* fix dtype issue * add a test * update copied from mentions * nits * fixup * fix copies * Apply suggestions from code review
This commit is contained in:
@@ -355,3 +355,13 @@ class GPTNeoXLanguageGenerationTest(unittest.TestCase):
|
||||
output_str = tokenizer.batch_decode(output_ids)[0]
|
||||
|
||||
self.assertEqual(output_str, expected_output)
|
||||
|
||||
def pythia_integration_test(self):
|
||||
model_name_or_path = "EleutherAI/pythia-70m"
|
||||
model = GPTNeoXForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float16).to(torch_device)
|
||||
EXPECTED_LOGITS = torch.tensor([1069.0000, 228.7500, 1072.0000, 1072.0000, 1069.0000, 1068.0000, 1068.0000, 1071.0000, 1071.0000, 1071.0000, 1073.0000, 1070.0000, 1071.0000, 1075.0000, 1073.0000, 1075.0000, 1074.0000, 1069.0000, 1072.0000, 1071.0000, 1071.0000, 1071.0000, 1070.0000, 1069.0000, 1069.0000, 1069.0000, 1070.0000, 1075.0000, 1073.0000, 1074.0000]) # fmt: skip
|
||||
input_ids = [29, 93, 303, 64, 5478, 49651, 10394, 187, 34, 12939, 875]
|
||||
# alternative: tokenizer('<|im_start|>system\nA chat between')
|
||||
input_ids = torch.as_tensor(input_ids)[None].to(torch_device)
|
||||
outputs = model(input_ids)["logits"][:, -1][0, :30]
|
||||
self.assertTrue(torch.allclose(EXPECTED_LOGITS, outputs, atol=1e-5))
|
||||
|
||||
Reference in New Issue
Block a user