From 73fdc8c5b4509580a047696a4c672a80a8f1710c Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Wed, 22 Mar 2023 12:18:57 -0700 Subject: [PATCH] [deepspeed zero3] need `generate(synced_gpus=True, ...)` (#22242) * [deepspeed zero3] need generate(synced_gpus=True, ...) * fix * rework per Sylvain's suggestion * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --------- Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- docs/source/en/main_classes/deepspeed.mdx | 8 ++++++++ src/transformers/generation/utils.py | 17 ++++++++++++++--- 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/docs/source/en/main_classes/deepspeed.mdx b/docs/source/en/main_classes/deepspeed.mdx index 2b5ea6a743..2c78078d60 100644 --- a/docs/source/en/main_classes/deepspeed.mdx +++ b/docs/source/en/main_classes/deepspeed.mdx @@ -2268,6 +2268,14 @@ rank1: This was a very basic example and you will want to adapt it to your needs. +### `generate` nuances + +When using multiple GPUs with ZeRO Stage-3, one has to synchronize the GPUs by calling `generate(..., synced_gpus=True)`. If this is not done if one GPU finished generating before other GPUs the whole system will hang as the rest of the GPUs will not be able to received the shard of weights from the GPU that stopped generating. + +Starting from `transformers>=4.28`, if `synced_gpus` isn't explicitly specified, it'll be set to `True` automatically if these conditions are detected. But you can still override the value of `synced_gpus` if need to. + + + ## Testing Deepspeed Integration If you submit a PR that involves DeepSpeed integration please note our CircleCI PR CI setup has no GPUs, so we only run tests requiring gpus on a different CI nightly. Therefore if you get a green CI report in your PR it doesn't mean DeepSpeed tests pass. diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index eb0e0ed4c2..6140f4cb40 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -24,6 +24,7 @@ import torch import torch.distributed as dist from torch import nn +from ..deepspeed import is_deepspeed_zero3_enabled from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput from ..models.auto import ( MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING, @@ -1114,7 +1115,7 @@ class GenerationMixin: logits_processor: Optional[LogitsProcessorList] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, - synced_gpus: Optional[bool] = False, + synced_gpus: Optional[bool] = None, **kwargs, ) -> Union[GenerateOutput, torch.LongTensor]: r""" @@ -1160,8 +1161,11 @@ class GenerationMixin: on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful for constrained generation conditioned on the prefix, as described in [Autoregressive Entity Retrieval](https://arxiv.org/abs/2010.00904). - synced_gpus (`bool`, *optional*, defaults to `False`): - Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + synced_gpus (`bool`, *optional*): + Whether to continue running the while loop until max_length. Unless overridden this flag will be set to + `True` under DeepSpeed ZeRO Stage 3 multiple GPUs environment to avoid hanging if one GPU finished + generating before other GPUs. Otherwise it'll be set to `False`. + kwargs: Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder @@ -1187,6 +1191,13 @@ class GenerationMixin: - [`~generation.BeamSearchEncoderDecoderOutput`], - [`~generation.BeamSampleEncoderDecoderOutput`] """ + + if synced_gpus is None: + if is_deepspeed_zero3_enabled() and dist.world_size() > 1: + synced_gpus = True + else: + synced_gpus = False + # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call self._validate_model_class()