[CodeGen] support device_map="auto" for sharded checkpoints (#17871)

This commit is contained in:
Suraj Patil
2022-06-24 18:06:30 +02:00
committed by GitHub
parent d6b6fb9963
commit 061a73d16f

View File

@@ -332,6 +332,7 @@ class CodeGenPreTrainedModel(PreTrainedModel):
config_class = CodeGenConfig config_class = CodeGenConfig
base_model_prefix = "transformer" base_model_prefix = "transformer"
supports_gradient_checkpointing = True supports_gradient_checkpointing = True
_no_split_modules = ["CodeGenBlock"]
def __init__(self, *inputs, **kwargs): def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs) super().__init__(*inputs, **kwargs)