Enable multi-device for some models (#30207)
* feat: multidevice for resnet * feat: yes! resnet * fix: compare all elements in tuple * feat: support for regnet * feat: support for convnextv2 * feat: support for bit * feat: support for cvt * feat: add support for focalnet * feat: support for yolos * feat: support for glpn * feat: support for imagegpt * feat: support for levit * feat: support for mgp_str * feat: support for mobilenet_v1 * feat: support for mobilenet_v2 * feat: support for mobilevit * feat: support for mobilevitv2 * feat: support for poolformer * fix: copies * fix: code quality check * update: upstream changes from main * fix: consistency check * feat: support for sam * feat: support for swiftformer * feat: support for swin * feat: support for swinv2 * feat: support for timesformer * feat: support for trocr * feat: support for upernet * fix: check copies * update: rerun CI * update: rerun again, maybe * update: one more rerun --------- Co-authored-by: Jacky Lee <jackylee328@gmail.com>
This commit is contained in:
@@ -658,6 +658,7 @@ class BitPreTrainedModel(PreTrainedModel):
|
||||
config_class = BitConfig
|
||||
base_model_prefix = "bit"
|
||||
main_input_name = "pixel_values"
|
||||
_no_split_modules = ["BitEmbeddings"]
|
||||
|
||||
def _init_weights(self, module):
|
||||
if isinstance(module, nn.Conv2d):
|
||||
|
||||
@@ -280,6 +280,7 @@ class ConvNextPreTrainedModel(PreTrainedModel):
|
||||
config_class = ConvNextConfig
|
||||
base_model_prefix = "convnext"
|
||||
main_input_name = "pixel_values"
|
||||
_no_split_modules = ["ConvNextLayer"]
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
|
||||
@@ -301,6 +301,7 @@ class ConvNextV2PreTrainedModel(PreTrainedModel):
|
||||
config_class = ConvNextV2Config
|
||||
base_model_prefix = "convnextv2"
|
||||
main_input_name = "pixel_values"
|
||||
_no_split_modules = ["ConvNextV2Layer"]
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
|
||||
@@ -534,6 +534,7 @@ class CvtPreTrainedModel(PreTrainedModel):
|
||||
config_class = CvtConfig
|
||||
base_model_prefix = "cvt"
|
||||
main_input_name = "pixel_values"
|
||||
_no_split_modules = ["CvtLayer"]
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
|
||||
@@ -809,6 +809,7 @@ class DonutSwinPreTrainedModel(PreTrainedModel):
|
||||
base_model_prefix = "swin"
|
||||
main_input_name = "pixel_values"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["DonutSwinStage"]
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
|
||||
@@ -636,6 +636,7 @@ class FocalNetPreTrainedModel(PreTrainedModel):
|
||||
base_model_prefix = "focalnet"
|
||||
main_input_name = "pixel_values"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["FocalNetStage"]
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
|
||||
@@ -426,6 +426,7 @@ class GLPNPreTrainedModel(PreTrainedModel):
|
||||
config_class = GLPNConfig
|
||||
base_model_prefix = "glpn"
|
||||
main_input_name = "pixel_values"
|
||||
_no_split_modules = []
|
||||
|
||||
# Copied from transformers.models.segformer.modeling_segformer.SegformerPreTrainedModel._init_weights
|
||||
def _init_weights(self, module):
|
||||
|
||||
@@ -491,6 +491,7 @@ class ImageGPTPreTrainedModel(PreTrainedModel):
|
||||
base_model_prefix = "transformer"
|
||||
main_input_name = "input_ids"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["ImageGPTBlock"]
|
||||
|
||||
def __init__(self, *inputs, **kwargs):
|
||||
super().__init__(*inputs, **kwargs)
|
||||
|
||||
@@ -491,6 +491,7 @@ class LevitPreTrainedModel(PreTrainedModel):
|
||||
config_class = LevitConfig
|
||||
base_model_prefix = "levit"
|
||||
main_input_name = "pixel_values"
|
||||
_no_split_modules = ["LevitResidualLayer"]
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
|
||||
@@ -735,6 +735,7 @@ class MaskFormerSwinPreTrainedModel(PreTrainedModel):
|
||||
base_model_prefix = "model"
|
||||
main_input_name = "pixel_values"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["MaskFormerSwinStage"]
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
|
||||
@@ -317,6 +317,7 @@ class MgpstrPreTrainedModel(PreTrainedModel):
|
||||
|
||||
config_class = MgpstrConfig
|
||||
base_model_prefix = "mgp_str"
|
||||
_no_split_modules = []
|
||||
|
||||
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
|
||||
"""Initialize the weights"""
|
||||
|
||||
@@ -254,6 +254,7 @@ class MobileNetV1PreTrainedModel(PreTrainedModel):
|
||||
base_model_prefix = "mobilenet_v1"
|
||||
main_input_name = "pixel_values"
|
||||
supports_gradient_checkpointing = False
|
||||
_no_split_modules = []
|
||||
|
||||
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d]) -> None:
|
||||
"""Initialize the weights"""
|
||||
|
||||
@@ -453,6 +453,7 @@ class MobileNetV2PreTrainedModel(PreTrainedModel):
|
||||
base_model_prefix = "mobilenet_v2"
|
||||
main_input_name = "pixel_values"
|
||||
supports_gradient_checkpointing = False
|
||||
_no_split_modules = []
|
||||
|
||||
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d]) -> None:
|
||||
"""Initialize the weights"""
|
||||
|
||||
@@ -644,6 +644,7 @@ class MobileViTPreTrainedModel(PreTrainedModel):
|
||||
base_model_prefix = "mobilevit"
|
||||
main_input_name = "pixel_values"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["MobileViTLayer"]
|
||||
|
||||
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
|
||||
"""Initialize the weights"""
|
||||
|
||||
@@ -606,6 +606,7 @@ class MobileViTV2PreTrainedModel(PreTrainedModel):
|
||||
base_model_prefix = "mobilevitv2"
|
||||
main_input_name = "pixel_values"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["MobileViTV2Layer"]
|
||||
|
||||
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
|
||||
"""Initialize the weights"""
|
||||
|
||||
@@ -268,6 +268,7 @@ class PoolFormerPreTrainedModel(PreTrainedModel):
|
||||
config_class = PoolFormerConfig
|
||||
base_model_prefix = "poolformer"
|
||||
main_input_name = "pixel_values"
|
||||
_no_split_modules = ["PoolFormerLayer"]
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
|
||||
@@ -281,6 +281,7 @@ class RegNetPreTrainedModel(PreTrainedModel):
|
||||
config_class = RegNetConfig
|
||||
base_model_prefix = "regnet"
|
||||
main_input_name = "pixel_values"
|
||||
_no_split_modules = ["RegNetYLayer"]
|
||||
|
||||
# Copied from transformers.models.resnet.modeling_resnet.ResNetPreTrainedModel._init_weights
|
||||
def _init_weights(self, module):
|
||||
|
||||
@@ -272,6 +272,7 @@ class ResNetPreTrainedModel(PreTrainedModel):
|
||||
config_class = ResNetConfig
|
||||
base_model_prefix = "resnet"
|
||||
main_input_name = "pixel_values"
|
||||
_no_split_modules = ["ResNetConvLayer", "ResNetShortCut"]
|
||||
|
||||
def _init_weights(self, module):
|
||||
if isinstance(module, nn.Conv2d):
|
||||
|
||||
@@ -1074,6 +1074,7 @@ class SamPreTrainedModel(PreTrainedModel):
|
||||
config_class = SamConfig
|
||||
base_model_prefix = "sam"
|
||||
main_input_name = "pixel_values"
|
||||
_no_split_modules = ["SamVisionAttention"]
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.initializer_range
|
||||
|
||||
@@ -428,6 +428,7 @@ class SwiftFormerPreTrainedModel(PreTrainedModel):
|
||||
base_model_prefix = "swiftformer"
|
||||
main_input_name = "pixel_values"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["SwiftFormerEncoderBlock"]
|
||||
|
||||
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
|
||||
"""Initialize the weights"""
|
||||
|
||||
@@ -884,6 +884,7 @@ class SwinPreTrainedModel(PreTrainedModel):
|
||||
base_model_prefix = "swin"
|
||||
main_input_name = "pixel_values"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["SwinStage"]
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
|
||||
@@ -939,6 +939,7 @@ class Swinv2PreTrainedModel(PreTrainedModel):
|
||||
base_model_prefix = "swinv2"
|
||||
main_input_name = "pixel_values"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["Swinv2Stage"]
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
|
||||
@@ -472,6 +472,7 @@ class TimesformerPreTrainedModel(PreTrainedModel):
|
||||
base_model_prefix = "timesformer"
|
||||
main_input_name = "pixel_values"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["TimesformerLayer"]
|
||||
|
||||
def _init_weights(self, module):
|
||||
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
||||
|
||||
@@ -407,6 +407,7 @@ class TrOCRPreTrainedModel(PreTrainedModel):
|
||||
config_class = TrOCRConfig
|
||||
base_model_prefix = "model"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["TrOCRDecoderLayer"]
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.init_std
|
||||
|
||||
@@ -293,6 +293,7 @@ class UperNetPreTrainedModel(PreTrainedModel):
|
||||
|
||||
config_class = UperNetConfig
|
||||
main_input_name = "pixel_values"
|
||||
_no_split_modules = []
|
||||
|
||||
def _init_weights(self, module):
|
||||
if isinstance(module, UperNetPreTrainedModel):
|
||||
|
||||
@@ -533,6 +533,7 @@ class YolosPreTrainedModel(PreTrainedModel):
|
||||
base_model_prefix = "vit"
|
||||
main_input_name = "pixel_values"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = []
|
||||
|
||||
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
|
||||
"""Initialize the weights"""
|
||||
|
||||
@@ -2907,6 +2907,9 @@ class ModelTesterMixin:
|
||||
torch.manual_seed(0)
|
||||
new_output = new_model(**inputs_dict_class)
|
||||
|
||||
if isinstance(base_output[0], tuple) and isinstance(new_output[0], tuple):
|
||||
# Wrap in all(): a bare generator object is always truthy, so assertTrue would pass without comparing any tensors.
self.assertTrue(all(torch.allclose(a, b, atol=1e-5) for a, b in zip(base_output[0], new_output[0])))
|
||||
else:
|
||||
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
|
||||
|
||||
@require_accelerate
|
||||
@@ -2939,6 +2942,9 @@ class ModelTesterMixin:
|
||||
torch.manual_seed(0)
|
||||
new_output = new_model(**inputs_dict_class)
|
||||
|
||||
if isinstance(base_output[0], tuple) and isinstance(new_output[0], tuple):
|
||||
# Wrap in all(): a bare generator object is always truthy, so assertTrue would pass without comparing any tensors.
self.assertTrue(all(torch.allclose(a, b, atol=1e-5) for a, b in zip(base_output[0], new_output[0])))
|
||||
else:
|
||||
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
|
||||
|
||||
@require_accelerate
|
||||
@@ -2975,6 +2981,9 @@ class ModelTesterMixin:
|
||||
torch.manual_seed(0)
|
||||
new_output = new_model(**inputs_dict_class)
|
||||
|
||||
if isinstance(base_output[0], tuple) and isinstance(new_output[0], tuple):
|
||||
# Wrap in all(): a bare generator object is always truthy, so assertTrue would pass without comparing any tensors.
self.assertTrue(all(torch.allclose(a, b, atol=1e-5) for a, b in zip(base_output[0], new_output[0])))
|
||||
else:
|
||||
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
|
||||
|
||||
@require_accelerate
|
||||
@@ -3011,6 +3020,9 @@ class ModelTesterMixin:
|
||||
torch.manual_seed(0)
|
||||
new_output = new_model(**inputs_dict_class)
|
||||
|
||||
if isinstance(base_output[0], tuple) and isinstance(new_output[0], tuple):
|
||||
# Wrap in all(): a bare generator object is always truthy, so assertTrue would pass without comparing any tensors.
self.assertTrue(all(torch.allclose(a, b, atol=1e-5) for a, b in zip(base_output[0], new_output[0])))
|
||||
else:
|
||||
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
|
||||
|
||||
def test_problem_types(self):
|
||||
|
||||
Reference in New Issue
Block a user