[deepspeed zero3] need generate(synced_gpus=True, ...) (#22242)
* [deepspeed zero3] need generate(synced_gpus=True, ...) * fix * rework per Sylvain's suggestion * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --------- Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -2268,6 +2268,14 @@ rank1:
|
|||||||
|
|
||||||
This was a very basic example and you will want to adapt it to your needs.
|
This was a very basic example and you will want to adapt it to your needs.
|
||||||
|
|
||||||
|
### `generate` nuances
|
||||||
|
|
||||||
|
When using multiple GPUs with ZeRO Stage-3, one has to synchronize the GPUs by calling `generate(..., synced_gpus=True)`. If this is not done if one GPU finished generating before other GPUs the whole system will hang as the rest of the GPUs will not be able to received the shard of weights from the GPU that stopped generating.
|
||||||
|
|
||||||
|
Starting from `transformers>=4.28`, if `synced_gpus` isn't explicitly specified, it'll be set to `True` automatically if these conditions are detected. But you can still override the value of `synced_gpus` if need to.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
## Testing Deepspeed Integration
|
## Testing Deepspeed Integration
|
||||||
|
|
||||||
If you submit a PR that involves DeepSpeed integration please note our CircleCI PR CI setup has no GPUs, so we only run tests requiring gpus on a different CI nightly. Therefore if you get a green CI report in your PR it doesn't mean DeepSpeed tests pass.
|
If you submit a PR that involves DeepSpeed integration please note our CircleCI PR CI setup has no GPUs, so we only run tests requiring gpus on a different CI nightly. Therefore if you get a green CI report in your PR it doesn't mean DeepSpeed tests pass.
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ import torch
|
|||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
|
from ..deepspeed import is_deepspeed_zero3_enabled
|
||||||
from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
|
from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
|
||||||
from ..models.auto import (
|
from ..models.auto import (
|
||||||
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,
|
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,
|
||||||
@@ -1114,7 +1115,7 @@ class GenerationMixin:
|
|||||||
logits_processor: Optional[LogitsProcessorList] = None,
|
logits_processor: Optional[LogitsProcessorList] = None,
|
||||||
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
||||||
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
|
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
|
||||||
synced_gpus: Optional[bool] = False,
|
synced_gpus: Optional[bool] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Union[GenerateOutput, torch.LongTensor]:
|
) -> Union[GenerateOutput, torch.LongTensor]:
|
||||||
r"""
|
r"""
|
||||||
@@ -1160,8 +1161,11 @@ class GenerationMixin:
|
|||||||
on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful
|
on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful
|
||||||
for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
|
for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
|
||||||
Retrieval](https://arxiv.org/abs/2010.00904).
|
Retrieval](https://arxiv.org/abs/2010.00904).
|
||||||
synced_gpus (`bool`, *optional*, defaults to `False`):
|
synced_gpus (`bool`, *optional*):
|
||||||
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
|
Whether to continue running the while loop until max_length. Unless overridden this flag will be set to
|
||||||
|
`True` under DeepSpeed ZeRO Stage 3 multiple GPUs environment to avoid hanging if one GPU finished
|
||||||
|
generating before other GPUs. Otherwise it'll be set to `False`.
|
||||||
|
|
||||||
kwargs:
|
kwargs:
|
||||||
Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
|
Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
|
||||||
forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
|
forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
|
||||||
@@ -1187,6 +1191,13 @@ class GenerationMixin:
|
|||||||
- [`~generation.BeamSearchEncoderDecoderOutput`],
|
- [`~generation.BeamSearchEncoderDecoderOutput`],
|
||||||
- [`~generation.BeamSampleEncoderDecoderOutput`]
|
- [`~generation.BeamSampleEncoderDecoderOutput`]
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if synced_gpus is None:
|
||||||
|
if is_deepspeed_zero3_enabled() and dist.world_size() > 1:
|
||||||
|
synced_gpus = True
|
||||||
|
else:
|
||||||
|
synced_gpus = False
|
||||||
|
|
||||||
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
|
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
|
||||||
self._validate_model_class()
|
self._validate_model_class()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user