Optimize to_py_obj for python-native numeric lists and scalars (#36885)

* Optimize to_py_obj for python-native numeric lists and scalars

* Fix bug that tuple is not converted to list

* Try np.array for more robust type checking

* Apply review and add tests for to_py_obj
This commit is contained in:
Sungyoon Jeong
2025-03-27 22:16:46 +09:00
committed by GitHub
parent 0e56fb69a2
commit d1eafe8d4e
2 changed files with 84 additions and 5 deletions

View File

@@ -28,6 +28,7 @@ from transformers.utils import (
is_torch_available,
reshape,
squeeze,
to_py_obj,
transpose,
)
@@ -201,6 +202,77 @@ class GenericTester(unittest.TestCase):
t = jnp.array(x)
self.assertTrue(np.allclose(expand_dims(x, axis=1), np.asarray(expand_dims(t, axis=1))))
def test_to_py_obj_native(self):
self.assertTrue(to_py_obj(1) == 1)
self.assertTrue(to_py_obj([1, 2, 3]) == [1, 2, 3])
self.assertTrue(to_py_obj([((1.0, 1.1), 1.2), (2, 3)]) == [[[1.0, 1.1], 1.2], [2, 3]])
def test_to_py_obj_numpy(self):
x1 = [[1, 2, 3], [4, 5, 6]]
t1 = np.array(x1)
self.assertTrue(to_py_obj(t1) == x1)
x2 = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
t2 = np.array(x2)
self.assertTrue(to_py_obj(t2) == x2)
self.assertTrue(to_py_obj([t1, t2]) == [x1, x2])
@require_torch
def test_to_py_obj_torch(self):
x1 = [[1, 2, 3], [4, 5, 6]]
t1 = torch.tensor(x1)
self.assertTrue(to_py_obj(t1) == x1)
x2 = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
t2 = torch.tensor(x2)
self.assertTrue(to_py_obj(t2) == x2)
self.assertTrue(to_py_obj([t1, t2]) == [x1, x2])
@require_tf
def test_to_py_obj_tf(self):
x1 = [[1, 2, 3], [4, 5, 6]]
t1 = tf.constant(x1)
self.assertTrue(to_py_obj(t1) == x1)
x2 = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
t2 = tf.constant(x2)
self.assertTrue(to_py_obj(t2) == x2)
self.assertTrue(to_py_obj([t1, t2]) == [x1, x2])
@require_flax
def test_to_py_obj_flax(self):
x1 = [[1, 2, 3], [4, 5, 6]]
t1 = jnp.array(x1)
self.assertTrue(to_py_obj(t1) == x1)
x2 = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
t2 = jnp.array(x2)
self.assertTrue(to_py_obj(t2) == x2)
self.assertTrue(to_py_obj([t1, t2]) == [x1, x2])
@require_torch
@require_tf
@require_flax
def test_to_py_obj_mixed(self):
x1 = [[1], [2]]
t1 = np.array(x1)
x2 = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
t2 = torch.tensor(x2)
x3 = [1, 2, 3]
t3 = tf.constant(x3)
x4 = [[[1.0, 2.0]]]
t4 = jnp.array(x4)
mixed = [(t1, t2), (t3, t4)]
self.assertTrue(to_py_obj(mixed) == [[x1, x2], [x3, x4]])
class ValidationDecoratorTester(unittest.TestCase):
def test_cases_no_warning(self):