Fix TF input for np.ndarray (#9294)
* Fix input for np.ndarray" * add a test * add a test * Add a test * Apply style * Fix test
This commit is contained in:
@@ -324,7 +324,7 @@ def input_processing(func, config, input_ids, **kwargs):
|
||||
signature.pop("kwargs", None)
|
||||
parameter_names = list(signature.keys())
|
||||
output = {}
|
||||
allowed_types = (tf.Tensor, bool, int, ModelOutput, tuple, list, dict)
|
||||
allowed_types = (tf.Tensor, bool, int, ModelOutput, tuple, list, dict, np.ndarray)
|
||||
|
||||
if "inputs" in kwargs["kwargs_call"]:
|
||||
warnings.warn(
|
||||
|
||||
Reference in New Issue
Block a user