TF Bert inference - support np.ndarray optional arguments (#15074)

* TF Bert inference - support np.ndarray optional arguments

* apply np input tests to all TF architectures
This commit is contained in:
Joao Gante
2022-01-14 15:19:04 +00:00
committed by GitHub
parent 4663c609b9
commit 7d9a33fb5c
2 changed files with 8 additions and 3 deletions

View File

@@ -1941,16 +1941,19 @@ class TFSequenceSummary(tf.keras.layers.Layer):
return output
def shape_list(tensor: tf.Tensor) -> List[int]:
def shape_list(tensor: Union[tf.Tensor, np.ndarray]) -> List[int]:
"""
Deal with dynamic shape in tensorflow cleanly.
Args:
tensor (`tf.Tensor`): The tensor we want the shape of.
tensor (`tf.Tensor` or `np.ndarray`): The tensor we want the shape of.
Returns:
`List[int]`: The shape of the tensor as a list.
"""
if isinstance(tensor, np.ndarray):
return list(tensor.shape)
dynamic = tf.shape(tensor)
if tensor.shape == tf.TensorShape(None):

View File

@@ -846,7 +846,9 @@ class TFModelTesterMixin:
inputs = self._prepare_for_class(inputs_dict, model_class)
inputs_np = prepare_numpy_arrays(inputs)
model(inputs_np)
output_for_dict_input = model(inputs_np)
output_for_kw_input = model(**inputs_np)
self.assert_outputs_same(output_for_dict_input, output_for_kw_input)
def test_resize_token_embeddings(self):
if not self.test_resize_embeddings: