Fix symbolic_trace with kv cache (#28724)
* fix symbolic_trace with kv cache * comment & better test
This commit is contained in:
@@ -765,7 +765,7 @@ class HFTracer(Tracer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _generate_dummy_input(
|
def _generate_dummy_input(
|
||||||
self, model: PreTrainedModel, input_name: str, shape: List[int]
|
self, model: PreTrainedModel, input_name: str, shape: List[int], input_names: List[str]
|
||||||
) -> Dict[str, torch.Tensor]:
|
) -> Dict[str, torch.Tensor]:
|
||||||
"""Generates dummy input for model inference recording."""
|
"""Generates dummy input for model inference recording."""
|
||||||
# Retrieving the model class, either from the "class_for_deserialization" attribute if the model was restored
|
# Retrieving the model class, either from the "class_for_deserialization" attribute if the model was restored
|
||||||
@@ -774,6 +774,11 @@ class HFTracer(Tracer):
|
|||||||
device = model.device
|
device = model.device
|
||||||
inputs_dict = {}
|
inputs_dict = {}
|
||||||
|
|
||||||
|
# when tracing a model with KV cache, we simply need to unsure that the KV cache length is larger than one to
|
||||||
|
# rightfully pass certain controlflows (Example: https://github.com/huggingface/transformers/blob/5c8d941d66734811d2ef6f57f15b44f7fb7a98c4/src/transformers/modeling_attn_mask_utils.py#L162).
|
||||||
|
# After tracing, the model can then still be used with arbitrary lengths different than the one used during tracing.
|
||||||
|
kv_cache_length = 5
|
||||||
|
|
||||||
if input_name in ["labels", "start_positions", "end_positions"]:
|
if input_name in ["labels", "start_positions", "end_positions"]:
|
||||||
batch_size = shape[0]
|
batch_size = shape[0]
|
||||||
if model_class_name in [
|
if model_class_name in [
|
||||||
@@ -883,7 +888,14 @@ class HFTracer(Tracer):
|
|||||||
# Generating big sequence length for audio inputs.
|
# Generating big sequence length for audio inputs.
|
||||||
seq_length = _generate_random_int(low=10000, high=20000)
|
seq_length = _generate_random_int(low=10000, high=20000)
|
||||||
inputs_dict[input_name] = torch.zeros(batch_size, seq_length, dtype=torch.float, device=device)
|
inputs_dict[input_name] = torch.zeros(batch_size, seq_length, dtype=torch.float, device=device)
|
||||||
elif "mask" in input_name or "ids" in input_name:
|
elif "mask" in input_name:
|
||||||
|
if "past_key_values" in input_names:
|
||||||
|
mask_shape = [shape[0], shape[1] + kv_cache_length]
|
||||||
|
else:
|
||||||
|
mask_shape = shape
|
||||||
|
|
||||||
|
inputs_dict[input_name] = torch.zeros(mask_shape, dtype=torch.long, device=device)
|
||||||
|
elif "ids" in input_name:
|
||||||
inputs_dict[input_name] = torch.zeros(shape, dtype=torch.long, device=device)
|
inputs_dict[input_name] = torch.zeros(shape, dtype=torch.long, device=device)
|
||||||
elif "past_key_values" in input_name:
|
elif "past_key_values" in input_name:
|
||||||
if model.config.model_type not in _FX_SUPPORTED_MODELS_WITH_KV_CACHE:
|
if model.config.model_type not in _FX_SUPPORTED_MODELS_WITH_KV_CACHE:
|
||||||
@@ -893,7 +905,7 @@ class HFTracer(Tracer):
|
|||||||
num_heads = model.config.num_attention_heads
|
num_heads = model.config.num_attention_heads
|
||||||
head_dim = model.config.hidden_size // model.config.num_attention_heads
|
head_dim = model.config.hidden_size // model.config.num_attention_heads
|
||||||
|
|
||||||
cache_shape = (shape[0], num_heads, 0, head_dim)
|
cache_shape = (shape[0], num_heads, kv_cache_length, head_dim)
|
||||||
pkv = tuple(
|
pkv = tuple(
|
||||||
(
|
(
|
||||||
torch.rand(cache_shape, dtype=torch.float, device=device),
|
torch.rand(cache_shape, dtype=torch.float, device=device),
|
||||||
@@ -1095,7 +1107,7 @@ class HFTracer(Tracer):
|
|||||||
if isinstance(root, self.supported_archs) or type(root).__qualname__.startswith(
|
if isinstance(root, self.supported_archs) or type(root).__qualname__.startswith(
|
||||||
("_deserialize_graph_module", "_CodeOnlyModule")
|
("_deserialize_graph_module", "_CodeOnlyModule")
|
||||||
):
|
):
|
||||||
inputs.update(self._generate_dummy_input(root, input_name, shape))
|
inputs.update(self._generate_dummy_input(root, input_name, shape, input_names=input_names))
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"Could not generate input named {input_name} for because root is not a"
|
f"Could not generate input named {input_name} for because root is not a"
|
||||||
|
|||||||
@@ -1053,7 +1053,9 @@ class ModelTesterMixin:
|
|||||||
model.eval()
|
model.eval()
|
||||||
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=output_loss)
|
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=output_loss)
|
||||||
|
|
||||||
try:
|
# We may want to test several inputs (various shapes, etc.).
|
||||||
|
inputs_to_test = [inputs]
|
||||||
|
|
||||||
if model.config.is_encoder_decoder:
|
if model.config.is_encoder_decoder:
|
||||||
model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
|
model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
|
||||||
labels = inputs.get("labels", None)
|
labels = inputs.get("labels", None)
|
||||||
@@ -1067,14 +1069,6 @@ class ModelTesterMixin:
|
|||||||
]
|
]
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
input_names.append("labels")
|
input_names.append("labels")
|
||||||
|
|
||||||
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
|
|
||||||
input_names = list(filtered_inputs.keys())
|
|
||||||
|
|
||||||
model_output = model(**filtered_inputs)
|
|
||||||
|
|
||||||
traced_model = symbolic_trace(model, input_names)
|
|
||||||
traced_output = traced_model(**filtered_inputs)
|
|
||||||
else:
|
else:
|
||||||
input_names = [
|
input_names = [
|
||||||
"attention_mask",
|
"attention_mask",
|
||||||
@@ -1108,7 +1102,7 @@ class ModelTesterMixin:
|
|||||||
head_dim = model.config.hidden_size // model.config.num_attention_heads
|
head_dim = model.config.hidden_size // model.config.num_attention_heads
|
||||||
|
|
||||||
cache_shape = (batch_size, num_heads, 0, head_dim)
|
cache_shape = (batch_size, num_heads, 0, head_dim)
|
||||||
pkv = tuple(
|
empty_pkv = tuple(
|
||||||
(
|
(
|
||||||
torch.rand(cache_shape, dtype=torch.float, device=torch_device),
|
torch.rand(cache_shape, dtype=torch.float, device=torch_device),
|
||||||
torch.rand(cache_shape, dtype=torch.float, device=torch_device),
|
torch.rand(cache_shape, dtype=torch.float, device=torch_device),
|
||||||
@@ -1116,9 +1110,30 @@ class ModelTesterMixin:
|
|||||||
for i in range(model.config.num_hidden_layers)
|
for i in range(model.config.num_hidden_layers)
|
||||||
)
|
)
|
||||||
|
|
||||||
inputs["past_key_values"] = pkv
|
cache_length = 9
|
||||||
|
cache_shape = (batch_size, num_heads, cache_length, head_dim)
|
||||||
|
non_empty_pkv = tuple(
|
||||||
|
(
|
||||||
|
torch.rand(cache_shape, dtype=torch.float, device=torch_device),
|
||||||
|
torch.rand(cache_shape, dtype=torch.float, device=torch_device),
|
||||||
|
)
|
||||||
|
for i in range(model.config.num_hidden_layers)
|
||||||
|
)
|
||||||
|
|
||||||
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
|
inps = copy.deepcopy(inputs_to_test[0])
|
||||||
|
|
||||||
|
inputs_to_test[0]["past_key_values"] = empty_pkv
|
||||||
|
|
||||||
|
inps["past_key_values"] = non_empty_pkv
|
||||||
|
inputs_to_test.append(inps)
|
||||||
|
|
||||||
|
past_mask = torch.ones(batch_size, cache_length, device=torch_device, dtype=torch.float)
|
||||||
|
inputs_to_test[1]["attention_mask"] = torch.cat(
|
||||||
|
(past_mask, inputs_to_test[1]["attention_mask"]), dim=1
|
||||||
|
)
|
||||||
|
|
||||||
|
for inps in inputs_to_test:
|
||||||
|
filtered_inputs = {k: v for (k, v) in inps.items() if k in input_names}
|
||||||
input_names = list(filtered_inputs.keys())
|
input_names = list(filtered_inputs.keys())
|
||||||
|
|
||||||
if model.__class__.__name__ in set(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.values()) and (
|
if model.__class__.__name__ in set(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.values()) and (
|
||||||
@@ -1132,9 +1147,6 @@ class ModelTesterMixin:
|
|||||||
traced_output = traced_model(**filtered_inputs)
|
traced_output = traced_model(**filtered_inputs)
|
||||||
model_output = model(**filtered_inputs)
|
model_output = model(**filtered_inputs)
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self.fail(f"Couldn't trace module: {e}")
|
|
||||||
|
|
||||||
def flatten_output(output):
|
def flatten_output(output):
|
||||||
flatten = []
|
flatten = []
|
||||||
for x in output:
|
for x in output:
|
||||||
|
|||||||
Reference in New Issue
Block a user