From 87918d322145118af81809ec3b458b42bc730f3f Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Mon, 31 Jan 2022 19:20:53 +0100 Subject: [PATCH] [examples/Flax] add a section about GPUs (#15198) * add a section about GPUs * Apply suggestions from code review Co-authored-by: Patrick von Platen Co-authored-by: Patrick von Platen --- examples/flax/README.md | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/examples/flax/README.md b/examples/flax/README.md index 634537c56e..6209ffb6f7 100644 --- a/examples/flax/README.md +++ b/examples/flax/README.md @@ -51,6 +51,15 @@ Consider applying for the [Google TPU Research Cloud project](https://sites.rese Each example README contains more details on the specific model and training procedure. + +## Running on single or multiple GPUs + +All of our JAX/Flax examples also run efficiently on single and multiple GPUs. You can use the same instructions in the README to launch training on GPU. +Distributed training is supported out-of-the box and scripts will use all the GPUs that are detected. + +You should follow this [guide for installing JAX on GPUs](https://github.com/google/jax/#pip-installation-gpu-cuda) since the installation depends on +your CUDA and CuDNN version. + ## Supported models Porting models from PyTorch to JAX/Flax is an ongoing effort.