Fixed torch.finfo issue with torch.fx (#20040)
This commit is contained in:
@@ -230,6 +230,15 @@ def torch_arange(*args, **kwargs):
|
||||
return torch.empty((end - start) // step, dtype=dtype, device="meta")
|
||||
|
||||
|
||||
def torch_full(*args, **kwargs):
|
||||
args = list(args)
|
||||
if isinstance(args[1], torch.Tensor) and args[1].device == torch.device("meta"):
|
||||
args[1] = 1 # Any value.
|
||||
kwargs_without_device = dict(kwargs)
|
||||
kwargs_without_device.pop("device", None)
|
||||
return torch.full(*args, **kwargs_without_device)
|
||||
|
||||
|
||||
def torch_cat(tensors, dim=None, axis=None, *, out=None):
|
||||
if dim is None and axis is None:
|
||||
dim = 0
|
||||
@@ -509,6 +518,7 @@ _MANUAL_META_OVERRIDES: Dict[Callable, Callable] = {
|
||||
torch.where: torch_where,
|
||||
torch.abs: torch_abs,
|
||||
torch.arange: torch_arange,
|
||||
torch.full: torch_full,
|
||||
torch.cat: torch_cat,
|
||||
torch.stack: torch_stack,
|
||||
torch.add: torch_add,
|
||||
@@ -552,12 +562,6 @@ class HFProxy(Proxy):
|
||||
def shape(self):
|
||||
return self.tracer.create_proxy("call_method", "size", (self,), {})
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
if hasattr(self, "_metadata") and self._metadata is not None:
|
||||
return self._metadata.dtype
|
||||
return self.tracer.create_proxy("call_function", builtins.getattr, (self, "dtype"), {})
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
# Hack so we can track when devices are used. During meta-tensor propagation,
|
||||
@@ -597,12 +601,15 @@ class HFAttribute(HFProxy):
|
||||
self.tracer = root.tracer
|
||||
self._node = None
|
||||
|
||||
if hasattr(self.root, "_metadata"):
|
||||
self.install_metadata(getattr(self.root._metadata, attr))
|
||||
|
||||
@property
|
||||
def node(self):
|
||||
# the node for attributes is added lazily, since most will just be method calls
|
||||
# which do not rely on the getitem call
|
||||
if self._node is None:
|
||||
self._node = self.tracer.create_proxy("call_function", getattr, (self.root, self.attr), {}).node
|
||||
self._node = self.tracer.create_proxy("call_function", builtins.getattr, (self.root, self.attr), {}).node
|
||||
return self._node
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
@@ -663,7 +670,18 @@ class HFTracer(Tracer):
|
||||
# Feature flag for proxying accesses to buffer values
|
||||
proxy_buffer_attributes: bool = True
|
||||
allow_insert_stateless_mods: bool = True
|
||||
_TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full", "full_like", "eye", "empty", "tensor"]
|
||||
_TORCH_METHODS_TO_PATCH = [
|
||||
"arange",
|
||||
"zeros",
|
||||
"ones",
|
||||
"full",
|
||||
"full_like",
|
||||
"eye",
|
||||
"empty",
|
||||
"tensor",
|
||||
"clamp",
|
||||
"finfo",
|
||||
]
|
||||
|
||||
def __init__(self, autowrap_modules=(math,), autowrap_functions=()):
|
||||
|
||||
@@ -737,6 +755,8 @@ class HFTracer(Tracer):
|
||||
"GPT2DoubleHeadsModel",
|
||||
]:
|
||||
inputs_dict["labels"] = torch.zeros(shape, dtype=torch.long, device=device)
|
||||
elif model_class_name in [*get_values(MODEL_FOR_CTC_MAPPING_NAMES)]:
|
||||
inputs_dict["labels"] = torch.zeros(shape, dtype=torch.float32, device=device)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Generating the dummy input named {input_name} for {model_class_name} is not supported yet."
|
||||
|
||||
@@ -835,17 +835,14 @@ class ModelTesterMixin:
|
||||
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
|
||||
input_names = list(filtered_inputs.keys())
|
||||
|
||||
model_output = model(**filtered_inputs)
|
||||
|
||||
if (
|
||||
isinstance(model, tuple(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values()))
|
||||
and not hasattr(model.config, "problem_type")
|
||||
or model.config.problem_type is None
|
||||
if isinstance(model, tuple(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values())) and (
|
||||
not hasattr(model.config, "problem_type") or model.config.problem_type is None
|
||||
):
|
||||
model.config.problem_type = "single_label_classification"
|
||||
|
||||
traced_model = symbolic_trace(model, input_names)
|
||||
traced_output = traced_model(**filtered_inputs)
|
||||
model_output = model(**filtered_inputs)
|
||||
|
||||
except Exception as e:
|
||||
self.fail(f"Couldn't trace module: {e}")
|
||||
@@ -871,20 +868,6 @@ class ModelTesterMixin:
|
||||
f"traced {i}th output doesn't match model {i}th output for {model_class}",
|
||||
)
|
||||
|
||||
# Test that the model can be TorchScripted
|
||||
try:
|
||||
scripted = torch.jit.script(traced_model)
|
||||
except Exception as e:
|
||||
self.fail(f"Could not TorchScript the traced model: {e}")
|
||||
scripted_output = scripted(**filtered_inputs)
|
||||
scripted_output = flatten_output(scripted_output)
|
||||
|
||||
for i in range(num_outputs):
|
||||
self.assertTrue(
|
||||
torch.allclose(model_output[i], scripted_output[i]),
|
||||
f"scripted {i}th output doesn't match model {i}th output for {model_class}",
|
||||
)
|
||||
|
||||
# Test that the model can be serialized and restored properly
|
||||
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||
pkl_file_name = os.path.join(tmp_dir_name, "model.pkl")
|
||||
|
||||
Reference in New Issue
Block a user