From 30b453206d224ee5f747afa33ff216671558e6a0 Mon Sep 17 00:00:00 2001 From: Jacky Lee <39754370+jla524@users.noreply.github.com> Date: Fri, 19 Apr 2024 01:24:44 -0700 Subject: [PATCH] Enable multi-device for some models (#30207) * feat: multidevice for resnet * feat: yes! resnet * fix: compare all elements in tuple * feat: support for regnet * feat: support for convnextv2 * feat: support for bit * feat: support for cvt * feat: add support for focalnet * feat: support for yolos * feat: support for glpn * feat: support for imagegpt * feat: support for levit * feat: support for mgp_str * feat: support for mobilnet_v1 * feat: support for mobilnet_v2 * feat: support for mobilevit * feat: support for mobilevitv2 * feat: support for poolformer * fix: copies * fix: code quality check * update: upstream changes from main * fix: consistency check * feat: support for sam * feat: support for switchformer * feat: support for swin * feat: support for swinv2 * feat: support for timesformer * feat: suport for trocr * feat: support for upernet * fix: check copies * update: rerun CI * update: rerun again, maybe * update: one more rerun --------- Co-authored-by: Jacky Lee --- src/transformers/models/bit/modeling_bit.py | 1 + .../models/convnext/modeling_convnext.py | 1 + .../models/convnextv2/modeling_convnextv2.py | 1 + src/transformers/models/cvt/modeling_cvt.py | 1 + .../models/donut/modeling_donut_swin.py | 1 + .../models/focalnet/modeling_focalnet.py | 1 + src/transformers/models/glpn/modeling_glpn.py | 1 + .../models/imagegpt/modeling_imagegpt.py | 1 + .../models/levit/modeling_levit.py | 1 + .../maskformer/modeling_maskformer_swin.py | 1 + .../models/mgp_str/modeling_mgp_str.py | 1 + .../mobilenet_v1/modeling_mobilenet_v1.py | 1 + .../mobilenet_v2/modeling_mobilenet_v2.py | 1 + .../models/mobilevit/modeling_mobilevit.py | 1 + .../mobilevitv2/modeling_mobilevitv2.py | 1 + .../models/poolformer/modeling_poolformer.py | 1 + .../models/regnet/modeling_regnet.py | 1 + .../models/resnet/modeling_resnet.py | 1 + src/transformers/models/sam/modeling_sam.py | 1 + .../swiftformer/modeling_swiftformer.py | 1 + src/transformers/models/swin/modeling_swin.py | 1 + .../models/swinv2/modeling_swinv2.py | 1 + .../timesformer/modeling_timesformer.py | 1 + .../models/trocr/modeling_trocr.py | 1 + .../models/upernet/modeling_upernet.py | 1 + .../models/yolos/modeling_yolos.py | 1 + tests/test_modeling_common.py | 20 +++++++++++++++---- 27 files changed, 42 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/bit/modeling_bit.py b/src/transformers/models/bit/modeling_bit.py index 27141a9009..5906aae5e5 100644 --- a/src/transformers/models/bit/modeling_bit.py +++ b/src/transformers/models/bit/modeling_bit.py @@ -658,6 +658,7 @@ class BitPreTrainedModel(PreTrainedModel): config_class = BitConfig base_model_prefix = "bit" main_input_name = "pixel_values" + _no_split_modules = ["BitEmbeddings"] def _init_weights(self, module): if isinstance(module, nn.Conv2d): diff --git a/src/transformers/models/convnext/modeling_convnext.py b/src/transformers/models/convnext/modeling_convnext.py index 147d2ac22d..7aee810ab9 100755 --- a/src/transformers/models/convnext/modeling_convnext.py +++ b/src/transformers/models/convnext/modeling_convnext.py @@ -280,6 +280,7 @@ class ConvNextPreTrainedModel(PreTrainedModel): config_class = ConvNextConfig base_model_prefix = "convnext" main_input_name = "pixel_values" + _no_split_modules = ["ConvNextLayer"] def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/convnextv2/modeling_convnextv2.py b/src/transformers/models/convnextv2/modeling_convnextv2.py index 7439f21297..ef878748a4 100644 --- a/src/transformers/models/convnextv2/modeling_convnextv2.py +++ b/src/transformers/models/convnextv2/modeling_convnextv2.py @@ -301,6 +301,7 @@ class ConvNextV2PreTrainedModel(PreTrainedModel): config_class = ConvNextV2Config base_model_prefix = "convnextv2" main_input_name = "pixel_values" + _no_split_modules = ["ConvNextV2Layer"] def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/cvt/modeling_cvt.py b/src/transformers/models/cvt/modeling_cvt.py index 25cf3963cb..c2d1dd56d2 100644 --- a/src/transformers/models/cvt/modeling_cvt.py +++ b/src/transformers/models/cvt/modeling_cvt.py @@ -534,6 +534,7 @@ class CvtPreTrainedModel(PreTrainedModel): config_class = CvtConfig base_model_prefix = "cvt" main_input_name = "pixel_values" + _no_split_modules = ["CvtLayer"] def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/donut/modeling_donut_swin.py b/src/transformers/models/donut/modeling_donut_swin.py index b2aa8d61b1..bf293ae167 100644 --- a/src/transformers/models/donut/modeling_donut_swin.py +++ b/src/transformers/models/donut/modeling_donut_swin.py @@ -809,6 +809,7 @@ class DonutSwinPreTrainedModel(PreTrainedModel): base_model_prefix = "swin" main_input_name = "pixel_values" supports_gradient_checkpointing = True + _no_split_modules = ["DonutSwinStage"] def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/focalnet/modeling_focalnet.py b/src/transformers/models/focalnet/modeling_focalnet.py index a452f4171d..ef3e2de52f 100644 --- a/src/transformers/models/focalnet/modeling_focalnet.py +++ b/src/transformers/models/focalnet/modeling_focalnet.py @@ -636,6 +636,7 @@ class FocalNetPreTrainedModel(PreTrainedModel): base_model_prefix = "focalnet" main_input_name = "pixel_values" supports_gradient_checkpointing = True + _no_split_modules = ["FocalNetStage"] def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/glpn/modeling_glpn.py b/src/transformers/models/glpn/modeling_glpn.py index e5d30b6272..0791cc0434 100755 --- a/src/transformers/models/glpn/modeling_glpn.py +++ b/src/transformers/models/glpn/modeling_glpn.py @@ -426,6 +426,7 @@ class GLPNPreTrainedModel(PreTrainedModel): config_class = GLPNConfig base_model_prefix = "glpn" main_input_name = "pixel_values" + _no_split_modules = [] # Copied from transformers.models.segformer.modeling_segformer.SegformerPreTrainedModel._init_weights def _init_weights(self, module): diff --git a/src/transformers/models/imagegpt/modeling_imagegpt.py b/src/transformers/models/imagegpt/modeling_imagegpt.py index 3b9be17246..81b4107863 100755 --- a/src/transformers/models/imagegpt/modeling_imagegpt.py +++ b/src/transformers/models/imagegpt/modeling_imagegpt.py @@ -491,6 +491,7 @@ class ImageGPTPreTrainedModel(PreTrainedModel): base_model_prefix = "transformer" main_input_name = "input_ids" supports_gradient_checkpointing = True + _no_split_modules = ["ImageGPTBlock"] def __init__(self, *inputs, **kwargs): super().__init__(*inputs, **kwargs) diff --git a/src/transformers/models/levit/modeling_levit.py b/src/transformers/models/levit/modeling_levit.py index 11eda7bcc5..00dccf9eff 100644 --- a/src/transformers/models/levit/modeling_levit.py +++ b/src/transformers/models/levit/modeling_levit.py @@ -491,6 +491,7 @@ class LevitPreTrainedModel(PreTrainedModel): config_class = LevitConfig base_model_prefix = "levit" main_input_name = "pixel_values" + _no_split_modules = ["LevitResidualLayer"] def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/maskformer/modeling_maskformer_swin.py b/src/transformers/models/maskformer/modeling_maskformer_swin.py index b4714860e6..1c358c88de 100644 --- a/src/transformers/models/maskformer/modeling_maskformer_swin.py +++ b/src/transformers/models/maskformer/modeling_maskformer_swin.py @@ -735,6 +735,7 @@ class MaskFormerSwinPreTrainedModel(PreTrainedModel): base_model_prefix = "model" main_input_name = "pixel_values" supports_gradient_checkpointing = True + _no_split_modules = ["MaskFormerSwinStage"] def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/mgp_str/modeling_mgp_str.py b/src/transformers/models/mgp_str/modeling_mgp_str.py index e35c414d73..2997e5903c 100644 --- a/src/transformers/models/mgp_str/modeling_mgp_str.py +++ b/src/transformers/models/mgp_str/modeling_mgp_str.py @@ -317,6 +317,7 @@ class MgpstrPreTrainedModel(PreTrainedModel): config_class = MgpstrConfig base_model_prefix = "mgp_str" + _no_split_modules = [] def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" diff --git a/src/transformers/models/mobilenet_v1/modeling_mobilenet_v1.py b/src/transformers/models/mobilenet_v1/modeling_mobilenet_v1.py index adfb5c5670..af9d232be8 100755 --- a/src/transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +++ b/src/transformers/models/mobilenet_v1/modeling_mobilenet_v1.py @@ -254,6 +254,7 @@ class MobileNetV1PreTrainedModel(PreTrainedModel): base_model_prefix = "mobilenet_v1" main_input_name = "pixel_values" supports_gradient_checkpointing = False + _no_split_modules = [] def _init_weights(self, module: Union[nn.Linear, nn.Conv2d]) -> None: """Initialize the weights""" diff --git a/src/transformers/models/mobilenet_v2/modeling_mobilenet_v2.py b/src/transformers/models/mobilenet_v2/modeling_mobilenet_v2.py index 789da48401..e555941bac 100755 --- a/src/transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +++ b/src/transformers/models/mobilenet_v2/modeling_mobilenet_v2.py @@ -453,6 +453,7 @@ class MobileNetV2PreTrainedModel(PreTrainedModel): base_model_prefix = "mobilenet_v2" main_input_name = "pixel_values" supports_gradient_checkpointing = False + _no_split_modules = [] def _init_weights(self, module: Union[nn.Linear, nn.Conv2d]) -> None: """Initialize the weights""" diff --git a/src/transformers/models/mobilevit/modeling_mobilevit.py b/src/transformers/models/mobilevit/modeling_mobilevit.py index 939982148c..04105effff 100755 --- a/src/transformers/models/mobilevit/modeling_mobilevit.py +++ b/src/transformers/models/mobilevit/modeling_mobilevit.py @@ -644,6 +644,7 @@ class MobileViTPreTrainedModel(PreTrainedModel): base_model_prefix = "mobilevit" main_input_name = "pixel_values" supports_gradient_checkpointing = True + _no_split_modules = ["MobileViTLayer"] def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" diff --git a/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py b/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py index c6c446b186..1943f52f51 100644 --- a/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py +++ b/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py @@ -606,6 +606,7 @@ class MobileViTV2PreTrainedModel(PreTrainedModel): base_model_prefix = "mobilevitv2" main_input_name = "pixel_values" supports_gradient_checkpointing = True + _no_split_modules = ["MobileViTV2Layer"] def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" diff --git a/src/transformers/models/poolformer/modeling_poolformer.py b/src/transformers/models/poolformer/modeling_poolformer.py index 80208bd1fc..86297e7332 100755 --- a/src/transformers/models/poolformer/modeling_poolformer.py +++ b/src/transformers/models/poolformer/modeling_poolformer.py @@ -268,6 +268,7 @@ class PoolFormerPreTrainedModel(PreTrainedModel): config_class = PoolFormerConfig base_model_prefix = "poolformer" main_input_name = "pixel_values" + _no_split_modules = ["PoolFormerLayer"] def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/regnet/modeling_regnet.py b/src/transformers/models/regnet/modeling_regnet.py index 915e4cbae4..2e05f8329a 100644 --- a/src/transformers/models/regnet/modeling_regnet.py +++ b/src/transformers/models/regnet/modeling_regnet.py @@ -281,6 +281,7 @@ class RegNetPreTrainedModel(PreTrainedModel): config_class = RegNetConfig base_model_prefix = "regnet" main_input_name = "pixel_values" + _no_split_modules = ["RegNetYLayer"] # Copied from transformers.models.resnet.modeling_resnet.ResNetPreTrainedModel._init_weights def _init_weights(self, module): diff --git a/src/transformers/models/resnet/modeling_resnet.py b/src/transformers/models/resnet/modeling_resnet.py index ab2ff4814e..560e807c24 100644 --- a/src/transformers/models/resnet/modeling_resnet.py +++ b/src/transformers/models/resnet/modeling_resnet.py @@ -272,6 +272,7 @@ class ResNetPreTrainedModel(PreTrainedModel): config_class = ResNetConfig base_model_prefix = "resnet" main_input_name = "pixel_values" + _no_split_modules = ["ResNetConvLayer", "ResNetShortCut"] def _init_weights(self, module): if isinstance(module, nn.Conv2d): diff --git a/src/transformers/models/sam/modeling_sam.py b/src/transformers/models/sam/modeling_sam.py index 385fb9c00a..3203031cc9 100644 --- a/src/transformers/models/sam/modeling_sam.py +++ b/src/transformers/models/sam/modeling_sam.py @@ -1074,6 +1074,7 @@ class SamPreTrainedModel(PreTrainedModel): config_class = SamConfig base_model_prefix = "sam" main_input_name = "pixel_values" + _no_split_modules = ["SamVisionAttention"] def _init_weights(self, module): std = self.config.initializer_range diff --git a/src/transformers/models/swiftformer/modeling_swiftformer.py b/src/transformers/models/swiftformer/modeling_swiftformer.py index c447c0ce12..0455a31641 100644 --- a/src/transformers/models/swiftformer/modeling_swiftformer.py +++ b/src/transformers/models/swiftformer/modeling_swiftformer.py @@ -428,6 +428,7 @@ class SwiftFormerPreTrainedModel(PreTrainedModel): base_model_prefix = "swiftformer" main_input_name = "pixel_values" supports_gradient_checkpointing = True + _no_split_modules = ["SwiftFormerEncoderBlock"] def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py index c841faddf0..f21029dcbf 100644 --- a/src/transformers/models/swin/modeling_swin.py +++ b/src/transformers/models/swin/modeling_swin.py @@ -884,6 +884,7 @@ class SwinPreTrainedModel(PreTrainedModel): base_model_prefix = "swin" main_input_name = "pixel_values" supports_gradient_checkpointing = True + _no_split_modules = ["SwinStage"] def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/swinv2/modeling_swinv2.py b/src/transformers/models/swinv2/modeling_swinv2.py index a83965ede7..83b8ed5ec3 100644 --- a/src/transformers/models/swinv2/modeling_swinv2.py +++ b/src/transformers/models/swinv2/modeling_swinv2.py @@ -939,6 +939,7 @@ class Swinv2PreTrainedModel(PreTrainedModel): base_model_prefix = "swinv2" main_input_name = "pixel_values" supports_gradient_checkpointing = True + _no_split_modules = ["Swinv2Stage"] def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/timesformer/modeling_timesformer.py b/src/transformers/models/timesformer/modeling_timesformer.py index 3374472508..17b80ee5a1 100644 --- a/src/transformers/models/timesformer/modeling_timesformer.py +++ b/src/transformers/models/timesformer/modeling_timesformer.py @@ -472,6 +472,7 @@ class TimesformerPreTrainedModel(PreTrainedModel): base_model_prefix = "timesformer" main_input_name = "pixel_values" supports_gradient_checkpointing = True + _no_split_modules = ["TimesformerLayer"] def _init_weights(self, module): if isinstance(module, (nn.Linear, nn.Conv2d)): diff --git a/src/transformers/models/trocr/modeling_trocr.py b/src/transformers/models/trocr/modeling_trocr.py index 72ead7143a..c80171292b 100644 --- a/src/transformers/models/trocr/modeling_trocr.py +++ b/src/transformers/models/trocr/modeling_trocr.py @@ -407,6 +407,7 @@ class TrOCRPreTrainedModel(PreTrainedModel): config_class = TrOCRConfig base_model_prefix = "model" supports_gradient_checkpointing = True + _no_split_modules = ["TrOCRDecoderLayer"] def _init_weights(self, module): std = self.config.init_std diff --git a/src/transformers/models/upernet/modeling_upernet.py b/src/transformers/models/upernet/modeling_upernet.py index 2d5b4443e3..58f64995ae 100644 --- a/src/transformers/models/upernet/modeling_upernet.py +++ b/src/transformers/models/upernet/modeling_upernet.py @@ -293,6 +293,7 @@ class UperNetPreTrainedModel(PreTrainedModel): config_class = UperNetConfig main_input_name = "pixel_values" + _no_split_modules = [] def _init_weights(self, module): if isinstance(module, UperNetPreTrainedModel): diff --git a/src/transformers/models/yolos/modeling_yolos.py b/src/transformers/models/yolos/modeling_yolos.py index f47b6b228f..fe558b33a3 100755 --- a/src/transformers/models/yolos/modeling_yolos.py +++ b/src/transformers/models/yolos/modeling_yolos.py @@ -533,6 +533,7 @@ class YolosPreTrainedModel(PreTrainedModel): base_model_prefix = "vit" main_input_name = "pixel_values" supports_gradient_checkpointing = True + _no_split_modules = [] def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 71cb28d754..f1e9c7f2d1 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -2907,7 +2907,10 @@ class ModelTesterMixin: torch.manual_seed(0) new_output = new_model(**inputs_dict_class) - self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5)) + if isinstance(base_output[0], tuple) and isinstance(new_output[0], tuple): + self.assertTrue(torch.allclose(a, b, atol=1e-5) for a, b in zip(base_output[0], new_output[0])) + else: + self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5)) @require_accelerate @mark.accelerate_tests @@ -2939,7 +2942,10 @@ class ModelTesterMixin: torch.manual_seed(0) new_output = new_model(**inputs_dict_class) - self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5)) + if isinstance(base_output[0], tuple) and isinstance(new_output[0], tuple): + self.assertTrue(torch.allclose(a, b, atol=1e-5) for a, b in zip(base_output[0], new_output[0])) + else: + self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5)) @require_accelerate @mark.accelerate_tests @@ -2975,7 +2981,10 @@ class ModelTesterMixin: torch.manual_seed(0) new_output = new_model(**inputs_dict_class) - self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5)) + if isinstance(base_output[0], tuple) and isinstance(new_output[0], tuple): + self.assertTrue(torch.allclose(a, b, atol=1e-5) for a, b in zip(base_output[0], new_output[0])) + else: + self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5)) @require_accelerate @mark.accelerate_tests @@ -3011,7 +3020,10 @@ class ModelTesterMixin: torch.manual_seed(0) new_output = new_model(**inputs_dict_class) - self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5)) + if isinstance(base_output[0], tuple) and isinstance(new_output[0], tuple): + self.assertTrue(torch.allclose(a, b, atol=1e-5) for a, b in zip(base_output[0], new_output[0])) + else: + self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5)) def test_problem_types(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()