Model debugger upgrades (#37391)

* debugging improvements

* add debugging details

* add more debugging details

* debug more

* clean up layers + output

* add summary json file

* cleanup

* copies 👀

* remove hooks + add documentation

* draft a small test, why not

* respect the format (respect it)

* fixup imports

* nit

* add tests and configurable pruning of layers
This commit is contained in:
Pablo Montalvo
2025-04-18 16:45:54 +02:00
committed by GitHub
parent e5ac23081e
commit 4afd3f4820
5 changed files with 425 additions and 112 deletions

View File

@@ -17,7 +17,8 @@ import functools
import json
import os
import re
from contextlib import contextmanager
from contextlib import contextmanager, redirect_stdout
from io import StringIO
from typing import Optional
from transformers.utils.import_utils import requires
@@ -28,9 +29,7 @@ from .utils import is_torch_available
if is_torch_available():
import torch
import torch.distributed.tensor
from torch import nn
from .modeling_utils import PreTrainedModel
from .utils import logging
@@ -87,21 +86,64 @@ def _serialize_io(value):
if hasattr(value, "_local_tensor"):
# DTensor-like handling, just use local tensor attribute
return {
torch.set_printoptions(sci_mode=True)
val_repr = _repr_to_list(value)
out = {
"shape": repr(value._local_tensor.shape),
"dtype": repr(value._local_tensor.dtype),
"value": _sanitize_repr_for_diff(repr(value)),
"value": val_repr,
}
if value._local_tensor.dtype in {torch.float16, torch.float32, torch.bfloat16}:
value = value._local_tensor.clone()
out.update(
{
"mean": _sanitize_repr_for_diff(repr(value.mean())),
"std": _sanitize_repr_for_diff(repr(value.std())),
"min": _sanitize_repr_for_diff(repr(value.min())),
"max": _sanitize_repr_for_diff(repr(value.max())),
}
)
return out
if isinstance(value, torch.Tensor):
# standard PyTorch Tensor
# return also the shape of such
return {"shape": repr(value.shape), "dtype": repr(value.dtype), "value": _sanitize_repr_for_diff(repr(value))}
torch.set_printoptions(sci_mode=True)
val_repr = _repr_to_list(value)
out = {
"shape": repr(value.shape),
"dtype": repr(value.dtype),
"value": val_repr,
}
if value.dtype in {torch.float16, torch.float32, torch.bfloat16}:
out.update(
{
"mean": _sanitize_repr_for_diff(repr(value.mean())),
"std": _sanitize_repr_for_diff(repr(value.std())),
"min": _sanitize_repr_for_diff(repr(value.min())),
"max": _sanitize_repr_for_diff(repr(value.max())),
}
)
return out
# fallback for everything else (bool, int, float, None, or custom class)
return _sanitize_repr_for_diff(repr(value))
def _repr_to_list(value: torch.Tensor):
"""
Converts a tensor into a sanitized multi-line string representation.
Args:
value (`torch.Tensor`): The tensor to represent.
Returns:
`List[str]`: List of string lines representing the tensor.
"""
torch.set_printoptions(sci_mode=True, linewidth=120)
with StringIO() as buf, redirect_stdout(buf):
print(value) # to redirected stdout to avoid line splits
raw = buf.getvalue()
return _sanitize_repr_for_diff(raw).splitlines()
def prune_outputs_if_children(node):
# if there are children, remove this node's "outputs"
# so we only see outputs at the leaf level
@@ -111,22 +153,106 @@ def prune_outputs_if_children(node):
prune_outputs_if_children(child)
LAYER_SUFFIX_RE = re.compile(r"(.*)\.(\d+)$") # should be generic enough, ends with a number
def is_layer_block(node):
"""
Checks whether a node represents a layer block with submodules.
Args:
node (`dict`): A node from the call tree.
Returns:
`bool`: Whether the node is a layer block.
"""
match = LAYER_SUFFIX_RE.match(node.get("module_path", ""))
if not match or not node.get("children"):
return False
number = match.group(2)
return any(f".{number}." in child.get("module_path", "") for child in node["children"])
def prune_intermediate_layers(node):
"""
Recursively removes intermediate layers from the tree to improve readability.
Keeps at least the first and last layers if many consecutive layers are present.
Args:
node (`dict`): The root or subnode to prune recursively.
"""
if not node.get("children"):
return
layer_blocks = [(i, child) for i, child in enumerate(node["children"]) if is_layer_block(child)]
if len(layer_blocks) > 2:
to_remove = [i for i, _ in layer_blocks[1:-1]]
node["children"] = [child for i, child in enumerate(node["children"]) if i not in to_remove]
for child in node["children"]:
prune_intermediate_layers(child)
def log_model_debug_trace(debug_path, model):
if debug_path:
try:
os.makedirs(debug_path, exist_ok=False)
output_path = os.path.join(debug_path, model._debugger_module_dump_name + "_debug_tree.json")
os.makedirs(debug_path, exist_ok=True)
base = os.path.join(debug_path, model._debugger_module_dump_name + "_debug_tree")
except Exception as e:
raise ValueError(f"Unexpected or existing debug_path={debug_path}. {e}")
else:
output_path = model._debugger_module_dump_name + "_debug_tree.json"
logger.info(f"Writing model trace at {output_path}")
with open(output_path, "w") as outfile:
prune_outputs_if_children(model._call_tree)
json.dump(model._call_tree, outfile, indent=2)
base = model._debugger_module_dump_name + "_debug_tree"
logger.info(f"Writing model trace at {base}.json")
full_path = base + "_FULL_TENSORS.json"
summary_path = base + "_SUMMARY.json"
prune_outputs_if_children(model._call_tree)
with open(full_path, "w") as f:
json.dump(model._call_tree, f, indent=2)
# summary-only version for readability - traversing the tree again #TODO optimize?
def strip_values(node):
def clean(val):
if isinstance(val, dict):
val.pop("value", None)
for v in val.values():
clean(v)
elif isinstance(val, list):
for item in val:
clean(item)
clean(node.get("inputs", {}))
clean(node.get("outputs", {}))
for child in node.get("children", []):
strip_values(child)
tree_copy = json.loads(json.dumps(model._call_tree)) # deep copy
strip_values(tree_copy)
with open(summary_path, "w") as f:
json.dump(tree_copy, f, indent=2)
def _attach_debugger_logic(model, class_name, debug_path: str):
def _attach_debugger_logic(
model,
debug_path: Optional[str] = ".",
do_prune_layers: Optional[bool] = True,
):
"""
Attaches a debugging wrapper to every module in the model.
This records structured inputs and outputs during the forward pass into a call tree.
Args:
model (`PreTrainedModel`, `nn.Module`): Model to wrap.
debug_path (`str`): Optional directory to dump debug JSON files.
do_prune_layers (`bool`, *optional*, defaults to `True`): Whether to prune intermediate layers.
"""
class_name = model.__class__.__name__
# Prepare data structures on the model object
model._call_tree = {"module_path": class_name, "inputs": None, "outputs": None, "children": []}
model._debugger_model_call_stack = []
@@ -147,7 +273,7 @@ def _attach_debugger_logic(model, class_name, debug_path: str):
"children": [],
}
model._debugger_model_call_stack.append(node)
with torch.inference_mode():
with torch.no_grad():
out = orig_forward(*inps, **kws)
if _is_rank_zero():
@@ -188,7 +314,6 @@ def _attach_debugger_logic(model, class_name, debug_path: str):
model._debugger_model_call_stack.append(top_node)
out = real_top_forward(*inps, **kws)
if _is_rank_zero() and model._debugger_model_call_stack:
top_node["outputs"] = _serialize_io(out)
finished = model._debugger_model_call_stack.pop()
@@ -198,98 +323,24 @@ def _attach_debugger_logic(model, class_name, debug_path: str):
# prune empty stuff for visibility
[model._call_tree.pop(k, None) for k in list(model._call_tree.keys()) if not model._call_tree[k]]
# prune layers that are not 0 or last
if do_prune_layers:
prune_intermediate_layers(model._call_tree)
# Write final JSON trace here
log_model_debug_trace(debug_path=debug_path, model=model)
return out
model.forward = top_wrapped_forward
# Final hook for writing JSON on forward-end
def final_hook(_, inputs, outputs):
if _is_rank_zero() and model._debugger_model_call_stack:
finished = model._debugger_model_call_stack.pop()
model._call_tree["inputs"] = finished["inputs"]
model._call_tree["outputs"] = finished["outputs"]
model._call_tree["children"] = finished["children"]
if _is_rank_zero():
log_model_debug_trace(debug_path=debug_path, model=model)
model.register_forward_hook(final_hook)
# Optionally also for a couple possible hooks that have specific names. It should be just one.
# This means modules that are not typically called "forward" within the model. But we should not need to recurse
# through them.
possible_model_calls = ["language_model", "model"]
for model_call in possible_model_calls:
this_model_call = getattr(model, model_call, None)
if this_model_call and isinstance(this_model_call, (nn.Module, PreTrainedModel)):
this_model_call.register_forward_hook(final_hook)
break # exit the loop after finding one (unsure, but should be just one call.)
@requires(backends=("torch",))
def model_addition_debugger(cls):
"""
# Model addition debugger - a model adder tracer
This decorator is a power user tool intended for model adders.
It tracks all forward calls within a model forward and logs a slice of each input and output on a nested Json.
To note, this decorator enforces `torch.inference_mode()`.
## Usage
add decorator to your model class
```python
from ...modeling_utils import model_addition_debugger
@model_addition_debugger
class MyModel(nn.Module) # Can inherit from PreTrainedModel too
# ... nothing else changes
```
Then, in a separate script (example is for Llava)
```python
import torch
from PIL import Image
import requests
from transformers import LlavaProcessor, LlavaForConditionalGeneration
torch.random.manual_seed(673)
# load pretrained model and processor
model_id = "llava-hf/llava-1.5-7b-hf"
processor = LlavaProcessor.from_pretrained(model_id)
model = LlavaForConditionalGeneration.from_pretrained(model_id, low_cpu_mem_usage=True)
# create random image input
random_image = Image.fromarray(torch.randint(0, 256, (224, 224, 3), dtype=torch.uint8).numpy())
# prompt
prompt = "<image>Describe this image."
# process inputs
inputs = processor(text=prompt, images=random_image, return_tensors="pt")
# call forward method (not .generate!)
with torch.no_grad():
output = model.forward(**inputs)
```
"""
orig_init = cls.__init__
@functools.wraps(cls.__init__)
def wrapped_init(self, *args, **kwargs):
orig_init(self, *args, **kwargs)
_attach_debugger_logic(self, cls.__name__)
cls.__init__ = wrapped_init
return cls
@requires(backends=("torch",))
@contextmanager
def model_addition_debugger_context(model, debug_path: Optional[str] = None):
def model_addition_debugger_context(model, debug_path: Optional[str] = None, do_prune_layers: Optional[bool] = True):
"""
# Model addition debugger - context manager for model adders
This context manager is a power user tool intended for model adders.
It tracks all forward calls within a model forward and logs a slice of each input and output on a nested Json.
To note, this context manager enforces `torch.inference_mode()`.
To note, this context manager enforces `torch.no_grad()`.
## Usage
@@ -300,6 +351,7 @@ def model_addition_debugger_context(model, debug_path: Optional[str] = None):
from PIL import Image
import requests
from transformers import LlavaProcessor, LlavaForConditionalGeneration
from transformers.model_debugging_utils import model_addition_debugger_context
torch.random.manual_seed(673)
# load pretrained model and processor
@@ -317,13 +369,16 @@ def model_addition_debugger_context(model, debug_path: Optional[str] = None):
inputs = processor(text=prompt, images=random_image, return_tensors="pt")
# call forward method (not .generate!)
with model_addition_debugger_context(model):
with model_addition_debugger_context(model, debug_path="Your_debug_path", do_prune_layers=False):
output = model.forward(**inputs)
```
"""
_attach_debugger_logic(model, model.__class__.__name__, debug_path)
orig_forwards = {m: m.forward for _, m in model.named_modules()}
orig_forwards[model] = model.forward
_attach_debugger_logic(model, debug_path, do_prune_layers)
try:
yield model
finally:
pass
for module_instance, forward_method in orig_forwards.items():
module_instance.forward = forward_method