From 9080607b2c652cfb6dbc84f0310a0ab9ecc8e8fc Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Thu, 3 Nov 2022 16:14:44 +0100 Subject: [PATCH] Fixed torch.finfo issue with torch.fx (#20040) --- src/transformers/utils/fx.py | 36 +++++++++++++++++++++++++++-------- tests/test_modeling_common.py | 23 +++------------------- 2 files changed, 31 insertions(+), 28 deletions(-) diff --git a/src/transformers/utils/fx.py b/src/transformers/utils/fx.py index af58b4cc96..37893f18fd 100644 --- a/src/transformers/utils/fx.py +++ b/src/transformers/utils/fx.py @@ -230,6 +230,15 @@ def torch_arange(*args, **kwargs): return torch.empty((end - start) // step, dtype=dtype, device="meta") +def torch_full(*args, **kwargs): + args = list(args) + if isinstance(args[1], torch.Tensor) and args[1].device == torch.device("meta"): + args[1] = 1 # Any value. + kwargs_without_device = dict(kwargs) + kwargs_without_device.pop("device", None) + return torch.full(*args, **kwargs_without_device) + + def torch_cat(tensors, dim=None, axis=None, *, out=None): if dim is None and axis is None: dim = 0 @@ -509,6 +518,7 @@ _MANUAL_META_OVERRIDES: Dict[Callable, Callable] = { torch.where: torch_where, torch.abs: torch_abs, torch.arange: torch_arange, + torch.full: torch_full, torch.cat: torch_cat, torch.stack: torch_stack, torch.add: torch_add, @@ -552,12 +562,6 @@ class HFProxy(Proxy): def shape(self): return self.tracer.create_proxy("call_method", "size", (self,), {}) - @property - def dtype(self): - if hasattr(self, "_metadata") and self._metadata is not None: - return self._metadata.dtype - return self.tracer.create_proxy("call_function", builtins.getattr, (self, "dtype"), {}) - @property def device(self): # Hack so we can track when devices are used. During meta-tensor propagation, @@ -597,12 +601,15 @@ class HFAttribute(HFProxy): self.tracer = root.tracer self._node = None + if hasattr(self.root, "_metadata"): + self.install_metadata(getattr(self.root._metadata, attr)) + @property def node(self): # the node for attributes is added lazily, since most will just be method calls # which do not rely on the getitem call if self._node is None: - self._node = self.tracer.create_proxy("call_function", getattr, (self.root, self.attr), {}).node + self._node = self.tracer.create_proxy("call_function", builtins.getattr, (self.root, self.attr), {}).node return self._node def __call__(self, *args, **kwargs): @@ -663,7 +670,18 @@ class HFTracer(Tracer): # Feature flag for proxying accesses to buffer values proxy_buffer_attributes: bool = True allow_insert_stateless_mods: bool = True - _TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full", "full_like", "eye", "empty", "tensor"] + _TORCH_METHODS_TO_PATCH = [ + "arange", + "zeros", + "ones", + "full", + "full_like", + "eye", + "empty", + "tensor", + "clamp", + "finfo", + ] def __init__(self, autowrap_modules=(math,), autowrap_functions=()): @@ -737,6 +755,8 @@ class HFTracer(Tracer): "GPT2DoubleHeadsModel", ]: inputs_dict["labels"] = torch.zeros(shape, dtype=torch.long, device=device) + elif model_class_name in [*get_values(MODEL_FOR_CTC_MAPPING_NAMES)]: + inputs_dict["labels"] = torch.zeros(shape, dtype=torch.float32, device=device) else: raise NotImplementedError( f"Generating the dummy input named {input_name} for {model_class_name} is not supported yet." diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 6a0d3b7dc9..f5452b5043 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -835,17 +835,14 @@ class ModelTesterMixin: filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names} input_names = list(filtered_inputs.keys()) - model_output = model(**filtered_inputs) - - if ( - isinstance(model, tuple(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values())) - and not hasattr(model.config, "problem_type") - or model.config.problem_type is None + if isinstance(model, tuple(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values())) and ( + not hasattr(model.config, "problem_type") or model.config.problem_type is None ): model.config.problem_type = "single_label_classification" traced_model = symbolic_trace(model, input_names) traced_output = traced_model(**filtered_inputs) + model_output = model(**filtered_inputs) except Exception as e: self.fail(f"Couldn't trace module: {e}") @@ -871,20 +868,6 @@ class ModelTesterMixin: f"traced {i}th output doesn't match model {i}th output for {model_class}", ) - # Test that the model can be TorchScripted - try: - scripted = torch.jit.script(traced_model) - except Exception as e: - self.fail(f"Could not TorchScript the traced model: {e}") - scripted_output = scripted(**filtered_inputs) - scripted_output = flatten_output(scripted_output) - - for i in range(num_outputs): - self.assertTrue( - torch.allclose(model_output[i], scripted_output[i]), - f"scripted {i}th output doesn't match model {i}th output for {model_class}", - ) - # Test that the model can be serialized and restored properly with tempfile.TemporaryDirectory() as tmp_dir_name: pkl_file_name = os.path.join(tmp_dir_name, "model.pkl")