From db5e0c329251a7fdd96495653f5edcb0b3c8fc5f Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Thu, 12 Oct 2023 12:31:11 +0200 Subject: [PATCH] Fix `MistralIntegrationTest` OOM (#26754) * fix * fix * fix --------- Co-authored-by: ydshieh --- tests/models/mistral/test_modeling_mistral.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/tests/models/mistral/test_modeling_mistral.py b/tests/models/mistral/test_modeling_mistral.py index d2e9b2685f..311ed55892 100644 --- a/tests/models/mistral/test_modeling_mistral.py +++ b/tests/models/mistral/test_modeling_mistral.py @@ -15,6 +15,7 @@ """ Testing suite for the PyTorch Mistral model. """ +import gc import tempfile import unittest @@ -447,17 +448,23 @@ class MistralIntegrationTest(unittest.TestCase): print(out[0, 0, :30]) torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, atol=1e-4, rtol=1e-4) + del model + torch.cuda.empty_cache() + gc.collect() + @slow def test_model_7b_generation(self): - EXPECTED_TEXT_COMPLETION = ( - """My favourite condiment is mayonnaise. I love it on sandwiches, in salads, on burgers""" - ) + EXPECTED_TEXT_COMPLETION = """My favourite condiment is 100% ketchup. I love it on everything. I’m not a big""" prompt = "My favourite condiment is " tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=False) input_ids = tokenizer.encode(prompt, return_tensors="pt").to(torch_device) - model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1").to(torch_device) + model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", device_map="auto") # greedy generation outputs generated_ids = model.generate(input_ids, max_new_tokens=20, temperature=0) text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) self.assertEqual(EXPECTED_TEXT_COMPLETION, text) + + del model + torch.cuda.empty_cache() + gc.collect()