avoid calling gc.collect and cuda.empty_cache (#34514)

* update

* update

* update

* update

* update

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2024-10-31 16:36:13 +01:00
committed by GitHub
parent dca93ca076
commit ab98f0b0a1
24 changed files with 77 additions and 94 deletions

View File

@@ -14,13 +14,12 @@
# limitations under the License.
"""Testing suite for the PyTorch SAM model."""
import gc
import unittest
import requests
from transformers import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig, pipeline
from transformers.testing_utils import backend_empty_cache, require_torch, slow, torch_device
from transformers.testing_utils import cleanup, require_torch, slow, torch_device
from transformers.utils import is_torch_available, is_vision_available
from ...test_configuration_common import ConfigTester
@@ -469,8 +468,7 @@ class SamModelIntegrationTest(unittest.TestCase):
def tearDown(self):
super().tearDown()
# clean-up as much as possible GPU memory occupied by PyTorch
gc.collect()
backend_empty_cache(torch_device)
cleanup(torch_device, gc_collect=True)
def test_inference_mask_generation_no_point(self):
model = SamModel.from_pretrained("facebook/sam-vit-base")