Add offload support to Bark (#25037)
* initial Bark offload proposal * use hooks instead of manually offloading * add test of bark offload to cpu feature * Apply nit suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update docstrings of offload Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * remove unnecessary set_seed in Bark tests --------- Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
This commit is contained in:
@@ -23,8 +23,13 @@ from torch.nn import functional as F
|
|||||||
|
|
||||||
from ...generation.logits_process import AlternatingCodebooksLogitsProcessor, SuppressTokensLogitsProcessor
|
from ...generation.logits_process import AlternatingCodebooksLogitsProcessor, SuppressTokensLogitsProcessor
|
||||||
from ...modeling_outputs import CausalLMOutputWithPast, MaskedLMOutput
|
from ...modeling_outputs import CausalLMOutputWithPast, MaskedLMOutput
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel, get_parameter_device
|
||||||
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging
|
from ...utils import (
|
||||||
|
add_start_docstrings,
|
||||||
|
add_start_docstrings_to_model_forward,
|
||||||
|
is_accelerate_available,
|
||||||
|
logging,
|
||||||
|
)
|
||||||
from ..auto import AutoModel
|
from ..auto import AutoModel
|
||||||
from .configuration_bark import (
|
from .configuration_bark import (
|
||||||
BarkCoarseConfig,
|
BarkCoarseConfig,
|
||||||
@@ -288,6 +293,26 @@ class BarkPreTrainedModel(PreTrainedModel):
|
|||||||
def __init__(self, *inputs, **kwargs):
|
def __init__(self, *inputs, **kwargs):
|
||||||
super().__init__(*inputs, **kwargs)
|
super().__init__(*inputs, **kwargs)
|
||||||
|
|
||||||
|
@property
def device(self) -> torch.device:
    """
    `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
    device).

    If the model carries a `_hf_hook` attribute it has been offloaded, so the device is read from the first hook
    found among its modules that exposes a non-`None` `execution_device`. Otherwise, or when no such hook is
    found, the device of the parameters is returned.
    """
    if hasattr(self, "_hf_hook"):
        # Offloaded model: the execution device lives in a hook rather than in the parameters.
        for submodule in self.modules():
            hook = getattr(submodule, "_hf_hook", None)
            execution_device = getattr(hook, "execution_device", None)
            if execution_device is not None:
                return torch.device(execution_device)

    # Either there is no hook at all, or no hook told us where execution happens.
    return get_parameter_device(self)
|
||||||
|
|
||||||
def _set_gradient_checkpointing(self, module, value=False):
|
def _set_gradient_checkpointing(self, module, value=False):
|
||||||
if isinstance(module, BarkCausalModel) or isinstance(module, BarkFineModel) or isinstance(module, BarkModel):
|
if isinstance(module, BarkCausalModel) or isinstance(module, BarkFineModel) or isinstance(module, BarkModel):
|
||||||
module.gradient_checkpointing = value
|
module.gradient_checkpointing = value
|
||||||
@@ -1376,6 +1401,63 @@ class BarkModel(BarkPreTrainedModel):
|
|||||||
|
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
|
@property
def device(self) -> torch.device:
    """
    `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
    device).

    The lookup is done on the `semantic` sub-model: if it carries a `_hf_hook` attribute it has been offloaded,
    so the device is read from the first hook among its modules that exposes a non-`None` `execution_device`.
    In every other case the device of the parameters is returned.
    """
    # for bark_model, device must be verified on its sub-models
    # if has _hf_hook, has been offloaded so the device has to be found in the hook
    if not hasattr(self.semantic, "_hf_hook"):
        return get_parameter_device(self)

    for module in self.semantic.modules():
        hook = getattr(module, "_hf_hook", None)
        execution_device = getattr(hook, "execution_device", None)
        if execution_device is not None:
            return torch.device(execution_device)

    # No hook exposes a usable execution device: fall back to the parameters instead of
    # implicitly returning None, which would break callers relying on a `torch.device`.
    return get_parameter_device(self)
|
||||||
|
|
||||||
|
def enable_cpu_offload(self, gpu_id: Optional[int] = 0):
    r"""
    Offloads all sub-models to CPU using accelerate, reducing memory usage with a low impact on performance. This
    method moves one whole sub-model at a time to the GPU when it is used, and the sub-model remains in GPU until
    the next sub-model runs.

    Args:
        gpu_id (`int`, *optional*, defaults to 0):
            GPU id on which the sub-models will be loaded and offloaded.

    Raises:
        ImportError: If `accelerate` is not available.
    """
    if not is_accelerate_available():
        raise ImportError("`enable_cpu_offload` requires `accelerate`.")
    from accelerate import cpu_offload_with_hook

    device = torch.device(f"cuda:{gpu_id}")

    if self.device.type != "cpu":
        self.to("cpu")
        torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)

    # this layer is used outside the first forward pass of semantic so need to be loaded before semantic
    self.semantic.input_embeds_layer, _ = cpu_offload_with_hook(self.semantic.input_embeds_layer, device)

    # Chain the hooks: each sub-model receives the hook of the previous one as `prev_module_hook`.
    hook = None
    for cpu_offloaded_model in [
        self.semantic,
        self.coarse_acoustics,
        self.fine_acoustics,
    ]:
        _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)

    # Keep the hook of the last sub-model of the loop (fine_acoustics) on the instance.
    self.fine_acoustics_hook = hook

    _, hook = cpu_offload_with_hook(self.codec_model, device, prev_module_hook=hook)

    # We'll offload the last model manually.
    self.codec_model_hook = hook
|
||||||
|
|
||||||
def codec_decode(self, fine_output):
|
def codec_decode(self, fine_output):
|
||||||
"""Turn quantized audio codes into audio array using encodec."""
|
"""Turn quantized audio codes into audio array using encodec."""
|
||||||
|
|
||||||
@@ -1490,9 +1572,20 @@ class BarkModel(BarkPreTrainedModel):
|
|||||||
**kwargs_fine,
|
**kwargs_fine,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if getattr(self, "fine_acoustics_hook", None) is not None:
|
||||||
|
# Manually offload fine_acoustics to CPU
|
||||||
|
# and load codec_model to GPU
|
||||||
|
# since bark doesn't use codec_model forward pass
|
||||||
|
self.fine_acoustics_hook.offload()
|
||||||
|
self.codec_model = self.codec_model.to(self.device)
|
||||||
|
|
||||||
# 4. Decode the output and generate audio array
|
# 4. Decode the output and generate audio array
|
||||||
audio = self.codec_decode(output)
|
audio = self.codec_decode(output)
|
||||||
|
|
||||||
|
if getattr(self, "codec_model_hook", None) is not None:
|
||||||
|
# Offload codec_model to CPU
|
||||||
|
self.codec_model_hook.offload()
|
||||||
|
|
||||||
return audio
|
return audio
|
||||||
|
|
||||||
def can_generate(self) -> bool:
|
def can_generate(self) -> bool:
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ from transformers.models.bark.generation_configuration_bark import (
|
|||||||
BarkFineGenerationConfig,
|
BarkFineGenerationConfig,
|
||||||
BarkSemanticGenerationConfig,
|
BarkSemanticGenerationConfig,
|
||||||
)
|
)
|
||||||
from transformers.testing_utils import require_torch, slow, torch_device
|
from transformers.testing_utils import require_torch, require_torch_gpu, slow, torch_device
|
||||||
from transformers.utils import cached_property
|
from transformers.utils import cached_property
|
||||||
|
|
||||||
from ...generation.test_utils import GenerationTesterMixin
|
from ...generation.test_utils import GenerationTesterMixin
|
||||||
@@ -989,3 +989,42 @@ class BarkModelIntegrationTests(unittest.TestCase):
|
|||||||
coarse_temperature=0.2,
|
coarse_temperature=0.2,
|
||||||
fine_temperature=0.1,
|
fine_temperature=0.1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@require_torch_gpu
@slow
def test_generate_end_to_end_with_offload(self):
    # Compares generation before and after enabling CPU offload, and checks the CUDA memory,
    # the reported device and the presence of the offload hook in between.
    inputs = self.inputs

    with torch.no_grad():
        # reference generation, without offload
        reference_output = self.model.generate(**inputs, do_sample=False, fine_temperature=None)

        torch.cuda.empty_cache()

        allocated_before = torch.cuda.memory_allocated()
        footprint = self.model.get_memory_footprint()

        # switch the model to cpu offload
        self.model.enable_cpu_offload()

        allocated_after = torch.cuda.memory_allocated()

        # check that the model has been offloaded:
        # CUDA memory usage after offload should be near 0, leaving room for small differences
        tolerance = 1.1
        upper_bound = (allocated_before - footprint) * tolerance
        self.assertGreater(upper_bound, allocated_after)

        # the reported device must still be the test device
        self.assertEqual(self.model.device.type, torch_device)

        # the hook must have been attached to the semantic sub-model
        self.assertTrue(hasattr(self.model.semantic, "_hf_hook"))

        # generation with cpu offload
        offloaded_output = self.model.generate(**inputs, do_sample=False, fine_temperature=None)

        # both runs must produce the same output
        self.assertListEqual(reference_output.tolist(), offloaded_output.tolist())
|
||||||
|
|||||||
Reference in New Issue
Block a user