[tests] remove test_export_to_onnx (#36241)
This commit is contained in:
@@ -709,20 +709,6 @@ class SwitchTransformersModelTest(ModelTesterMixin, GenerationTesterMixin, Pipel
|
||||
model = SwitchTransformersModel.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
@unittest.skip(reason="Test has a segmentation fault on torch 1.8.0")
|
||||
def test_export_to_onnx(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
model = SwitchTransformersModel(config_and_inputs[0]).to(torch_device)
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
torch.onnx.export(
|
||||
model,
|
||||
(config_and_inputs[1], config_and_inputs[3], config_and_inputs[2]),
|
||||
f"{tmpdirname}/switch_transformers_test.onnx",
|
||||
export_params=True,
|
||||
opset_version=9,
|
||||
input_names=["input_ids", "decoder_input_ids"],
|
||||
)
|
||||
|
||||
def test_generate_with_head_masking(self):
|
||||
attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
|
||||
Reference in New Issue
Block a user