diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py
index ea9a7ef6a2..06581e732c 100644
--- a/src/transformers/models/codegen/modeling_codegen.py
+++ b/src/transformers/models/codegen/modeling_codegen.py
@@ -332,6 +332,7 @@ class CodeGenPreTrainedModel(PreTrainedModel):
     config_class = CodeGenConfig
     base_model_prefix = "transformer"
     supports_gradient_checkpointing = True
+    _no_split_modules = ["CodeGenBlock"]
 
     def __init__(self, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)