Adds use_repr to model_addition_debugger_context (#37984)
* Adds use_repr to model_addition_debugger_context * Updating docs for use_repr option
This commit is contained in:
@@ -21,6 +21,8 @@ from contextlib import contextmanager, redirect_stdout
|
||||
from io import StringIO
|
||||
from typing import Optional
|
||||
|
||||
from safetensors.torch import save_file
|
||||
|
||||
from transformers.utils.import_utils import requires
|
||||
|
||||
from .utils import is_torch_available
|
||||
@@ -65,64 +67,94 @@ def _dtensor_repr(x):
|
||||
return "DTensor(non-rank0)"
|
||||
|
||||
|
||||
def _serialize_io(value):
|
||||
def _serialize_tensor_like_io(
|
||||
value, debug_path: Optional[str] = None, use_repr: bool = True, path_to_value: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
Converts Tensors and DTensors to a JSON-serializable dictionary representation.
|
||||
|
||||
Args:
|
||||
value: Any Python object, often including torch Tensors, lists, dicts, etc.
|
||||
debug_path (`str`, *optional*, defaults to `None`): Directory to dump debug JSON and SafeTensors files.
|
||||
use_repr (bool, *optional*, defaults to `True`): Whether to save a `repr()`-ized version of the tensor as the
|
||||
`value` property in the asscoiated FULL_TENSORS.json file, or to store the full tensors in separate
|
||||
SafeTensors file and store the relative path to that file in the `value` property in the dictionary.
|
||||
path_to_value (`str`, *optional*, defaults to `None`): The file name for the SafeTensors file holding the full
|
||||
tensor value if `use_repr=False`.
|
||||
|
||||
Returns:
|
||||
A nested Python structure (list, dict, or sanitized string) that is safe to json.dump.
|
||||
"""
|
||||
torch.set_printoptions(sci_mode=True)
|
||||
|
||||
if use_repr:
|
||||
value_out = _repr_to_list(value)
|
||||
elif path_to_value:
|
||||
if not path_to_value.endswith(".safetensors"):
|
||||
path_to_value += ".safetensors"
|
||||
|
||||
filepath = os.path.join(debug_path, path_to_value) if debug_path else path_to_value
|
||||
save_file({"data": value.contiguous().detach().cpu()}, filepath)
|
||||
value_out = f"./{path_to_value}"
|
||||
else:
|
||||
raise ValueError(f"{use_repr=} and {path_to_value=} cannot both be falsy.")
|
||||
|
||||
out = {
|
||||
"shape": repr(value.shape),
|
||||
"dtype": repr(value.dtype),
|
||||
"value": value_out,
|
||||
}
|
||||
if value.dtype in {torch.float16, torch.float32, torch.bfloat16}:
|
||||
out.update(
|
||||
{
|
||||
"mean": _sanitize_repr_for_diff(repr(value.mean())),
|
||||
"std": _sanitize_repr_for_diff(repr(value.std())),
|
||||
"min": _sanitize_repr_for_diff(repr(value.min())),
|
||||
"max": _sanitize_repr_for_diff(repr(value.max())),
|
||||
}
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
def _serialize_io(value, debug_path: Optional[str] = None, use_repr: bool = True, path_to_value: Optional[str] = None):
|
||||
"""
|
||||
Recursively build a JSON-serializable Python structure from `value`.
|
||||
Tensors and DTensors become sanitized repr strings.
|
||||
Tensors and DTensors become either sanitized repr strings, or are saved to disk as SafeTensors files and their
|
||||
relative paths are recorded in the returned Python structure.
|
||||
Lists/tuples/dicts are recursed into.
|
||||
All memory addresses are replaced with a stable placeholder.
|
||||
|
||||
Args:
|
||||
value: Any Python object, often including torch Tensors, lists, dicts, etc.
|
||||
debug_path (`str`, *optional*, defaults to `None`): Directory to dump debug JSON and SafeTensors files.
|
||||
use_repr (bool, *optional*, defaults to `True`): Whether to save a `repr()`-ized version of the tensors as the
|
||||
`value` property in the asscoiated FULL_TENSORS.json file, or to store full tensors in separate SafeTensors
|
||||
files and store the relative path to that file in the `value` property.
|
||||
path_to_value (`str`, *optional*, defaults to `None`): The file name for the SafeTensors file holding the full
|
||||
tensor value if `use_repr=False`.
|
||||
|
||||
Returns:
|
||||
A nested Python structure (list, dict, or sanitized string) that is safe to json.dump.
|
||||
"""
|
||||
if isinstance(value, (list, tuple)):
|
||||
return [_serialize_io(v) for v in value]
|
||||
return [
|
||||
_serialize_io(v, debug_path=debug_path, use_repr=use_repr, path_to_value=f"{path_to_value}_{i}")
|
||||
for i, v in enumerate(value)
|
||||
]
|
||||
|
||||
if isinstance(value, dict):
|
||||
return {k: _serialize_io(v) for k, v in value.items()}
|
||||
return {
|
||||
k: _serialize_io(v, debug_path=debug_path, use_repr=use_repr, path_to_value=f"{path_to_value}_{k}")
|
||||
for k, v in value.items()
|
||||
}
|
||||
|
||||
if hasattr(value, "_local_tensor"):
|
||||
# DTensor-like handling, just use local tensor attribute
|
||||
torch.set_printoptions(sci_mode=True)
|
||||
val_repr = _repr_to_list(value)
|
||||
out = {
|
||||
"shape": repr(value._local_tensor.shape),
|
||||
"dtype": repr(value._local_tensor.dtype),
|
||||
"value": val_repr,
|
||||
}
|
||||
if value._local_tensor.dtype in {torch.float16, torch.float32, torch.bfloat16}:
|
||||
value = value._local_tensor.clone()
|
||||
out.update(
|
||||
{
|
||||
"mean": _sanitize_repr_for_diff(repr(value.mean())),
|
||||
"std": _sanitize_repr_for_diff(repr(value.std())),
|
||||
"min": _sanitize_repr_for_diff(repr(value.min())),
|
||||
"max": _sanitize_repr_for_diff(repr(value.max())),
|
||||
}
|
||||
)
|
||||
return out
|
||||
return _serialize_tensor_like_io(
|
||||
value._local_tensor, debug_path=debug_path, use_repr=use_repr, path_to_value=path_to_value
|
||||
)
|
||||
|
||||
if isinstance(value, torch.Tensor):
|
||||
torch.set_printoptions(sci_mode=True)
|
||||
val_repr = _repr_to_list(value)
|
||||
out = {
|
||||
"shape": repr(value.shape),
|
||||
"dtype": repr(value.dtype),
|
||||
"value": val_repr,
|
||||
}
|
||||
if value.dtype in {torch.float16, torch.float32, torch.bfloat16}:
|
||||
out.update(
|
||||
{
|
||||
"mean": _sanitize_repr_for_diff(repr(value.mean())),
|
||||
"std": _sanitize_repr_for_diff(repr(value.std())),
|
||||
"min": _sanitize_repr_for_diff(repr(value.min())),
|
||||
"max": _sanitize_repr_for_diff(repr(value.max())),
|
||||
}
|
||||
)
|
||||
return out
|
||||
return _serialize_tensor_like_io(value, debug_path=debug_path, use_repr=use_repr, path_to_value=path_to_value)
|
||||
|
||||
return _sanitize_repr_for_diff(repr(value))
|
||||
|
||||
@@ -199,7 +231,7 @@ def log_model_debug_trace(debug_path, model):
|
||||
os.makedirs(debug_path, exist_ok=True)
|
||||
base = os.path.join(debug_path, model._debugger_module_dump_name + "_debug_tree")
|
||||
except Exception as e:
|
||||
raise ValueError(f"Unexpected or existing debug_path={debug_path}. {e}")
|
||||
raise ValueError(f"Unexpected or existing debug_path={debug_path}.") from e
|
||||
else:
|
||||
base = model._debugger_module_dump_name + "_debug_tree"
|
||||
|
||||
@@ -240,6 +272,7 @@ def _attach_debugger_logic(
|
||||
model,
|
||||
debug_path: Optional[str] = ".",
|
||||
do_prune_layers: Optional[bool] = True,
|
||||
use_repr: bool = True,
|
||||
):
|
||||
"""
|
||||
Attaches a debugging wrapper to every module in the model.
|
||||
@@ -250,6 +283,9 @@ def _attach_debugger_logic(
|
||||
model (`PreTrainedModel`, `nn.Module`): Model to wrap.
|
||||
debug_path (`str`): Optional directory to dump debug JSON files.
|
||||
do_prune_layers (`bool`, *optional*, defaults to `True`): Whether to prune intermediate layers.
|
||||
use_repr (bool, *optional*, defaults to `True`): Whether to save a `repr()`-ized version of the tensors as the
|
||||
`value` property in the asscoiated FULL_TENSORS.json file, or to store full tensors in separate SafeTensors
|
||||
files and store the relative path to that file in the `value` property.
|
||||
"""
|
||||
class_name = model.__class__.__name__
|
||||
|
||||
@@ -258,6 +294,12 @@ def _attach_debugger_logic(
|
||||
model._debugger_model_call_stack = []
|
||||
model._debugger_module_dump_name = class_name # used for final JSON filename
|
||||
|
||||
if debug_path:
|
||||
try:
|
||||
os.makedirs(debug_path, exist_ok=True)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Unexpected or existing debug_path={debug_path}.") from e
|
||||
|
||||
def wrap_forward(module, full_path):
|
||||
orig_forward = module.forward
|
||||
|
||||
@@ -268,7 +310,12 @@ def _attach_debugger_logic(
|
||||
dict_inputs = {k: dict_inputs[k] for k in dict_inputs if len(dict_inputs[k]) > 0}
|
||||
node = {
|
||||
"module_path": full_path,
|
||||
"inputs": _serialize_io(dict_inputs),
|
||||
"inputs": _serialize_io(
|
||||
dict_inputs,
|
||||
debug_path=debug_path,
|
||||
use_repr=use_repr,
|
||||
path_to_value=f"{full_path}_inputs",
|
||||
),
|
||||
"outputs": None,
|
||||
"children": [],
|
||||
}
|
||||
@@ -280,7 +327,12 @@ def _attach_debugger_logic(
|
||||
if sum(1 for _ in module.named_children()) > 0:
|
||||
node["outputs"] = None
|
||||
else:
|
||||
node["outputs"] = _serialize_io(out)
|
||||
node["outputs"] = _serialize_io(
|
||||
out,
|
||||
debug_path=debug_path,
|
||||
use_repr=use_repr,
|
||||
path_to_value=f"{full_path}_outputs",
|
||||
)
|
||||
|
||||
finished = model._debugger_model_call_stack.pop()
|
||||
# prune empty vertices here as well (mostly empty children nodes)
|
||||
@@ -307,7 +359,12 @@ def _attach_debugger_logic(
|
||||
if _is_rank_zero():
|
||||
top_node = {
|
||||
"module_path": f"{class_name} (top-level)",
|
||||
"inputs": _serialize_io({"args": inps, "kwargs": kws}),
|
||||
"inputs": _serialize_io(
|
||||
{"args": inps, "kwargs": kws},
|
||||
debug_path=debug_path,
|
||||
use_repr=use_repr,
|
||||
path_to_value=f"{class_name}_inputs",
|
||||
),
|
||||
"outputs": None,
|
||||
"children": [],
|
||||
}
|
||||
@@ -315,7 +372,12 @@ def _attach_debugger_logic(
|
||||
|
||||
out = real_top_forward(*inps, **kws)
|
||||
if _is_rank_zero() and model._debugger_model_call_stack:
|
||||
top_node["outputs"] = _serialize_io(out)
|
||||
top_node["outputs"] = _serialize_io(
|
||||
out,
|
||||
debug_path=debug_path,
|
||||
use_repr=use_repr,
|
||||
path_to_value=f"{class_name}_outputs",
|
||||
)
|
||||
finished = model._debugger_model_call_stack.pop()
|
||||
model._call_tree["inputs"] = finished["inputs"]
|
||||
model._call_tree["outputs"] = finished["outputs"]
|
||||
@@ -335,11 +397,21 @@ def _attach_debugger_logic(
|
||||
|
||||
@requires(backends=("torch",))
|
||||
@contextmanager
|
||||
def model_addition_debugger_context(model, debug_path: Optional[str] = None, do_prune_layers: Optional[bool] = True):
|
||||
def model_addition_debugger_context(
|
||||
model,
|
||||
debug_path: Optional[str] = None,
|
||||
do_prune_layers: Optional[bool] = True,
|
||||
use_repr: Optional[bool] = True,
|
||||
):
|
||||
"""
|
||||
# Model addition debugger - context manager for model adders
|
||||
This context manager is a power user tool intended for model adders.
|
||||
It tracks all forward calls within a model forward and logs a slice of each input and output on a nested Json.
|
||||
|
||||
It tracks all forward calls within a model forward and logs a slice of each input and output on a nested JSON file.
|
||||
If `use_repr=True` (the default), the JSON file will record a `repr()`-ized version of the tensors as a list of
|
||||
strings. If `use_repr=False`, the full tensors will be stored in spearate SafeTensors files and the JSON file will
|
||||
provide a relative path to that file.
|
||||
|
||||
To note, this context manager enforces `torch.no_grad()`.
|
||||
|
||||
## Usage
|
||||
@@ -348,10 +420,10 @@ def model_addition_debugger_context(model, debug_path: Optional[str] = None, do_
|
||||
|
||||
```python
|
||||
import torch
|
||||
|
||||
from PIL import Image
|
||||
import requests
|
||||
from transformers import LlavaProcessor, LlavaForConditionalGeneration
|
||||
from transformers.model_debugging_utils import model_addition_debugger_context
|
||||
from transformers import LlavaProcessor, LlavaForConditionalGeneration, model_addition_debugger_context
|
||||
|
||||
torch.random.manual_seed(673)
|
||||
|
||||
# load pretrained model and processor
|
||||
@@ -376,7 +448,7 @@ def model_addition_debugger_context(model, debug_path: Optional[str] = None, do_
|
||||
"""
|
||||
orig_forwards = {m: m.forward for _, m in model.named_modules()}
|
||||
orig_forwards[model] = model.forward
|
||||
_attach_debugger_logic(model, debug_path, do_prune_layers)
|
||||
_attach_debugger_logic(model, debug_path, do_prune_layers, use_repr)
|
||||
try:
|
||||
yield model
|
||||
finally:
|
||||
|
||||
Reference in New Issue
Block a user