diff --git a/setup.py b/setup.py index b9afe24852..2a0ef33b88 100644 --- a/setup.py +++ b/setup.py @@ -124,8 +124,8 @@ _deps = [ "importlib_metadata", "ipadic>=1.0.0,<2.0", "isort>=5.5.4", - "jax>=0.2.8,!=0.3.2,<=0.4.13", - "jaxlib>=0.1.65,<=0.4.13", + "jax>=0.4.1,<=0.4.13", + "jaxlib>=0.4.1,<=0.4.13", "jieba", "kenlm", "keras-nlp>=0.3.1", diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py index 350c312134..cce16d66da 100644 --- a/src/transformers/dependency_versions_table.py +++ b/src/transformers/dependency_versions_table.py @@ -29,8 +29,8 @@ deps = { "importlib_metadata": "importlib_metadata", "ipadic": "ipadic>=1.0.0,<2.0", "isort": "isort>=5.5.4", - "jax": "jax>=0.2.8,!=0.3.2,<=0.4.13", - "jaxlib": "jaxlib>=0.1.65,<=0.4.13", + "jax": "jax>=0.4.1,<=0.4.13", + "jaxlib": "jaxlib>=0.4.1,<=0.4.13", "jieba": "jieba", "kenlm": "kenlm", "keras-nlp": "keras-nlp>=0.3.1",