[Backend support] Allow num_logits_to_keep as Tensor + add flag (#35757)

* support

* Update modeling_utils.py

* style

* most models

* Other models

* fix-copies

* tests + generation utils
This commit is contained in:
Cyril Vallez
2025-01-23 09:47:54 +01:00
committed by GitHub
parent 8736e91ad6
commit d3af76df58
62 changed files with 603 additions and 315 deletions

View File

@@ -37,6 +37,7 @@ from ...utils import (
logging,
replace_return_docstrings,
)
from ...utils.deprecation import deprecate_kwarg
from ...utils.import_utils import is_torch_available
from ..auto import AutoModel, AutoModelForCausalLM
from .configuration_aria import AriaConfig, AriaTextConfig
@@ -708,6 +709,7 @@ class AriaPreTrainedModel(PreTrainedModel):
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True
_supports_attention_backend = False
def _init_weights(self, module):
std = self.config.initializer_range
@@ -1168,6 +1170,7 @@ class AriaTextForCausalLM(AriaTextPreTrainedModel, GenerationMixin):
def get_decoder(self):
return self.model
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(ARIA_TEXT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
@@ -1183,7 +1186,7 @@ class AriaTextForCausalLM(AriaTextPreTrainedModel, GenerationMixin):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
@@ -1193,10 +1196,12 @@ class AriaTextForCausalLM(AriaTextPreTrainedModel, GenerationMixin):
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
num_logits_to_keep (`int`, *optional*):
Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
@@ -1239,7 +1244,8 @@ class AriaTextForCausalLM(AriaTextPreTrainedModel, GenerationMixin):
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
@@ -1324,8 +1330,9 @@ ARIA_INPUTS_DOCSTRING = r"""
Whether to output hidden states.
return_dict (`bool`, *optional*):
Whether to return a `ModelOutput` object.
num_logits_to_keep (`int`, *optional*, defaults to 0):
Calculate logits for the last `num_logits_to_keep` tokens, or all `input_ids` if `0`.
logits_to_keep (`int` or `torch.Tensor`, *optional*, defaults to 0):
If an `int`, calculate logits for the last `logits_to_keep` tokens, or all `input_ids` if `0`.
Otherwise, slice according to the 1D tensor in the sequence length dimension
cache_position (`torch.LongTensor`, *optional*):
Cache positions.
**loss_kwargs:
@@ -1426,6 +1433,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
image_features = self.multi_modal_projector(selected_image_feature, attn_mask=image_attn_mask)
return image_features
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(ARIA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=AriaCausalLMOutputWithPast, config_class=AriaConfig)
def forward(
@@ -1442,7 +1450,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
num_logits_to_keep: int = 0,
logits_to_keep: Union[int, torch.Tensor] = 0,
cache_position: Optional[torch.LongTensor] = None,
**loss_kwargs,
) -> Union[Tuple, AriaCausalLMOutputWithPast]:
@@ -1552,7 +1560,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
num_logits_to_keep=num_logits_to_keep,
logits_to_keep=logits_to_keep,
)
logits = outputs[0]
@@ -1584,7 +1592,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
pixel_mask=None,
attention_mask=None,
cache_position=None,
num_logits_to_keep=None,
logits_to_keep=None,
**kwargs,
):
model_inputs = self.language_model.prepare_inputs_for_generation(
@@ -1593,7 +1601,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
num_logits_to_keep=num_logits_to_keep,
logits_to_keep=logits_to_keep,
**kwargs,
)

View File

@@ -45,6 +45,7 @@ from ...utils import (
logging,
replace_return_docstrings,
)
from ...utils.deprecation import deprecate_kwarg
from ...utils.import_utils import is_torch_available
from ..auto import CONFIG_MAPPING, AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer
from ..llama.configuration_llama import LlamaConfig
@@ -1222,6 +1223,8 @@ class AriaTextPreTrainedModel(PreTrainedModel):
class AriaPreTrainedModel(LlamaPreTrainedModel):
_supports_attention_backend = False
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
@@ -1301,8 +1304,9 @@ ARIA_INPUTS_DOCSTRING = r"""
Whether to output hidden states.
return_dict (`bool`, *optional*):
Whether to return a `ModelOutput` object.
num_logits_to_keep (`int`, *optional*, defaults to 0):
Calculate logits for the last `num_logits_to_keep` tokens, or all `input_ids` if `0`.
logits_to_keep (`int` or `torch.Tensor`, *optional*, defaults to 0):
If an `int`, calculate logits for the last `logits_to_keep` tokens, or all `input_ids` if `0`.
Otherwise, slice according to the 1D tensor in the sequence length dimension
cache_position (`torch.LongTensor`, *optional*):
Cache positions.
**loss_kwargs:
@@ -1403,6 +1407,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
image_features = self.multi_modal_projector(selected_image_feature, attn_mask=image_attn_mask)
return image_features
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(ARIA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=AriaCausalLMOutputWithPast, config_class=AriaConfig)
def forward(
@@ -1419,7 +1424,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
num_logits_to_keep: int = 0,
logits_to_keep: Union[int, torch.Tensor] = 0,
cache_position: Optional[torch.LongTensor] = None,
**loss_kwargs,
) -> Union[Tuple, AriaCausalLMOutputWithPast]:
@@ -1529,7 +1534,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
num_logits_to_keep=num_logits_to_keep,
logits_to_keep=logits_to_keep,
)
logits = outputs[0]
@@ -1561,7 +1566,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
pixel_mask=None,
attention_mask=None,
cache_position=None,
num_logits_to_keep=None,
logits_to_keep=None,
**kwargs,
):
model_inputs = self.language_model.prepare_inputs_for_generation(
@@ -1570,7 +1575,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
num_logits_to_keep=num_logits_to_keep,
logits_to_keep=logits_to_keep,
**kwargs,
)