Flax T5 (#12150)
* copy pytorch-t5 * init * boom boom * forward pass same * make generation work * add more tests * make test work * finish normal tests * make fix-copies * finish quality * correct slow example * correct slow test * version table * upload models * Update tests/test_modeling_flax_t5.py * correct incorrectly deleted line Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Patrick von Platen <patrick@huggingface.co>
This commit is contained in:
5
setup.py
5
setup.py
@@ -114,6 +114,7 @@ _deps = [
|
||||
"onnxruntime-tools>=1.4.2",
|
||||
"onnxruntime>=1.4.0",
|
||||
"optuna",
|
||||
"optax>=0.0.8",
|
||||
"packaging",
|
||||
"parameterized",
|
||||
"protobuf",
|
||||
@@ -234,7 +235,7 @@ if os.name == "nt": # windows
|
||||
extras["flax"] = [] # jax is not supported on windows
|
||||
else:
|
||||
extras["retrieval"] = deps_list("faiss-cpu", "datasets")
|
||||
extras["flax"] = deps_list("jax", "jaxlib", "flax")
|
||||
extras["flax"] = deps_list("jax", "jaxlib", "flax", "optax")
|
||||
|
||||
extras["tokenizers"] = deps_list("tokenizers")
|
||||
extras["onnxruntime"] = deps_list("onnxruntime", "onnxruntime-tools")
|
||||
@@ -325,7 +326,7 @@ install_requires = [
|
||||
deps["huggingface-hub"],
|
||||
deps["numpy"],
|
||||
deps["packaging"], # utilities from PyPA to e.g., compare versions
|
||||
deps["pyyaml"], # used for the model cards metadata
|
||||
deps["pyyaml"], # used for the model cards metadata
|
||||
deps["regex"], # for OpenAI GPT
|
||||
deps["requests"], # for downloading models over HTTPS
|
||||
deps["sacremoses"], # for XLM
|
||||
|
||||
Reference in New Issue
Block a user