add accelerate support for Whisper (#19697)
This commit is contained in:
@@ -446,6 +446,7 @@ class WhisperPreTrainedModel(PreTrainedModel):
|
|||||||
base_model_prefix = "model"
|
base_model_prefix = "model"
|
||||||
main_input_name = "input_features"
|
main_input_name = "input_features"
|
||||||
supports_gradient_checkpointing = True
|
supports_gradient_checkpointing = True
|
||||||
|
_no_split_modules = ["WhisperEncoderLayer"]
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
std = self.config.init_std
|
std = self.config.init_std
|
||||||
|
|||||||
Reference in New Issue
Block a user