Optimize to_py_obj for python-native numeric lists and scalars (#36885)
* Optimize to_py_obj for python-native numeric lists and scalars * Fix bug that tuple is not converted to list * Try np.array for more robust type checking * Apply review and add tests for to_py_obj
This commit is contained in:
@@ -257,6 +257,18 @@ def to_py_obj(obj):
|
|||||||
"""
|
"""
|
||||||
Convert a TensorFlow tensor, PyTorch tensor, Numpy array or python list to a python list.
|
Convert a TensorFlow tensor, PyTorch tensor, Numpy array or python list to a python list.
|
||||||
"""
|
"""
|
||||||
|
if isinstance(obj, (int, float)):
|
||||||
|
return obj
|
||||||
|
elif isinstance(obj, (dict, UserDict)):
|
||||||
|
return {k: to_py_obj(v) for k, v in obj.items()}
|
||||||
|
elif isinstance(obj, (list, tuple)):
|
||||||
|
try:
|
||||||
|
arr = np.array(obj)
|
||||||
|
if np.issubdtype(arr.dtype, np.integer) or np.issubdtype(arr.dtype, np.floating):
|
||||||
|
return arr.tolist()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return [to_py_obj(o) for o in obj]
|
||||||
|
|
||||||
framework_to_py_obj = {
|
framework_to_py_obj = {
|
||||||
"pt": lambda obj: obj.detach().cpu().tolist(),
|
"pt": lambda obj: obj.detach().cpu().tolist(),
|
||||||
@@ -265,11 +277,6 @@ def to_py_obj(obj):
|
|||||||
"np": lambda obj: obj.tolist(),
|
"np": lambda obj: obj.tolist(),
|
||||||
}
|
}
|
||||||
|
|
||||||
if isinstance(obj, (dict, UserDict)):
|
|
||||||
return {k: to_py_obj(v) for k, v in obj.items()}
|
|
||||||
elif isinstance(obj, (list, tuple)):
|
|
||||||
return [to_py_obj(o) for o in obj]
|
|
||||||
|
|
||||||
# This gives us a smart order to test the frameworks with the corresponding tests.
|
# This gives us a smart order to test the frameworks with the corresponding tests.
|
||||||
framework_to_test_func = _get_frameworks_and_test_func(obj)
|
framework_to_test_func = _get_frameworks_and_test_func(obj)
|
||||||
for framework, test_func in framework_to_test_func.items():
|
for framework, test_func in framework_to_test_func.items():
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ from transformers.utils import (
|
|||||||
is_torch_available,
|
is_torch_available,
|
||||||
reshape,
|
reshape,
|
||||||
squeeze,
|
squeeze,
|
||||||
|
to_py_obj,
|
||||||
transpose,
|
transpose,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -201,6 +202,77 @@ class GenericTester(unittest.TestCase):
|
|||||||
t = jnp.array(x)
|
t = jnp.array(x)
|
||||||
self.assertTrue(np.allclose(expand_dims(x, axis=1), np.asarray(expand_dims(t, axis=1))))
|
self.assertTrue(np.allclose(expand_dims(x, axis=1), np.asarray(expand_dims(t, axis=1))))
|
||||||
|
|
||||||
|
def test_to_py_obj_native(self):
|
||||||
|
self.assertTrue(to_py_obj(1) == 1)
|
||||||
|
self.assertTrue(to_py_obj([1, 2, 3]) == [1, 2, 3])
|
||||||
|
self.assertTrue(to_py_obj([((1.0, 1.1), 1.2), (2, 3)]) == [[[1.0, 1.1], 1.2], [2, 3]])
|
||||||
|
|
||||||
|
def test_to_py_obj_numpy(self):
|
||||||
|
x1 = [[1, 2, 3], [4, 5, 6]]
|
||||||
|
t1 = np.array(x1)
|
||||||
|
self.assertTrue(to_py_obj(t1) == x1)
|
||||||
|
|
||||||
|
x2 = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
|
||||||
|
t2 = np.array(x2)
|
||||||
|
self.assertTrue(to_py_obj(t2) == x2)
|
||||||
|
|
||||||
|
self.assertTrue(to_py_obj([t1, t2]) == [x1, x2])
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_to_py_obj_torch(self):
|
||||||
|
x1 = [[1, 2, 3], [4, 5, 6]]
|
||||||
|
t1 = torch.tensor(x1)
|
||||||
|
self.assertTrue(to_py_obj(t1) == x1)
|
||||||
|
|
||||||
|
x2 = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
|
||||||
|
t2 = torch.tensor(x2)
|
||||||
|
self.assertTrue(to_py_obj(t2) == x2)
|
||||||
|
|
||||||
|
self.assertTrue(to_py_obj([t1, t2]) == [x1, x2])
|
||||||
|
|
||||||
|
@require_tf
|
||||||
|
def test_to_py_obj_tf(self):
|
||||||
|
x1 = [[1, 2, 3], [4, 5, 6]]
|
||||||
|
t1 = tf.constant(x1)
|
||||||
|
self.assertTrue(to_py_obj(t1) == x1)
|
||||||
|
|
||||||
|
x2 = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
|
||||||
|
t2 = tf.constant(x2)
|
||||||
|
self.assertTrue(to_py_obj(t2) == x2)
|
||||||
|
|
||||||
|
self.assertTrue(to_py_obj([t1, t2]) == [x1, x2])
|
||||||
|
|
||||||
|
@require_flax
|
||||||
|
def test_to_py_obj_flax(self):
|
||||||
|
x1 = [[1, 2, 3], [4, 5, 6]]
|
||||||
|
t1 = jnp.array(x1)
|
||||||
|
self.assertTrue(to_py_obj(t1) == x1)
|
||||||
|
|
||||||
|
x2 = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
|
||||||
|
t2 = jnp.array(x2)
|
||||||
|
self.assertTrue(to_py_obj(t2) == x2)
|
||||||
|
|
||||||
|
self.assertTrue(to_py_obj([t1, t2]) == [x1, x2])
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
@require_tf
|
||||||
|
@require_flax
|
||||||
|
def test_to_py_obj_mixed(self):
|
||||||
|
x1 = [[1], [2]]
|
||||||
|
t1 = np.array(x1)
|
||||||
|
|
||||||
|
x2 = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
|
||||||
|
t2 = torch.tensor(x2)
|
||||||
|
|
||||||
|
x3 = [1, 2, 3]
|
||||||
|
t3 = tf.constant(x3)
|
||||||
|
|
||||||
|
x4 = [[[1.0, 2.0]]]
|
||||||
|
t4 = jnp.array(x4)
|
||||||
|
|
||||||
|
mixed = [(t1, t2), (t3, t4)]
|
||||||
|
self.assertTrue(to_py_obj(mixed) == [[x1, x2], [x3, x4]])
|
||||||
|
|
||||||
|
|
||||||
class ValidationDecoratorTester(unittest.TestCase):
|
class ValidationDecoratorTester(unittest.TestCase):
|
||||||
def test_cases_no_warning(self):
|
def test_cases_no_warning(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user