Fix symbolic_trace with kv cache (#28724)
* fix symbolic_trace with kv cache * comment & better test
This commit is contained in:
@@ -1053,132 +1053,144 @@ class ModelTesterMixin:
|
||||
model.eval()
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=output_loss)
|
||||
|
||||
try:
|
||||
if model.config.is_encoder_decoder:
|
||||
model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
|
||||
labels = inputs.get("labels", None)
|
||||
input_names = [
|
||||
"attention_mask",
|
||||
"decoder_attention_mask",
|
||||
"decoder_input_ids",
|
||||
"input_features",
|
||||
"input_ids",
|
||||
"input_values",
|
||||
]
|
||||
if labels is not None:
|
||||
input_names.append("labels")
|
||||
# We may want to test several inputs (various shapes, etc.).
|
||||
inputs_to_test = [inputs]
|
||||
|
||||
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
|
||||
input_names = list(filtered_inputs.keys())
|
||||
if model.config.is_encoder_decoder:
|
||||
model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
|
||||
labels = inputs.get("labels", None)
|
||||
input_names = [
|
||||
"attention_mask",
|
||||
"decoder_attention_mask",
|
||||
"decoder_input_ids",
|
||||
"input_features",
|
||||
"input_ids",
|
||||
"input_values",
|
||||
]
|
||||
if labels is not None:
|
||||
input_names.append("labels")
|
||||
else:
|
||||
input_names = [
|
||||
"attention_mask",
|
||||
"bbox",
|
||||
"input_features",
|
||||
"input_ids",
|
||||
"input_values",
|
||||
"pixel_values",
|
||||
"token_type_ids",
|
||||
"visual_feats",
|
||||
"visual_pos",
|
||||
]
|
||||
|
||||
labels = inputs.get("labels", None)
|
||||
start_positions = inputs.get("start_positions", None)
|
||||
end_positions = inputs.get("end_positions", None)
|
||||
if labels is not None:
|
||||
input_names.append("labels")
|
||||
if start_positions is not None:
|
||||
input_names.append("start_positions")
|
||||
if end_positions is not None:
|
||||
input_names.append("end_positions")
|
||||
|
||||
if model.config.model_type in _FX_SUPPORTED_MODELS_WITH_KV_CACHE:
|
||||
input_names.append("past_key_values")
|
||||
|
||||
# Generally model_tester.prepare_config_and_inputs_for_common seem not to generate past key values inputs.
|
||||
if "past_key_values" not in inputs:
|
||||
batch_size = inputs[next(iter(inputs))].shape[0]
|
||||
num_heads = model.config.num_attention_heads
|
||||
head_dim = model.config.hidden_size // model.config.num_attention_heads
|
||||
|
||||
cache_shape = (batch_size, num_heads, 0, head_dim)
|
||||
empty_pkv = tuple(
|
||||
(
|
||||
torch.rand(cache_shape, dtype=torch.float, device=torch_device),
|
||||
torch.rand(cache_shape, dtype=torch.float, device=torch_device),
|
||||
)
|
||||
for i in range(model.config.num_hidden_layers)
|
||||
)
|
||||
|
||||
cache_length = 9
|
||||
cache_shape = (batch_size, num_heads, cache_length, head_dim)
|
||||
non_empty_pkv = tuple(
|
||||
(
|
||||
torch.rand(cache_shape, dtype=torch.float, device=torch_device),
|
||||
torch.rand(cache_shape, dtype=torch.float, device=torch_device),
|
||||
)
|
||||
for i in range(model.config.num_hidden_layers)
|
||||
)
|
||||
|
||||
inps = copy.deepcopy(inputs_to_test[0])
|
||||
|
||||
inputs_to_test[0]["past_key_values"] = empty_pkv
|
||||
|
||||
inps["past_key_values"] = non_empty_pkv
|
||||
inputs_to_test.append(inps)
|
||||
|
||||
past_mask = torch.ones(batch_size, cache_length, device=torch_device, dtype=torch.float)
|
||||
inputs_to_test[1]["attention_mask"] = torch.cat(
|
||||
(past_mask, inputs_to_test[1]["attention_mask"]), dim=1
|
||||
)
|
||||
|
||||
for inps in inputs_to_test:
|
||||
filtered_inputs = {k: v for (k, v) in inps.items() if k in input_names}
|
||||
input_names = list(filtered_inputs.keys())
|
||||
|
||||
if model.__class__.__name__ in set(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.values()) and (
|
||||
not hasattr(model.config, "problem_type") or model.config.problem_type is None
|
||||
):
|
||||
model.config.problem_type = "single_label_classification"
|
||||
|
||||
traced_model = symbolic_trace(model, input_names)
|
||||
|
||||
with torch.no_grad():
|
||||
traced_output = traced_model(**filtered_inputs)
|
||||
model_output = model(**filtered_inputs)
|
||||
|
||||
traced_model = symbolic_trace(model, input_names)
|
||||
traced_output = traced_model(**filtered_inputs)
|
||||
else:
|
||||
input_names = [
|
||||
"attention_mask",
|
||||
"bbox",
|
||||
"input_features",
|
||||
"input_ids",
|
||||
"input_values",
|
||||
"pixel_values",
|
||||
"token_type_ids",
|
||||
"visual_feats",
|
||||
"visual_pos",
|
||||
]
|
||||
def flatten_output(output):
|
||||
flatten = []
|
||||
for x in output:
|
||||
if isinstance(x, (tuple, list)):
|
||||
flatten += flatten_output(x)
|
||||
elif not isinstance(x, torch.Tensor):
|
||||
continue
|
||||
else:
|
||||
flatten.append(x)
|
||||
return flatten
|
||||
|
||||
labels = inputs.get("labels", None)
|
||||
start_positions = inputs.get("start_positions", None)
|
||||
end_positions = inputs.get("end_positions", None)
|
||||
if labels is not None:
|
||||
input_names.append("labels")
|
||||
if start_positions is not None:
|
||||
input_names.append("start_positions")
|
||||
if end_positions is not None:
|
||||
input_names.append("end_positions")
|
||||
|
||||
if model.config.model_type in _FX_SUPPORTED_MODELS_WITH_KV_CACHE:
|
||||
input_names.append("past_key_values")
|
||||
|
||||
# Generally model_tester.prepare_config_and_inputs_for_common seem not to generate past key values inputs.
|
||||
if "past_key_values" not in inputs:
|
||||
batch_size = inputs[next(iter(inputs))].shape[0]
|
||||
num_heads = model.config.num_attention_heads
|
||||
head_dim = model.config.hidden_size // model.config.num_attention_heads
|
||||
|
||||
cache_shape = (batch_size, num_heads, 0, head_dim)
|
||||
pkv = tuple(
|
||||
(
|
||||
torch.rand(cache_shape, dtype=torch.float, device=torch_device),
|
||||
torch.rand(cache_shape, dtype=torch.float, device=torch_device),
|
||||
)
|
||||
for i in range(model.config.num_hidden_layers)
|
||||
)
|
||||
|
||||
inputs["past_key_values"] = pkv
|
||||
|
||||
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
|
||||
input_names = list(filtered_inputs.keys())
|
||||
|
||||
if model.__class__.__name__ in set(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.values()) and (
|
||||
not hasattr(model.config, "problem_type") or model.config.problem_type is None
|
||||
):
|
||||
model.config.problem_type = "single_label_classification"
|
||||
|
||||
traced_model = symbolic_trace(model, input_names)
|
||||
|
||||
with torch.no_grad():
|
||||
traced_output = traced_model(**filtered_inputs)
|
||||
model_output = model(**filtered_inputs)
|
||||
|
||||
except Exception as e:
|
||||
self.fail(f"Couldn't trace module: {e}")
|
||||
|
||||
def flatten_output(output):
|
||||
flatten = []
|
||||
for x in output:
|
||||
if isinstance(x, (tuple, list)):
|
||||
flatten += flatten_output(x)
|
||||
elif not isinstance(x, torch.Tensor):
|
||||
continue
|
||||
else:
|
||||
flatten.append(x)
|
||||
return flatten
|
||||
|
||||
model_output = flatten_output(model_output)
|
||||
traced_output = flatten_output(traced_output)
|
||||
num_outputs = len(model_output)
|
||||
|
||||
for i in range(num_outputs):
|
||||
self.assertTrue(
|
||||
torch.allclose(model_output[i], traced_output[i]),
|
||||
f"traced {i}th output doesn't match model {i}th output for {model_class}",
|
||||
)
|
||||
|
||||
# Test that the model can be serialized and restored properly
|
||||
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||
pkl_file_name = os.path.join(tmp_dir_name, "model.pkl")
|
||||
try:
|
||||
with open(pkl_file_name, "wb") as f:
|
||||
pickle.dump(traced_model, f)
|
||||
with open(pkl_file_name, "rb") as f:
|
||||
loaded = pickle.load(f)
|
||||
except Exception as e:
|
||||
self.fail(f"Couldn't serialize / deserialize the traced model: {e}")
|
||||
|
||||
loaded_output = loaded(**filtered_inputs)
|
||||
loaded_output = flatten_output(loaded_output)
|
||||
model_output = flatten_output(model_output)
|
||||
traced_output = flatten_output(traced_output)
|
||||
num_outputs = len(model_output)
|
||||
|
||||
for i in range(num_outputs):
|
||||
self.assertTrue(
|
||||
torch.allclose(model_output[i], loaded_output[i]),
|
||||
f"serialized model {i}th output doesn't match model {i}th output for {model_class}",
|
||||
torch.allclose(model_output[i], traced_output[i]),
|
||||
f"traced {i}th output doesn't match model {i}th output for {model_class}",
|
||||
)
|
||||
|
||||
# Avoid memory leak. Without this, each call increase RAM usage by ~20MB.
|
||||
# (Even with this call, there are still memory leak by ~0.04MB)
|
||||
self.clear_torch_jit_class_registry()
|
||||
# Test that the model can be serialized and restored properly
|
||||
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||
pkl_file_name = os.path.join(tmp_dir_name, "model.pkl")
|
||||
try:
|
||||
with open(pkl_file_name, "wb") as f:
|
||||
pickle.dump(traced_model, f)
|
||||
with open(pkl_file_name, "rb") as f:
|
||||
loaded = pickle.load(f)
|
||||
except Exception as e:
|
||||
self.fail(f"Couldn't serialize / deserialize the traced model: {e}")
|
||||
|
||||
loaded_output = loaded(**filtered_inputs)
|
||||
loaded_output = flatten_output(loaded_output)
|
||||
|
||||
for i in range(num_outputs):
|
||||
self.assertTrue(
|
||||
torch.allclose(model_output[i], loaded_output[i]),
|
||||
f"serialized model {i}th output doesn't match model {i}th output for {model_class}",
|
||||
)
|
||||
|
||||
# Avoid memory leak. Without this, each call increase RAM usage by ~20MB.
|
||||
# (Even with this call, there are still memory leak by ~0.04MB)
|
||||
self.clear_torch_jit_class_registry()
|
||||
|
||||
def test_headmasking(self):
|
||||
if not self.test_head_masking:
|
||||
|
||||
Reference in New Issue
Block a user