fix block mask typing (#36661)
* fix block mask typing

* updated

Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com>

* gemma

* fix

---------

Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com>
This commit is contained in:
@@ -1,16 +1,14 @@
|
|||||||
- title: Get started
|
- sections:
|
||||||
sections:
|
|
||||||
- local: index
|
- local: index
|
||||||
title: Transformers
|
title: Transformers
|
||||||
- local: installation
|
- local: installation
|
||||||
title: Installation
|
title: Installation
|
||||||
- local: quicktour
|
- local: quicktour
|
||||||
title: Quickstart
|
title: Quickstart
|
||||||
- title: Base classes
|
title: Get started
|
||||||
isExpanded: False
|
- isExpanded: false
|
||||||
sections:
|
|
||||||
- title: Models
|
|
||||||
sections:
|
sections:
|
||||||
|
- sections:
|
||||||
- local: models
|
- local: models
|
||||||
title: Loading models
|
title: Loading models
|
||||||
- local: custom_models
|
- local: custom_models
|
||||||
@@ -31,8 +29,8 @@
|
|||||||
title: The Transformer model family
|
title: The Transformer model family
|
||||||
- local: attention
|
- local: attention
|
||||||
title: Attention mechanisms
|
title: Attention mechanisms
|
||||||
- title: Preprocessors
|
title: Models
|
||||||
sections:
|
- sections:
|
||||||
- local: fast_tokenizers
|
- local: fast_tokenizers
|
||||||
title: Tokenizers
|
title: Tokenizers
|
||||||
- local: image_processors
|
- local: image_processors
|
||||||
@@ -47,11 +45,11 @@
|
|||||||
title: Summary of the tokenizers
|
title: Summary of the tokenizers
|
||||||
- local: pad_truncation
|
- local: pad_truncation
|
||||||
title: Padding and truncation
|
title: Padding and truncation
|
||||||
- title: Inference
|
title: Preprocessors
|
||||||
isExpanded: False
|
title: Base classes
|
||||||
sections:
|
- isExpanded: false
|
||||||
- title: Pipeline API
|
|
||||||
sections:
|
sections:
|
||||||
|
- sections:
|
||||||
- local: pipeline_tutorial
|
- local: pipeline_tutorial
|
||||||
title: Pipeline
|
title: Pipeline
|
||||||
- local: pipeline_gradio
|
- local: pipeline_gradio
|
||||||
@@ -60,8 +58,8 @@
|
|||||||
title: Web server inference
|
title: Web server inference
|
||||||
- local: add_new_pipeline
|
- local: add_new_pipeline
|
||||||
title: Adding a new pipeline
|
title: Adding a new pipeline
|
||||||
- title: LLMs
|
title: Pipeline API
|
||||||
sections:
|
- sections:
|
||||||
- local: llm_tutorial
|
- local: llm_tutorial
|
||||||
title: Text generation
|
title: Text generation
|
||||||
- local: generation_strategies
|
- local: generation_strategies
|
||||||
@@ -82,8 +80,8 @@
|
|||||||
title: Getting the most out of LLMs
|
title: Getting the most out of LLMs
|
||||||
- local: perplexity
|
- local: perplexity
|
||||||
title: Perplexity of fixed-length models
|
title: Perplexity of fixed-length models
|
||||||
- title: Chat with models
|
title: LLMs
|
||||||
sections:
|
- sections:
|
||||||
- local: conversations
|
- local: conversations
|
||||||
title: Chat basics
|
title: Chat basics
|
||||||
- local: chat_templating
|
- local: chat_templating
|
||||||
@@ -94,8 +92,8 @@
|
|||||||
title: Template writing
|
title: Template writing
|
||||||
- local: chat_extras
|
- local: chat_extras
|
||||||
title: Tools and RAG
|
title: Tools and RAG
|
||||||
- title: Optimization
|
title: Chat with models
|
||||||
sections:
|
- sections:
|
||||||
- local: perf_torch_compile
|
- local: perf_torch_compile
|
||||||
title: torch.compile
|
title: torch.compile
|
||||||
- local: perf_infer_gpu_one
|
- local: perf_infer_gpu_one
|
||||||
@@ -106,15 +104,15 @@
|
|||||||
title: CPU
|
title: CPU
|
||||||
- local: tf_xla
|
- local: tf_xla
|
||||||
title: XLA
|
title: XLA
|
||||||
|
title: Optimization
|
||||||
- local: agents
|
- local: agents
|
||||||
title: Agents
|
title: Agents
|
||||||
- local: tools
|
- local: tools
|
||||||
title: Tools
|
title: Tools
|
||||||
- title: Training
|
title: Inference
|
||||||
isExpanded: False
|
- isExpanded: false
|
||||||
sections:
|
|
||||||
- title: Trainer API
|
|
||||||
sections:
|
sections:
|
||||||
|
- sections:
|
||||||
- local: trainer
|
- local: trainer
|
||||||
title: Trainer
|
title: Trainer
|
||||||
- local: training
|
- local: training
|
||||||
@@ -123,8 +121,8 @@
|
|||||||
title: Optimizers
|
title: Optimizers
|
||||||
- local: hpo_train
|
- local: hpo_train
|
||||||
title: Hyperparameter search
|
title: Hyperparameter search
|
||||||
- title: Distributed training
|
title: Trainer API
|
||||||
sections:
|
- sections:
|
||||||
- local: gpu_selection
|
- local: gpu_selection
|
||||||
title: GPU selection
|
title: GPU selection
|
||||||
- local: accelerate
|
- local: accelerate
|
||||||
@@ -139,8 +137,8 @@
|
|||||||
title: Distributed CPUs
|
title: Distributed CPUs
|
||||||
- local: perf_train_gpu_many
|
- local: perf_train_gpu_many
|
||||||
title: Parallelism methods
|
title: Parallelism methods
|
||||||
- title: Hardware
|
title: Distributed training
|
||||||
sections:
|
- sections:
|
||||||
- local: perf_train_gpu_one
|
- local: perf_train_gpu_one
|
||||||
title: GPU
|
title: GPU
|
||||||
- local: perf_train_cpu
|
- local: perf_train_cpu
|
||||||
@@ -151,12 +149,13 @@
|
|||||||
title: Apple Silicon
|
title: Apple Silicon
|
||||||
- local: perf_hardware
|
- local: perf_hardware
|
||||||
title: Build your own machine
|
title: Build your own machine
|
||||||
|
title: Hardware
|
||||||
- local: peft
|
- local: peft
|
||||||
title: PEFT
|
title: PEFT
|
||||||
- local: model_memory_anatomy
|
- local: model_memory_anatomy
|
||||||
title: Model training anatomy
|
title: Model training anatomy
|
||||||
- title: Quantization
|
title: Training
|
||||||
isExpanded: False
|
- isExpanded: false
|
||||||
sections:
|
sections:
|
||||||
- local: quantization/overview
|
- local: quantization/overview
|
||||||
title: Overview
|
title: Overview
|
||||||
@@ -196,8 +195,8 @@
|
|||||||
title: VPTQ
|
title: VPTQ
|
||||||
- local: quantization/contribute
|
- local: quantization/contribute
|
||||||
title: Contribute
|
title: Contribute
|
||||||
- title: Export to production
|
title: Quantization
|
||||||
isExpanded: False
|
- isExpanded: false
|
||||||
sections:
|
sections:
|
||||||
- local: serialization
|
- local: serialization
|
||||||
title: ONNX
|
title: ONNX
|
||||||
@@ -207,13 +206,11 @@
|
|||||||
title: ExecuTorch
|
title: ExecuTorch
|
||||||
- local: torchscript
|
- local: torchscript
|
||||||
title: TorchScript
|
title: TorchScript
|
||||||
- title: Resources
|
title: Export to production
|
||||||
isExpanded: False
|
- isExpanded: false
|
||||||
sections:
|
|
||||||
- title: Task recipes
|
|
||||||
sections:
|
|
||||||
- title: Natural language processing
|
|
||||||
sections:
|
sections:
|
||||||
|
- sections:
|
||||||
|
- sections:
|
||||||
- local: tasks/sequence_classification
|
- local: tasks/sequence_classification
|
||||||
title: Text classification
|
title: Text classification
|
||||||
- local: tasks/token_classification
|
- local: tasks/token_classification
|
||||||
@@ -230,14 +227,14 @@
|
|||||||
title: Summarization
|
title: Summarization
|
||||||
- local: tasks/multiple_choice
|
- local: tasks/multiple_choice
|
||||||
title: Multiple choice
|
title: Multiple choice
|
||||||
- title: Audio
|
title: Natural language processing
|
||||||
sections:
|
- sections:
|
||||||
- local: tasks/audio_classification
|
- local: tasks/audio_classification
|
||||||
title: Audio classification
|
title: Audio classification
|
||||||
- local: tasks/asr
|
- local: tasks/asr
|
||||||
title: Automatic speech recognition
|
title: Automatic speech recognition
|
||||||
- title: Computer vision
|
title: Audio
|
||||||
sections:
|
- sections:
|
||||||
- local: tasks/image_classification
|
- local: tasks/image_classification
|
||||||
title: Image classification
|
title: Image classification
|
||||||
- local: tasks/semantic_segmentation
|
- local: tasks/semantic_segmentation
|
||||||
@@ -262,8 +259,8 @@
|
|||||||
title: Keypoint detection
|
title: Keypoint detection
|
||||||
- local: tasks/knowledge_distillation_for_image_classification
|
- local: tasks/knowledge_distillation_for_image_classification
|
||||||
title: Knowledge Distillation for Computer Vision
|
title: Knowledge Distillation for Computer Vision
|
||||||
- title: Multimodal
|
title: Computer vision
|
||||||
sections:
|
- sections:
|
||||||
- local: tasks/image_captioning
|
- local: tasks/image_captioning
|
||||||
title: Image captioning
|
title: Image captioning
|
||||||
- local: tasks/document_question_answering
|
- local: tasks/document_question_answering
|
||||||
@@ -278,6 +275,8 @@
|
|||||||
title: Image-text-to-text
|
title: Image-text-to-text
|
||||||
- local: tasks/video_text_to_text
|
- local: tasks/video_text_to_text
|
||||||
title: Video-text-to-text
|
title: Video-text-to-text
|
||||||
|
title: Multimodal
|
||||||
|
title: Task recipes
|
||||||
- local: run_scripts
|
- local: run_scripts
|
||||||
title: Training scripts
|
title: Training scripts
|
||||||
- local: glossary
|
- local: glossary
|
||||||
@@ -290,8 +289,8 @@
|
|||||||
title: Community resources
|
title: Community resources
|
||||||
- local: troubleshooting
|
- local: troubleshooting
|
||||||
title: Troubleshoot
|
title: Troubleshoot
|
||||||
- title: Contribute
|
title: Resources
|
||||||
isExpanded: False
|
- isExpanded: false
|
||||||
sections:
|
sections:
|
||||||
- local: contributing
|
- local: contributing
|
||||||
title: Contribute to Transformers
|
title: Contribute to Transformers
|
||||||
@@ -299,11 +298,10 @@
|
|||||||
title: Transformers model tests
|
title: Transformers model tests
|
||||||
- local: pr_checks
|
- local: pr_checks
|
||||||
title: Pull request checks
|
title: Pull request checks
|
||||||
- title: API
|
title: Contribute
|
||||||
isExpanded: False
|
- isExpanded: false
|
||||||
sections:
|
|
||||||
- title: Main classes
|
|
||||||
sections:
|
sections:
|
||||||
|
- sections:
|
||||||
- local: main_classes/agent
|
- local: main_classes/agent
|
||||||
title: Agents and Tools
|
title: Agents and Tools
|
||||||
- local: model_doc/auto
|
- local: model_doc/auto
|
||||||
@@ -350,10 +348,9 @@
|
|||||||
title: Feature Extractor
|
title: Feature Extractor
|
||||||
- local: main_classes/image_processor
|
- local: main_classes/image_processor
|
||||||
title: Image Processor
|
title: Image Processor
|
||||||
- title: Models
|
title: Main classes
|
||||||
sections:
|
- sections:
|
||||||
- title: Text models
|
- sections:
|
||||||
sections:
|
|
||||||
- local: model_doc/albert
|
- local: model_doc/albert
|
||||||
title: ALBERT
|
title: ALBERT
|
||||||
- local: model_doc/bamba
|
- local: model_doc/bamba
|
||||||
@@ -662,8 +659,8 @@
|
|||||||
title: Zamba
|
title: Zamba
|
||||||
- local: model_doc/zamba2
|
- local: model_doc/zamba2
|
||||||
title: Zamba2
|
title: Zamba2
|
||||||
- title: Vision models
|
title: Text models
|
||||||
sections:
|
- sections:
|
||||||
- local: model_doc/beit
|
- local: model_doc/beit
|
||||||
title: BEiT
|
title: BEiT
|
||||||
- local: model_doc/bit
|
- local: model_doc/bit
|
||||||
@@ -790,8 +787,8 @@
|
|||||||
title: YOLOS
|
title: YOLOS
|
||||||
- local: model_doc/zoedepth
|
- local: model_doc/zoedepth
|
||||||
title: ZoeDepth
|
title: ZoeDepth
|
||||||
- title: Audio models
|
title: Vision models
|
||||||
sections:
|
- sections:
|
||||||
- local: model_doc/audio-spectrogram-transformer
|
- local: model_doc/audio-spectrogram-transformer
|
||||||
title: Audio Spectrogram Transformer
|
title: Audio Spectrogram Transformer
|
||||||
- local: model_doc/bark
|
- local: model_doc/bark
|
||||||
@@ -860,16 +857,16 @@
|
|||||||
title: XLS-R
|
title: XLS-R
|
||||||
- local: model_doc/xlsr_wav2vec2
|
- local: model_doc/xlsr_wav2vec2
|
||||||
title: XLSR-Wav2Vec2
|
title: XLSR-Wav2Vec2
|
||||||
- title: Video models
|
title: Audio models
|
||||||
sections:
|
- sections:
|
||||||
- local: model_doc/timesformer
|
- local: model_doc/timesformer
|
||||||
title: TimeSformer
|
title: TimeSformer
|
||||||
- local: model_doc/videomae
|
- local: model_doc/videomae
|
||||||
title: VideoMAE
|
title: VideoMAE
|
||||||
- local: model_doc/vivit
|
- local: model_doc/vivit
|
||||||
title: ViViT
|
title: ViViT
|
||||||
- title: Multimodal models
|
title: Video models
|
||||||
sections:
|
- sections:
|
||||||
- local: model_doc/align
|
- local: model_doc/align
|
||||||
title: ALIGN
|
title: ALIGN
|
||||||
- local: model_doc/altclip
|
- local: model_doc/altclip
|
||||||
@@ -908,6 +905,8 @@
|
|||||||
title: Emu3
|
title: Emu3
|
||||||
- local: model_doc/flava
|
- local: model_doc/flava
|
||||||
title: FLAVA
|
title: FLAVA
|
||||||
|
- local: model_doc/gemma3
|
||||||
|
title: Gemma3
|
||||||
- local: model_doc/git
|
- local: model_doc/git
|
||||||
title: GIT
|
title: GIT
|
||||||
- local: model_doc/got_ocr2
|
- local: model_doc/got_ocr2
|
||||||
@@ -1012,14 +1011,14 @@
|
|||||||
title: VisualBERT
|
title: VisualBERT
|
||||||
- local: model_doc/xclip
|
- local: model_doc/xclip
|
||||||
title: X-CLIP
|
title: X-CLIP
|
||||||
- title: Reinforcement learning models
|
title: Multimodal models
|
||||||
sections:
|
- sections:
|
||||||
- local: model_doc/decision_transformer
|
- local: model_doc/decision_transformer
|
||||||
title: Decision Transformer
|
title: Decision Transformer
|
||||||
- local: model_doc/trajectory_transformer
|
- local: model_doc/trajectory_transformer
|
||||||
title: Trajectory Transformer
|
title: Trajectory Transformer
|
||||||
- title: Time series models
|
title: Reinforcement learning models
|
||||||
sections:
|
- sections:
|
||||||
- local: model_doc/autoformer
|
- local: model_doc/autoformer
|
||||||
title: Autoformer
|
title: Autoformer
|
||||||
- local: model_doc/informer
|
- local: model_doc/informer
|
||||||
@@ -1030,12 +1029,13 @@
|
|||||||
title: PatchTST
|
title: PatchTST
|
||||||
- local: model_doc/time_series_transformer
|
- local: model_doc/time_series_transformer
|
||||||
title: Time Series Transformer
|
title: Time Series Transformer
|
||||||
- title: Graph models
|
title: Time series models
|
||||||
sections:
|
- sections:
|
||||||
- local: model_doc/graphormer
|
- local: model_doc/graphormer
|
||||||
title: Graphormer
|
title: Graphormer
|
||||||
- title: Internal helpers
|
title: Graph models
|
||||||
sections:
|
title: Models
|
||||||
|
- sections:
|
||||||
- local: internal/modeling_utils
|
- local: internal/modeling_utils
|
||||||
title: Custom Layers and Utilities
|
title: Custom Layers and Utilities
|
||||||
- local: internal/pipelines_utils
|
- local: internal/pipelines_utils
|
||||||
@@ -1054,4 +1054,5 @@
|
|||||||
title: General Utilities
|
title: General Utilities
|
||||||
- local: internal/time_series_utils
|
- local: internal/time_series_utils
|
||||||
title: Utilities for Time Series
|
title: Utilities for Time Series
|
||||||
|
title: Internal helpers
|
||||||
|
title: API
|
||||||
|
|||||||
@@ -71,7 +71,7 @@ class WrappedFlexAttention:
|
|||||||
return self._compiled_flex_attention
|
return self._compiled_flex_attention
|
||||||
|
|
||||||
|
|
||||||
def make_flex_block_causal_mask(attention_mask_2d: torch.Tensor) -> BlockMask:
|
def make_flex_block_causal_mask(attention_mask_2d: torch.Tensor) -> "BlockMask":
|
||||||
"""
|
"""
|
||||||
Create a block causal document mask for a batch of sequences, both packed and unpacked.
|
Create a block causal document mask for a batch of sequences, both packed and unpacked.
|
||||||
Create Block causal logic and passing it into :func:`torch.nn.attention.flex_attention.create_block_mask`.
|
Create Block causal logic and passing it into :func:`torch.nn.attention.flex_attention.create_block_mask`.
|
||||||
@@ -149,7 +149,7 @@ def flex_attention_forward(
|
|||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
key: torch.Tensor,
|
key: torch.Tensor,
|
||||||
value: torch.Tensor,
|
value: torch.Tensor,
|
||||||
attention_mask: Union[torch.Tensor, BlockMask],
|
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||||
scaling: Optional[float] = None,
|
scaling: Optional[float] = None,
|
||||||
softcap: Optional[float] = None,
|
softcap: Optional[float] = None,
|
||||||
head_mask: Optional[torch.Tensor] = None,
|
head_mask: Optional[torch.Tensor] = None,
|
||||||
|
|||||||
@@ -849,13 +849,13 @@ def _load_state_dict_into_meta_model(
|
|||||||
is_quantized = hf_quantizer is not None
|
is_quantized = hf_quantizer is not None
|
||||||
|
|
||||||
for serialized_param_name, empty_param in state_dict.items():
|
for serialized_param_name, empty_param in state_dict.items():
|
||||||
if serialized_param_name not in expected_keys:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# serialized_param_name is the raw, serialized name
|
# serialized_param_name is the raw, serialized name
|
||||||
# fixed_param_name is the model's equivalent
|
# fixed_param_name is the model's equivalent
|
||||||
fixed_param_name, _ = model.rename_key(serialized_param_name)
|
fixed_param_name, _ = model.rename_key(serialized_param_name)
|
||||||
|
|
||||||
|
if fixed_param_name not in expected_keys:
|
||||||
|
continue
|
||||||
|
|
||||||
# we need to use serialized_param_name as file pointer is untouched
|
# we need to use serialized_param_name as file pointer is untouched
|
||||||
if shard_file.endswith(".safetensors"):
|
if shard_file.endswith(".safetensors"):
|
||||||
param = file_pointer.get_slice(serialized_param_name)
|
param = file_pointer.get_slice(serialized_param_name)
|
||||||
|
|||||||
@@ -845,7 +845,7 @@ class Gemma3TextModel(Gemma3PreTrainedModel):
|
|||||||
dtype (`torch.dtype`):
|
dtype (`torch.dtype`):
|
||||||
The dtype to use for the 4D attention mask.
|
The dtype to use for the 4D attention mask.
|
||||||
device (`torch.device`):
|
device (`torch.device`):
|
||||||
The device to plcae the 4D attention mask on.
|
The device to place the 4D attention mask on.
|
||||||
cache_position (`torch.Tensor`):
|
cache_position (`torch.Tensor`):
|
||||||
Indices depicting the position of the input sequence tokens in the sequence.
|
Indices depicting the position of the input sequence tokens in the sequence.
|
||||||
batch_size (`torch.Tensor`):
|
batch_size (`torch.Tensor`):
|
||||||
|
|||||||
Reference in New Issue
Block a user