diff --git a/docs/source/en/internal/model_debugging_utils.md b/docs/source/en/internal/model_debugging_utils.md index ab11a45b34..6d30668c63 100644 --- a/docs/source/en/internal/model_debugging_utils.md +++ b/docs/source/en/internal/model_debugging_utils.md @@ -28,7 +28,7 @@ Most of those are only useful if you are adding new models in the library. This context manager is a power user tool intended for model adders. It tracks all forward calls within a model forward and logs a slice of each input and output on a nested Json. -To note, this context manager enforces `torch.inference_mode()`. +To note, this context manager enforces `torch.no_grad()`. ### Rationale @@ -43,6 +43,7 @@ import torch from PIL import Image import requests from transformers import LlavaProcessor, LlavaForConditionalGeneration +from transformers.model_debugging_utils import model_addition_debugger_context torch.random.manual_seed(673) # load pretrained model and processor @@ -60,12 +61,153 @@ prompt = "Describe this image." inputs = processor(text=prompt, images=random_image, return_tensors="pt") # call forward method (not .generate!) -with model_addition_debugger_context(model, "optional_path_to_your_output_file.json"): +with model_addition_debugger_context( + model, + debug_path="optional_path_to_your_directory", + do_prune_layers=False # This will output ALL the layers of a model. + ): output = model.forward(**inputs) ``` -[[autodoc]] model_addition_debugger +### Reading results + +The debugger generates two files from the forward call, both with the same base name, +but ending either with `_SUMMARY.json` or with `_FULL_TENSORS.json`. + +The first one will contain a summary of each module's _input_ and _output_ tensor values and shapes. + +```json +{ + "module_path": "MolmoForConditionalGeneration", + "inputs": { + "args": [], + "kwargs": { + "input_ids": { + "shape": "torch.Size([1, 589])", + "dtype": "torch.int64" + }, + "attention_mask": { + "shape": "torch.Size([1, 589])", + "dtype": "torch.int64" + }, + "pixel_values": { + "shape": "torch.Size([1, 5, 576, 588])", + "dtype": "torch.float32", + "mean": "tensor(-8.9514e-01, device='cuda:0')", + "std": "tensor(9.2586e-01, device='cuda:0')", + "min": "tensor(-1.7923e+00, device='cuda:0')", + "max": "tensor(1.8899e+00, device='cuda:0')" + } + }, + "children": [ + { + "module_path": "MolmoForConditionalGeneration.language_model.model.embed_tokens", + "inputs": { + "args": [ + { + "shape": "torch.Size([1, 589])", + "dtype": "torch.int64" + } + ] + }, + "outputs": { + "shape": "torch.Size([1, 589, 3584])", + "dtype": "torch.float32", + "mean": "tensor(6.5460e-06, device='cuda:0')", + "std": "tensor(2.3807e-02, device='cuda:0')", + "min": "tensor(-3.3398e-01, device='cuda:0')", + "max": "tensor(3.9453e-01, device='cuda:0')" + } + }, + { + "module_path": "MolmoForConditionalGeneration.vision_tower", + "inputs": { + "args": [ + { + "shape": "torch.Size([5, 1, 576, 588])", + "dtype": "torch.float32", + "mean": "tensor(-8.9514e-01, device='cuda:0')", + "std": "tensor(9.2586e-01, device='cuda:0')", + "min": "tensor(-1.7923e+00, device='cuda:0')", + "max": "tensor(1.8899e+00, device='cuda:0')" + } + ], + "kwargs": { + "output_hidden_states": "True" + } + }, + "children": [ + { ... and so on +``` + +The `_FULL_TENSORS.json` file will display a full view of all tensors, which is useful +for comparing two files. +```json + "pixel_values": { + "shape": "torch.Size([1, 5, 576, 588])", + "dtype": "torch.float32", + "value": [ + "tensor([[[[-1.7923e+00, -1.7521e+00, -1.4802e+00, ..., -1.7923e+00, -1.7521e+00, -1.4802e+00],", + " [-1.7923e+00, -1.7521e+00, -1.4802e+00, ..., -1.7923e+00, -1.7521e+00, -1.4802e+00],", + " [-1.7923e+00, -1.7521e+00, -1.4802e+00, ..., -1.7923e+00, -1.7521e+00, -1.4802e+00],", + " ...,", + " [-1.7923e+00, -1.7521e+00, -1.4802e+00, ..., -1.7923e+00, -1.7521e+00, -1.4802e+00],", + " [-1.7923e+00, -1.7521e+00, -1.4802e+00, ..., -1.7923e+00, -1.7521e+00, -1.4802e+00],", + " [-1.7923e+00, -1.7521e+00, -1.4802e+00, ..., -1.7923e+00, -1.7521e+00, -1.4802e+00]],", + "", + " [[-1.7923e+00, -1.7521e+00, -1.4802e+00, ..., -1.7923e+00, -1.7521e+00, -1.4802e+00],", + " [-1.7923e+00, -1.7521e+00, -1.4802e+00, ..., -1.7923e+00, -1.7521e+00, -1.4802e+00],", + " [-1.7923e+00, -1.7521e+00, -1.4802e+00, ..., -1.7923e+00, -1.7521e+00, -1.4802e+00],", + " ...,", + " [-1.4857e+00, -1.4820e+00, -1.2100e+00, ..., -6.0979e-01, -5.9650e-01, -3.8527e-01],", + " [-1.6755e+00, -1.7221e+00, -1.4518e+00, ..., -7.5577e-01, -7.4658e-01, -5.5592e-01],", + " [-7.9957e-01, -8.2162e-01, -5.7014e-01, ..., -1.3689e+00, -1.3169e+00, -1.0678e+00]],", + "", + " [[-1.7923e+00, -1.7521e+00, -1.4802e+00, ..., -1.7923e+00, -1.7521e+00, -1.4802e+00],", + " [-1.7923e+00, -1.7521e+00, -1.4802e+00, ..., -1.7923e+00, -1.7521e+00, -1.4802e+00],", + " [-1.7923e+00, -1.7521e+00, -1.4802e+00, ..., -1.7923e+00, -1.7521e+00, -1.4802e+00],", + " ...,", + " [-3.0322e-01, -5.0645e-01, -5.8436e-01, ..., -6.2439e-01, -7.9160e-01, -8.1188e-01],", + " [-4.4921e-01, -6.5653e-01, -7.2656e-01, ..., -3.4702e-01, -5.2146e-01, -5.1326e-01],", + " [-3.4702e-01, -5.3647e-01, -5.4170e-01, ..., -1.0915e+00, -1.1968e+00, -1.0252e+00]],", + "", + " [[-1.1207e+00, -1.2718e+00, -1.0678e+00, ..., 1.2013e-01, -1.3126e-01, -1.7197e-01],", + " [-6.9738e-01, -9.1166e-01, -8.5454e-01, ..., -5.5050e-02, -2.8134e-01, -4.2793e-01],", + " [-3.4702e-01, -5.5148e-01, -5.8436e-01, ..., 1.9312e-01, -8.6235e-02, -2.1463e-01],", + " ...,", + " [-1.7923e+00, -1.7521e+00, -1.4802e+00, ..., -1.7923e+00, -1.7521e+00, -1.4802e+00],", + " [-1.7923e+00, -1.7521e+00, -1.4802e+00, ..., -1.7923e+00, -1.7521e+00, -1.4802e+00],", + " [-1.7923e+00, -1.7521e+00, -1.4802e+00, ..., -1.7923e+00, -1.7521e+00, -1.4802e+00]],", + "", + " [[-1.0039e+00, -9.5669e-01, -6.5546e-01, ..., -1.4711e+00, -1.4219e+00, -1.1389e+00],", + " [-1.0039e+00, -9.5669e-01, -6.5546e-01, ..., -1.7193e+00, -1.6771e+00, -1.4091e+00],", + " [-1.6317e+00, -1.6020e+00, -1.2669e+00, ..., -1.2667e+00, -1.2268e+00, -8.9720e-01],", + " ...,", + " [-1.7923e+00, -1.7521e+00, -1.4802e+00, ..., -1.7923e+00, -1.7521e+00, -1.4802e+00],", + " [-1.7923e+00, -1.7521e+00, -1.4802e+00, ..., -1.7923e+00, -1.7521e+00, -1.4802e+00],", + " [-1.7923e+00, -1.7521e+00, -1.4802e+00, ..., -1.7923e+00, -1.7521e+00, -1.4802e+00]]]], device='cuda:0')" + ], + "mean": "tensor(-8.9514e-01, device='cuda:0')", + "std": "tensor(9.2586e-01, device='cuda:0')", + "min": "tensor(-1.7923e+00, device='cuda:0')", + "max": "tensor(1.8899e+00, device='cuda:0')" + }, +``` + +### Comparing between implementations + +Once the forward passes of two models have been traced by the debugger, one can compare the `json` output files. See below: we can see slight differences between these two implementations' key projection layer. Inputs are mostly identical, but not quite. Looking through the file differences makes it easier to pinpoint which layer is wrong. + + +![download-icon](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/files_difference_debugging.png) + + +### Limitations and scope + +This feature will only work for torch-based models, and would require more work and case-by-case approach for say `jax`-based models that are usually compiled. Models relying heavily on external kernel calls may work, but trace will probably miss some things. Regardless, any python implementation that aims at mimicking another implementation can be traced once instead of reran N times with breakpoints. + +If you pass `do_prune_layers=False` to your model debugger, ALL the layers will be outputted to `json`. Else, only the first and last layer will be shown. This is useful when some layers (typically cross-attention) appear only after N layers. [[autodoc]] model_addition_debugger_context diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 75f03315f9..b7ba86b64f 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -344,7 +344,6 @@ except OptionalDependencyNotAvailable: _import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")] else: _import_structure["model_debugging_utils"] = [ - "model_addition_debugger", "model_addition_debugger_context", ] _import_structure["activations"] = [] @@ -910,7 +909,6 @@ if TYPE_CHECKING: convert_and_export_with_cache, ) from .model_debugging_utils import ( - model_addition_debugger, model_addition_debugger_context, ) from .modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update diff --git a/src/transformers/model_debugging_utils.py b/src/transformers/model_debugging_utils.py index c419c2c273..009ac0c6b2 100644 --- a/src/transformers/model_debugging_utils.py +++ b/src/transformers/model_debugging_utils.py @@ -17,7 +17,8 @@ import functools import json import os import re -from contextlib import contextmanager +from contextlib import contextmanager, redirect_stdout +from io import StringIO from typing import Optional from transformers.utils.import_utils import requires @@ -28,9 +29,7 @@ from .utils import is_torch_available if is_torch_available(): import torch import torch.distributed.tensor - from torch import nn - from .modeling_utils import PreTrainedModel from .utils import logging @@ -87,21 +86,64 @@ def _serialize_io(value): if hasattr(value, "_local_tensor"): # DTensor-like handling, just use local tensor attribute - return { + torch.set_printoptions(sci_mode=True) + val_repr = _repr_to_list(value) + out = { "shape": repr(value._local_tensor.shape), "dtype": repr(value._local_tensor.dtype), - "value": _sanitize_repr_for_diff(repr(value)), + "value": val_repr, } + if value._local_tensor.dtype in {torch.float16, torch.float32, torch.bfloat16}: + value = value._local_tensor.clone() + out.update( + { + "mean": _sanitize_repr_for_diff(repr(value.mean())), + "std": _sanitize_repr_for_diff(repr(value.std())), + "min": _sanitize_repr_for_diff(repr(value.min())), + "max": _sanitize_repr_for_diff(repr(value.max())), + } + ) + return out if isinstance(value, torch.Tensor): - # standard PyTorch Tensor - # return also the shape of such - return {"shape": repr(value.shape), "dtype": repr(value.dtype), "value": _sanitize_repr_for_diff(repr(value))} + torch.set_printoptions(sci_mode=True) + val_repr = _repr_to_list(value) + out = { + "shape": repr(value.shape), + "dtype": repr(value.dtype), + "value": val_repr, + } + if value.dtype in {torch.float16, torch.float32, torch.bfloat16}: + out.update( + { + "mean": _sanitize_repr_for_diff(repr(value.mean())), + "std": _sanitize_repr_for_diff(repr(value.std())), + "min": _sanitize_repr_for_diff(repr(value.min())), + "max": _sanitize_repr_for_diff(repr(value.max())), + } + ) + return out - # fallback for everything else (bool, int, float, None, or custom class) return _sanitize_repr_for_diff(repr(value)) +def _repr_to_list(value: torch.Tensor): + """ + Converts a tensor into a sanitized multi-line string representation. + + Args: + value (`torch.Tensor`): The tensor to represent. + + Returns: + `List[str]`: List of string lines representing the tensor. + """ + torch.set_printoptions(sci_mode=True, linewidth=120) + with StringIO() as buf, redirect_stdout(buf): + print(value) # to redirected stdout to avoid line splits + raw = buf.getvalue() + return _sanitize_repr_for_diff(raw).splitlines() + + def prune_outputs_if_children(node): # if there are children, remove this node's "outputs" # so we only see outputs at the leaf level @@ -111,22 +153,106 @@ def prune_outputs_if_children(node): prune_outputs_if_children(child) +LAYER_SUFFIX_RE = re.compile(r"(.*)\.(\d+)$") # should be generic enough, ends with a number + + +def is_layer_block(node): + """ + Checks whether a node represents a layer block with submodules. + + Args: + node (`dict`): A node from the call tree. + + Returns: + `bool`: Whether the node is a layer block. + """ + match = LAYER_SUFFIX_RE.match(node.get("module_path", "")) + if not match or not node.get("children"): + return False + number = match.group(2) + return any(f".{number}." in child.get("module_path", "") for child in node["children"]) + + +def prune_intermediate_layers(node): + """ + Recursively removes intermediate layers from the tree to improve readability. + Keeps at least the first and last layers if many consecutive layers are present. + + Args: + node (`dict`): The root or subnode to prune recursively. + """ + if not node.get("children"): + return + layer_blocks = [(i, child) for i, child in enumerate(node["children"]) if is_layer_block(child)] + + if len(layer_blocks) > 2: + to_remove = [i for i, _ in layer_blocks[1:-1]] + node["children"] = [child for i, child in enumerate(node["children"]) if i not in to_remove] + + for child in node["children"]: + prune_intermediate_layers(child) + + def log_model_debug_trace(debug_path, model): if debug_path: try: - os.makedirs(debug_path, exist_ok=False) - output_path = os.path.join(debug_path, model._debugger_module_dump_name + "_debug_tree.json") + os.makedirs(debug_path, exist_ok=True) + base = os.path.join(debug_path, model._debugger_module_dump_name + "_debug_tree") except Exception as e: raise ValueError(f"Unexpected or existing debug_path={debug_path}. {e}") else: - output_path = model._debugger_module_dump_name + "_debug_tree.json" - logger.info(f"Writing model trace at {output_path}") - with open(output_path, "w") as outfile: - prune_outputs_if_children(model._call_tree) - json.dump(model._call_tree, outfile, indent=2) + base = model._debugger_module_dump_name + "_debug_tree" + + logger.info(f"Writing model trace at {base}.json") + full_path = base + "_FULL_TENSORS.json" + summary_path = base + "_SUMMARY.json" + + prune_outputs_if_children(model._call_tree) + + with open(full_path, "w") as f: + json.dump(model._call_tree, f, indent=2) + + # summary-only version for readability - traversing the tree again #TODO optimize? + def strip_values(node): + def clean(val): + if isinstance(val, dict): + val.pop("value", None) + for v in val.values(): + clean(v) + elif isinstance(val, list): + for item in val: + clean(item) + + clean(node.get("inputs", {})) + clean(node.get("outputs", {})) + + for child in node.get("children", []): + strip_values(child) + + tree_copy = json.loads(json.dumps(model._call_tree)) # deep copy + strip_values(tree_copy) + + with open(summary_path, "w") as f: + json.dump(tree_copy, f, indent=2) -def _attach_debugger_logic(model, class_name, debug_path: str): +def _attach_debugger_logic( + model, + debug_path: Optional[str] = ".", + do_prune_layers: Optional[bool] = True, +): + """ + Attaches a debugging wrapper to every module in the model. + + This records structured inputs and outputs during the forward pass into a call tree. + + Args: + model (`PreTrainedModel`, `nn.Module`): Model to wrap. + debug_path (`str`): Optional directory to dump debug JSON files. + do_prune_layers (`bool`, *optional*, defaults to `True`): Whether to prune intermediate layers. + """ + class_name = model.__class__.__name__ + # Prepare data structures on the model object model._call_tree = {"module_path": class_name, "inputs": None, "outputs": None, "children": []} model._debugger_model_call_stack = [] @@ -147,7 +273,7 @@ def _attach_debugger_logic(model, class_name, debug_path: str): "children": [], } model._debugger_model_call_stack.append(node) - with torch.inference_mode(): + with torch.no_grad(): out = orig_forward(*inps, **kws) if _is_rank_zero(): @@ -188,7 +314,6 @@ def _attach_debugger_logic(model, class_name, debug_path: str): model._debugger_model_call_stack.append(top_node) out = real_top_forward(*inps, **kws) - if _is_rank_zero() and model._debugger_model_call_stack: top_node["outputs"] = _serialize_io(out) finished = model._debugger_model_call_stack.pop() @@ -198,98 +323,24 @@ def _attach_debugger_logic(model, class_name, debug_path: str): # prune empty stuff for visibility [model._call_tree.pop(k, None) for k in list(model._call_tree.keys()) if not model._call_tree[k]] + # prune layers that are not 0 or last + if do_prune_layers: + prune_intermediate_layers(model._call_tree) + # Write final JSON trace here + log_model_debug_trace(debug_path=debug_path, model=model) return out model.forward = top_wrapped_forward - # Final hook for writing JSON on forward-end - def final_hook(_, inputs, outputs): - if _is_rank_zero() and model._debugger_model_call_stack: - finished = model._debugger_model_call_stack.pop() - model._call_tree["inputs"] = finished["inputs"] - model._call_tree["outputs"] = finished["outputs"] - model._call_tree["children"] = finished["children"] - - if _is_rank_zero(): - log_model_debug_trace(debug_path=debug_path, model=model) - - model.register_forward_hook(final_hook) - # Optionally also for a couple possible hooks that have specific names. It should be just one. - # This means modules that are not typically called "forward" within the model. But we should not need to recurse - # through them. - possible_model_calls = ["language_model", "model"] - for model_call in possible_model_calls: - this_model_call = getattr(model, model_call, None) - if this_model_call and isinstance(this_model_call, (nn.Module, PreTrainedModel)): - this_model_call.register_forward_hook(final_hook) - break # exit the loop after finding one (unsure, but should be just one call.) - - -@requires(backends=("torch",)) -def model_addition_debugger(cls): - """ - # Model addition debugger - a model adder tracer - This decorator is a power user tool intended for model adders. - It tracks all forward calls within a model forward and logs a slice of each input and output on a nested Json. - To note, this decorator enforces `torch.inference_mode()`. - ## Usage - - add decorator to your model class - ```python - from ...modeling_utils import model_addition_debugger - - @model_addition_debugger - class MyModel(nn.Module) # Can inherit from PreTrainedModel too - # ... nothing else changes - ``` - Then, in a separate script (example is for Llava) - - ```python - import torch - from PIL import Image - import requests - from transformers import LlavaProcessor, LlavaForConditionalGeneration - torch.random.manual_seed(673) - - # load pretrained model and processor - model_id = "llava-hf/llava-1.5-7b-hf" - processor = LlavaProcessor.from_pretrained(model_id) - model = LlavaForConditionalGeneration.from_pretrained(model_id, low_cpu_mem_usage=True) - - # create random image input - random_image = Image.fromarray(torch.randint(0, 256, (224, 224, 3), dtype=torch.uint8).numpy()) - - # prompt - prompt = "Describe this image." - - # process inputs - inputs = processor(text=prompt, images=random_image, return_tensors="pt") - - # call forward method (not .generate!) - with torch.no_grad(): - output = model.forward(**inputs) - ``` - - """ - orig_init = cls.__init__ - - @functools.wraps(cls.__init__) - def wrapped_init(self, *args, **kwargs): - orig_init(self, *args, **kwargs) - _attach_debugger_logic(self, cls.__name__) - - cls.__init__ = wrapped_init - return cls - @requires(backends=("torch",)) @contextmanager -def model_addition_debugger_context(model, debug_path: Optional[str] = None): +def model_addition_debugger_context(model, debug_path: Optional[str] = None, do_prune_layers: Optional[bool] = True): """ # Model addition debugger - context manager for model adders This context manager is a power user tool intended for model adders. It tracks all forward calls within a model forward and logs a slice of each input and output on a nested Json. - To note, this context manager enforces `torch.inference_mode()`. + To note, this context manager enforces `torch.no_grad()`. ## Usage @@ -300,6 +351,7 @@ def model_addition_debugger_context(model, debug_path: Optional[str] = None): from PIL import Image import requests from transformers import LlavaProcessor, LlavaForConditionalGeneration + from transformers.model_debugging_utils import model_addition_debugger_context torch.random.manual_seed(673) # load pretrained model and processor @@ -317,13 +369,16 @@ def model_addition_debugger_context(model, debug_path: Optional[str] = None): inputs = processor(text=prompt, images=random_image, return_tensors="pt") # call forward method (not .generate!) - with model_addition_debugger_context(model): + with model_addition_debugger_context(model, debug_path="Your_debug_path", do_prune_layers=False): output = model.forward(**inputs) ``` """ - _attach_debugger_logic(model, model.__class__.__name__, debug_path) + orig_forwards = {m: m.forward for _, m in model.named_modules()} + orig_forwards[model] = model.forward + _attach_debugger_logic(model, debug_path, do_prune_layers) try: yield model finally: - pass + for module_instance, forward_method in orig_forwards.items(): + module_instance.forward = forward_method diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 55c592082c..be4e47b9fc 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -538,10 +538,6 @@ def convert_and_export_with_cache(*args, **kwargs): requires_backends(convert_and_export_with_cache, ["torch"]) -def model_addition_debugger(*args, **kwargs): - requires_backends(model_addition_debugger, ["torch"]) - - def model_addition_debugger_context(*args, **kwargs): requires_backends(model_addition_debugger_context, ["torch"]) diff --git a/tests/utils/test_model_debugging_utils.py b/tests/utils/test_model_debugging_utils.py new file mode 100644 index 0000000000..30419f8b7e --- /dev/null +++ b/tests/utils/test_model_debugging_utils.py @@ -0,0 +1,122 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import json +import os +import tempfile +import unittest +from pathlib import Path + +from transformers import is_torch_available +from transformers.model_debugging_utils import model_addition_debugger_context + + +if is_torch_available(): + import torch + from torch import nn + + class ToyModel(nn.Module): + def __init__(self): + super().__init__() + self.embed = nn.Embedding(10, 4) + self.linear_1 = nn.Linear(4, 8) + self.linear_2 = nn.Linear(8, 2) + self.act = nn.ReLU() + + def forward(self, input_ids: str): + hidden_states = self.embed(input_ids).mean(dim=1) + hidden_states = self.act(self.linear_1(hidden_states)) + return self.linear_2(hidden_states) + + class TestModelAdditionDebugger(unittest.TestCase): + def setUp(self): + self.model = ToyModel() + self.inputs = {"input_ids": torch.randint(0, 10, (1, 3))} + + def tearDown(self): + gc.collect() + + def test_debugger_outputs(self): + with tempfile.TemporaryDirectory() as tmpdir: + with model_addition_debugger_context(self.model, debug_path=str(tmpdir)): + _ = self.model.forward(**self.inputs) + + base = f"{self.model.__class__.__name__}_debug_tree" + summary = Path(os.path.join(tmpdir, f"{base}_SUMMARY.json")) + full = Path(os.path.join(tmpdir, f"{base}_FULL_TENSORS.json")) + self.assertTrue(os.path.isfile(summary) and os.path.isfile(full)) + data = json.loads(summary.read_text()) + self.assertTrue({"module_path", "inputs", "children"} <= data.keys()) + self.assertTrue(data["children"]) + + class ToyLayer(nn.Module): + def __init__(self, layer_index): + super().__init__() + self.layer_index = layer_index + self.layer_operation = nn.Linear(4, 4) + + def forward(self, hidden_states): + return self.layer_operation(hidden_states) + + class ToyModelWithLayers(nn.Module): + def __init__(self): + super().__init__() + self.input_proj = nn.Linear(4, 4) + self.layers = nn.ModuleList([ToyLayer(layer_index) for layer_index in range(6)]) + self.output_proj = nn.Linear(4, 2) + + def forward(self, x): + x = self.input_proj(x) + for layer in self.layers: + x = layer(x) + return self.output_proj(x) + + class TestModelWithLayers(unittest.TestCase): + def setUp(self): + self.inputs = {"input_ids": torch.randint(0, 10, (1, 3))} + self.model_with_layers = ToyModelWithLayers() + self.dense_input = {"x": torch.randn(1, 4)} + + def tearDown(self): + gc.collect() + + def test_layer_pruning_behavior(self): + # No pruning: expect all 6 layers + with tempfile.TemporaryDirectory() as tmpdir: + with model_addition_debugger_context(self.model_with_layers, debug_path=tmpdir, do_prune_layers=False): + _ = self.model_with_layers(**self.dense_input) + + summary_path = os.path.join(tmpdir, "ToyModelWithLayers_debug_tree_SUMMARY.json") + with open(summary_path) as f: + data = json.load(f) + self.assertEqual(set(data.keys()), {"module_path", "inputs", "children"}) + for layer_index in range(6): + self.assertEqual( + data["children"][layer_index + 1]["module_path"], + f"ToyModelWithLayers.layers.{int(layer_index)}", + ) + + # Pruning: expect only 2 layers (0 and 5) + with tempfile.TemporaryDirectory() as tmpdir: + with model_addition_debugger_context(self.model_with_layers, debug_path=tmpdir, do_prune_layers=True): + _ = self.model_with_layers(**self.dense_input) + + summary_path = os.path.join(tmpdir, "ToyModelWithLayers_debug_tree_SUMMARY.json") + with open(summary_path) as f: + data = json.load(f) + self.assertEqual(set(data.keys()), {"module_path", "inputs", "children"}) + self.assertEqual(data["children"][1]["module_path"], "ToyModelWithLayers.layers.0") + self.assertEqual(data["children"][2]["module_path"], "ToyModelWithLayers.layers.5")