[CodeGen] support device_map="auto" for sharded checkpoints (#17871)
This commit is contained in:
@@ -332,6 +332,7 @@ class CodeGenPreTrainedModel(PreTrainedModel):
|
|||||||
config_class = CodeGenConfig
|
config_class = CodeGenConfig
|
||||||
base_model_prefix = "transformer"
|
base_model_prefix = "transformer"
|
||||||
supports_gradient_checkpointing = True
|
supports_gradient_checkpointing = True
|
||||||
|
_no_split_modules = ["CodeGenBlock"]
|
||||||
|
|
||||||
def __init__(self, *inputs, **kwargs):
|
def __init__(self, *inputs, **kwargs):
|
||||||
super().__init__(*inputs, **kwargs)
|
super().__init__(*inputs, **kwargs)
|
||||||
|
|||||||
Reference in New Issue
Block a user