From 4fbcf8ea496bece21b9442d71280257b9953152a Mon Sep 17 00:00:00 2001 From: Julien Plu Date: Fri, 8 Jan 2021 14:23:29 +0100 Subject: [PATCH] Fix TF input for np.ndarray (#9294) * Fix input for np.ndarray" * add a test * add a test * Add a test * Apply style * Fix test --- src/transformers/modeling_tf_utils.py | 2 +- tests/test_modeling_tf_common.py | 21 +++++++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index 694984ac02..d013bc910e 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -324,7 +324,7 @@ def input_processing(func, config, input_ids, **kwargs): signature.pop("kwargs", None) parameter_names = list(signature.keys()) output = {} - allowed_types = (tf.Tensor, bool, int, ModelOutput, tuple, list, dict) + allowed_types = (tf.Tensor, bool, int, ModelOutput, tuple, list, dict, np.ndarray) if "inputs" in kwargs["kwargs_call"]: warnings.warn( diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index 702b531b6c..1ef360cf89 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -805,6 +805,27 @@ class TFModelTesterMixin: model(inputs) + def test_numpy_arrays_inputs(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + def prepare_numpy_arrays(inputs_dict): + inputs_np_dict = {} + for k, v in inputs_dict.items(): + if tf.is_tensor(v): + inputs_np_dict[k] = v.numpy() + else: + inputs_np_dict[k] = np.array(k) + + return inputs_np_dict + + for model_class in self.all_model_classes: + model = model_class(config) + + inputs = self._prepare_for_class(inputs_dict, model_class) + inputs_np = prepare_numpy_arrays(inputs) + + model(inputs_np) + def test_resize_token_embeddings(self): if not self.test_resize_embeddings: return