From e7d52a10d721f4475c810d403b1e71689d4b94b9 Mon Sep 17 00:00:00 2001 From: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Date: Fri, 26 Apr 2024 18:04:41 +0100 Subject: [PATCH] Fix GroundingDINO, DPR after BERT SDPA update (#30506) Fix GroundingDINO, DPR after BET SDPA update --- docs/source/en/perf_infer_gpu_one.md | 1 + src/transformers/models/dpr/modeling_dpr.py | 2 ++ .../models/grounding_dino/modeling_grounding_dino.py | 4 +++- 3 files changed, 6 insertions(+), 1 deletion(-) diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index 64583e4bad..de49d4427b 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -194,6 +194,7 @@ For now, Transformers supports SDPA inference and training for the following arc * [Bert](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertModel) * [Cohere](https://huggingface.co/docs/transformers/model_doc/cohere#transformers.CohereModel) * [Dbrx](https://huggingface.co/docs/transformers/model_doc/dbrx#transformers.DbrxModel) +* [Dpr](https://huggingface.co/docs/transformers/model_doc/dpr#transformers.DprReader) * [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel) * [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel) * [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel) diff --git a/src/transformers/models/dpr/modeling_dpr.py b/src/transformers/models/dpr/modeling_dpr.py index 0a45ec7520..928f2b9311 100644 --- a/src/transformers/models/dpr/modeling_dpr.py +++ b/src/transformers/models/dpr/modeling_dpr.py @@ -142,6 +142,8 @@ class DPRReaderOutput(ModelOutput): class DPRPreTrainedModel(PreTrainedModel): + _supports_sdpa = True + def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): diff --git a/src/transformers/models/grounding_dino/modeling_grounding_dino.py b/src/transformers/models/grounding_dino/modeling_grounding_dino.py index 83009c9250..da8dd29a5c 100644 --- a/src/transformers/models/grounding_dino/modeling_grounding_dino.py +++ b/src/transformers/models/grounding_dino/modeling_grounding_dino.py @@ -2113,7 +2113,9 @@ class GroundingDinoModel(GroundingDinoPreTrainedModel): ) # Create text backbone - self.text_backbone = AutoModel.from_config(config.text_config, add_pooling_layer=False) + self.text_backbone = AutoModel.from_config( + config.text_config, add_pooling_layer=False, attn_implementation=config._attn_implementation + ) self.text_projection = nn.Linear(config.text_config.hidden_size, config.d_model) if config.embedding_init_target or not config.two_stage: