From 7d9a33fb5cf40a87ff7fa9b4b8556b9bd4760461 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 14 Jan 2022 15:19:04 +0000 Subject: [PATCH] TF Bert inference - support `np.ndarray` optional arguments (#15074) * TF Bert inference - support np.ndarray optional arguments * apply np input tests to all TF architectures --- src/transformers/modeling_tf_utils.py | 7 +++++-- tests/test_modeling_tf_common.py | 4 +++- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index e50f2a5989..ffd5045d18 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -1941,16 +1941,19 @@ class TFSequenceSummary(tf.keras.layers.Layer): return output -def shape_list(tensor: tf.Tensor) -> List[int]: +def shape_list(tensor: Union[tf.Tensor, np.ndarray]) -> List[int]: """ Deal with dynamic shape in tensorflow cleanly. Args: - tensor (`tf.Tensor`): The tensor we want the shape of. + tensor (`tf.Tensor` or `np.ndarray`): The tensor we want the shape of. Returns: `List[int]`: The shape of the tensor as a list. """ + if isinstance(tensor, np.ndarray): + return list(tensor.shape) + dynamic = tf.shape(tensor) if tensor.shape == tf.TensorShape(None): diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index 074e8174d2..57c97455ae 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -846,7 +846,9 @@ class TFModelTesterMixin: inputs = self._prepare_for_class(inputs_dict, model_class) inputs_np = prepare_numpy_arrays(inputs) - model(inputs_np) + output_for_dict_input = model(inputs_np) + output_for_kw_input = model(**inputs_np) + self.assert_outputs_same(output_for_dict_input, output_for_kw_input) def test_resize_token_embeddings(self): if not self.test_resize_embeddings: