From 75a6319864225b8350c31b623ea2c73c23012a40 Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Thu, 27 Jun 2024 17:51:42 +0200 Subject: [PATCH] Fix post gemma merge (#31660) * nit * toctree issue * protect gemma2 tests as well * sdpa supported --- docs/source/en/_toctree.yml | 2 ++ docs/source/en/perf_infer_gpu_one.md | 2 ++ tests/models/gemma2/test_modeling_gemma2.py | 11 ++++++----- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 94f5d8d19e..e48378d8c2 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -382,6 +382,8 @@ title: Fuyu - local: model_doc/gemma title: Gemma + - local: model_doc/gemma2 + title: Gemma2 - local: model_doc/openai-gpt title: GPT - local: model_doc/gpt_neo diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index add92a9440..1569bef1f6 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -43,6 +43,7 @@ FlashAttention-2 is currently supported for the following architectures: * [Dbrx](https://huggingface.co/docs/transformers/model_doc/dbrx#transformers.DbrxModel) * [DistilBert](https://huggingface.co/docs/transformers/model_doc/distilbert#transformers.DistilBertModel) * [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel) +* [Gemma2](https://huggingface.co/docs/transformers/model_doc/gemma2#transformers.Gemma2Model) * [GPT2](https://huggingface.co/docs/transformers/model_doc/gpt2) * [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel) * [GPTNeo](https://huggingface.co/docs/transformers/model_doc/gpt_neo#transformers.GPTNeoModel) @@ -202,6 +203,7 @@ For now, Transformers supports SDPA inference and training for the following arc * [Dpr](https://huggingface.co/docs/transformers/model_doc/dpr#transformers.DprReader) * [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel) * [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel) +* [Gemma2](https://huggingface.co/docs/transformers/model_doc/gemma2#transformers.Gemma2Model) * [GPT2](https://huggingface.co/docs/transformers/model_doc/gpt2) * [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel) * [GPTNeoX](https://huggingface.co/docs/transformers/model_doc/gpt_neox#transformers.GPTNeoXModel) diff --git a/tests/models/gemma2/test_modeling_gemma2.py b/tests/models/gemma2/test_modeling_gemma2.py index 6a6c5688d5..870265f946 100644 --- a/tests/models/gemma2/test_modeling_gemma2.py +++ b/tests/models/gemma2/test_modeling_gemma2.py @@ -41,11 +41,12 @@ if is_torch_available(): class Gemma2ModelTester(GemmaModelTester): - config_class = Gemma2Config - model_class = Gemma2Model - for_causal_lm_class = Gemma2ForCausalLM - for_sequence_class = Gemma2ForSequenceClassification - for_token_class = Gemma2ForTokenClassification + if is_torch_available(): + config_class = Gemma2Config + model_class = Gemma2Model + for_causal_lm_class = Gemma2ForCausalLM + for_sequence_class = Gemma2ForSequenceClassification + for_token_class = Gemma2ForTokenClassification @require_torch