From 31921d8d5e4fa322f1ee3ba2011190bdafd5d304 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Tue, 23 Apr 2024 10:42:17 +0100 Subject: [PATCH] Jax: scipy version pin (#30402) scipy pin for jax --- setup.py | 3 ++- src/transformers/dependency_versions_table.py | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 412c248dc8..37fa11379d 100644 --- a/setup.py +++ b/setup.py @@ -161,6 +161,7 @@ _deps = [ "safetensors>=0.4.1", "sagemaker>=2.31.0", "scikit-learn", + "scipy<1.13.0", # SciPy >= 1.13.0 is not supported with the current jax pin (`jax>=0.4.1,<=0.4.13`) "sentencepiece>=0.1.91,!=0.1.92", "sigopt", "starlette", @@ -267,7 +268,7 @@ if os.name == "nt": # windows extras["flax"] = [] # jax is not supported on windows else: extras["retrieval"] = deps_list("faiss-cpu", "datasets") - extras["flax"] = deps_list("jax", "jaxlib", "flax", "optax") + extras["flax"] = deps_list("jax", "jaxlib", "flax", "optax", "scipy") extras["tokenizers"] = deps_list("tokenizers") extras["ftfy"] = deps_list("ftfy") diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py index d40cae189a..7f78c8285b 100644 --- a/src/transformers/dependency_versions_table.py +++ b/src/transformers/dependency_versions_table.py @@ -67,6 +67,7 @@ deps = { "safetensors": "safetensors>=0.4.1", "sagemaker": "sagemaker>=2.31.0", "scikit-learn": "scikit-learn", + "scipy": "scipy<1.13.0", "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92", "sigopt": "sigopt", "starlette": "starlette",