Fx support for multiple model architectures (#17393)
* Support for Bart and LayoutLM, and partial support for XLNet * Support for mbart * A lot of new models supported * Support for other models * LayoutLM fix * Use strings instead of classes
This commit is contained in:
@@ -14,7 +14,6 @@
|
||||
# limitations under the License.
|
||||
""" Testing suite for the PyTorch Swin model. """
|
||||
|
||||
import copy
|
||||
import inspect
|
||||
import os
|
||||
import pickle
|
||||
@@ -26,7 +25,7 @@ from transformers.testing_utils import require_torch, require_vision, slow, torc
|
||||
from transformers.utils import cached_property, is_torch_available, is_torch_fx_available, is_vision_available
|
||||
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
||||
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@@ -45,14 +44,6 @@ if is_torch_fx_available():
|
||||
from transformers.utils.fx import symbolic_trace
|
||||
|
||||
|
||||
def _config_zero_init(config):
|
||||
configs_no_init = copy.deepcopy(config)
|
||||
for key in configs_no_init.__dict__.keys():
|
||||
if "_range" in key or "_std" in key or "initializer_factor" in key or "layer_scale" in key:
|
||||
setattr(configs_no_init, key, 1e-10)
|
||||
return configs_no_init
|
||||
|
||||
|
||||
class SwinModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -407,7 +398,9 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
input_names = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask"]
|
||||
if labels is not None:
|
||||
input_names.append("labels")
|
||||
|
||||
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
|
||||
input_names = list(filtered_inputs.keys())
|
||||
|
||||
model_output = model(**filtered_inputs)
|
||||
|
||||
@@ -427,7 +420,7 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
input_names.append("end_positions")
|
||||
|
||||
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
|
||||
input_names = filtered_inputs.keys()
|
||||
input_names = list(filtered_inputs.keys())
|
||||
|
||||
model_output = model(**filtered_inputs)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user