[Backend support] Allow num_logits_to_keep as Tensor + add flag (#35757)
* support * Update modeling_utils.py * style * most models * Other models * fix-copies * tests + generation utils
This commit is contained in:
@@ -37,6 +37,7 @@ from ...utils import (
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from ...utils.deprecation import deprecate_kwarg
|
||||
from ...utils.import_utils import is_torch_available
|
||||
from ..auto import AutoModel, AutoModelForCausalLM
|
||||
from .configuration_aria import AriaConfig, AriaTextConfig
|
||||
@@ -708,6 +709,7 @@ class AriaPreTrainedModel(PreTrainedModel):
|
||||
_supports_cache_class = True
|
||||
_supports_quantized_cache = True
|
||||
_supports_static_cache = True
|
||||
_supports_attention_backend = False
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.initializer_range
|
||||
@@ -1168,6 +1170,7 @@ class AriaTextForCausalLM(AriaTextPreTrainedModel, GenerationMixin):
|
||||
def get_decoder(self):
|
||||
return self.model
|
||||
|
||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||
@add_start_docstrings_to_model_forward(ARIA_TEXT_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
@@ -1183,7 +1186,7 @@ class AriaTextForCausalLM(AriaTextPreTrainedModel, GenerationMixin):
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
num_logits_to_keep: int = 0,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**kwargs: Unpack[KwargsForCausalLM],
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
r"""
|
||||
@@ -1193,10 +1196,12 @@ class AriaTextForCausalLM(AriaTextPreTrainedModel, GenerationMixin):
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
num_logits_to_keep (`int`, *optional*):
|
||||
Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
|
||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||
|
||||
Returns:
|
||||
|
||||
@@ -1239,7 +1244,8 @@ class AriaTextForCausalLM(AriaTextPreTrainedModel, GenerationMixin):
|
||||
|
||||
hidden_states = outputs[0]
|
||||
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
|
||||
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
@@ -1324,8 +1330,9 @@ ARIA_INPUTS_DOCSTRING = r"""
|
||||
Whether to output hidden states.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether to return a `ModelOutput` object.
|
||||
num_logits_to_keep (`int`, *optional*, defaults to 0):
|
||||
Calculate logits for the last `num_logits_to_keep` tokens, or all `input_ids` if `0`.
|
||||
logits_to_keep (`int` or `torch.Tensor`, *optional*, defaults to 0):
|
||||
If an `int`, calculate logits for the last `logits_to_keep` tokens, or all `input_ids` if `0`.
|
||||
Otherwise, slice according to the 1D tensor in the sequence length dimension
|
||||
cache_position (`torch.LongTensor`, *optional*):
|
||||
Cache positions.
|
||||
**loss_kwargs:
|
||||
@@ -1426,6 +1433,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
|
||||
image_features = self.multi_modal_projector(selected_image_feature, attn_mask=image_attn_mask)
|
||||
return image_features
|
||||
|
||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||
@add_start_docstrings_to_model_forward(ARIA_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=AriaCausalLMOutputWithPast, config_class=AriaConfig)
|
||||
def forward(
|
||||
@@ -1442,7 +1450,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
num_logits_to_keep: int = 0,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**loss_kwargs,
|
||||
) -> Union[Tuple, AriaCausalLMOutputWithPast]:
|
||||
@@ -1552,7 +1560,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
num_logits_to_keep=num_logits_to_keep,
|
||||
logits_to_keep=logits_to_keep,
|
||||
)
|
||||
|
||||
logits = outputs[0]
|
||||
@@ -1584,7 +1592,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
|
||||
pixel_mask=None,
|
||||
attention_mask=None,
|
||||
cache_position=None,
|
||||
num_logits_to_keep=None,
|
||||
logits_to_keep=None,
|
||||
**kwargs,
|
||||
):
|
||||
model_inputs = self.language_model.prepare_inputs_for_generation(
|
||||
@@ -1593,7 +1601,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
|
||||
inputs_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
num_logits_to_keep=num_logits_to_keep,
|
||||
logits_to_keep=logits_to_keep,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@@ -45,6 +45,7 @@ from ...utils import (
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from ...utils.deprecation import deprecate_kwarg
|
||||
from ...utils.import_utils import is_torch_available
|
||||
from ..auto import CONFIG_MAPPING, AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer
|
||||
from ..llama.configuration_llama import LlamaConfig
|
||||
@@ -1222,6 +1223,8 @@ class AriaTextPreTrainedModel(PreTrainedModel):
|
||||
|
||||
|
||||
class AriaPreTrainedModel(LlamaPreTrainedModel):
|
||||
_supports_attention_backend = False
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.initializer_range
|
||||
if isinstance(module, nn.Linear):
|
||||
@@ -1301,8 +1304,9 @@ ARIA_INPUTS_DOCSTRING = r"""
|
||||
Whether to output hidden states.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether to return a `ModelOutput` object.
|
||||
num_logits_to_keep (`int`, *optional*, defaults to 0):
|
||||
Calculate logits for the last `num_logits_to_keep` tokens, or all `input_ids` if `0`.
|
||||
logits_to_keep (`int` or `torch.Tensor`, *optional*, defaults to 0):
|
||||
If an `int`, calculate logits for the last `logits_to_keep` tokens, or all `input_ids` if `0`.
|
||||
Otherwise, slice according to the 1D tensor in the sequence length dimension
|
||||
cache_position (`torch.LongTensor`, *optional*):
|
||||
Cache positions.
|
||||
**loss_kwargs:
|
||||
@@ -1403,6 +1407,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
|
||||
image_features = self.multi_modal_projector(selected_image_feature, attn_mask=image_attn_mask)
|
||||
return image_features
|
||||
|
||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||
@add_start_docstrings_to_model_forward(ARIA_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=AriaCausalLMOutputWithPast, config_class=AriaConfig)
|
||||
def forward(
|
||||
@@ -1419,7 +1424,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
num_logits_to_keep: int = 0,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**loss_kwargs,
|
||||
) -> Union[Tuple, AriaCausalLMOutputWithPast]:
|
||||
@@ -1529,7 +1534,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
num_logits_to_keep=num_logits_to_keep,
|
||||
logits_to_keep=logits_to_keep,
|
||||
)
|
||||
|
||||
logits = outputs[0]
|
||||
@@ -1561,7 +1566,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
|
||||
pixel_mask=None,
|
||||
attention_mask=None,
|
||||
cache_position=None,
|
||||
num_logits_to_keep=None,
|
||||
logits_to_keep=None,
|
||||
**kwargs,
|
||||
):
|
||||
model_inputs = self.language_model.prepare_inputs_for_generation(
|
||||
@@ -1570,7 +1575,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
|
||||
inputs_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
num_logits_to_keep=num_logits_to_keep,
|
||||
logits_to_keep=logits_to_keep,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user