Test: generate with torch.compile(model.forward) as a fast test (#34544)
This commit is contained in:
@@ -708,7 +708,7 @@ class AriaPreTrainedModel(PreTrainedModel):
|
||||
_supports_flex_attn = True
|
||||
_supports_cache_class = True
|
||||
_supports_quantized_cache = True
|
||||
_supports_static_cache = True
|
||||
_supports_static_cache = False # MoE models don't work with torch.compile (dynamic slicing)
|
||||
_supports_attention_backend = False
|
||||
|
||||
def _init_weights(self, module):
|
||||
@@ -1561,6 +1561,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
logits_to_keep=logits_to_keep,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
logits = outputs[0]
|
||||
|
||||
@@ -1223,6 +1223,7 @@ class AriaTextPreTrainedModel(PreTrainedModel):
|
||||
|
||||
|
||||
class AriaPreTrainedModel(LlamaPreTrainedModel):
|
||||
_supports_static_cache = False # MoE models don't work with torch.compile (dynamic slicing)
|
||||
_supports_attention_backend = False
|
||||
|
||||
def _init_weights(self, module):
|
||||
@@ -1535,6 +1536,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
logits_to_keep=logits_to_keep,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
logits = outputs[0]
|
||||
|
||||
Reference in New Issue
Block a user