[examples/Flax] add a section about GPUs (#15198)
* add a section about GPUs * Apply suggestions from code review Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -51,6 +51,15 @@ Consider applying for the [Google TPU Research Cloud project](https://sites.rese
|
|||||||
Each example README contains more details on the specific model and training
|
Each example README contains more details on the specific model and training
|
||||||
procedure.
|
procedure.
|
||||||
|
|
||||||
|
|
||||||
|
## Running on single or multiple GPUs
|
||||||
|
|
||||||
|
All of our JAX/Flax examples also run efficiently on single and multiple GPUs. You can use the same instructions in the README to launch training on GPU.
|
||||||
|
Distributed training is supported out-of-the box and scripts will use all the GPUs that are detected.
|
||||||
|
|
||||||
|
You should follow this [guide for installing JAX on GPUs](https://github.com/google/jax/#pip-installation-gpu-cuda) since the installation depends on
|
||||||
|
your CUDA and CuDNN version.
|
||||||
|
|
||||||
## Supported models
|
## Supported models
|
||||||
|
|
||||||
Porting models from PyTorch to JAX/Flax is an ongoing effort.
|
Porting models from PyTorch to JAX/Flax is an ongoing effort.
|
||||||
|
|||||||
Reference in New Issue
Block a user