[GPTNeoX] Fix BC issue with 4.36 (#28602)

* fix dtype issue

* add a test

* update copied from mentions

* nits

* fixup

* fix copies

* Apply suggestions from code review
This commit is contained in:
Arthur
2024-01-21 18:01:19 +01:00
committed by GitHub
parent 3f69f415ad
commit 83f9196cc4
3 changed files with 26 additions and 15 deletions

View File

@@ -355,3 +355,13 @@ class GPTNeoXLanguageGenerationTest(unittest.TestCase):
output_str = tokenizer.batch_decode(output_ids)[0]
self.assertEqual(output_str, expected_output)
def pythia_integration_test(self):
model_name_or_path = "EleutherAI/pythia-70m"
model = GPTNeoXForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float16).to(torch_device)
EXPECTED_LOGITS = torch.tensor([1069.0000, 228.7500, 1072.0000, 1072.0000, 1069.0000, 1068.0000, 1068.0000, 1071.0000, 1071.0000, 1071.0000, 1073.0000, 1070.0000, 1071.0000, 1075.0000, 1073.0000, 1075.0000, 1074.0000, 1069.0000, 1072.0000, 1071.0000, 1071.0000, 1071.0000, 1070.0000, 1069.0000, 1069.0000, 1069.0000, 1070.0000, 1075.0000, 1073.0000, 1074.0000]) # fmt: skip
input_ids = [29, 93, 303, 64, 5478, 49651, 10394, 187, 34, 12939, 875]
# alternative: tokenizer('<|im_start|>system\nA chat between')
input_ids = torch.as_tensor(input_ids)[None].to(torch_device)
outputs = model(input_ids)["logits"][:, -1][0, :30]
self.assertTrue(torch.allclose(EXPECTED_LOGITS, outputs, atol=1e-5))