diff --git a/setup.py b/setup.py index 4d4d38fc4d..7ffacbff35 100644 --- a/setup.py +++ b/setup.py @@ -89,9 +89,12 @@ extras["tf-cpu"] = [ ] extras["torch"] = ["torch>=1.0"] -extras["flax"] = ["jaxlib==0.1.55", "jax>=0.2.0", "flax==0.2.2"] if os.name == "nt": # windows + extras["retrieval"] = ["datasets"] # faiss is not supported on windows extras["flax"] = [] # jax is not supported on windows +else: + extras["retrieval"] = ["faiss-cpu", "datasets"] + extras["flax"] = ["jaxlib==0.1.55", "jax>=0.2.0", "flax==0.2.2"] extras["tokenizers"] = ["tokenizers==0.9.2"] extras["onnxruntime"] = ["onnxruntime>=1.4.0", "onnxruntime-tools>=1.4.2"]