symbolic_trace: add past_key_values, llama, sdpa support (#28447)
* torch.fx: add pkv, llama, sdpa support * Update src/transformers/models/opt/modeling_opt.py * remove spaces * trigger ci * use explicit variable names
This commit is contained in:
@@ -292,6 +292,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
)
|
||||
test_headmasking = False
|
||||
test_pruning = False
|
||||
fx_compatible = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = LlamaModelTester(self)
|
||||
|
||||
@@ -118,7 +118,7 @@ if is_flax_available():
|
||||
)
|
||||
|
||||
if is_torch_fx_available():
|
||||
from transformers.utils.fx import symbolic_trace
|
||||
from transformers.utils.fx import _FX_SUPPORTED_MODELS_WITH_KV_CACHE, symbolic_trace
|
||||
|
||||
|
||||
def _config_zero_init(config):
|
||||
@@ -1004,7 +1004,9 @@ class ModelTesterMixin:
|
||||
|
||||
def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False):
|
||||
if not is_torch_fx_available() or not self.fx_compatible:
|
||||
return
|
||||
self.skipTest(
|
||||
f"Either torch.fx is not available, or the model type {config.model_type} is not compatible with torch.fx"
|
||||
)
|
||||
|
||||
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
|
||||
configs_no_init.return_dict = False
|
||||
@@ -1060,6 +1062,26 @@ class ModelTesterMixin:
|
||||
if end_positions is not None:
|
||||
input_names.append("end_positions")
|
||||
|
||||
if model.config.model_type in _FX_SUPPORTED_MODELS_WITH_KV_CACHE:
|
||||
input_names.append("past_key_values")
|
||||
|
||||
# Generally, model_tester.prepare_config_and_inputs_for_common does not seem to generate past key values inputs.
|
||||
if "past_key_values" not in inputs:
|
||||
batch_size = inputs[next(iter(inputs))].shape[0]
|
||||
num_heads = model.config.num_attention_heads
|
||||
head_dim = model.config.hidden_size // model.config.num_attention_heads
|
||||
|
||||
cache_shape = (batch_size, num_heads, 0, head_dim)
|
||||
pkv = tuple(
|
||||
(
|
||||
torch.rand(cache_shape, dtype=torch.float, device=torch_device),
|
||||
torch.rand(cache_shape, dtype=torch.float, device=torch_device),
|
||||
)
|
||||
for i in range(model.config.num_hidden_layers)
|
||||
)
|
||||
|
||||
inputs["past_key_values"] = pkv
|
||||
|
||||
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
|
||||
input_names = list(filtered_inputs.keys())
|
||||
|
||||
@@ -1069,8 +1091,10 @@ class ModelTesterMixin:
|
||||
model.config.problem_type = "single_label_classification"
|
||||
|
||||
traced_model = symbolic_trace(model, input_names)
|
||||
traced_output = traced_model(**filtered_inputs)
|
||||
model_output = model(**filtered_inputs)
|
||||
|
||||
with torch.no_grad():
|
||||
traced_output = traced_model(**filtered_inputs)
|
||||
model_output = model(**filtered_inputs)
|
||||
|
||||
except Exception as e:
|
||||
self.fail(f"Couldn't trace module: {e}")
|
||||
|
||||
Reference in New Issue
Block a user