From 66c240f3c950612fa05b2e14c85d4b86c88e473e Mon Sep 17 00:00:00 2001 From: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Date: Thu, 3 Aug 2023 16:05:02 +0100 Subject: [PATCH] [JAX] Bump min version (#25286) * [JAX] Bump min version * make fixup --- setup.py | 4 ++-- src/transformers/dependency_versions_table.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/setup.py b/setup.py index b9afe24852..2a0ef33b88 100644 --- a/setup.py +++ b/setup.py @@ -124,8 +124,8 @@ _deps = [ "importlib_metadata", "ipadic>=1.0.0,<2.0", "isort>=5.5.4", - "jax>=0.2.8,!=0.3.2,<=0.4.13", - "jaxlib>=0.1.65,<=0.4.13", + "jax>=0.4.1,<=0.4.13", + "jaxlib>=0.4.1,<=0.4.13", "jieba", "kenlm", "keras-nlp>=0.3.1", diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py index 350c312134..cce16d66da 100644 --- a/src/transformers/dependency_versions_table.py +++ b/src/transformers/dependency_versions_table.py @@ -29,8 +29,8 @@ deps = { "importlib_metadata": "importlib_metadata", "ipadic": "ipadic>=1.0.0,<2.0", "isort": "isort>=5.5.4", - "jax": "jax>=0.2.8,!=0.3.2,<=0.4.13", - "jaxlib": "jaxlib>=0.1.65,<=0.4.13", + "jax": "jax>=0.4.1,<=0.4.13", + "jaxlib": "jaxlib>=0.4.1,<=0.4.13", "jieba": "jieba", "kenlm": "kenlm", "keras-nlp": "keras-nlp>=0.3.1",