Unpin numba (#23162)

* fix for ragged list

* unpin numba

* make style

* np.object -> object

* propagate changes to tokenizer as well

* np.long -> "long"

* revert tokenization changes

* check with tokenization changes

* list/tuple logic

* catch numpy

* catch else case

* clean up

* up

* better check

* trigger ci

* Empty commit to trigger CI
This commit is contained in:
Sanchit Gandhi
2023-05-31 14:59:30 +01:00
committed by GitHub
parent d99f11e898
commit 8f915c450d
6 changed files with 23 additions and 10 deletions

View File

@@ -705,7 +705,15 @@ class BatchEncoding(UserDict):
as_tensor = jnp.array
is_tensor = is_jax_tensor
else:
as_tensor = np.asarray
def as_tensor(value, dtype=None):
    """Convert ``value`` to a NumPy array, tolerating ragged nested sequences.

    Args:
        value: The object to convert. When it is a non-empty list/tuple whose
            first element is itself a list, tuple or ``np.ndarray``, the
            lengths of its elements are compared to detect raggedness.
        dtype: Optional dtype forwarded to ``np.asarray``. Ragged handling is
            only applied when this is ``None``.

    Returns:
        The result of ``np.asarray``. For ragged input with no explicit
        ``dtype`` this is an object-dtype array whose elements are the
        per-row arrays.
    """
    # Guard against empty sequences before peeking at the first element:
    # an empty list/tuple has nothing to inspect and converts directly.
    if (
        isinstance(value, (list, tuple))
        and len(value) > 0
        and isinstance(value[0], (list, tuple, np.ndarray))
    ):
        value_lens = [len(val) for val in value]
        if len(set(value_lens)) > 1 and dtype is None:
            # We have a ragged list, so handle it explicitly: recent NumPy
            # releases refuse to build an array from rows of differing
            # length unless dtype=object is requested.
            value = as_tensor([np.asarray(val) for val in value], dtype=object)
    return np.asarray(value, dtype=dtype)
is_tensor = is_numpy_array
# Do the tensor conversion in batch