Enable instantiating model with pretrained backbone weights (#28214)
* Enable instantiating model with pretrained backbone weights * Update tests so backbone checkpoint isn't passed in * Remove doc updates until changes made in modeling code * Clarify pretrained import * Update configs - docs and validation check * Update src/transformers/utils/backbone_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Clarify exception message * Update config init in tests * Add test for when use_timm_backbone=True * Small test updates --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -134,6 +134,8 @@ class ConditionalDetrModelTester:
|
||||
num_labels=self.num_labels,
|
||||
use_timm_backbone=False,
|
||||
backbone_config=resnet_config,
|
||||
backbone=None,
|
||||
use_pretrained_backbone=False,
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
|
||||
@@ -149,7 +149,9 @@ class DeformableDetrModelTester:
|
||||
encoder_n_points=self.encoder_n_points,
|
||||
decoder_n_points=self.decoder_n_points,
|
||||
use_timm_backbone=False,
|
||||
backbone=None,
|
||||
backbone_config=resnet_config,
|
||||
use_pretrained_backbone=False,
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
@@ -518,6 +520,8 @@ class DeformableDetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
|
||||
|
||||
# let's pick a random timm backbone
|
||||
config.backbone = "tf_mobilenetv3_small_075"
|
||||
config.use_timm_backbone = True
|
||||
config.backbone_config = None
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
|
||||
@@ -157,6 +157,7 @@ class DetaModelTester:
|
||||
assign_first_stage=assign_first_stage,
|
||||
assign_second_stage=assign_second_stage,
|
||||
backbone_config=resnet_config,
|
||||
backbone=None,
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self, model_class_name="DetaModel"):
|
||||
|
||||
@@ -130,6 +130,8 @@ class DetrModelTester:
|
||||
num_labels=self.num_labels,
|
||||
use_timm_backbone=False,
|
||||
backbone_config=resnet_config,
|
||||
backbone=None,
|
||||
use_pretrained_backbone=False,
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
@@ -622,7 +624,7 @@ class DetrModelIntegrationTestsTimmBackbone(unittest.TestCase):
|
||||
torch_device
|
||||
)
|
||||
expected_number_of_segments = 5
|
||||
expected_first_segment = {"id": 1, "label_id": 17, "was_fused": False, "score": 0.994096}
|
||||
expected_first_segment = {"id": 1, "label_id": 17, "was_fused": False, "score": 0.994097}
|
||||
|
||||
number_of_unique_segments = len(torch.unique(results["segmentation"]))
|
||||
self.assertTrue(
|
||||
|
||||
@@ -95,6 +95,7 @@ class DPTModelTester:
|
||||
def get_config(self):
|
||||
return DPTConfig(
|
||||
backbone_config=self.get_backbone_config(),
|
||||
backbone=None,
|
||||
neck_hidden_sizes=self.neck_hidden_sizes,
|
||||
fusion_hidden_size=self.fusion_hidden_size,
|
||||
)
|
||||
|
||||
@@ -130,6 +130,7 @@ class DPTModelTester:
|
||||
initializer_range=self.initializer_range,
|
||||
is_hybrid=self.is_hybrid,
|
||||
backbone_config=backbone_config,
|
||||
backbone=None,
|
||||
backbone_featmap_shape=self.backbone_featmap_shape,
|
||||
neck_hidden_sizes=self.neck_hidden_sizes,
|
||||
)
|
||||
|
||||
@@ -114,6 +114,7 @@ class Mask2FormerModelTester:
|
||||
config.backbone_config.hidden_size = 16
|
||||
config.backbone_config.num_channels = self.num_channels
|
||||
config.backbone_config.num_heads = [1, 1, 2, 2]
|
||||
config.backbone = None
|
||||
|
||||
config.hidden_dim = self.hidden_dim
|
||||
config.mask_feature_size = self.hidden_dim
|
||||
|
||||
@@ -102,6 +102,7 @@ class MaskFormerModelTester:
|
||||
hidden_size=32,
|
||||
num_heads=[1, 1, 2, 2],
|
||||
),
|
||||
backbone=None,
|
||||
decoder_config=DetrConfig(
|
||||
decoder_ffn_dim=64,
|
||||
decoder_layers=self.num_hidden_layers,
|
||||
|
||||
@@ -133,6 +133,7 @@ class OneFormerModelTester:
|
||||
config.backbone_config.hidden_size = 16
|
||||
config.backbone_config.num_channels = self.num_channels
|
||||
config.backbone_config.num_heads = [1, 1, 2, 2]
|
||||
config.backbone = None
|
||||
|
||||
config.hidden_dim = self.hidden_dim
|
||||
config.mask_dim = self.hidden_dim
|
||||
|
||||
@@ -131,6 +131,8 @@ class TableTransformerModelTester:
|
||||
num_labels=self.num_labels,
|
||||
use_timm_backbone=False,
|
||||
backbone_config=resnet_config,
|
||||
backbone=None,
|
||||
use_pretrained_backbone=False,
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
|
||||
@@ -124,6 +124,7 @@ class TVPModelTester:
|
||||
)
|
||||
return TvpConfig(
|
||||
backbone_config=resnet_config,
|
||||
backbone=None,
|
||||
alpha=self.alpha,
|
||||
beta=self.beta,
|
||||
visual_prompter_type=self.visual_prompter_type,
|
||||
|
||||
@@ -105,6 +105,7 @@ class UperNetModelTester:
|
||||
def get_config(self):
|
||||
return UperNetConfig(
|
||||
backbone_config=self.get_backbone_config(),
|
||||
backbone=None,
|
||||
hidden_size=64,
|
||||
pool_scales=[1, 2, 3, 6],
|
||||
use_auxiliary_head=True,
|
||||
|
||||
@@ -122,6 +122,7 @@ class ViTHybridModelTester:
|
||||
initializer_range=self.initializer_range,
|
||||
backbone_featmap_shape=self.backbone_featmap_shape,
|
||||
backbone_config=backbone_config,
|
||||
backbone=None,
|
||||
)
|
||||
|
||||
def create_and_check_model(self, config, pixel_values, labels):
|
||||
|
||||
@@ -111,6 +111,7 @@ class VitMatteModelTester:
|
||||
def get_config(self):
|
||||
return VitMatteConfig(
|
||||
backbone_config=self.get_backbone_config(),
|
||||
backbone=None,
|
||||
hidden_size=self.hidden_size,
|
||||
fusion_hidden_sizes=self.fusion_hidden_sizes,
|
||||
)
|
||||
|
||||
@@ -16,11 +16,21 @@ import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from transformers import DetrConfig, MaskFormerConfig
|
||||
from transformers.testing_utils import require_torch, slow
|
||||
from transformers.utils.backbone_utils import (
|
||||
BackboneMixin,
|
||||
get_aligned_output_features_output_indices,
|
||||
load_backbone,
|
||||
verify_out_features_out_indices,
|
||||
)
|
||||
from transformers.utils.import_utils import is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import BertPreTrainedModel
|
||||
|
||||
|
||||
class BackboneUtilsTester(unittest.TestCase):
|
||||
@@ -126,3 +136,75 @@ class BackboneUtilsTester(unittest.TestCase):
|
||||
backbone.out_indices = [-3, -1]
|
||||
self.assertEqual(backbone.out_features, ["a", "c"])
|
||||
self.assertEqual(backbone.out_indices, [-3, -1])
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
def test_load_backbone_in_new_model(self):
|
||||
"""
|
||||
Tests that new model can be created, with its weights instantiated and pretrained backbone weights loaded.
|
||||
"""
|
||||
|
||||
# Inherit from PreTrainedModel to ensure that the weights are initialized
|
||||
class NewModel(BertPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.backbone = load_backbone(config)
|
||||
self.layer_0 = torch.nn.Linear(config.hidden_size, config.hidden_size)
|
||||
self.layer_1 = torch.nn.Linear(config.hidden_size, config.hidden_size)
|
||||
|
||||
def get_equal_not_equal_weights(model_0, model_1):
|
||||
equal_weights = []
|
||||
not_equal_weights = []
|
||||
for (k0, v0), (k1, v1) in zip(model_0.named_parameters(), model_1.named_parameters()):
|
||||
self.assertEqual(k0, k1)
|
||||
weights_are_equal = torch.allclose(v0, v1)
|
||||
if weights_are_equal:
|
||||
equal_weights.append(k0)
|
||||
else:
|
||||
not_equal_weights.append(k0)
|
||||
return equal_weights, not_equal_weights
|
||||
|
||||
config = MaskFormerConfig(use_pretrained_backbone=False, backbone="microsoft/resnet-18")
|
||||
model_0 = NewModel(config)
|
||||
model_1 = NewModel(config)
|
||||
equal_weights, not_equal_weights = get_equal_not_equal_weights(model_0, model_1)
|
||||
|
||||
# Norm layers are always initialized with the same weights
|
||||
equal_weights = [w for w in equal_weights if "normalization" not in w]
|
||||
self.assertEqual(len(equal_weights), 0)
|
||||
self.assertEqual(len(not_equal_weights), 24)
|
||||
|
||||
# Now we create a new model with backbone weights that are pretrained
|
||||
config.use_pretrained_backbone = True
|
||||
model_0 = NewModel(config)
|
||||
model_1 = NewModel(config)
|
||||
equal_weights, not_equal_weights = get_equal_not_equal_weights(model_0, model_1)
|
||||
|
||||
# Norm layers are always initialized with the same weights
|
||||
equal_weights = [w for w in equal_weights if "normalization" not in w]
|
||||
self.assertEqual(len(equal_weights), 20)
|
||||
# Linear layers are still initialized randomly
|
||||
self.assertEqual(len(not_equal_weights), 4)
|
||||
|
||||
# Check loading in timm backbone
|
||||
config = DetrConfig(use_pretrained_backbone=False, backbone="resnet18", use_timm_backbone=True)
|
||||
model_0 = NewModel(config)
|
||||
model_1 = NewModel(config)
|
||||
equal_weights, not_equal_weights = get_equal_not_equal_weights(model_0, model_1)
|
||||
|
||||
# Norm layers are always initialized with the same weights
|
||||
equal_weights = [w for w in equal_weights if "bn" not in w and "downsample.1" not in w]
|
||||
self.assertEqual(len(equal_weights), 0)
|
||||
self.assertEqual(len(not_equal_weights), 24)
|
||||
|
||||
# Now we create a new model with backbone weights that are pretrained
|
||||
config.use_pretrained_backbone = True
|
||||
model_0 = NewModel(config)
|
||||
model_1 = NewModel(config)
|
||||
equal_weights, not_equal_weights = get_equal_not_equal_weights(model_0, model_1)
|
||||
|
||||
# Norm layers are always initialized with the same weights
|
||||
equal_weights = [w for w in equal_weights if "bn" not in w and "downsample.1" not in w]
|
||||
self.assertEqual(len(equal_weights), 20)
|
||||
# Linear layers are still initialized randomly
|
||||
self.assertEqual(len(not_equal_weights), 4)
|
||||
|
||||
Reference in New Issue
Block a user