Add offload support to Bark (#25037)

* initial Bark offload proposal

* use hooks instead of manually offloading

* add test of bark offload to cpu feature

* Apply nit suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update docstrings of offload

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* remove unecessary set_seed in Bark tests

---------

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
This commit is contained in:
Yoach Lacombe
2023-07-27 16:35:17 +02:00
committed by GitHub
parent 9cea3e7b80
commit 0b92ae3489
2 changed files with 140 additions and 8 deletions

View File

@@ -31,7 +31,7 @@ from transformers.models.bark.generation_configuration_bark import (
BarkFineGenerationConfig,
BarkSemanticGenerationConfig,
)
from transformers.testing_utils import require_torch, slow, torch_device
from transformers.testing_utils import require_torch, require_torch_gpu, slow, torch_device
from transformers.utils import cached_property
from ...generation.test_utils import GenerationTesterMixin
@@ -989,3 +989,42 @@ class BarkModelIntegrationTests(unittest.TestCase):
coarse_temperature=0.2,
fine_temperature=0.1,
)
@require_torch_gpu
@slow
def test_generate_end_to_end_with_offload(self):
input_ids = self.inputs
with torch.no_grad():
# standard generation
output_with_no_offload = self.model.generate(**input_ids, do_sample=False, fine_temperature=None)
torch.cuda.empty_cache()
memory_before_offload = torch.cuda.memory_allocated()
model_memory_footprint = self.model.get_memory_footprint()
# activate cpu offload
self.model.enable_cpu_offload()
memory_after_offload = torch.cuda.memory_allocated()
# checks if the model have been offloaded
# CUDA memory usage after offload should be near 0, leaving room to small differences
room_for_difference = 1.1
self.assertGreater(
(memory_before_offload - model_memory_footprint) * room_for_difference, memory_after_offload
)
# checks if device is the correct one
self.assertEqual(self.model.device.type, torch_device)
# checks if hooks exist
self.assertTrue(hasattr(self.model.semantic, "_hf_hook"))
# output with cpu offload
output_with_offload = self.model.generate(**input_ids, do_sample=False, fine_temperature=None)
# checks if same output
self.assertListEqual(output_with_no_offload.tolist(), output_with_offload.tolist())