Jax: scipy version pin (#30402)

scipy pin for jax
This commit is contained in:
Joao Gante
2024-04-23 10:42:17 +01:00
committed by GitHub
parent 2d61823fa2
commit 31921d8d5e
2 changed files with 3 additions and 1 deletions

View File

@@ -161,6 +161,7 @@ _deps = [
"safetensors>=0.4.1",
"sagemaker>=2.31.0",
"scikit-learn",
"scipy<1.13.0", # SciPy >= 1.13.0 is not supported with the current jax pin (`jax>=0.4.1,<=0.4.13`)
"sentencepiece>=0.1.91,!=0.1.92",
"sigopt",
"starlette",
@@ -267,7 +268,7 @@ if os.name == "nt": # windows
extras["flax"] = [] # jax is not supported on windows
else:
extras["retrieval"] = deps_list("faiss-cpu", "datasets")
extras["flax"] = deps_list("jax", "jaxlib", "flax", "optax")
extras["flax"] = deps_list("jax", "jaxlib", "flax", "optax", "scipy")
extras["tokenizers"] = deps_list("tokenizers")
extras["ftfy"] = deps_list("ftfy")