Fix TF input for np.ndarray (#9294)
* Fix input for np.ndarray" * add a test * add a test * Add a test * Apply style * Fix test
This commit is contained in:
@@ -324,7 +324,7 @@ def input_processing(func, config, input_ids, **kwargs):
|
|||||||
signature.pop("kwargs", None)
|
signature.pop("kwargs", None)
|
||||||
parameter_names = list(signature.keys())
|
parameter_names = list(signature.keys())
|
||||||
output = {}
|
output = {}
|
||||||
allowed_types = (tf.Tensor, bool, int, ModelOutput, tuple, list, dict)
|
allowed_types = (tf.Tensor, bool, int, ModelOutput, tuple, list, dict, np.ndarray)
|
||||||
|
|
||||||
if "inputs" in kwargs["kwargs_call"]:
|
if "inputs" in kwargs["kwargs_call"]:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
|
|||||||
@@ -805,6 +805,27 @@ class TFModelTesterMixin:
|
|||||||
|
|
||||||
model(inputs)
|
model(inputs)
|
||||||
|
|
||||||
|
def test_numpy_arrays_inputs(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
def prepare_numpy_arrays(inputs_dict):
|
||||||
|
inputs_np_dict = {}
|
||||||
|
for k, v in inputs_dict.items():
|
||||||
|
if tf.is_tensor(v):
|
||||||
|
inputs_np_dict[k] = v.numpy()
|
||||||
|
else:
|
||||||
|
inputs_np_dict[k] = np.array(k)
|
||||||
|
|
||||||
|
return inputs_np_dict
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config)
|
||||||
|
|
||||||
|
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||||
|
inputs_np = prepare_numpy_arrays(inputs)
|
||||||
|
|
||||||
|
model(inputs_np)
|
||||||
|
|
||||||
def test_resize_token_embeddings(self):
|
def test_resize_token_embeddings(self):
|
||||||
if not self.test_resize_embeddings:
|
if not self.test_resize_embeddings:
|
||||||
return
|
return
|
||||||
|
|||||||
Reference in New Issue
Block a user