From af556a09f67ad279f14e9bcb93780d45a2b68a28 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Tue, 18 Oct 2022 18:25:49 +0200 Subject: [PATCH] add `accelerate` support for `Whisper` (#19697) --- src/transformers/models/whisper/modeling_whisper.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 41fab31493..e9d774b562 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -446,6 +446,7 @@ class WhisperPreTrainedModel(PreTrainedModel): base_model_prefix = "model" main_input_name = "input_features" supports_gradient_checkpointing = True + _no_split_modules = ["WhisperEncoderLayer"] def _init_weights(self, module): std = self.config.init_std