Add dpt-hybrid support (#20645)

* add `dpt-hybrid` support

* refactor

* final changes, all tests pass

* final cleanups

* final changes

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* fix docstring

* fix typo

* change `vit_hybrid` to `hybrid`

* replace dataclass

* add docstring

* move dataclasses

* fix test

* add `PretrainedConfig` support for `backbone_config`

* fix docstring

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* remove `embedding_type` and replace it by `is_hybrid`

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Younes Belkada
2022-12-07 17:01:55 +01:00
committed by GitHub
parent 3ac040bca1
commit 7c5eaf9e5a
6 changed files with 943 additions and 43 deletions

View File

@@ -61,6 +61,7 @@ class DPTModelTester:
attention_probs_dropout_prob=0.1,
initializer_range=0.02,
num_labels=3,
is_hybrid=False,
scope=None,
):
self.parent = parent
@@ -81,6 +82,7 @@ class DPTModelTester:
self.initializer_range = initializer_range
self.num_labels = num_labels
self.scope = scope
self.is_hybrid = is_hybrid
# sequence length of DPT = num_patches + 1 (we add 1 for the [CLS] token)
num_patches = (image_size // patch_size) ** 2
self.seq_length = num_patches + 1
@@ -111,6 +113,7 @@ class DPTModelTester:
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
is_decoder=False,
initializer_range=self.initializer_range,
is_hybrid=self.is_hybrid,
)
def create_and_check_model(self, config, pixel_values, labels):