Fix some typos. (#17560)
* Fix some typos. Signed-off-by: Yulv-git <yulvchi@qq.com> * Fix typo. Signed-off-by: Yulv-git <yulvchi@qq.com> * make fixup.
This commit is contained in:
@@ -22,7 +22,7 @@ the JAX/Flax backend and the [`pjit`](https://jax.readthedocs.io/en/latest/jax.e
|
||||
> Note: The example is experimental and might have bugs. Also currently it only supports single V3-8.
|
||||
|
||||
The `partition.py` file defines the `PyTree` of `ParitionSpec` for the GPTNeo model which describes how the model will be sharded.
|
||||
The actual sharding is auto-matically handled by `pjit`. The weights are sharded accross all local devices.
|
||||
The actual sharding is auto-matically handled by `pjit`. The weights are sharded across all local devices.
|
||||
To adapt the script for other models, we need to also change the `ParitionSpec` accordingly.
|
||||
|
||||
TODO: Add more explantion.
|
||||
|
||||
Reference in New Issue
Block a user