Add dpt-hybrid support (#20645)
* add `dpt-hybrid` support * refactor * final changes, all tests pass * final cleanups * final changes * Apply suggestions from code review Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * fix docstring * fix typo * change `vit_hybrid` to `hybrid` * replace dataclass * add docstring * move dataclasses * fix test * add `PretrainedConfig` support for `backbone_config` * fix docstring * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * remove `embedding_type` and replace it by `is_hybrid` Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -61,6 +61,7 @@ class DPTModelTester:
|
||||
attention_probs_dropout_prob=0.1,
|
||||
initializer_range=0.02,
|
||||
num_labels=3,
|
||||
is_hybrid=False,
|
||||
scope=None,
|
||||
):
|
||||
self.parent = parent
|
||||
@@ -81,6 +82,7 @@ class DPTModelTester:
|
||||
self.initializer_range = initializer_range
|
||||
self.num_labels = num_labels
|
||||
self.scope = scope
|
||||
self.is_hybrid = is_hybrid
|
||||
# sequence length of DPT = num_patches + 1 (we add 1 for the [CLS] token)
|
||||
num_patches = (image_size // patch_size) ** 2
|
||||
self.seq_length = num_patches + 1
|
||||
@@ -111,6 +113,7 @@ class DPTModelTester:
|
||||
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
||||
is_decoder=False,
|
||||
initializer_range=self.initializer_range,
|
||||
is_hybrid=self.is_hybrid,
|
||||
)
|
||||
|
||||
def create_and_check_model(self, config, pixel_values, labels):
|
||||
|
||||
Reference in New Issue
Block a user