diff --git a/tests/models/musicgen_melody/test_modeling_musicgen_melody.py b/tests/models/musicgen_melody/test_modeling_musicgen_melody.py index a09f37d115..73af767e04 100644 --- a/tests/models/musicgen_melody/test_modeling_musicgen_melody.py +++ b/tests/models/musicgen_melody/test_modeling_musicgen_melody.py @@ -31,6 +31,7 @@ from transformers import ( ) from transformers.testing_utils import ( Expectations, + cleanup, get_device_properties, is_torch_available, is_torchaudio_available, @@ -1256,6 +1257,12 @@ class MusicgenMelodyIntegrationTests(unittest.TestCase): def model(self): return MusicgenMelodyForConditionalGeneration.from_pretrained("ylacombe/musicgen-melody").to(torch_device) + def setUp(self): + cleanup(torch_device, gc_collect=True) + + def tearDown(self): + cleanup(torch_device, gc_collect=True) + @cached_property def processor(self): return MusicgenMelodyProcessor.from_pretrained("ylacombe/musicgen-melody")