Check TF ops for ONNX compliance (#10025)
* Add check-ops script * Finish to implement check_tf_ops and start the test * Make the test mandatory only for BERT * Update tf_ops folder * Remove useless classes * Add the ONNX test for GPT2 and BART * Add a onnxruntime slow test + better opset flexibility * Fix test + apply style * fix tests * Switch min opset from 12 to 10 * Update src/transformers/file_utils.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Fix GPT2 * Remove extra shape_list usage * Fix GPT2 * Address Morgan's comments Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
@@ -151,6 +151,16 @@ except importlib_metadata.PackageNotFoundError:
|
||||
_faiss_available = False
|
||||
|
||||
|
||||
_onnx_available = (
|
||||
importlib.util.find_spec("keras2onnx") is not None and importlib.util.find_spec("onnxruntime") is not None
|
||||
)
|
||||
try:
|
||||
_onxx_version = importlib_metadata.version("onnx")
|
||||
logger.debug(f"Successfully imported onnx version {_onxx_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_onnx_available = False
|
||||
|
||||
|
||||
_scatter_available = importlib.util.find_spec("torch_scatter") is not None
|
||||
try:
|
||||
_scatter_version = importlib_metadata.version("torch_scatter")
|
||||
@@ -230,6 +240,10 @@ def is_tf_available():
|
||||
return _tf_available
|
||||
|
||||
|
||||
def is_onnx_available():
|
||||
return _onnx_available
|
||||
|
||||
|
||||
def is_flax_available():
|
||||
return _flax_available
|
||||
|
||||
|
||||
@@ -1030,16 +1030,7 @@ class TFGPT2ForSequenceClassification(TFGPT2PreTrainedModel, TFSequenceClassific
|
||||
)
|
||||
- 1
|
||||
)
|
||||
|
||||
def get_seq_element(sequence_position, input_batch):
|
||||
return tf.strided_slice(
|
||||
input_batch, [sequence_position, 0], [sequence_position + 1, input_batch.shape[-1]], [1, 1]
|
||||
)
|
||||
|
||||
result = tf.map_fn(
|
||||
fn=lambda t: get_seq_element(t[0], t[1]), elems=[sequence_lengths, logits], dtype="float"
|
||||
)
|
||||
in_logits = tf.reshape(result, [logits_shape[0], logits_shape[-1]])
|
||||
in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)
|
||||
else:
|
||||
sequence_lengths = -1
|
||||
logger.warning(
|
||||
@@ -1049,16 +1040,12 @@ class TFGPT2ForSequenceClassification(TFGPT2PreTrainedModel, TFSequenceClassific
|
||||
loss = None
|
||||
|
||||
if inputs["labels"] is not None:
|
||||
if input_ids is not None:
|
||||
batch_size, sequence_length = shape_list(inputs["input_ids"])[:2]
|
||||
else:
|
||||
batch_size, sequence_length = shape_list(inputs["inputs_embeds"])[:2]
|
||||
assert (
|
||||
self.config.pad_token_id is not None or batch_size == 1
|
||||
self.config.pad_token_id is not None or logits_shape[0] == 1
|
||||
), "Cannot handle batch sizes > 1 if no padding token is defined."
|
||||
|
||||
if not tf.is_tensor(sequence_lengths):
|
||||
in_logits = logits[0:batch_size, sequence_lengths]
|
||||
in_logits = logits[0 : logits_shape[0], sequence_lengths]
|
||||
|
||||
loss = self.compute_loss(tf.reshape(inputs["labels"], [-1]), tf.reshape(in_logits, [-1, self.num_labels]))
|
||||
pooled_logits = in_logits if in_logits is not None else logits
|
||||
|
||||
@@ -28,6 +28,7 @@ from .file_utils import (
|
||||
is_datasets_available,
|
||||
is_faiss_available,
|
||||
is_flax_available,
|
||||
is_onnx_available,
|
||||
is_pandas_available,
|
||||
is_scatter_available,
|
||||
is_sentencepiece_available,
|
||||
@@ -160,6 +161,13 @@ def require_git_lfs(test_case):
|
||||
return test_case
|
||||
|
||||
|
||||
def require_onnx(test_case):
|
||||
if not is_onnx_available():
|
||||
return unittest.skip("test requires ONNX")(test_case)
|
||||
else:
|
||||
return test_case
|
||||
|
||||
|
||||
def require_torch(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires PyTorch.
|
||||
|
||||
Reference in New Issue
Block a user