From 10c774cf60b87e7c53c3a11c3126f1d115d8ea23 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Wed, 7 Sep 2022 16:22:09 +0200 Subject: [PATCH] remvoe `_create_and_check_torch_fx_tracing` in specific test files (#18667) * remvoe _create_and_check_torch_fx_tracing defined in specific model test files Co-authored-by: ydshieh --- .../models/donut/test_modeling_donut_swin.py | 101 +-------------- .../test_modeling_speech_to_text.py | 105 +-------------- tests/models/swin/test_modeling_swin.py | 101 +-------------- tests/models/xglm/test_modeling_xglm.py | 121 +----------------- 4 files changed, 4 insertions(+), 424 deletions(-) diff --git a/tests/models/donut/test_modeling_donut_swin.py b/tests/models/donut/test_modeling_donut_swin.py index f909d96188..a35a655059 100644 --- a/tests/models/donut/test_modeling_donut_swin.py +++ b/tests/models/donut/test_modeling_donut_swin.py @@ -16,14 +16,11 @@ import collections import inspect -import os -import pickle -import tempfile import unittest from transformers import DonutSwinConfig from transformers.testing_utils import require_torch, slow, torch_device -from transformers.utils import is_torch_available, is_torch_fx_available +from transformers.utils import is_torch_available from ...test_configuration_common import ConfigTester from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor @@ -36,9 +33,6 @@ if is_torch_available(): from transformers import DonutSwinModel from transformers.models.donut.modeling_donut_swin import DONUT_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST -if is_torch_fx_available(): - from transformers.utils.fx import symbolic_trace - class DonutSwinModelTester: def __init__( @@ -369,96 +363,3 @@ class DonutSwinModelTest(ModelTesterMixin, unittest.TestCase): [0.0, 1.0], msg=f"Parameter {name} of model {model_class} seems not properly initialized", ) - - def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False): - if not is_torch_fx_available() or not self.fx_compatible: - return - - configs_no_init = _config_zero_init(config) # To be sure we have no Nan - configs_no_init.return_dict = False - - for model_class in self.all_model_classes: - model = model_class(config=configs_no_init) - model.to(torch_device) - model.eval() - inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=output_loss) - - try: - if model.config.is_encoder_decoder: - model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward - labels = inputs.get("labels", None) - input_names = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask"] - if labels is not None: - input_names.append("labels") - - filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names} - input_names = list(filtered_inputs.keys()) - - model_output = model(**filtered_inputs) - - traced_model = symbolic_trace(model, input_names) - traced_output = traced_model(**filtered_inputs) - else: - input_names = ["input_ids", "attention_mask", "token_type_ids", "pixel_values"] - - labels = inputs.get("labels", None) - start_positions = inputs.get("start_positions", None) - end_positions = inputs.get("end_positions", None) - if labels is not None: - input_names.append("labels") - if start_positions is not None: - input_names.append("start_positions") - if end_positions is not None: - input_names.append("end_positions") - - filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names} - input_names = list(filtered_inputs.keys()) - - model_output = model(**filtered_inputs) - - traced_model = symbolic_trace(model, input_names) - traced_output = traced_model(**filtered_inputs) - - except RuntimeError as e: - self.fail(f"Couldn't trace module: {e}") - - def flatten_output(output): - flatten = [] - for x in output: - if isinstance(x, (tuple, list)): - flatten += flatten_output(x) - elif not isinstance(x, torch.Tensor): - continue - else: - flatten.append(x) - return flatten - - model_output = flatten_output(model_output) - traced_output = flatten_output(traced_output) - num_outputs = len(model_output) - - for i in range(num_outputs): - self.assertTrue( - torch.allclose(model_output[i], traced_output[i]), - f"traced {i}th output doesn't match model {i}th output for {model_class}", - ) - - # Test that the model can be serialized and restored properly - with tempfile.TemporaryDirectory() as tmp_dir_name: - pkl_file_name = os.path.join(tmp_dir_name, "model.pkl") - try: - with open(pkl_file_name, "wb") as f: - pickle.dump(traced_model, f) - with open(pkl_file_name, "rb") as f: - loaded = pickle.load(f) - except Exception as e: - self.fail(f"Couldn't serialize / deserialize the traced model: {e}") - - loaded_output = loaded(**filtered_inputs) - loaded_output = flatten_output(loaded_output) - - for i in range(num_outputs): - self.assertTrue( - torch.allclose(model_output[i], loaded_output[i]), - f"serialized model {i}th output doesn't match model {i}th output for {model_class}", - ) diff --git a/tests/models/speech_to_text/test_modeling_speech_to_text.py b/tests/models/speech_to_text/test_modeling_speech_to_text.py index a1a625a9b4..f7645a2f06 100644 --- a/tests/models/speech_to_text/test_modeling_speech_to_text.py +++ b/tests/models/speech_to_text/test_modeling_speech_to_text.py @@ -17,7 +17,6 @@ import copy import inspect import os -import pickle import tempfile import unittest @@ -31,7 +30,7 @@ from transformers.testing_utils import ( slow, torch_device, ) -from transformers.utils import cached_property, is_torch_fx_available +from transformers.utils import cached_property from ...generation.test_generation_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester @@ -44,9 +43,6 @@ if is_torch_available(): from transformers import Speech2TextForConditionalGeneration, Speech2TextModel, Speech2TextProcessor from transformers.models.speech_to_text.modeling_speech_to_text import Speech2TextDecoder, Speech2TextEncoder -if is_torch_fx_available(): - from transformers.utils.fx import symbolic_trace - def prepare_speech_to_text_inputs_dict( config, @@ -720,105 +716,6 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Tes self.assertTrue(models_equal) - def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False): - if not is_torch_fx_available() or not self.fx_compatible: - return - - configs_no_init = _config_zero_init(config) # To be sure we have no Nan - configs_no_init.return_dict = False - - for model_class in self.all_model_classes: - model = model_class(config=configs_no_init) - model.to(torch_device) - model.eval() - inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=output_loss) - - try: - if model.config.is_encoder_decoder: - model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward - labels = inputs.get("labels", None) - input_names = [ - "input_ids", - "attention_mask", - "decoder_input_ids", - "decoder_attention_mask", - "input_features", - ] - if labels is not None: - input_names.append("labels") - - filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names} - input_names = list(filtered_inputs.keys()) - - model_output = model(**filtered_inputs) - - traced_model = symbolic_trace(model, input_names) - traced_output = traced_model(**filtered_inputs) - else: - input_names = ["input_ids", "attention_mask", "token_type_ids", "pixel_values", "input_features"] - - labels = inputs.get("labels", None) - start_positions = inputs.get("start_positions", None) - end_positions = inputs.get("end_positions", None) - if labels is not None: - input_names.append("labels") - if start_positions is not None: - input_names.append("start_positions") - if end_positions is not None: - input_names.append("end_positions") - - filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names} - input_names = list(filtered_inputs.keys()) - - model_output = model(**filtered_inputs) - - traced_model = symbolic_trace(model, input_names) - traced_output = traced_model(**filtered_inputs) - - except RuntimeError as e: - self.fail(f"Couldn't trace module: {e}") - - def flatten_output(output): - flatten = [] - for x in output: - if isinstance(x, (tuple, list)): - flatten += flatten_output(x) - elif not isinstance(x, torch.Tensor): - continue - else: - flatten.append(x) - return flatten - - model_output = flatten_output(model_output) - traced_output = flatten_output(traced_output) - num_outputs = len(model_output) - - for i in range(num_outputs): - self.assertTrue( - torch.allclose(model_output[i], traced_output[i]), - f"traced {i}th output doesn't match model {i}th output for {model_class}", - ) - - # Test that the model can be serialized and restored properly - with tempfile.TemporaryDirectory() as tmp_dir_name: - pkl_file_name = os.path.join(tmp_dir_name, "model.pkl") - try: - with open(pkl_file_name, "wb") as f: - pickle.dump(traced_model, f) - with open(pkl_file_name, "rb") as f: - loaded = pickle.load(f) - except Exception as e: - self.fail(f"Couldn't serialize / deserialize the traced model: {e}") - - loaded_output = loaded(**filtered_inputs) - loaded_output = flatten_output(loaded_output) - - for i in range(num_outputs): - self.assertTrue( - torch.allclose(model_output[i], loaded_output[i]), - f"serialized model {i}th output doesn't match model {i}th output for {model_class}", - ) - @require_torch @require_torchaudio diff --git a/tests/models/swin/test_modeling_swin.py b/tests/models/swin/test_modeling_swin.py index 5e07efa2a3..9a5541d509 100644 --- a/tests/models/swin/test_modeling_swin.py +++ b/tests/models/swin/test_modeling_swin.py @@ -16,14 +16,11 @@ import collections import inspect -import os -import pickle -import tempfile import unittest from transformers import SwinConfig from transformers.testing_utils import require_torch, require_vision, slow, torch_device -from transformers.utils import cached_property, is_torch_available, is_torch_fx_available, is_vision_available +from transformers.utils import cached_property, is_torch_available, is_vision_available from ...test_configuration_common import ConfigTester from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor @@ -41,9 +38,6 @@ if is_vision_available(): from transformers import AutoFeatureExtractor -if is_torch_fx_available(): - from transformers.utils.fx import symbolic_trace - class SwinModelTester: def __init__( @@ -428,99 +422,6 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase): msg=f"Parameter {name} of model {model_class} seems not properly initialized", ) - def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False): - if not is_torch_fx_available() or not self.fx_compatible: - return - - configs_no_init = _config_zero_init(config) # To be sure we have no Nan - configs_no_init.return_dict = False - - for model_class in self.all_model_classes: - model = model_class(config=configs_no_init) - model.to(torch_device) - model.eval() - inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=output_loss) - - try: - if model.config.is_encoder_decoder: - model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward - labels = inputs.get("labels", None) - input_names = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask"] - if labels is not None: - input_names.append("labels") - - filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names} - input_names = list(filtered_inputs.keys()) - - model_output = model(**filtered_inputs) - - traced_model = symbolic_trace(model, input_names) - traced_output = traced_model(**filtered_inputs) - else: - input_names = ["input_ids", "attention_mask", "token_type_ids", "pixel_values"] - - labels = inputs.get("labels", None) - start_positions = inputs.get("start_positions", None) - end_positions = inputs.get("end_positions", None) - if labels is not None: - input_names.append("labels") - if start_positions is not None: - input_names.append("start_positions") - if end_positions is not None: - input_names.append("end_positions") - - filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names} - input_names = list(filtered_inputs.keys()) - - model_output = model(**filtered_inputs) - - traced_model = symbolic_trace(model, input_names) - traced_output = traced_model(**filtered_inputs) - - except RuntimeError as e: - self.fail(f"Couldn't trace module: {e}") - - def flatten_output(output): - flatten = [] - for x in output: - if isinstance(x, (tuple, list)): - flatten += flatten_output(x) - elif not isinstance(x, torch.Tensor): - continue - else: - flatten.append(x) - return flatten - - model_output = flatten_output(model_output) - traced_output = flatten_output(traced_output) - num_outputs = len(model_output) - - for i in range(num_outputs): - self.assertTrue( - torch.allclose(model_output[i], traced_output[i]), - f"traced {i}th output doesn't match model {i}th output for {model_class}", - ) - - # Test that the model can be serialized and restored properly - with tempfile.TemporaryDirectory() as tmp_dir_name: - pkl_file_name = os.path.join(tmp_dir_name, "model.pkl") - try: - with open(pkl_file_name, "wb") as f: - pickle.dump(traced_model, f) - with open(pkl_file_name, "rb") as f: - loaded = pickle.load(f) - except Exception as e: - self.fail(f"Couldn't serialize / deserialize the traced model: {e}") - - loaded_output = loaded(**filtered_inputs) - loaded_output = flatten_output(loaded_output) - - for i in range(num_outputs): - self.assertTrue( - torch.allclose(model_output[i], loaded_output[i]), - f"serialized model {i}th output doesn't match model {i}th output for {model_class}", - ) - @require_vision @require_torch diff --git a/tests/models/xglm/test_modeling_xglm.py b/tests/models/xglm/test_modeling_xglm.py index f4da499426..6d40ddab8e 100644 --- a/tests/models/xglm/test_modeling_xglm.py +++ b/tests/models/xglm/test_modeling_xglm.py @@ -15,24 +15,14 @@ import datetime import math -import os -import pickle -import tempfile import unittest from transformers import XGLMConfig, is_torch_available from transformers.testing_utils import require_torch, slow, torch_device -from transformers.utils import is_torch_fx_available from ...generation.test_generation_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester -from ...test_modeling_common import ( - ModelTesterMixin, - _config_zero_init, - floats_tensor, - ids_tensor, - random_attention_mask, -) +from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask if is_torch_available(): @@ -40,9 +30,6 @@ if is_torch_available(): from transformers import XGLM_PRETRAINED_MODEL_ARCHIVE_LIST, XGLMForCausalLM, XGLMModel, XGLMTokenizer -if is_torch_fx_available(): - from transformers.utils.fx import symbolic_trace - class XGLMModelTester: def __init__( @@ -350,112 +337,6 @@ class XGLMModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_xglm_weight_initialization(*config_and_inputs) - def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False): - if not is_torch_fx_available() or not self.fx_compatible: - return - - configs_no_init = _config_zero_init(config) # To be sure we have no Nan - configs_no_init.return_dict = False - - for model_class in self.all_model_classes: - model = model_class(config=configs_no_init) - model.to(torch_device) - model.eval() - inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=output_loss) - - try: - if model.config.is_encoder_decoder: - model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward - labels = inputs.get("labels", None) - input_names = [ - "input_ids", - "attention_mask", - "decoder_input_ids", - "decoder_attention_mask", - "input_features", - ] - if labels is not None: - input_names.append("labels") - - filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names} - input_names = list(filtered_inputs.keys()) - - model_output = model(**filtered_inputs) - - traced_model = symbolic_trace(model, input_names) - traced_output = traced_model(**filtered_inputs) - else: - input_names = [ - "input_ids", - "attention_mask", - "token_type_ids", - "pixel_values", - "bbox", - "input_features", - ] - - labels = inputs.get("labels", None) - start_positions = inputs.get("start_positions", None) - end_positions = inputs.get("end_positions", None) - if labels is not None: - input_names.append("labels") - if start_positions is not None: - input_names.append("start_positions") - if end_positions is not None: - input_names.append("end_positions") - - filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names} - input_names = list(filtered_inputs.keys()) - - model_output = model(**filtered_inputs) - - traced_model = symbolic_trace(model, input_names) - traced_output = traced_model(**filtered_inputs) - - except RuntimeError as e: - self.fail(f"Couldn't trace module: {e}") - - def flatten_output(output): - flatten = [] - for x in output: - if isinstance(x, (tuple, list)): - flatten += flatten_output(x) - elif not isinstance(x, torch.Tensor): - continue - else: - flatten.append(x) - return flatten - - model_output = flatten_output(model_output) - traced_output = flatten_output(traced_output) - num_outputs = len(model_output) - - for i in range(num_outputs): - self.assertTrue( - torch.allclose(model_output[i], traced_output[i]), - f"traced {i}th output doesn't match model {i}th output for {model_class}", - ) - - # Test that the model can be serialized and restored properly - with tempfile.TemporaryDirectory() as tmp_dir_name: - pkl_file_name = os.path.join(tmp_dir_name, "model.pkl") - try: - with open(pkl_file_name, "wb") as f: - pickle.dump(traced_model, f) - with open(pkl_file_name, "rb") as f: - loaded = pickle.load(f) - except Exception as e: - self.fail(f"Couldn't serialize / deserialize the traced model: {e}") - - loaded_output = loaded(**filtered_inputs) - loaded_output = flatten_output(loaded_output) - - for i in range(num_outputs): - self.assertTrue( - torch.allclose(model_output[i], loaded_output[i]), - f"serialized model {i}th output doesn't match model {i}th output for {model_class}", - ) - @slow def test_batch_generation(self): model = XGLMForCausalLM.from_pretrained("facebook/xglm-564M")