Unpin numba (#23162)
* fix for ragged list
* unpin numba
* make style
* np.object -> object
* propagate changes to tokenizer as well
* np.long -> "long"
* revert tokenization changes
* check with tokenization changes
* list/tuple logic
* catch numpy
* catch else case
* clean up
* up
* better check
* trigger ci
* Empty commit to trigger CI
This commit is contained in:
@@ -705,7 +705,15 @@ class BatchEncoding(UserDict):
|
||||
as_tensor = jnp.array
|
||||
is_tensor = is_jax_tensor
|
||||
else:
|
||||
as_tensor = np.asarray
|
||||
|
||||
def as_tensor(value, dtype=None):
    """Convert `value` to a NumPy array, handling ragged nested sequences explicitly.

    Args:
        value: Object to convert. When it is a non-empty list/tuple whose first element is itself a
            list, tuple or `np.ndarray`, the lengths of its elements are compared to detect raggedness.
        dtype: Optional dtype forwarded to `np.asarray`. Ragged handling only applies when this is `None`.

    Returns:
        The result of `np.asarray(value, dtype=dtype)`. For ragged input with `dtype=None`, each row is
        first converted with `np.asarray` and the rows are packed into an object-dtype array.
    """
    # Check for emptiness before peeking at the first element: `value[0]` on an empty list/tuple raises
    # IndexError, whereas `np.asarray([])` is well defined and simply returns an empty array.
    is_nested_sequence = (
        isinstance(value, (list, tuple))
        and len(value) > 0
        and isinstance(value[0], (list, tuple, np.ndarray))
    )
    if is_nested_sequence:
        value_lens = [len(val) for val in value]
        if len(set(value_lens)) > 1 and dtype is None:
            # we have a ragged list so handle explicitly
            # (passing dtype=object makes the recursive call skip this branch and go straight to np.asarray)
            value = as_tensor([np.asarray(val) for val in value], dtype=object)
    return np.asarray(value, dtype=dtype)
|
||||
|
||||
is_tensor = is_numpy_array
|
||||
|
||||
# Do the tensor conversion in batch
|
||||
|
||||
Reference in New Issue
Block a user