Check TF ops for ONNX compliance (#10025)
* Add check-ops script * Finish to implement check_tf_ops and start the test * Make the test mandatory only for BERT * Update tf_ops folder * Remove useless classes * Add the ONNX test for GPT2 and BART * Add a onnxruntime slow test + better opset flexibility * Fix test + apply style * fix tests * Switch min opset from 12 to 10 * Update src/transformers/file_utils.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Fix GPT2 * Remove extra shape_list usage * Fix GPT2 * Address Morgan's comments Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
@@ -151,6 +151,16 @@ except importlib_metadata.PackageNotFoundError:
|
||||
_faiss_available = False
|
||||
|
||||
|
||||
_onnx_available = (
|
||||
importlib.util.find_spec("keras2onnx") is not None and importlib.util.find_spec("onnxruntime") is not None
|
||||
)
|
||||
try:
|
||||
_onxx_version = importlib_metadata.version("onnx")
|
||||
logger.debug(f"Successfully imported onnx version {_onxx_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_onnx_available = False
|
||||
|
||||
|
||||
_scatter_available = importlib.util.find_spec("torch_scatter") is not None
|
||||
try:
|
||||
_scatter_version = importlib_metadata.version("torch_scatter")
|
||||
@@ -230,6 +240,10 @@ def is_tf_available():
|
||||
return _tf_available
|
||||
|
||||
|
||||
def is_onnx_available():
|
||||
return _onnx_available
|
||||
|
||||
|
||||
def is_flax_available():
|
||||
return _flax_available
|
||||
|
||||
|
||||
Reference in New Issue
Block a user