diff --git a/setup.py b/setup.py index a3100d6415..9cddf97c55 100644 --- a/setup.py +++ b/setup.py @@ -107,6 +107,7 @@ _deps = [ "jax>=0.2.8", "jaxlib>=0.1.65", "jieba", + "keras!=2.7.0", # Remove when they fix their release "keras2onnx", "nltk", "numpy>=1.17", @@ -227,8 +228,8 @@ extras = {} extras["ja"] = deps_list("fugashi", "ipadic", "unidic_lite", "unidic") extras["sklearn"] = deps_list("scikit-learn") -extras["tf"] = deps_list("tensorflow", "onnxconverter-common", "keras2onnx") -extras["tf-cpu"] = deps_list("tensorflow-cpu", "onnxconverter-common", "keras2onnx") +extras["tf"] = deps_list("tensorflow", "onnxconverter-common", "keras2onnx", "keras") +extras["tf-cpu"] = deps_list("tensorflow-cpu", "onnxconverter-common", "keras2onnx", "keras") extras["torch"] = deps_list("torch") diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py index 786e5a5691..2997951df3 100644 --- a/src/transformers/dependency_versions_table.py +++ b/src/transformers/dependency_versions_table.py @@ -25,6 +25,7 @@ deps = { "jax": "jax>=0.2.8", "jaxlib": "jaxlib>=0.1.65", "jieba": "jieba", + "keras": "keras!=2.7.0", "keras2onnx": "keras2onnx", "nltk": "nltk", "numpy": "numpy>=1.17",