From 2829013d2d00e63d75a1f6f7a3f003bc60cc69af Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Wed, 12 Mar 2025 11:29:11 +0100 Subject: [PATCH] fix block mask typing (#36661) * fix block mask typing * updated Co-authored-by: Cyril Vallez * gemma * fix --------- Co-authored-by: Cyril Vallez --- docs/source/en/_toctree.yml | 139 +++++++++--------- .../integrations/flex_attention.py | 4 +- src/transformers/modeling_utils.py | 6 +- .../models/gemma3/modeling_gemma3.py | 2 +- 4 files changed, 76 insertions(+), 75 deletions(-) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 33c4a7df57..5c4a643507 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -1,16 +1,14 @@ -- title: Get started - sections: +- sections: - local: index title: Transformers - local: installation title: Installation - local: quicktour title: Quickstart -- title: Base classes - isExpanded: False + title: Get started +- isExpanded: false sections: - - title: Models - sections: + - sections: - local: models title: Loading models - local: custom_models @@ -31,8 +29,8 @@ title: The Transformer model family - local: attention title: Attention mechanisms - - title: Preprocessors - sections: + title: Models + - sections: - local: fast_tokenizers title: Tokenizers - local: image_processors @@ -47,11 +45,11 @@ title: Summary of the tokenizers - local: pad_truncation title: Padding and truncation -- title: Inference - isExpanded: False + title: Preprocessors + title: Base classes +- isExpanded: false sections: - - title: Pipeline API - sections: + - sections: - local: pipeline_tutorial title: Pipeline - local: pipeline_gradio @@ -60,8 +58,8 @@ title: Web server inference - local: add_new_pipeline title: Adding a new pipeline - - title: LLMs - sections: + title: Pipeline API + - sections: - local: llm_tutorial title: Text generation - local: generation_strategies @@ -82,8 +80,8 @@ title: Getting the most out of LLMs - local: perplexity title: Perplexity of fixed-length models - - title: Chat with models - sections: + title: LLMs + - sections: - local: conversations title: Chat basics - local: chat_templating @@ -94,8 +92,8 @@ title: Template writing - local: chat_extras title: Tools and RAG - - title: Optimization - sections: + title: Chat with models + - sections: - local: perf_torch_compile title: torch.compile - local: perf_infer_gpu_one @@ -106,15 +104,15 @@ title: CPU - local: tf_xla title: XLA + title: Optimization - local: agents title: Agents - local: tools title: Tools -- title: Training - isExpanded: False + title: Inference +- isExpanded: false sections: - - title: Trainer API - sections: + - sections: - local: trainer title: Trainer - local: training @@ -123,8 +121,8 @@ title: Optimizers - local: hpo_train title: Hyperparameter search - - title: Distributed training - sections: + title: Trainer API + - sections: - local: gpu_selection title: GPU selection - local: accelerate @@ -139,8 +137,8 @@ title: Distributed CPUs - local: perf_train_gpu_many title: Parallelism methods - - title: Hardware - sections: + title: Distributed training + - sections: - local: perf_train_gpu_one title: GPU - local: perf_train_cpu @@ -151,12 +149,13 @@ title: Apple Silicon - local: perf_hardware title: Build your own machine + title: Hardware - local: peft title: PEFT - local: model_memory_anatomy title: Model training anatomy -- title: Quantization - isExpanded: False + title: Training +- isExpanded: false sections: - local: quantization/overview title: Overview @@ -196,8 +195,8 @@ title: VPTQ - local: quantization/contribute title: Contribute -- title: Export to production - isExpanded: False + title: Quantization +- isExpanded: false sections: - local: serialization title: ONNX @@ -207,13 +206,11 @@ title: ExecuTorch - local: torchscript title: TorchScript -- title: Resources - isExpanded: False + title: Export to production +- isExpanded: false sections: - - title: Task recipes - sections: - - title: Natural language processing - sections: + - sections: + - sections: - local: tasks/sequence_classification title: Text classification - local: tasks/token_classification @@ -230,14 +227,14 @@ title: Summarization - local: tasks/multiple_choice title: Multiple choice - - title: Audio - sections: + title: Natural language processing + - sections: - local: tasks/audio_classification title: Audio classification - local: tasks/asr title: Automatic speech recognition - - title: Computer vision - sections: + title: Audio + - sections: - local: tasks/image_classification title: Image classification - local: tasks/semantic_segmentation @@ -262,8 +259,8 @@ title: Keypoint detection - local: tasks/knowledge_distillation_for_image_classification title: Knowledge Distillation for Computer Vision - - title: Multimodal - sections: + title: Computer vision + - sections: - local: tasks/image_captioning title: Image captioning - local: tasks/document_question_answering @@ -278,6 +275,8 @@ title: Image-text-to-text - local: tasks/video_text_to_text title: Video-text-to-text + title: Multimodal + title: Task recipes - local: run_scripts title: Training scripts - local: glossary @@ -290,8 +289,8 @@ title: Community resources - local: troubleshooting title: Troubleshoot -- title: Contribute - isExpanded: False + title: Resources +- isExpanded: false sections: - local: contributing title: Contribute to Transformers @@ -299,11 +298,10 @@ title: Transformers model tests - local: pr_checks title: Pull request checks -- title: API - isExpanded: False + title: Contribute +- isExpanded: false sections: - - title: Main classes - sections: + - sections: - local: main_classes/agent title: Agents and Tools - local: model_doc/auto @@ -350,10 +348,9 @@ title: Feature Extractor - local: main_classes/image_processor title: Image Processor - - title: Models - sections: - - title: Text models - sections: + title: Main classes + - sections: + - sections: - local: model_doc/albert title: ALBERT - local: model_doc/bamba @@ -662,8 +659,8 @@ title: Zamba - local: model_doc/zamba2 title: Zamba2 - - title: Vision models - sections: + title: Text models + - sections: - local: model_doc/beit title: BEiT - local: model_doc/bit @@ -790,8 +787,8 @@ title: YOLOS - local: model_doc/zoedepth title: ZoeDepth - - title: Audio models - sections: + title: Vision models + - sections: - local: model_doc/audio-spectrogram-transformer title: Audio Spectrogram Transformer - local: model_doc/bark @@ -860,16 +857,16 @@ title: XLS-R - local: model_doc/xlsr_wav2vec2 title: XLSR-Wav2Vec2 - - title: Video models - sections: + title: Audio models + - sections: - local: model_doc/timesformer title: TimeSformer - local: model_doc/videomae title: VideoMAE - local: model_doc/vivit title: ViViT - - title: Multimodal models - sections: + title: Video models + - sections: - local: model_doc/align title: ALIGN - local: model_doc/altclip @@ -908,6 +905,8 @@ title: Emu3 - local: model_doc/flava title: FLAVA + - local: model_doc/gemma3 + title: Gemma3 - local: model_doc/git title: GIT - local: model_doc/got_ocr2 @@ -1012,14 +1011,14 @@ title: VisualBERT - local: model_doc/xclip title: X-CLIP - - title: Reinforcement learning models - sections: + title: Multimodal models + - sections: - local: model_doc/decision_transformer title: Decision Transformer - local: model_doc/trajectory_transformer title: Trajectory Transformer - - title: Time series models - sections: + title: Reinforcement learning models + - sections: - local: model_doc/autoformer title: Autoformer - local: model_doc/informer @@ -1030,12 +1029,13 @@ title: PatchTST - local: model_doc/time_series_transformer title: Time Series Transformer - - title: Graph models - sections: + title: Time series models + - sections: - local: model_doc/graphormer title: Graphormer - - title: Internal helpers - sections: + title: Graph models + title: Models + - sections: - local: internal/modeling_utils title: Custom Layers and Utilities - local: internal/pipelines_utils @@ -1054,4 +1054,5 @@ title: General Utilities - local: internal/time_series_utils title: Utilities for Time Series - \ No newline at end of file + title: Internal helpers + title: API diff --git a/src/transformers/integrations/flex_attention.py b/src/transformers/integrations/flex_attention.py index aff1eb93af..b0a054998c 100644 --- a/src/transformers/integrations/flex_attention.py +++ b/src/transformers/integrations/flex_attention.py @@ -71,7 +71,7 @@ class WrappedFlexAttention: return self._compiled_flex_attention -def make_flex_block_causal_mask(attention_mask_2d: torch.Tensor) -> BlockMask: +def make_flex_block_causal_mask(attention_mask_2d: torch.Tensor) -> "BlockMask": """ Create a block causal document mask for a batch of sequences, both packed and unpacked. Create Block causal logic and passing it into :func:`torch.nn.attention.flex_attention.create_block_mask`. @@ -149,7 +149,7 @@ def flex_attention_forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - attention_mask: Union[torch.Tensor, BlockMask], + attention_mask: Union[torch.Tensor, "BlockMask"], scaling: Optional[float] = None, softcap: Optional[float] = None, head_mask: Optional[torch.Tensor] = None, diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 14fa632131..c4cf20c060 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -849,13 +849,13 @@ def _load_state_dict_into_meta_model( is_quantized = hf_quantizer is not None for serialized_param_name, empty_param in state_dict.items(): - if serialized_param_name not in expected_keys: - continue - # serialized_param_name is the raw, serialized name # fixed_param_name is the model's equivalent fixed_param_name, _ = model.rename_key(serialized_param_name) + if fixed_param_name not in expected_keys: + continue + # we need to use serialized_param_name as file pointer is untouched if shard_file.endswith(".safetensors"): param = file_pointer.get_slice(serialized_param_name) diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py index d5498c8615..500910404b 100644 --- a/src/transformers/models/gemma3/modeling_gemma3.py +++ b/src/transformers/models/gemma3/modeling_gemma3.py @@ -845,7 +845,7 @@ class Gemma3TextModel(Gemma3PreTrainedModel): dtype (`torch.dtype`): The dtype to use for the 4D attention mask. device (`torch.device`): - The device to plcae the 4D attention mask on. + The device to place the 4D attention mask on. cache_position (`torch.Tensor`): Indices depicting the position of the input sequence tokens in the sequence. batch_size (`torch.Tensor`):