Fix TF input for np.ndarray (#9294)

* Fix input for np.ndarray"

* add a test

* add a test

* Add a test

* Apply style

* Fix test
This commit is contained in:
Julien Plu
2021-01-08 14:23:29 +01:00
committed by GitHub
parent e34e45536f
commit 4fbcf8ea49
2 changed files with 22 additions and 1 deletions

View File

@@ -324,7 +324,7 @@ def input_processing(func, config, input_ids, **kwargs):
signature.pop("kwargs", None)
parameter_names = list(signature.keys())
output = {}
allowed_types = (tf.Tensor, bool, int, ModelOutput, tuple, list, dict)
allowed_types = (tf.Tensor, bool, int, ModelOutput, tuple, list, dict, np.ndarray)
if "inputs" in kwargs["kwargs_call"]:
warnings.warn(