skip low_cpu_mem_usage tests (#30782)

This commit is contained in:
Marc Sun
2024-05-13 18:00:43 +02:00
committed by GitHub
parent 0f8fefd481
commit 539ed75d50

View File

@@ -21,6 +21,7 @@ import os.path
import random import random
import re import re
import tempfile import tempfile
import unittest
import warnings import warnings
from collections import defaultdict from collections import defaultdict
from typing import Dict, List, Tuple from typing import Dict, List, Tuple
@@ -440,6 +441,7 @@ class ModelTesterMixin:
@slow @slow
@require_accelerate @require_accelerate
@mark.accelerate_tests @mark.accelerate_tests
@unittest.skip("Need to fix since we have a device mismatch")
def test_save_load_low_cpu_mem_usage(self): def test_save_load_low_cpu_mem_usage(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
with tempfile.TemporaryDirectory() as saved_model_path: with tempfile.TemporaryDirectory() as saved_model_path:
@@ -452,6 +454,7 @@ class ModelTesterMixin:
@slow @slow
@require_accelerate @require_accelerate
@mark.accelerate_tests @mark.accelerate_tests
@unittest.skip("Need to fix since we have a device mismatch")
def test_save_load_low_cpu_mem_usage_checkpoints(self): def test_save_load_low_cpu_mem_usage_checkpoints(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
with tempfile.TemporaryDirectory() as saved_model_path: with tempfile.TemporaryDirectory() as saved_model_path:
@@ -465,6 +468,7 @@ class ModelTesterMixin:
@slow @slow
@require_accelerate @require_accelerate
@mark.accelerate_tests @mark.accelerate_tests
@unittest.skip("Need to fix since we have a device mismatch")
def test_save_load_low_cpu_mem_usage_no_safetensors(self): def test_save_load_low_cpu_mem_usage_no_safetensors(self):
with tempfile.TemporaryDirectory() as saved_model_path: with tempfile.TemporaryDirectory() as saved_model_path:
for model_class in self.all_model_classes: for model_class in self.all_model_classes: