Optimize to_py_obj for python-native numeric lists and scalars (#36885)
* Optimize to_py_obj for python-native numeric lists and scalars * Fix bug that tuple is not converted to list * Try np.array for more robust type checking * Apply review and add tests for to_py_obj
This commit is contained in:
@@ -257,6 +257,18 @@ def to_py_obj(obj):
|
||||
"""
|
||||
Convert a TensorFlow tensor, PyTorch tensor, Numpy array or python list to a python list.
|
||||
"""
|
||||
if isinstance(obj, (int, float)):
|
||||
return obj
|
||||
elif isinstance(obj, (dict, UserDict)):
|
||||
return {k: to_py_obj(v) for k, v in obj.items()}
|
||||
elif isinstance(obj, (list, tuple)):
|
||||
try:
|
||||
arr = np.array(obj)
|
||||
if np.issubdtype(arr.dtype, np.integer) or np.issubdtype(arr.dtype, np.floating):
|
||||
return arr.tolist()
|
||||
except Exception:
|
||||
pass
|
||||
return [to_py_obj(o) for o in obj]
|
||||
|
||||
framework_to_py_obj = {
|
||||
"pt": lambda obj: obj.detach().cpu().tolist(),
|
||||
@@ -265,11 +277,6 @@ def to_py_obj(obj):
|
||||
"np": lambda obj: obj.tolist(),
|
||||
}
|
||||
|
||||
if isinstance(obj, (dict, UserDict)):
|
||||
return {k: to_py_obj(v) for k, v in obj.items()}
|
||||
elif isinstance(obj, (list, tuple)):
|
||||
return [to_py_obj(o) for o in obj]
|
||||
|
||||
# This gives us a smart order to test the frameworks with the corresponding tests.
|
||||
framework_to_test_func = _get_frameworks_and_test_func(obj)
|
||||
for framework, test_func in framework_to_test_func.items():
|
||||
|
||||
Reference in New Issue
Block a user