Add offload support to Bark (#25037)
* initial Bark offload proposal * use hooks instead of manually offloading * add test of bark offload to cpu feature * Apply nit suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update docstrings of offload Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * remove unnecessary set_seed in Bark tests --------- Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
This commit is contained in:
@@ -31,7 +31,7 @@ from transformers.models.bark.generation_configuration_bark import (
|
||||
BarkFineGenerationConfig,
|
||||
BarkSemanticGenerationConfig,
|
||||
)
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
from transformers.testing_utils import require_torch, require_torch_gpu, slow, torch_device
|
||||
from transformers.utils import cached_property
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
@@ -989,3 +989,42 @@ class BarkModelIntegrationTests(unittest.TestCase):
|
||||
coarse_temperature=0.2,
|
||||
fine_temperature=0.1,
|
||||
)
|
||||
|
||||
@require_torch_gpu
@slow
def test_generate_end_to_end_with_offload(self):
    """Check that generation gives the same output before and after `enable_cpu_offload()`.

    The test generates once with the model as-is, calls `enable_cpu_offload()`,
    verifies that allocated CUDA memory dropped, that the reported device type is
    still `torch_device`, and that the semantic sub-model carries an `_hf_hook`
    attribute, then generates again and compares the two outputs.
    """
    inputs = self.inputs
    # Both generation passes use the same sampling-free settings so that their
    # outputs can be compared element by element.
    generation_kwargs = {"do_sample": False, "fine_temperature": None}

    # NOTE(review): the source this was reconstructed from had lost its indentation;
    # the extent of the `no_grad` block is inferred — confirm against the real file.
    with torch.no_grad():
        # reference generation, before offloading is enabled
        baseline_output = self.model.generate(**inputs, **generation_kwargs)

        torch.cuda.empty_cache()

        allocated_before = torch.cuda.memory_allocated()
        footprint = self.model.get_memory_footprint()

        # activate cpu offload
        self.model.enable_cpu_offload()

        allocated_after = torch.cuda.memory_allocated()

        # Check that the model has been offloaded: CUDA memory usage after offload
        # should be near 0, leaving room for small differences.
        tolerance = 1.1
        upper_bound = (allocated_before - footprint) * tolerance
        self.assertGreater(upper_bound, allocated_after)

        # the model must still report the expected device type
        self.assertEqual(self.model.device.type, torch_device)

        # the semantic sub-model must carry a hook attribute
        self.assertTrue(hasattr(self.model.semantic, "_hf_hook"))

        # generation with cpu offload enabled
        offloaded_output = self.model.generate(**inputs, **generation_kwargs)

    # both passes must produce exactly the same output
    self.assertListEqual(baseline_output.tolist(), offloaded_output.tolist())
|
||||
|
||||
Reference in New Issue
Block a user