Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| | 4fdf58afb7 | | |
| | 6530a989c6 | | |
| | bb98e7ce58 | | |
| | 9fe3f585bb | | |
| | f8fec6b0ad | | |
2
setup.py
2
setup.py
@@ -429,7 +429,7 @@ install_requires = [
|
||||
|
||||
setup(
|
||||
name="transformers",
|
||||
version="4.40.0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
version="4.40.2", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)",
|
||||
author_email="transformers@huggingface.co",
|
||||
description="State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow",
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
# to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
|
||||
# in the namespace without actually importing anything (and especially none of the backends).
|
||||
|
||||
__version__ = "4.40.0"
|
||||
__version__ = "4.40.2"
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
|
||||
@@ -146,7 +146,18 @@ class EosTokenCriteria(StoppingCriteria):
|
||||
|
||||
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
|
||||
is_done = torch.isin(input_ids[:, -1], self.eos_token_id.to(input_ids.device))
|
||||
if input_ids.device.type == "mps":
|
||||
# https://github.com/pytorch/pytorch/issues/77764#issuecomment-2067838075
|
||||
is_done = (
|
||||
input_ids[:, -1]
|
||||
.tile(self.eos_token_id.shape[0], 1)
|
||||
.eq(self.eos_token_id.unsqueeze(1).to(input_ids.device))
|
||||
.sum(dim=0)
|
||||
.bool()
|
||||
.squeeze()
|
||||
)
|
||||
else:
|
||||
is_done = torch.isin(input_ids[:, -1], self.eos_token_id.to(input_ids.device))
|
||||
return is_done
|
||||
|
||||
|
||||
|
||||
@@ -1021,8 +1021,11 @@ class CohereModel(CoherePreTrainedModel):
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
if attention_mask.dim() == 2:
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
|
||||
causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
elif attention_mask.dim() == 4:
|
||||
# backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
|
||||
# cache. In that case, the 4D attention mask attends to the newest tokens only.
|
||||
|
||||
@@ -1253,8 +1253,11 @@ class DbrxModel(DbrxPreTrainedModel):
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
if attention_mask.dim() == 2:
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
|
||||
causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
elif attention_mask.dim() == 4:
|
||||
# backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
|
||||
# cache. In that case, the 4D attention mask attends to the newest tokens only.
|
||||
|
||||
@@ -1007,8 +1007,11 @@ class GemmaModel(GemmaPreTrainedModel):
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
if attention_mask.dim() == 2:
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
|
||||
causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
elif attention_mask.dim() == 4:
|
||||
# backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
|
||||
# cache. In that case, the 4D attention mask attends to the newest tokens only.
|
||||
|
||||
@@ -1099,8 +1099,11 @@ class LlamaModel(LlamaPreTrainedModel):
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
if attention_mask.dim() == 2:
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
|
||||
causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
elif attention_mask.dim() == 4:
|
||||
# backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
|
||||
# cache. In that case, the 4D attention mask attends to the newest tokens only.
|
||||
|
||||
@@ -1082,8 +1082,11 @@ class OlmoModel(OlmoPreTrainedModel):
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
if attention_mask.dim() == 2:
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
|
||||
causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
elif attention_mask.dim() == 4:
|
||||
# backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
|
||||
# cache. In that case, the 4D attention mask attends to the newest tokens only.
|
||||
|
||||
@@ -84,12 +84,12 @@ if is_torch_neuroncore_available(check_device=False):
|
||||
if os.environ.get("TORCHELASTIC_RUN_ID"):
|
||||
if is_optimum_neuron_available():
|
||||
logger.info(
|
||||
"Make sure that you are performing the training with the TrainiumTrainer from optimum[neuron], this "
|
||||
"Make sure that you are performing the training with the NeuronTrainer from optimum[neuron], this "
|
||||
"will fail otherwise."
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"Please use the TrainiumTrainer from optimum[neuron] instead of the Transformers library to perform "
|
||||
"Please use the NeuronTrainer from optimum[neuron] instead of the Transformers library to perform "
|
||||
"training on AWS Trainium instances. More information here: "
|
||||
"https://github.com/huggingface/optimum-neuron"
|
||||
)
|
||||
|
||||
@@ -15,22 +15,28 @@
|
||||
|
||||
import builtins
|
||||
import collections
|
||||
import contextlib
|
||||
import functools
|
||||
import inspect
|
||||
import math
|
||||
import operator
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
import warnings
|
||||
from typing import Any, Callable, Dict, List, Optional, Type, Union
|
||||
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union
|
||||
|
||||
import torch
|
||||
import torch.utils._pytree as pytree
|
||||
from torch import nn
|
||||
from torch.fx import Graph, GraphModule, Proxy, Tracer
|
||||
from torch.fx import Graph, GraphModule, Node, Proxy, Tracer
|
||||
from torch.fx._compatibility import compatibility
|
||||
from torch.fx._symbolic_trace import is_fx_tracing
|
||||
from torch.fx.proxy import ParameterProxy
|
||||
|
||||
from .. import PretrainedConfig, PreTrainedModel, logging
|
||||
from .. import logging
|
||||
from ..cache_utils import Cache, DynamicCache, SinkCache, StaticCache
|
||||
from ..modeling_utils import PretrainedConfig, PreTrainedModel
|
||||
from ..models.auto import get_values
|
||||
from ..models.auto.modeling_auto import (
|
||||
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
|
||||
@@ -55,7 +61,7 @@ from ..models.auto.modeling_auto import (
|
||||
MODEL_MAPPING_NAMES,
|
||||
)
|
||||
from ..pytorch_utils import is_torch_greater_or_equal_than_2_0
|
||||
from ..utils import (
|
||||
from .import_utils import (
|
||||
ENV_VARS_TRUE_VALUES,
|
||||
TORCH_FX_REQUIRED_VERSION,
|
||||
get_torch_version,
|
||||
@@ -192,6 +198,8 @@ _SPECIAL_SUPPORTED_MODELS = [
|
||||
]
|
||||
_SUPPORTED_MODELS = tuple(sorted(set(_REGULAR_SUPPORTED_MODELS + _SPECIAL_SUPPORTED_MODELS)))
|
||||
|
||||
_CURRENT_TRACER = None
|
||||
|
||||
|
||||
def torch_nn_embedding(self, input):
|
||||
return torch.empty(*input.shape, self.weight.shape[-1], device="meta", dtype=self.weight.dtype)
|
||||
@@ -701,6 +709,92 @@ class MetaDeviceAttribute(HFAttribute):
|
||||
pass
|
||||
|
||||
|
||||
class HFCacheProxy(HFProxy):
|
||||
"""
|
||||
Proxy that represents an instance of `transformers.cache_utils.Cache`.
|
||||
"""
|
||||
|
||||
@property
|
||||
def __class__(self):
|
||||
return ProxyableCache
|
||||
|
||||
|
||||
def create_wrapper(
|
||||
function: Callable,
|
||||
op_type: Union[Literal["call_function"], Literal["call_method"], Literal["get_attr"]],
|
||||
proxy_factory_fn: Optional[Callable[[Node], Proxy]] = None,
|
||||
) -> Callable:
|
||||
@functools.wraps(function)
|
||||
def wrapper(*args, **kwargs):
|
||||
if not is_fx_tracing():
|
||||
return function(*args, **kwargs)
|
||||
|
||||
found_proxies = []
|
||||
|
||||
def check_proxy(a):
|
||||
if isinstance(a, Proxy):
|
||||
found_proxies.append(a)
|
||||
|
||||
torch.fx.node.map_aggregate(args, check_proxy)
|
||||
torch.fx.node.map_aggregate(kwargs, check_proxy)
|
||||
|
||||
if len(found_proxies) > 0:
|
||||
tracer = found_proxies[0].tracer
|
||||
if op_type == "call_function":
|
||||
target = function
|
||||
elif op_type == "call_method":
|
||||
target = function.__name__
|
||||
elif op_type == "get_attr":
|
||||
target = function.__name__
|
||||
else:
|
||||
raise ValueError(f"op_type {op_type} not supported.")
|
||||
return tracer.create_proxy(op_type, target, args, kwargs, proxy_factory_fn=proxy_factory_fn)
|
||||
else:
|
||||
return function(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class HFProxyableClassMeta(type):
|
||||
"""
|
||||
Metaclass that creates a class with its main methods wrapped to be proxyable.
|
||||
"""
|
||||
|
||||
def __new__(
|
||||
cls,
|
||||
name: str,
|
||||
bases: Tuple[Type, ...],
|
||||
attrs: Dict[str, Any],
|
||||
proxy_factory_fn: Optional[Callable[[Node], Proxy]] = None,
|
||||
):
|
||||
cls = super().__new__(cls, name, bases, attrs)
|
||||
for attr_name in dir(cls):
|
||||
attr = getattr(cls, attr_name, None)
|
||||
if attr is None:
|
||||
continue
|
||||
if attr_name == "__init__":
|
||||
op_type = "call_function"
|
||||
elif attr_name.startswith("__"):
|
||||
op_type = None
|
||||
elif inspect.ismethod(attr):
|
||||
op_type = "call_function"
|
||||
elif inspect.isfunction(attr):
|
||||
op_type = "call_method"
|
||||
else:
|
||||
op_type = None
|
||||
if op_type is not None:
|
||||
setattr(cls, attr_name, create_wrapper(attr, op_type, proxy_factory_fn=proxy_factory_fn))
|
||||
return cls
|
||||
|
||||
|
||||
def gen_constructor_wrapper(target: Callable) -> Tuple[Callable, Callable]:
|
||||
"""
|
||||
Wraps `target` to be proxyable. Used for tensor creators like `torch.ones`, `torch.arange` and so on.
|
||||
"""
|
||||
wrapper = create_wrapper(target, "call_function")
|
||||
return wrapper, target
|
||||
|
||||
|
||||
def _proxies_to_metas(v):
|
||||
"""Returns the underlying metadata for HFProxies, and behaves like the identity for the others."""
|
||||
if isinstance(v, MetaDeviceAttribute):
|
||||
@@ -712,25 +806,24 @@ def _proxies_to_metas(v):
|
||||
return v
|
||||
|
||||
|
||||
def _gen_constructor_wrapper(target):
|
||||
@functools.wraps(target)
|
||||
def wrapper(*args, **kwargs):
|
||||
proxy = None
|
||||
def cache_proxy_factory_fn(n: Node) -> HFCacheProxy:
|
||||
global _CURRENT_TRACER
|
||||
if not isinstance(_CURRENT_TRACER, HFTracer):
|
||||
raise RuntimeError("Cannot create HFCacheProxy because there is no HFTracer currently tracing.")
|
||||
return HFCacheProxy(n, _CURRENT_TRACER)
|
||||
|
||||
def check_has_proxy(v):
|
||||
if isinstance(v, Proxy):
|
||||
nonlocal proxy
|
||||
proxy = v
|
||||
|
||||
torch.fx.node.map_aggregate(args, check_has_proxy)
|
||||
torch.fx.node.map_aggregate(kwargs, check_has_proxy)
|
||||
|
||||
if proxy is not None:
|
||||
return proxy.tracer.create_proxy("call_function", target, args, kwargs)
|
||||
else:
|
||||
return target(*args, **kwargs)
|
||||
|
||||
return wrapper, target
|
||||
# Proxyable equivalent of the cache classes defined in `transformers.cache_utils`.
|
||||
ProxyableCache = HFProxyableClassMeta("ProxyableCache", (Cache,), {}, proxy_factory_fn=cache_proxy_factory_fn)
|
||||
ProxyableDynamicCache = HFProxyableClassMeta(
|
||||
"ProxyableDynamicCache", (DynamicCache,), {}, proxy_factory_fn=cache_proxy_factory_fn
|
||||
)
|
||||
ProxyableSinkCache = HFProxyableClassMeta(
|
||||
"ProxyableSinkCache", (SinkCache,), {}, proxy_factory_fn=cache_proxy_factory_fn
|
||||
)
|
||||
ProxyableStaticCache = HFProxyableClassMeta(
|
||||
"ProxyableStaticCache", (StaticCache,), {}, proxy_factory_fn=cache_proxy_factory_fn
|
||||
)
|
||||
|
||||
|
||||
def _generate_random_int(low: int = 10, high: int = 20, forbidden_values: Optional[List[int]] = None):
|
||||
@@ -764,6 +857,13 @@ class HFTracer(Tracer):
|
||||
"finfo",
|
||||
"tril",
|
||||
]
|
||||
_CLASSES_TO_PATCH = {
|
||||
Cache: ProxyableCache,
|
||||
DynamicCache: ProxyableDynamicCache,
|
||||
SinkCache: ProxyableSinkCache,
|
||||
StaticCache: ProxyableStaticCache,
|
||||
}
|
||||
|
||||
supported_archs = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel)
|
||||
|
||||
def __init__(self, autowrap_modules=(math,), autowrap_functions=()):
|
||||
@@ -776,7 +876,7 @@ class HFTracer(Tracer):
|
||||
)
|
||||
|
||||
def _generate_dummy_input(
|
||||
self, model: PreTrainedModel, input_name: str, shape: List[int], input_names: List[str]
|
||||
self, model: "PreTrainedModel", input_name: str, shape: List[int], input_names: List[str]
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
"""Generates dummy input for model inference recording."""
|
||||
# Retrieving the model class, either from the "class_for_deserialization" attribute if the model was restored
|
||||
@@ -951,6 +1051,11 @@ class HFTracer(Tracer):
|
||||
args_metas = torch.fx.node.map_aggregate(args, _proxies_to_metas)
|
||||
kwargs_metas = torch.fx.node.map_aggregate(kwargs, _proxies_to_metas)
|
||||
|
||||
should_install_metadata = True
|
||||
|
||||
self._disable_module_getattr = True
|
||||
self._disable_call_module = True
|
||||
|
||||
if kind == "call_function":
|
||||
meta_target = _MANUAL_META_OVERRIDES.get(target, target)
|
||||
meta_out = meta_target(*args_metas, **kwargs_metas)
|
||||
@@ -963,39 +1068,36 @@ class HFTracer(Tracer):
|
||||
elif kind == "call_module":
|
||||
if not hasattr(self, "orig_forward"):
|
||||
raise AttributeError(f"{self} does not have an attribute called orig_forward")
|
||||
self._disable_module_getattr = True
|
||||
try:
|
||||
mod = self.root.get_submodule(target)
|
||||
mod_type = type(mod)
|
||||
if mod_type in _MANUAL_META_OVERRIDES:
|
||||
meta_out = _MANUAL_META_OVERRIDES[mod_type](mod, *args_metas, **kwargs_metas)
|
||||
else:
|
||||
meta_out = self.orig_forward(*args_metas, **kwargs_metas)
|
||||
finally:
|
||||
self._disable_module_getattr = False
|
||||
mod = self.root.get_submodule(target)
|
||||
mod_type = type(mod)
|
||||
if mod_type in _MANUAL_META_OVERRIDES:
|
||||
meta_out = _MANUAL_META_OVERRIDES[mod_type](mod, *args_metas, **kwargs_metas)
|
||||
else:
|
||||
meta_out = self.orig_forward(*args_metas, **kwargs_metas)
|
||||
elif kind == "get_attr":
|
||||
self._disable_module_getattr = True
|
||||
try:
|
||||
attr_itr = self.root
|
||||
atoms = target.split(".")
|
||||
for atom in atoms:
|
||||
attr_itr = getattr(attr_itr, atom)
|
||||
if isinstance(attr_itr, torch.Tensor):
|
||||
meta_out = attr_itr.to(device="meta")
|
||||
else:
|
||||
meta_out = attr_itr
|
||||
finally:
|
||||
self._disable_module_getattr = False
|
||||
attr_itr = self.root
|
||||
atoms = target.split(".")
|
||||
for atom in atoms:
|
||||
attr_itr = getattr(attr_itr, atom)
|
||||
if isinstance(attr_itr, torch.Tensor):
|
||||
meta_out = attr_itr.to(device="meta")
|
||||
else:
|
||||
meta_out = attr_itr
|
||||
else:
|
||||
return rv
|
||||
should_install_metadata = False
|
||||
|
||||
if should_install_metadata:
|
||||
if not isinstance(rv, Proxy):
|
||||
raise ValueError("Don't support composite output yet")
|
||||
rv.install_metadata(meta_out)
|
||||
|
||||
if not isinstance(rv, Proxy):
|
||||
raise ValueError("Don't support composite output yet")
|
||||
rv.install_metadata(meta_out)
|
||||
except Exception as e:
|
||||
if _IS_IN_DEBUG_MODE:
|
||||
warnings.warn(f"Could not compute metadata for {kind} target {target}: {e}")
|
||||
|
||||
self._disable_module_getattr = False
|
||||
self._disable_call_module = False
|
||||
|
||||
return rv
|
||||
|
||||
# Replaced by .getattr from PyTorch 1.13
|
||||
@@ -1041,12 +1143,51 @@ class HFTracer(Tracer):
|
||||
return self._module_getattr(attr, attr_val, parameter_proxy_cache)
|
||||
|
||||
def call_module(self, m, forward, args, kwargs):
|
||||
if getattr(self, "_disable_call_module", False):
|
||||
return forward(*args, **kwargs)
|
||||
self.orig_forward = forward
|
||||
return super().call_module(m, forward, args, kwargs)
|
||||
|
||||
def proxy(self, node):
|
||||
return HFProxy(node, self)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def patch_for_tracing(self, root: Union[torch.nn.Module, Callable[..., Any]]):
|
||||
# Patching torch functions
|
||||
self.patched_torch_methods = {
|
||||
target: gen_constructor_wrapper(getattr(torch, target)) for target in self._TORCH_METHODS_TO_PATCH
|
||||
}
|
||||
self.orig_fns = set()
|
||||
|
||||
for name, (wrapper, orig) in self.patched_torch_methods.items():
|
||||
setattr(torch, name, wrapper)
|
||||
self.orig_fns.add(orig)
|
||||
|
||||
# Patching classes
|
||||
patched = []
|
||||
module_of_model = inspect.getmodule(root)
|
||||
for name, mod in sys.modules.items():
|
||||
if module_of_model is not None and mod is not module_of_model:
|
||||
continue
|
||||
if not name.startswith("transformers"):
|
||||
continue
|
||||
for orig_cls, patched_cls in self._CLASSES_TO_PATCH.items():
|
||||
for attr_name, attr in mod.__dict__.items():
|
||||
if attr is orig_cls:
|
||||
patched.append((mod, attr_name, orig_cls))
|
||||
setattr(mod, attr_name, patched_cls)
|
||||
|
||||
yield
|
||||
|
||||
# Restoring patched functions and classes.
|
||||
for name, (_, orig) in self.patched_torch_methods.items():
|
||||
setattr(torch, name, orig)
|
||||
self.patched_torch_methods = {}
|
||||
self.orig_fns = set()
|
||||
|
||||
for mod, attr_name, orig_cls in patched:
|
||||
setattr(mod, attr_name, orig_cls)
|
||||
|
||||
def trace(
|
||||
self,
|
||||
root: Union[torch.nn.Module, Callable[..., Any]],
|
||||
@@ -1125,28 +1266,25 @@ class HFTracer(Tracer):
|
||||
" transformers.PreTrainedModel."
|
||||
)
|
||||
|
||||
concrete_metas = {
|
||||
input_name: input_.to("meta") if isinstance(input_, torch.Tensor) else input_
|
||||
for input_name, input_ in inputs.items()
|
||||
}
|
||||
def to_meta(value):
|
||||
if isinstance(value, torch.Tensor):
|
||||
return value.to("meta")
|
||||
return value
|
||||
|
||||
concrete_metas = pytree.tree_map(to_meta, inputs)
|
||||
|
||||
for param in sig.parameters.values():
|
||||
if param.kind == inspect.Parameter.VAR_KEYWORD and param.name not in input_names:
|
||||
concrete_metas[f"**{param.name}"] = {}
|
||||
self.meta_args = concrete_metas
|
||||
self.patched_torch_methods = {
|
||||
target: _gen_constructor_wrapper(getattr(torch, target)) for target in self._TORCH_METHODS_TO_PATCH
|
||||
}
|
||||
self.orig_fns = set()
|
||||
|
||||
for name, (wrapper, orig) in self.patched_torch_methods.items():
|
||||
setattr(torch, name, wrapper)
|
||||
self.orig_fns.add(orig)
|
||||
|
||||
try:
|
||||
self.graph = super().trace(root, concrete_args=concrete_args)
|
||||
finally:
|
||||
for name, (_, orig) in self.patched_torch_methods.items():
|
||||
setattr(torch, name, orig)
|
||||
global _CURRENT_TRACER
|
||||
_CURRENT_TRACER = self
|
||||
with self.patch_for_tracing(root):
|
||||
try:
|
||||
self.graph = super().trace(root, concrete_args=concrete_args)
|
||||
finally:
|
||||
_CURRENT_TRACER = None
|
||||
|
||||
# This is necessary because concrete args are added as input to the traced module since
|
||||
# https://github.com/pytorch/pytorch/pull/55888.
|
||||
@@ -1256,11 +1394,11 @@ def get_concrete_args(model: nn.Module, input_names: List[str]):
|
||||
return {p.name: p.default for p in sig.parameters.values() if p.name not in input_names}
|
||||
|
||||
|
||||
def is_model_supported(model: PreTrainedModel):
|
||||
def is_model_supported(model: "PreTrainedModel"):
|
||||
return model.__class__.__name__ in _SUPPORTED_MODELS
|
||||
|
||||
|
||||
def check_if_model_is_supported(model: PreTrainedModel):
|
||||
def check_if_model_is_supported(model: "PreTrainedModel"):
|
||||
if not is_model_supported(model):
|
||||
supported_model_names = ", ".join(_SUPPORTED_MODELS)
|
||||
raise NotImplementedError(
|
||||
@@ -1269,7 +1407,7 @@ def check_if_model_is_supported(model: PreTrainedModel):
|
||||
|
||||
|
||||
def symbolic_trace(
|
||||
model: PreTrainedModel,
|
||||
model: "PreTrainedModel",
|
||||
input_names: Optional[List[str]] = None,
|
||||
disable_check: bool = False,
|
||||
tracer_cls: Type[HFTracer] = HFTracer,
|
||||
@@ -1307,6 +1445,18 @@ def symbolic_trace(
|
||||
if not disable_check:
|
||||
check_if_model_is_supported(model)
|
||||
|
||||
if "past_key_values" in input_names and not getattr(model.config, "use_cache", False):
|
||||
logger.warning(
|
||||
"`past_key_values` were specified as input names, but model.config.use_cache = False, this might lead to "
|
||||
"unexpected behavior."
|
||||
)
|
||||
if "past_key_values" not in input_names and getattr(model.config, "use_cache", False):
|
||||
logger.warning(
|
||||
"`past_key_values` were not specified as input names, but model.config.use_cache = True. Setting "
|
||||
"model.config.use_cache = False."
|
||||
)
|
||||
model.config.use_cache = False
|
||||
|
||||
# Tracing.
|
||||
tracer = tracer_cls()
|
||||
traced_graph = tracer.trace(model, concrete_args=concrete_args)
|
||||
|
||||
@@ -18,7 +18,6 @@ import gc
|
||||
import inspect
|
||||
import os
|
||||
import os.path
|
||||
import pickle
|
||||
import random
|
||||
import re
|
||||
import tempfile
|
||||
@@ -1279,26 +1278,6 @@ class ModelTesterMixin:
|
||||
f"traced {i}th output doesn't match model {i}th output for {model_class}",
|
||||
)
|
||||
|
||||
# Test that the model can be serialized and restored properly
|
||||
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||
pkl_file_name = os.path.join(tmp_dir_name, "model.pkl")
|
||||
try:
|
||||
with open(pkl_file_name, "wb") as f:
|
||||
pickle.dump(traced_model, f)
|
||||
with open(pkl_file_name, "rb") as f:
|
||||
loaded = pickle.load(f)
|
||||
except Exception as e:
|
||||
self.fail(f"Couldn't serialize / deserialize the traced model: {e}")
|
||||
|
||||
loaded_output = loaded(**filtered_inputs)
|
||||
loaded_output = flatten_output(loaded_output)
|
||||
|
||||
for i in range(num_outputs):
|
||||
self.assertTrue(
|
||||
torch.allclose(model_output[i], loaded_output[i]),
|
||||
f"serialized model {i}th output doesn't match model {i}th output for {model_class}",
|
||||
)
|
||||
|
||||
# Avoid memory leak. Without this, each call increase RAM usage by ~20MB.
|
||||
# (Even with this call, there are still memory leak by ~0.04MB)
|
||||
self.clear_torch_jit_class_registry()
|
||||
|
||||
Reference in New Issue
Block a user