From 709dc43239c66318f6f3c23123700192adf4fc2b Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Thu, 1 Feb 2024 09:45:02 +0100 Subject: [PATCH] Fix symbolic_trace with kv cache (#28724) * fix symbolic_trace with kv cache * comment & better test --- src/transformers/utils/fx.py | 20 ++- tests/test_modeling_common.py | 244 ++++++++++++++++++---------------- 2 files changed, 144 insertions(+), 120 deletions(-) diff --git a/src/transformers/utils/fx.py b/src/transformers/utils/fx.py index fd15fd9466..9f5c36a18a 100755 --- a/src/transformers/utils/fx.py +++ b/src/transformers/utils/fx.py @@ -765,7 +765,7 @@ class HFTracer(Tracer): ) def _generate_dummy_input( - self, model: PreTrainedModel, input_name: str, shape: List[int] + self, model: PreTrainedModel, input_name: str, shape: List[int], input_names: List[str] ) -> Dict[str, torch.Tensor]: """Generates dummy input for model inference recording.""" # Retrieving the model class, either from the "class_for_deserialization" attribute if the model was restored @@ -774,6 +774,11 @@ class HFTracer(Tracer): device = model.device inputs_dict = {} + # when tracing a model with KV cache, we simply need to ensure that the KV cache length is larger than one to + # rightfully pass certain control flows (Example: https://github.com/huggingface/transformers/blob/5c8d941d66734811d2ef6f57f15b44f7fb7a98c4/src/transformers/modeling_attn_mask_utils.py#L162). + # After tracing, the model can then still be used with arbitrary lengths different than the one used during tracing. + kv_cache_length = 5 + if input_name in ["labels", "start_positions", "end_positions"]: batch_size = shape[0] if model_class_name in [ @@ -883,7 +888,14 @@ class HFTracer(Tracer): # Generating big sequence length for audio inputs. 
seq_length = _generate_random_int(low=10000, high=20000) inputs_dict[input_name] = torch.zeros(batch_size, seq_length, dtype=torch.float, device=device) - elif "mask" in input_name or "ids" in input_name: + elif "mask" in input_name: + if "past_key_values" in input_names: + mask_shape = [shape[0], shape[1] + kv_cache_length] + else: + mask_shape = shape + + inputs_dict[input_name] = torch.zeros(mask_shape, dtype=torch.long, device=device) + elif "ids" in input_name: inputs_dict[input_name] = torch.zeros(shape, dtype=torch.long, device=device) elif "past_key_values" in input_name: if model.config.model_type not in _FX_SUPPORTED_MODELS_WITH_KV_CACHE: @@ -893,7 +905,7 @@ class HFTracer(Tracer): num_heads = model.config.num_attention_heads head_dim = model.config.hidden_size // model.config.num_attention_heads - cache_shape = (shape[0], num_heads, 0, head_dim) + cache_shape = (shape[0], num_heads, kv_cache_length, head_dim) pkv = tuple( ( torch.rand(cache_shape, dtype=torch.float, device=device), @@ -1095,7 +1107,7 @@ class HFTracer(Tracer): if isinstance(root, self.supported_archs) or type(root).__qualname__.startswith( ("_deserialize_graph_module", "_CodeOnlyModule") ): - inputs.update(self._generate_dummy_input(root, input_name, shape)) + inputs.update(self._generate_dummy_input(root, input_name, shape, input_names=input_names)) else: raise RuntimeError( f"Could not generate input named {input_name} for because root is not a" diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index fc904d0bc4..cefba1577a 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -1053,132 +1053,144 @@ class ModelTesterMixin: model.eval() inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=output_loss) - try: - if model.config.is_encoder_decoder: - model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward - labels = inputs.get("labels", None) - 
input_names = [ - "attention_mask", - "decoder_attention_mask", - "decoder_input_ids", - "input_features", - "input_ids", - "input_values", - ] - if labels is not None: - input_names.append("labels") + # We may want to test several inputs (various shapes, etc.). + inputs_to_test = [inputs] - filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names} - input_names = list(filtered_inputs.keys()) + if model.config.is_encoder_decoder: + model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward + labels = inputs.get("labels", None) + input_names = [ + "attention_mask", + "decoder_attention_mask", + "decoder_input_ids", + "input_features", + "input_ids", + "input_values", + ] + if labels is not None: + input_names.append("labels") + else: + input_names = [ + "attention_mask", + "bbox", + "input_features", + "input_ids", + "input_values", + "pixel_values", + "token_type_ids", + "visual_feats", + "visual_pos", + ] + labels = inputs.get("labels", None) + start_positions = inputs.get("start_positions", None) + end_positions = inputs.get("end_positions", None) + if labels is not None: + input_names.append("labels") + if start_positions is not None: + input_names.append("start_positions") + if end_positions is not None: + input_names.append("end_positions") + + if model.config.model_type in _FX_SUPPORTED_MODELS_WITH_KV_CACHE: + input_names.append("past_key_values") + + # Generally model_tester.prepare_config_and_inputs_for_common seem not to generate past key values inputs. 
+ if "past_key_values" not in inputs: + batch_size = inputs[next(iter(inputs))].shape[0] + num_heads = model.config.num_attention_heads + head_dim = model.config.hidden_size // model.config.num_attention_heads + + cache_shape = (batch_size, num_heads, 0, head_dim) + empty_pkv = tuple( + ( + torch.rand(cache_shape, dtype=torch.float, device=torch_device), + torch.rand(cache_shape, dtype=torch.float, device=torch_device), + ) + for i in range(model.config.num_hidden_layers) + ) + + cache_length = 9 + cache_shape = (batch_size, num_heads, cache_length, head_dim) + non_empty_pkv = tuple( + ( + torch.rand(cache_shape, dtype=torch.float, device=torch_device), + torch.rand(cache_shape, dtype=torch.float, device=torch_device), + ) + for i in range(model.config.num_hidden_layers) + ) + + inps = copy.deepcopy(inputs_to_test[0]) + + inputs_to_test[0]["past_key_values"] = empty_pkv + + inps["past_key_values"] = non_empty_pkv + inputs_to_test.append(inps) + + past_mask = torch.ones(batch_size, cache_length, device=torch_device, dtype=torch.float) + inputs_to_test[1]["attention_mask"] = torch.cat( + (past_mask, inputs_to_test[1]["attention_mask"]), dim=1 + ) + + for inps in inputs_to_test: + filtered_inputs = {k: v for (k, v) in inps.items() if k in input_names} + input_names = list(filtered_inputs.keys()) + + if model.__class__.__name__ in set(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.values()) and ( + not hasattr(model.config, "problem_type") or model.config.problem_type is None + ): + model.config.problem_type = "single_label_classification" + + traced_model = symbolic_trace(model, input_names) + + with torch.no_grad(): + traced_output = traced_model(**filtered_inputs) model_output = model(**filtered_inputs) - traced_model = symbolic_trace(model, input_names) - traced_output = traced_model(**filtered_inputs) - else: - input_names = [ - "attention_mask", - "bbox", - "input_features", - "input_ids", - "input_values", - "pixel_values", - "token_type_ids", - "visual_feats", 
- "visual_pos", - ] + def flatten_output(output): + flatten = [] + for x in output: + if isinstance(x, (tuple, list)): + flatten += flatten_output(x) + elif not isinstance(x, torch.Tensor): + continue + else: + flatten.append(x) + return flatten - labels = inputs.get("labels", None) - start_positions = inputs.get("start_positions", None) - end_positions = inputs.get("end_positions", None) - if labels is not None: - input_names.append("labels") - if start_positions is not None: - input_names.append("start_positions") - if end_positions is not None: - input_names.append("end_positions") - - if model.config.model_type in _FX_SUPPORTED_MODELS_WITH_KV_CACHE: - input_names.append("past_key_values") - - # Generally model_tester.prepare_config_and_inputs_for_common seem not to generate past key values inputs. - if "past_key_values" not in inputs: - batch_size = inputs[next(iter(inputs))].shape[0] - num_heads = model.config.num_attention_heads - head_dim = model.config.hidden_size // model.config.num_attention_heads - - cache_shape = (batch_size, num_heads, 0, head_dim) - pkv = tuple( - ( - torch.rand(cache_shape, dtype=torch.float, device=torch_device), - torch.rand(cache_shape, dtype=torch.float, device=torch_device), - ) - for i in range(model.config.num_hidden_layers) - ) - - inputs["past_key_values"] = pkv - - filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names} - input_names = list(filtered_inputs.keys()) - - if model.__class__.__name__ in set(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.values()) and ( - not hasattr(model.config, "problem_type") or model.config.problem_type is None - ): - model.config.problem_type = "single_label_classification" - - traced_model = symbolic_trace(model, input_names) - - with torch.no_grad(): - traced_output = traced_model(**filtered_inputs) - model_output = model(**filtered_inputs) - - except Exception as e: - self.fail(f"Couldn't trace module: {e}") - - def flatten_output(output): - flatten = [] - for x in 
output: - if isinstance(x, (tuple, list)): - flatten += flatten_output(x) - elif not isinstance(x, torch.Tensor): - continue - else: - flatten.append(x) - return flatten - - model_output = flatten_output(model_output) - traced_output = flatten_output(traced_output) - num_outputs = len(model_output) - - for i in range(num_outputs): - self.assertTrue( - torch.allclose(model_output[i], traced_output[i]), - f"traced {i}th output doesn't match model {i}th output for {model_class}", - ) - - # Test that the model can be serialized and restored properly - with tempfile.TemporaryDirectory() as tmp_dir_name: - pkl_file_name = os.path.join(tmp_dir_name, "model.pkl") - try: - with open(pkl_file_name, "wb") as f: - pickle.dump(traced_model, f) - with open(pkl_file_name, "rb") as f: - loaded = pickle.load(f) - except Exception as e: - self.fail(f"Couldn't serialize / deserialize the traced model: {e}") - - loaded_output = loaded(**filtered_inputs) - loaded_output = flatten_output(loaded_output) + model_output = flatten_output(model_output) + traced_output = flatten_output(traced_output) + num_outputs = len(model_output) for i in range(num_outputs): self.assertTrue( - torch.allclose(model_output[i], loaded_output[i]), - f"serialized model {i}th output doesn't match model {i}th output for {model_class}", + torch.allclose(model_output[i], traced_output[i]), + f"traced {i}th output doesn't match model {i}th output for {model_class}", ) - # Avoid memory leak. Without this, each call increase RAM usage by ~20MB. 
- # (Even with this call, there are still memory leak by ~0.04MB) - self.clear_torch_jit_class_registry() + # Test that the model can be serialized and restored properly + with tempfile.TemporaryDirectory() as tmp_dir_name: + pkl_file_name = os.path.join(tmp_dir_name, "model.pkl") + try: + with open(pkl_file_name, "wb") as f: + pickle.dump(traced_model, f) + with open(pkl_file_name, "rb") as f: + loaded = pickle.load(f) + except Exception as e: + self.fail(f"Couldn't serialize / deserialize the traced model: {e}") + + loaded_output = loaded(**filtered_inputs) + loaded_output = flatten_output(loaded_output) + + for i in range(num_outputs): + self.assertTrue( + torch.allclose(model_output[i], loaded_output[i]), + f"serialized model {i}th output doesn't match model {i}th output for {model_class}", + ) + + # Avoid memory leak. Without this, each call increases RAM usage by ~20MB. + # (Even with this call, there is still a memory leak of ~0.04MB) + self.clear_torch_jit_class_registry() def test_headmasking(self): if not self.test_head_masking: