symbolic_trace: add past_key_values, llama, sdpa support (#28447)

* torch.fx: add pkv, llama, sdpa support

* Update src/transformers/models/opt/modeling_opt.py

* remove spaces

* trigger ci

* use explicit variable names
This commit is contained in:
fxmarty
2024-01-17 11:50:53 +01:00
committed by GitHub
parent 09eb11a1bd
commit a6adc05e6b
4 changed files with 68 additions and 11 deletions

View File

@@ -118,7 +118,7 @@ if is_flax_available():
)
if is_torch_fx_available():
from transformers.utils.fx import symbolic_trace
from transformers.utils.fx import _FX_SUPPORTED_MODELS_WITH_KV_CACHE, symbolic_trace
def _config_zero_init(config):
@@ -1004,7 +1004,9 @@ class ModelTesterMixin:
def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False):
if not is_torch_fx_available() or not self.fx_compatible:
return
self.skipTest(
f"Either torch.fx is not available, or the model type {config.model_type} is not compatible with torch.fx"
)
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
configs_no_init.return_dict = False
@@ -1060,6 +1062,26 @@ class ModelTesterMixin:
if end_positions is not None:
input_names.append("end_positions")
if model.config.model_type in _FX_SUPPORTED_MODELS_WITH_KV_CACHE:
input_names.append("past_key_values")
# Generally model_tester.prepare_config_and_inputs_for_common seem not to generate past key values inputs.
if "past_key_values" not in inputs:
batch_size = inputs[next(iter(inputs))].shape[0]
num_heads = model.config.num_attention_heads
head_dim = model.config.hidden_size // model.config.num_attention_heads
cache_shape = (batch_size, num_heads, 0, head_dim)
pkv = tuple(
(
torch.rand(cache_shape, dtype=torch.float, device=torch_device),
torch.rand(cache_shape, dtype=torch.float, device=torch_device),
)
for i in range(model.config.num_hidden_layers)
)
inputs["past_key_values"] = pkv
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
input_names = list(filtered_inputs.keys())
@@ -1069,8 +1091,10 @@ class ModelTesterMixin:
model.config.problem_type = "single_label_classification"
traced_model = symbolic_trace(model, input_names)
traced_output = traced_model(**filtered_inputs)
model_output = model(**filtered_inputs)
with torch.no_grad():
traced_output = traced_model(**filtered_inputs)
model_output = model(**filtered_inputs)
except Exception as e:
self.fail(f"Couldn't trace module: {e}")