Test: generate with torch.compile(model.forward) as a fast test (#34544)

This commit is contained in:
Joao Gante
2025-01-28 14:10:38 +00:00
committed by GitHub
parent f48ecd7608
commit ece8c42488
25 changed files with 105 additions and 53 deletions

View File

@@ -708,7 +708,7 @@ class AriaPreTrainedModel(PreTrainedModel):
_supports_flex_attn = True
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True
_supports_static_cache = False # MoE models don't work with torch.compile (dynamic slicing)
_supports_attention_backend = False
def _init_weights(self, module):
@@ -1561,6 +1561,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
output_hidden_states=output_hidden_states,
return_dict=return_dict,
logits_to_keep=logits_to_keep,
cache_position=cache_position,
)
logits = outputs[0]

View File

@@ -1223,6 +1223,7 @@ class AriaTextPreTrainedModel(PreTrainedModel):
class AriaPreTrainedModel(LlamaPreTrainedModel):
_supports_static_cache = False # MoE models don't work with torch.compile (dynamic slicing)
_supports_attention_backend = False
def _init_weights(self, module):
@@ -1535,6 +1536,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
output_hidden_states=output_hidden_states,
return_dict=return_dict,
logits_to_keep=logits_to_keep,
cache_position=cache_position,
)
logits = outputs[0]