Fix TF input for np.ndarray (#9294)
* Fix input for np.ndarray" * add a test * add a test * Add a test * Apply style * Fix test
This commit is contained in:
@@ -324,7 +324,7 @@ def input_processing(func, config, input_ids, **kwargs):
|
||||
signature.pop("kwargs", None)
|
||||
parameter_names = list(signature.keys())
|
||||
output = {}
|
||||
allowed_types = (tf.Tensor, bool, int, ModelOutput, tuple, list, dict)
|
||||
allowed_types = (tf.Tensor, bool, int, ModelOutput, tuple, list, dict, np.ndarray)
|
||||
|
||||
if "inputs" in kwargs["kwargs_call"]:
|
||||
warnings.warn(
|
||||
|
||||
@@ -805,6 +805,27 @@ class TFModelTesterMixin:
|
||||
|
||||
model(inputs)
|
||||
|
||||
def test_numpy_arrays_inputs(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
def prepare_numpy_arrays(inputs_dict):
|
||||
inputs_np_dict = {}
|
||||
for k, v in inputs_dict.items():
|
||||
if tf.is_tensor(v):
|
||||
inputs_np_dict[k] = v.numpy()
|
||||
else:
|
||||
inputs_np_dict[k] = np.array(k)
|
||||
|
||||
return inputs_np_dict
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
inputs_np = prepare_numpy_arrays(inputs)
|
||||
|
||||
model(inputs_np)
|
||||
|
||||
def test_resize_token_embeddings(self):
|
||||
if not self.test_resize_embeddings:
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user