add accelerate support for Whisper (#19697)
This commit is contained in:
@@ -446,6 +446,7 @@ class WhisperPreTrainedModel(PreTrainedModel):
|
||||
base_model_prefix = "model"
|
||||
main_input_name = "input_features"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["WhisperEncoderLayer"]
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.init_std
|
||||
|
||||
Reference in New Issue
Block a user