add accelerate support for Whisper (#19697)

This commit is contained in:
Younes Belkada
2022-10-18 18:25:49 +02:00
committed by GitHub
parent fb0bd7b7a8
commit af556a09f6

View File

@@ -446,6 +446,7 @@ class WhisperPreTrainedModel(PreTrainedModel):
base_model_prefix = "model"
main_input_name = "input_features"
supports_gradient_checkpointing = True
_no_split_modules = ["WhisperEncoderLayer"]
def _init_weights(self, module):
std = self.config.init_std