From 89b6ee49fdc3ad625fba8067743641facdbb8597 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Tue, 27 Jun 2023 21:35:15 -0400 Subject: [PATCH] Finishing tidying keys to ignore on load (#24535) --- .../models/instructblip/modeling_instructblip.py | 10 +++------- .../timm_backbone/test_modeling_timm_backbone.py | 8 ++++++++ 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/instructblip/modeling_instructblip.py b/src/transformers/models/instructblip/modeling_instructblip.py index acc0df6309..3eca3b0a32 100644 --- a/src/transformers/models/instructblip/modeling_instructblip.py +++ b/src/transformers/models/instructblip/modeling_instructblip.py @@ -275,12 +275,6 @@ class InstructBlipPreTrainedModel(PreTrainedModel): config_class = InstructBlipConfig base_model_prefix = "blip" supports_gradient_checkpointing = True - _keys_to_ignore_on_load_missing = [ - r"position_ids", - r"language_model.encoder.embed_tokens.weight", - r"language_model.decoder.embed_tokens.weight", - r"language_model.lm_head.weight", - ] _no_split_modules = ["InstructBlipAttention", "InstructBlipQFormerMultiHeadAttention"] _keep_in_fp32_modules = [] @@ -1011,7 +1005,9 @@ class InstructBlipQFormerEmbeddings(nn.Module): self.dropout = nn.Dropout(config.hidden_dropout_prob) # position_ids (1, len position emb) is contiguous in memory and exported when serialized - self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False + ) self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") self.config = config diff --git a/tests/models/timm_backbone/test_modeling_timm_backbone.py b/tests/models/timm_backbone/test_modeling_timm_backbone.py index 145238c6bf..2af7c4c617 100644 --- a/tests/models/timm_backbone/test_modeling_timm_backbone.py +++ b/tests/models/timm_backbone/test_modeling_timm_backbone.py @@ -176,6 +176,14 @@ class TimmBackboneModelTest(ModelTesterMixin, BackboneTesterMixin, PipelineTeste def test_tied_model_weights_key_ignore(self): pass + @unittest.skip("Only checkpoints on timm can be loaded into TimmBackbone") + def test_load_save_without_tied_weights(self): + pass + + @unittest.skip("Only checkpoints on timm can be loaded into TimmBackbone") + def test_model_weights_reload_no_missing_tied_weights(self): + pass + @unittest.skip("TimmBackbone doesn't have hidden size info in its configuration.") def test_channels(self): pass