fix block mask typing (#36661)

* fix block mask typing

* updated

Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com>

* gemma

* fix

---------

Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com>
This commit is contained in:
Arthur
2025-03-12 11:29:11 +01:00
committed by GitHub
parent 89f6956015
commit 2829013d2d
4 changed files with 76 additions and 75 deletions

View File

@@ -1,16 +1,14 @@
- title: Get started - sections:
sections:
- local: index - local: index
title: Transformers title: Transformers
- local: installation - local: installation
title: Installation title: Installation
- local: quicktour - local: quicktour
title: Quickstart title: Quickstart
- title: Base classes title: Get started
isExpanded: False - isExpanded: false
sections:
- title: Models
sections: sections:
- sections:
- local: models - local: models
title: Loading models title: Loading models
- local: custom_models - local: custom_models
@@ -31,8 +29,8 @@
title: The Transformer model family title: The Transformer model family
- local: attention - local: attention
title: Attention mechanisms title: Attention mechanisms
- title: Preprocessors title: Models
sections: - sections:
- local: fast_tokenizers - local: fast_tokenizers
title: Tokenizers title: Tokenizers
- local: image_processors - local: image_processors
@@ -47,11 +45,11 @@
title: Summary of the tokenizers title: Summary of the tokenizers
- local: pad_truncation - local: pad_truncation
title: Padding and truncation title: Padding and truncation
- title: Inference title: Preprocessors
isExpanded: False title: Base classes
sections: - isExpanded: false
- title: Pipeline API
sections: sections:
- sections:
- local: pipeline_tutorial - local: pipeline_tutorial
title: Pipeline title: Pipeline
- local: pipeline_gradio - local: pipeline_gradio
@@ -60,8 +58,8 @@
title: Web server inference title: Web server inference
- local: add_new_pipeline - local: add_new_pipeline
title: Adding a new pipeline title: Adding a new pipeline
- title: LLMs title: Pipeline API
sections: - sections:
- local: llm_tutorial - local: llm_tutorial
title: Text generation title: Text generation
- local: generation_strategies - local: generation_strategies
@@ -82,8 +80,8 @@
title: Getting the most out of LLMs title: Getting the most out of LLMs
- local: perplexity - local: perplexity
title: Perplexity of fixed-length models title: Perplexity of fixed-length models
- title: Chat with models title: LLMs
sections: - sections:
- local: conversations - local: conversations
title: Chat basics title: Chat basics
- local: chat_templating - local: chat_templating
@@ -94,8 +92,8 @@
title: Template writing title: Template writing
- local: chat_extras - local: chat_extras
title: Tools and RAG title: Tools and RAG
- title: Optimization title: Chat with models
sections: - sections:
- local: perf_torch_compile - local: perf_torch_compile
title: torch.compile title: torch.compile
- local: perf_infer_gpu_one - local: perf_infer_gpu_one
@@ -106,15 +104,15 @@
title: CPU title: CPU
- local: tf_xla - local: tf_xla
title: XLA title: XLA
title: Optimization
- local: agents - local: agents
title: Agents title: Agents
- local: tools - local: tools
title: Tools title: Tools
- title: Training title: Inference
isExpanded: False - isExpanded: false
sections:
- title: Trainer API
sections: sections:
- sections:
- local: trainer - local: trainer
title: Trainer title: Trainer
- local: training - local: training
@@ -123,8 +121,8 @@
title: Optimizers title: Optimizers
- local: hpo_train - local: hpo_train
title: Hyperparameter search title: Hyperparameter search
- title: Distributed training title: Trainer API
sections: - sections:
- local: gpu_selection - local: gpu_selection
title: GPU selection title: GPU selection
- local: accelerate - local: accelerate
@@ -139,8 +137,8 @@
title: Distributed CPUs title: Distributed CPUs
- local: perf_train_gpu_many - local: perf_train_gpu_many
title: Parallelism methods title: Parallelism methods
- title: Hardware title: Distributed training
sections: - sections:
- local: perf_train_gpu_one - local: perf_train_gpu_one
title: GPU title: GPU
- local: perf_train_cpu - local: perf_train_cpu
@@ -151,12 +149,13 @@
title: Apple Silicon title: Apple Silicon
- local: perf_hardware - local: perf_hardware
title: Build your own machine title: Build your own machine
title: Hardware
- local: peft - local: peft
title: PEFT title: PEFT
- local: model_memory_anatomy - local: model_memory_anatomy
title: Model training anatomy title: Model training anatomy
- title: Quantization title: Training
isExpanded: False - isExpanded: false
sections: sections:
- local: quantization/overview - local: quantization/overview
title: Overview title: Overview
@@ -196,8 +195,8 @@
title: VPTQ title: VPTQ
- local: quantization/contribute - local: quantization/contribute
title: Contribute title: Contribute
- title: Export to production title: Quantization
isExpanded: False - isExpanded: false
sections: sections:
- local: serialization - local: serialization
title: ONNX title: ONNX
@@ -207,13 +206,11 @@
title: ExecuTorch title: ExecuTorch
- local: torchscript - local: torchscript
title: TorchScript title: TorchScript
- title: Resources title: Export to production
isExpanded: False - isExpanded: false
sections:
- title: Task recipes
sections:
- title: Natural language processing
sections: sections:
- sections:
- sections:
- local: tasks/sequence_classification - local: tasks/sequence_classification
title: Text classification title: Text classification
- local: tasks/token_classification - local: tasks/token_classification
@@ -230,14 +227,14 @@
title: Summarization title: Summarization
- local: tasks/multiple_choice - local: tasks/multiple_choice
title: Multiple choice title: Multiple choice
- title: Audio title: Natural language processing
sections: - sections:
- local: tasks/audio_classification - local: tasks/audio_classification
title: Audio classification title: Audio classification
- local: tasks/asr - local: tasks/asr
title: Automatic speech recognition title: Automatic speech recognition
- title: Computer vision title: Audio
sections: - sections:
- local: tasks/image_classification - local: tasks/image_classification
title: Image classification title: Image classification
- local: tasks/semantic_segmentation - local: tasks/semantic_segmentation
@@ -262,8 +259,8 @@
title: Keypoint detection title: Keypoint detection
- local: tasks/knowledge_distillation_for_image_classification - local: tasks/knowledge_distillation_for_image_classification
title: Knowledge Distillation for Computer Vision title: Knowledge Distillation for Computer Vision
- title: Multimodal title: Computer vision
sections: - sections:
- local: tasks/image_captioning - local: tasks/image_captioning
title: Image captioning title: Image captioning
- local: tasks/document_question_answering - local: tasks/document_question_answering
@@ -278,6 +275,8 @@
title: Image-text-to-text title: Image-text-to-text
- local: tasks/video_text_to_text - local: tasks/video_text_to_text
title: Video-text-to-text title: Video-text-to-text
title: Multimodal
title: Task recipes
- local: run_scripts - local: run_scripts
title: Training scripts title: Training scripts
- local: glossary - local: glossary
@@ -290,8 +289,8 @@
title: Community resources title: Community resources
- local: troubleshooting - local: troubleshooting
title: Troubleshoot title: Troubleshoot
- title: Contribute title: Resources
isExpanded: False - isExpanded: false
sections: sections:
- local: contributing - local: contributing
title: Contribute to Transformers title: Contribute to Transformers
@@ -299,11 +298,10 @@
title: Transformers model tests title: Transformers model tests
- local: pr_checks - local: pr_checks
title: Pull request checks title: Pull request checks
- title: API title: Contribute
isExpanded: False - isExpanded: false
sections:
- title: Main classes
sections: sections:
- sections:
- local: main_classes/agent - local: main_classes/agent
title: Agents and Tools title: Agents and Tools
- local: model_doc/auto - local: model_doc/auto
@@ -350,10 +348,9 @@
title: Feature Extractor title: Feature Extractor
- local: main_classes/image_processor - local: main_classes/image_processor
title: Image Processor title: Image Processor
- title: Models title: Main classes
sections: - sections:
- title: Text models - sections:
sections:
- local: model_doc/albert - local: model_doc/albert
title: ALBERT title: ALBERT
- local: model_doc/bamba - local: model_doc/bamba
@@ -662,8 +659,8 @@
title: Zamba title: Zamba
- local: model_doc/zamba2 - local: model_doc/zamba2
title: Zamba2 title: Zamba2
- title: Vision models title: Text models
sections: - sections:
- local: model_doc/beit - local: model_doc/beit
title: BEiT title: BEiT
- local: model_doc/bit - local: model_doc/bit
@@ -790,8 +787,8 @@
title: YOLOS title: YOLOS
- local: model_doc/zoedepth - local: model_doc/zoedepth
title: ZoeDepth title: ZoeDepth
- title: Audio models title: Vision models
sections: - sections:
- local: model_doc/audio-spectrogram-transformer - local: model_doc/audio-spectrogram-transformer
title: Audio Spectrogram Transformer title: Audio Spectrogram Transformer
- local: model_doc/bark - local: model_doc/bark
@@ -860,16 +857,16 @@
title: XLS-R title: XLS-R
- local: model_doc/xlsr_wav2vec2 - local: model_doc/xlsr_wav2vec2
title: XLSR-Wav2Vec2 title: XLSR-Wav2Vec2
- title: Video models title: Audio models
sections: - sections:
- local: model_doc/timesformer - local: model_doc/timesformer
title: TimeSformer title: TimeSformer
- local: model_doc/videomae - local: model_doc/videomae
title: VideoMAE title: VideoMAE
- local: model_doc/vivit - local: model_doc/vivit
title: ViViT title: ViViT
- title: Multimodal models title: Video models
sections: - sections:
- local: model_doc/align - local: model_doc/align
title: ALIGN title: ALIGN
- local: model_doc/altclip - local: model_doc/altclip
@@ -908,6 +905,8 @@
title: Emu3 title: Emu3
- local: model_doc/flava - local: model_doc/flava
title: FLAVA title: FLAVA
- local: model_doc/gemma3
title: Gemma3
- local: model_doc/git - local: model_doc/git
title: GIT title: GIT
- local: model_doc/got_ocr2 - local: model_doc/got_ocr2
@@ -1012,14 +1011,14 @@
title: VisualBERT title: VisualBERT
- local: model_doc/xclip - local: model_doc/xclip
title: X-CLIP title: X-CLIP
- title: Reinforcement learning models title: Multimodal models
sections: - sections:
- local: model_doc/decision_transformer - local: model_doc/decision_transformer
title: Decision Transformer title: Decision Transformer
- local: model_doc/trajectory_transformer - local: model_doc/trajectory_transformer
title: Trajectory Transformer title: Trajectory Transformer
- title: Time series models title: Reinforcement learning models
sections: - sections:
- local: model_doc/autoformer - local: model_doc/autoformer
title: Autoformer title: Autoformer
- local: model_doc/informer - local: model_doc/informer
@@ -1030,12 +1029,13 @@
title: PatchTST title: PatchTST
- local: model_doc/time_series_transformer - local: model_doc/time_series_transformer
title: Time Series Transformer title: Time Series Transformer
- title: Graph models title: Time series models
sections: - sections:
- local: model_doc/graphormer - local: model_doc/graphormer
title: Graphormer title: Graphormer
- title: Internal helpers title: Graph models
sections: title: Models
- sections:
- local: internal/modeling_utils - local: internal/modeling_utils
title: Custom Layers and Utilities title: Custom Layers and Utilities
- local: internal/pipelines_utils - local: internal/pipelines_utils
@@ -1054,4 +1054,5 @@
title: General Utilities title: General Utilities
- local: internal/time_series_utils - local: internal/time_series_utils
title: Utilities for Time Series title: Utilities for Time Series
title: Internal helpers
title: API

View File

@@ -71,7 +71,7 @@ class WrappedFlexAttention:
return self._compiled_flex_attention return self._compiled_flex_attention
def make_flex_block_causal_mask(attention_mask_2d: torch.Tensor) -> BlockMask: def make_flex_block_causal_mask(attention_mask_2d: torch.Tensor) -> "BlockMask":
""" """
Create a block causal document mask for a batch of sequences, both packed and unpacked. Create a block causal document mask for a batch of sequences, both packed and unpacked.
Create Block causal logic and passing it into :func:`torch.nn.attention.flex_attention.create_block_mask`. Create Block causal logic and passing it into :func:`torch.nn.attention.flex_attention.create_block_mask`.
@@ -149,7 +149,7 @@ def flex_attention_forward(
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
attention_mask: Union[torch.Tensor, BlockMask], attention_mask: Union[torch.Tensor, "BlockMask"],
scaling: Optional[float] = None, scaling: Optional[float] = None,
softcap: Optional[float] = None, softcap: Optional[float] = None,
head_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None,

View File

@@ -849,13 +849,13 @@ def _load_state_dict_into_meta_model(
is_quantized = hf_quantizer is not None is_quantized = hf_quantizer is not None
for serialized_param_name, empty_param in state_dict.items(): for serialized_param_name, empty_param in state_dict.items():
if serialized_param_name not in expected_keys:
continue
# serialized_param_name is the raw, serialized name # serialized_param_name is the raw, serialized name
# fixed_param_name is the model's equivalent # fixed_param_name is the model's equivalent
fixed_param_name, _ = model.rename_key(serialized_param_name) fixed_param_name, _ = model.rename_key(serialized_param_name)
if fixed_param_name not in expected_keys:
continue
# we need to use serialized_param_name as file pointer is untouched # we need to use serialized_param_name as file pointer is untouched
if shard_file.endswith(".safetensors"): if shard_file.endswith(".safetensors"):
param = file_pointer.get_slice(serialized_param_name) param = file_pointer.get_slice(serialized_param_name)

View File

@@ -845,7 +845,7 @@ class Gemma3TextModel(Gemma3PreTrainedModel):
dtype (`torch.dtype`): dtype (`torch.dtype`):
The dtype to use for the 4D attention mask. The dtype to use for the 4D attention mask.
device (`torch.device`): device (`torch.device`):
The device to plcae the 4D attention mask on. The device to place the 4D attention mask on.
cache_position (`torch.Tensor`): cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence. Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`): batch_size (`torch.Tensor`):