[CodeGen] support device_map="auto" for sharded checkpoints (#17871)
This commit is contained in:
@@ -332,6 +332,7 @@ class CodeGenPreTrainedModel(PreTrainedModel):
|
||||
config_class = CodeGenConfig
|
||||
base_model_prefix = "transformer"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["CodeGenBlock"]
|
||||
|
||||
def __init__(self, *inputs, **kwargs):
|
||||
super().__init__(*inputs, **kwargs)
|
||||
|
||||
Reference in New Issue
Block a user