Unpin numba (#23162)
* fix for ragged list
* unpin numba
* make style
* np.object -> object
* propagate changes to tokenizer as well
* np.long -> "long"
* revert tokenization changes
* check with tokenization changes
* list/tuple logic
* catch numpy
* catch else case
* clean up
* up
* better check
* trigger ci
* Empty commit to trigger CI
This commit is contained in:
4
setup.py
4
setup.py
@@ -132,7 +132,6 @@ _deps = [
|
|||||||
"librosa",
|
"librosa",
|
||||||
"nltk",
|
"nltk",
|
||||||
"natten>=0.14.6",
|
"natten>=0.14.6",
|
||||||
"numba<0.57.0", # Can be removed once unpinned.
|
|
||||||
"numpy>=1.17",
|
"numpy>=1.17",
|
||||||
"onnxconverter-common",
|
"onnxconverter-common",
|
||||||
"onnxruntime-tools>=1.4.2",
|
"onnxruntime-tools>=1.4.2",
|
||||||
@@ -286,8 +285,7 @@ extras["sigopt"] = deps_list("sigopt")
|
|||||||
extras["integrations"] = extras["optuna"] + extras["ray"] + extras["sigopt"]
|
extras["integrations"] = extras["optuna"] + extras["ray"] + extras["sigopt"]
|
||||||
|
|
||||||
extras["serving"] = deps_list("pydantic", "uvicorn", "fastapi", "starlette")
|
extras["serving"] = deps_list("pydantic", "uvicorn", "fastapi", "starlette")
|
||||||
# numba can be removed here once unpinned
|
extras["audio"] = deps_list("librosa", "pyctcdecode", "phonemizer", "kenlm")
|
||||||
extras["audio"] = deps_list("librosa", "pyctcdecode", "phonemizer", "kenlm", "numba")
|
|
||||||
# `pip install ".[speech]"` is deprecated and `pip install ".[torch-speech]"` should be used instead
|
# `pip install ".[speech]"` is deprecated and `pip install ".[torch-speech]"` should be used instead
|
||||||
extras["speech"] = deps_list("torchaudio") + extras["audio"]
|
extras["speech"] = deps_list("torchaudio") + extras["audio"]
|
||||||
extras["torch-speech"] = deps_list("torchaudio") + extras["audio"]
|
extras["torch-speech"] = deps_list("torchaudio") + extras["audio"]
|
||||||
|
|||||||
@@ -37,7 +37,6 @@ deps = {
|
|||||||
"librosa": "librosa",
|
"librosa": "librosa",
|
||||||
"nltk": "nltk",
|
"nltk": "nltk",
|
||||||
"natten": "natten>=0.14.6",
|
"natten": "natten>=0.14.6",
|
||||||
"numba": "numba<0.57.0",
|
|
||||||
"numpy": "numpy>=1.17",
|
"numpy": "numpy>=1.17",
|
||||||
"onnxconverter-common": "onnxconverter-common",
|
"onnxconverter-common": "onnxconverter-common",
|
||||||
"onnxruntime-tools": "onnxruntime-tools>=1.4.2",
|
"onnxruntime-tools": "onnxruntime-tools>=1.4.2",
|
||||||
|
|||||||
@@ -156,7 +156,15 @@ class BatchFeature(UserDict):
|
|||||||
as_tensor = jnp.array
|
as_tensor = jnp.array
|
||||||
is_tensor = is_jax_tensor
|
is_tensor = is_jax_tensor
|
||||||
else:
|
else:
|
||||||
as_tensor = np.asarray
|
|
||||||
|
def as_tensor(value, dtype=None):
|
||||||
|
if isinstance(value, (list, tuple)) and isinstance(value[0], (list, tuple, np.ndarray)):
|
||||||
|
value_lens = [len(val) for val in value]
|
||||||
|
if len(set(value_lens)) > 1 and dtype is None:
|
||||||
|
# we have a ragged list so handle explicitly
|
||||||
|
value = as_tensor([np.asarray(val) for val in value], dtype=object)
|
||||||
|
return np.asarray(value, dtype=dtype)
|
||||||
|
|
||||||
is_tensor = is_numpy_array
|
is_tensor = is_numpy_array
|
||||||
|
|
||||||
# Do the tensor conversion in batch
|
# Do the tensor conversion in batch
|
||||||
|
|||||||
@@ -705,7 +705,15 @@ class BatchEncoding(UserDict):
|
|||||||
as_tensor = jnp.array
|
as_tensor = jnp.array
|
||||||
is_tensor = is_jax_tensor
|
is_tensor = is_jax_tensor
|
||||||
else:
|
else:
|
||||||
as_tensor = np.asarray
|
|
||||||
|
def as_tensor(value, dtype=None):
|
||||||
|
if isinstance(value, (list, tuple)) and isinstance(value[0], (list, tuple, np.ndarray)):
|
||||||
|
value_lens = [len(val) for val in value]
|
||||||
|
if len(set(value_lens)) > 1 and dtype is None:
|
||||||
|
# we have a ragged list so handle explicitly
|
||||||
|
value = as_tensor([np.asarray(val) for val in value], dtype=object)
|
||||||
|
return np.asarray(value, dtype=dtype)
|
||||||
|
|
||||||
is_tensor = is_numpy_array
|
is_tensor = is_numpy_array
|
||||||
|
|
||||||
# Do the tensor conversion in batch
|
# Do the tensor conversion in batch
|
||||||
|
|||||||
@@ -392,7 +392,7 @@ class RealmModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
|||||||
b"This is the fourth record.",
|
b"This is the fourth record.",
|
||||||
b"This is the fifth record.",
|
b"This is the fifth record.",
|
||||||
],
|
],
|
||||||
dtype=np.object,
|
dtype=object,
|
||||||
)
|
)
|
||||||
retriever = RealmRetriever(block_records, tokenizer)
|
retriever = RealmRetriever(block_records, tokenizer)
|
||||||
model = RealmForOpenQA(openqa_config, retriever)
|
model = RealmForOpenQA(openqa_config, retriever)
|
||||||
|
|||||||
@@ -100,7 +100,7 @@ class RealmRetrieverTest(TestCase):
|
|||||||
b"This is the fifth record",
|
b"This is the fifth record",
|
||||||
b"This is a longer longer longer record",
|
b"This is a longer longer longer record",
|
||||||
],
|
],
|
||||||
dtype=np.object,
|
dtype=object,
|
||||||
)
|
)
|
||||||
return block_records
|
return block_records
|
||||||
|
|
||||||
@@ -116,7 +116,7 @@ class RealmRetrieverTest(TestCase):
|
|||||||
retriever = self.get_dummy_retriever()
|
retriever = self.get_dummy_retriever()
|
||||||
tokenizer = retriever.tokenizer
|
tokenizer = retriever.tokenizer
|
||||||
|
|
||||||
retrieved_block_ids = np.array([0, 3], dtype=np.long)
|
retrieved_block_ids = np.array([0, 3], dtype="long")
|
||||||
question_input_ids = tokenizer(["Test question"]).input_ids
|
question_input_ids = tokenizer(["Test question"]).input_ids
|
||||||
answer_ids = tokenizer(
|
answer_ids = tokenizer(
|
||||||
["the fourth"],
|
["the fourth"],
|
||||||
@@ -151,7 +151,7 @@ class RealmRetrieverTest(TestCase):
|
|||||||
retriever = self.get_dummy_retriever()
|
retriever = self.get_dummy_retriever()
|
||||||
tokenizer = retriever.tokenizer
|
tokenizer = retriever.tokenizer
|
||||||
|
|
||||||
retrieved_block_ids = np.array([0, 3, 5], dtype=np.long)
|
retrieved_block_ids = np.array([0, 3, 5], dtype="long")
|
||||||
question_input_ids = tokenizer(["Test question"]).input_ids
|
question_input_ids = tokenizer(["Test question"]).input_ids
|
||||||
answer_ids = tokenizer(
|
answer_ids = tokenizer(
|
||||||
["the fourth", "longer longer"],
|
["the fourth", "longer longer"],
|
||||||
|
|||||||
Reference in New Issue
Block a user