Fix GPU OOM for mistral.py::Mask4DTestHard (#31212)
* build * build * build * build --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -734,15 +734,24 @@ class MistralIntegrationTest(unittest.TestCase):
|
|||||||
@slow
|
@slow
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
class Mask4DTestHard(unittest.TestCase):
|
class Mask4DTestHard(unittest.TestCase):
|
||||||
|
model_name = "mistralai/Mistral-7B-v0.1"
|
||||||
|
_model = None
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model(self):
|
||||||
|
if self.__class__._model is None:
|
||||||
|
self.__class__._model = MistralForCausalLM.from_pretrained(
|
||||||
|
self.model_name, torch_dtype=self.model_dtype
|
||||||
|
).to(torch_device)
|
||||||
|
return self.__class__._model
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
model_name = "mistralai/Mistral-7B-v0.1"
|
self.model_dtype = torch.float16
|
||||||
self.model_dtype = torch.float32
|
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=False)
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
|
|
||||||
self.model = MistralForCausalLM.from_pretrained(model_name, torch_dtype=self.model_dtype).to(torch_device)
|
|
||||||
|
|
||||||
def get_test_data(self):
|
def get_test_data(self):
|
||||||
template = "my favorite {}"
|
template = "my favorite {}"
|
||||||
|
|||||||
Reference in New Issue
Block a user