diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 41fab31493..e9d774b562 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -446,6 +446,7 @@ class WhisperPreTrainedModel(PreTrainedModel): base_model_prefix = "model" main_input_name = "input_features" supports_gradient_checkpointing = True + _no_split_modules = ["WhisperEncoderLayer"] def _init_weights(self, module): std = self.config.init_std