Check TF ops for ONNX compliance (#10025)
* Add check-ops script * Finish to implement check_tf_ops and start the test * Make the test mandatory only for BERT * Update tf_ops folder * Remove useless classes * Add the ONNX test for GPT2 and BART * Add a onnxruntime slow test + better opset flexibility * Fix test + apply style * fix tests * Switch min opset from 12 to 10 * Update src/transformers/file_utils.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Fix GPT2 * Remove extra shape_list usage * Fix GPT2 * Address Morgan's comments Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
@@ -1030,16 +1030,7 @@ class TFGPT2ForSequenceClassification(TFGPT2PreTrainedModel, TFSequenceClassific
|
||||
)
|
||||
- 1
|
||||
)
|
||||
|
||||
def get_seq_element(sequence_position, input_batch):
|
||||
return tf.strided_slice(
|
||||
input_batch, [sequence_position, 0], [sequence_position + 1, input_batch.shape[-1]], [1, 1]
|
||||
)
|
||||
|
||||
result = tf.map_fn(
|
||||
fn=lambda t: get_seq_element(t[0], t[1]), elems=[sequence_lengths, logits], dtype="float"
|
||||
)
|
||||
in_logits = tf.reshape(result, [logits_shape[0], logits_shape[-1]])
|
||||
in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)
|
||||
else:
|
||||
sequence_lengths = -1
|
||||
logger.warning(
|
||||
@@ -1049,16 +1040,12 @@ class TFGPT2ForSequenceClassification(TFGPT2PreTrainedModel, TFSequenceClassific
|
||||
loss = None
|
||||
|
||||
if inputs["labels"] is not None:
|
||||
if input_ids is not None:
|
||||
batch_size, sequence_length = shape_list(inputs["input_ids"])[:2]
|
||||
else:
|
||||
batch_size, sequence_length = shape_list(inputs["inputs_embeds"])[:2]
|
||||
assert (
|
||||
self.config.pad_token_id is not None or batch_size == 1
|
||||
self.config.pad_token_id is not None or logits_shape[0] == 1
|
||||
), "Cannot handle batch sizes > 1 if no padding token is defined."
|
||||
|
||||
if not tf.is_tensor(sequence_lengths):
|
||||
in_logits = logits[0:batch_size, sequence_lengths]
|
||||
in_logits = logits[0 : logits_shape[0], sequence_lengths]
|
||||
|
||||
loss = self.compute_loss(tf.reshape(inputs["labels"], [-1]), tf.reshape(in_logits, [-1, self.num_labels]))
|
||||
pooled_logits = in_logits if in_logits is not None else logits
|
||||
|
||||
Reference in New Issue
Block a user