From 72256bc72ac2f2e341da47b9aea57e2e37879700 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Thu, 12 Oct 2023 11:24:18 +0200 Subject: [PATCH] Fix `PersimmonIntegrationTest` OOM (#26750) * fix --------- Co-authored-by: ydshieh --- .../persimmon/test_modeling_persimmon.py | 27 ++++++++++++++----- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/tests/models/persimmon/test_modeling_persimmon.py b/tests/models/persimmon/test_modeling_persimmon.py index 60a5dabf10..aa092f3870 100644 --- a/tests/models/persimmon/test_modeling_persimmon.py +++ b/tests/models/persimmon/test_modeling_persimmon.py @@ -15,6 +15,7 @@ """ Testing suite for the PyTorch Persimmon model. """ +import gc import unittest from parameterized import parameterized @@ -395,19 +396,27 @@ class PersimmonIntegrationTest(unittest.TestCase): def test_model_8b_chat_logits(self): input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338] model = PersimmonForCausalLM.from_pretrained( - "adept/persimmon-8b-chat", device_map="auto", torch_dtype=torch.float16 + "adept/persimmon-8b-chat", load_in_8bit=True, device_map={"": 0}, torch_dtype=torch.float16 ) - out = model(torch.tensor([input_ids])).logits + out = model(torch.tensor([input_ids], device=torch_device)).logits EXPECTED_MEAN = torch.tensor( - [[-11.2879, -11.2628, -11.2498, -11.2534, -11.2676, -11.2638, -11.2501, -11.2431]], dtype=torch.float16 + [[-11.4726, -11.1495, -11.2694, -11.2223, -10.9452, -11.0663, -11.0031, -11.1028]] ) - torch.testing.assert_close(out.cpu().mean(-1), EXPECTED_MEAN, atol=1e-4, rtol=1e-4) + # change dtype to `torch.float32` before calling `mean` to avoid `nan` values + torch.testing.assert_close(out.cpu().to(torch.float32).mean(-1), EXPECTED_MEAN, atol=1e-4, rtol=1e-4) # fmt: off - EXPECTED_SLICE = torch.tensor([-16.9670, -16.9647, -16.9649, -16.9630, -16.9577, -16.9623, -17.0164, -16.9673, -16.9648, -16.9668, -17.0160, -16.9651, -17.0156, -16.9668, -16.9655, -16.9653, -16.9665, -16.9682, -17.0112, -16.9667, -16.9717, -16.9654, -16.9650, -16.9701, -16.9657, -17.0160, -16.9676, -17.0138, -16.9610, -16.9695]) + EXPECTED_SLICE = torch.tensor( + [-16.9062, -16.9062, -16.9062, -16.9062, -16.8906, -16.9062, -16.9531, -16.9062, -16.9062, -16.9062, -16.9531, -16.9062, -16.9531, -16.9062, -16.9062, -16.9062, -16.9062, -16.9062, -16.9531, -16.9062, -16.9062, -16.9062, -16.9062, -16.9062, -16.9062, -16.9531, -16.9062, -16.9531, -16.9062, -16.9062], + dtype=torch.float16 + ) # fmt: on torch.testing.assert_close(out.cpu()[0, 0, :30], EXPECTED_SLICE, atol=1e-5, rtol=1e-5) + torch.cuda.empty_cache() + del model + gc.collect() + @slow @require_torch_gpu def test_model_8b_chat_greedy_generation(self): @@ -415,11 +424,15 @@ class PersimmonIntegrationTest(unittest.TestCase): prompt = "human: Simply put, the theory of relativity states that?\n\nadept:" tokenizer = AutoTokenizer.from_pretrained("adept/persimmon-8b-chat", use_fast=False) input_ids = tokenizer.encode(prompt, return_tensors="pt").to(torch_device) - model = PersimmonForCausalLM.from_pretrained("adept/persimmon-8b-chat", torch_dtype=torch.float16).to( - torch_device + model = PersimmonForCausalLM.from_pretrained( + "adept/persimmon-8b-chat", load_in_8bit=True, device_map={"": 0}, torch_dtype=torch.float16 ) # greedy generation outputs generated_ids = model.generate(input_ids, max_new_tokens=64) text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) self.assertEqual(EXPECTED_TEXT_COMPLETION, text) + + torch.cuda.empty_cache() + del model + gc.collect()