Skip test_export_to_onnx for LongT5 if torch < 1.11 (#19122)
* Skip if torch < 1.11 * fix quality * fix import * fix typo * fix condition * fix condition * fix condition * fix quality * fix condition Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -39,6 +39,7 @@ if is_torch_available():
|
|||||||
LongT5Model,
|
LongT5Model,
|
||||||
)
|
)
|
||||||
from transformers.models.longt5.modeling_longt5 import LONGT5_PRETRAINED_MODEL_ARCHIVE_LIST
|
from transformers.models.longt5.modeling_longt5 import LONGT5_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||||
|
from transformers.pytorch_utils import is_torch_less_than_1_11
|
||||||
|
|
||||||
|
|
||||||
class LongT5ModelTester:
|
class LongT5ModelTester:
|
||||||
@@ -584,6 +585,10 @@ class LongT5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
|
|||||||
model = LongT5Model.from_pretrained(model_name)
|
model = LongT5Model.from_pretrained(model_name)
|
||||||
self.assertIsNotNone(model)
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
|
@unittest.skipIf(
|
||||||
|
not is_torch_available() or is_torch_less_than_1_11,
|
||||||
|
"Test failed with torch < 1.11 with an exception in a C++ file.",
|
||||||
|
)
|
||||||
@slow
|
@slow
|
||||||
def test_export_to_onnx(self):
|
def test_export_to_onnx(self):
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
|||||||
Reference in New Issue
Block a user