Update README.md
This commit is contained in:
committed by
GitHub
parent
d24a523130
commit
b21905e03d
@@ -234,10 +234,14 @@ datasets["train"] = datasets["train"].select(range(1000))
|
||||
|
||||
## How to install relevant libraries
|
||||
|
||||
In the following we will explain how to install all relevant libraries on your local computer and on TPU VM.
|
||||
|
||||
It is recommended to install all relevant libraries both on your local machine
|
||||
and on the TPU virtual machine. This way, quick prototyping and testing can be done on
|
||||
your local machine and the actual training can be done on the TPU VM.
|
||||
|
||||
### Local computer
|
||||
|
||||
The following libraries are required to train a JAX/Flax model with 🤗 Transformers and 🤗 Datasets:
|
||||
|
||||
- [JAX](https://github.com/google/jax/)
|
||||
@@ -250,37 +254,22 @@ You should install the above libraries in a [virtual environment](https://docs.p
|
||||
If you're unfamiliar with Python virtual environments, check out the [user guide](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/). Create a virtual environment with the version of Python you're going
|
||||
to use and activate it.
|
||||
|
||||
You should be able to run the command:
|
||||
|
||||
```bash
|
||||
python3 -m venv <your-venv-name>
|
||||
```
|
||||
|
||||
You can activate your venv by running
|
||||
|
||||
```bash
|
||||
source ~/<your-venv-name>/bin/activate
|
||||
```
|
||||
|
||||
We strongly recommend to make use of the provided JAX/Flax examples scripts in [transformers/examples/flax](https://github.com/huggingface/transformers/tree/master/examples/flax) even if you want to train a JAX/Flax model of another github repository that is not integrated into 🤗 Transformers.
|
||||
In all likelihood, you will need to adapt one of the example scripts, so we recommend forking and cloning the 🤗 Transformers repository as follows.
|
||||
Doing so will allow you to share your fork of the Transformers library with your team members so that the team effectively works on the same code base. It will also automatically install the newest versions of `flax`, `jax` and `optax`.
|
||||
|
||||
**IMPORTANT**: If you are setting up your environment on a TPU VM, make sure to
|
||||
install JAX's TPU version before cloning and installing the transformers repository.
|
||||
Otherwise, an incorrect version of JAX will be installed, and the following commands will
|
||||
throw an error.
|
||||
To install JAX's TPU version first run the following command:
|
||||
|
||||
```
|
||||
$ pip install requests
|
||||
```
|
||||
|
||||
and then:
|
||||
|
||||
```
|
||||
$ pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
|
||||
```
|
||||
|
||||
To verify that JAX was correctly installed, you can run the following command:
|
||||
|
||||
```python
|
||||
import jax
|
||||
jax.device_count()
|
||||
```
|
||||
|
||||
This should display the number of TPU cores, which should be 8 on a TPUv3-8 VM.
|
||||
|
||||
Now you can run the following steps as usual.
|
||||
|
||||
1. Fork the [repository](https://github.com/huggingface/transformers) by
|
||||
clicking on the 'Fork' button on the repository's page. This creates a copy of the code
|
||||
under your GitHub user account.
|
||||
@@ -352,6 +341,162 @@ model = FlaxRobertaModel.from_pretrained("julien-c/dummy-unknown")
|
||||
model(input_ids)
|
||||
```
|
||||
|
||||
### TPU VM
|
||||
|
||||
**VERY IMPORTANT** - Only one process can access the TPU cores at a time. This means that if multiple team members
|
||||
are trying to connect to the TPU cores errors, such as:
|
||||
|
||||
```
|
||||
libtpu.so already in used by another process. Not attempting to load libtpu.so in this process.
|
||||
```
|
||||
|
||||
are thrown. As a conclusion, we recommend every team member to create her/his own virtual environment, but only one
|
||||
person should run the heavy training processes. Also, please take turns when setting up the TPUv3-8 so that everybody
|
||||
can verify that JAX is correctly installed.
|
||||
|
||||
The following libraries are required to train a JAX/Flax model with 🤗 Transformers and 🤗 Datasets on TPU VM:
|
||||
|
||||
- [JAX](https://github.com/google/jax/)
|
||||
- [Flax](https://github.com/google/flax)
|
||||
- [Optax](https://github.com/deepmind/optax)
|
||||
- [Transformers](https://github.com/huggingface/transformers)
|
||||
- [Datasets](https://github.com/huggingface/datasets)
|
||||
|
||||
You should install the above libraries in a [virtual environment](https://docs.python.org/3/library/venv.html).
|
||||
If you're unfamiliar with Python virtual environments, check out the [user guide](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/). Create a virtual environment with the version of Python you're going
|
||||
to use and activate it.
|
||||
|
||||
You should be able to run the command:
|
||||
|
||||
```bash
|
||||
python3 -m venv <your-venv-name>
|
||||
```
|
||||
|
||||
If this doesn't work, you first might to have install `python3-venv`. You can do this as follows:
|
||||
|
||||
```bash
|
||||
sudo apt-get install python3-venv
|
||||
```
|
||||
|
||||
You can activate your venv by running
|
||||
|
||||
```bash
|
||||
source ~/<your-venv-name>/bin/activate
|
||||
```
|
||||
|
||||
Next you should install JAX's TPU version on TPU by running the following command:
|
||||
|
||||
```
|
||||
$ pip install requests
|
||||
```
|
||||
|
||||
and then:
|
||||
|
||||
```
|
||||
$ pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
|
||||
```
|
||||
|
||||
**Note**: Running this command might actually throw an error, such as:
|
||||
```
|
||||
Building wheel for jax (setup.py) ... error
|
||||
ERROR: Command errored out with exit status 1:
|
||||
command: /home/patrick/patrick/bin/python3 -u -c 'import sys, setuptools, tokenize; sys.argv[0] = '"'"'/tmp/pip-install-lwseckn1/jax/setup.py'"'"'; __file__='"'"'/tmp/pip-install-lwseckn1/jax/setup.py'"'"';f=getattr(tokenize, '"'"'open'"'"', open)(__file__);code=f.read().replace('"'"'\r\n'"'"', '"'"'\n'"'"');f.close();exec(compile(code, __file__, '"'"'exec'"'"'))' bdist_wheel -d /tmp/pip-wheel-pydotzlo
|
||||
cwd: /tmp/pip-install-lwseckn1/jax/
|
||||
Complete output (6 lines):
|
||||
usage: setup.py [global_opts] cmd1 [cmd1_opts] [cmd2 [cmd2_opts] ...]
|
||||
or: setup.py --help [cmd1 cmd2 ...]
|
||||
or: setup.py --help-commands
|
||||
or: setup.py cmd --help
|
||||
|
||||
error: invalid command 'bdist_wheel'
|
||||
----------------------------------------
|
||||
ERROR: Failed building wheel for jax
|
||||
```
|
||||
Jax should have been installed correctly nevertheless.
|
||||
|
||||
To verify that JAX was correctly installed, you can run the following command:
|
||||
|
||||
```python
|
||||
import jax
|
||||
jax.device_count()
|
||||
```
|
||||
|
||||
This should display the number of TPU cores, which should be 8 on a TPUv3-8 VM.
|
||||
|
||||
We strongly recommend to make use of the provided JAX/Flax examples scripts in [transformers/examples/flax](https://github.com/huggingface/transformers/tree/master/examples/flax) even if you want to train a JAX/Flax model of another github repository that is not integrated into 🤗 Transformers.
|
||||
In all likelihood, you will need to adapt one of the example scripts, so we recommend forking and cloning the 🤗 Transformers repository as follows.
|
||||
Doing so will allow you to share your fork of the Transformers library with your team members so that the team effectively works on the same code base. It will also automatically install the newest versions of `flax`, `jax` and `optax`.
|
||||
|
||||
1. Fork the [repository](https://github.com/huggingface/transformers) by
|
||||
clicking on the 'Fork' button on the repository's page. This creates a copy of the code
|
||||
under your GitHub user account.
|
||||
|
||||
2. Clone your fork to your local disk, and add the base repository as a remote:
|
||||
|
||||
```bash
|
||||
$ git clone https://github.com/<your Github handle>/transformers.git
|
||||
$ cd transformers
|
||||
$ git remote add upstream https://github.com/huggingface/transformers.git
|
||||
```
|
||||
|
||||
3. Create a new branch to hold your development changes. This is especially useful to share code changes with your team:
|
||||
|
||||
```bash
|
||||
$ git checkout -b a-descriptive-name-for-my-project
|
||||
```
|
||||
|
||||
4. Set up a flax environment by running the following command in a virtual environment:
|
||||
|
||||
```bash
|
||||
$ pip install -e ".[flax]"
|
||||
```
|
||||
|
||||
(If transformers was already installed in the virtual environment, remove
|
||||
it with `pip uninstall transformers` before reinstalling it in editable
|
||||
mode with the `-e` flag.)
|
||||
|
||||
If you have already cloned that repo, you might need to `git pull` to get the most recent changes in the `datasets`
|
||||
library.
|
||||
|
||||
Running this command will automatically install `flax`, `jax` and `optax`.
|
||||
|
||||
Next, you should also install the 🤗 Datasets library. We strongly recommend installing the
|
||||
library from source to profit from the most current additions during the community week.
|
||||
|
||||
Simply run the following steps:
|
||||
|
||||
```
|
||||
$ cd ~/
|
||||
$ git clone https://github.com/huggingface/datasets.git
|
||||
$ cd datasets
|
||||
$ pip install -e ".[streaming]"
|
||||
```
|
||||
|
||||
If you plan on contributing a specific dataset during
|
||||
the community week, please fork the datasets repository and follow the instructions
|
||||
[here](https://github.com/huggingface/datasets/blob/master/CONTRIBUTING.md#how-to-create-a-pull-request).
|
||||
|
||||
To verify that all libraries are correctly installed, you can run the following command.
|
||||
It assumes that both `transformers` and `datasets` were installed from master - otherwise
|
||||
datasets streaming will not work correctly.
|
||||
|
||||
```python
|
||||
from transformers import FlaxRobertaModel, RobertaTokenizerFast
|
||||
from datasets import load_dataset
|
||||
import jax
|
||||
|
||||
dataset = load_dataset('oscar', "unshuffled_deduplicated_en", split='train', streaming=True)
|
||||
|
||||
dummy_input = next(iter(dataset))["text"]
|
||||
|
||||
tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
|
||||
input_ids = tokenizer(dummy_input, return_tensors="np").input_ids[:, :10]
|
||||
|
||||
model = FlaxRobertaModel.from_pretrained("julien-c/dummy-unknown")
|
||||
|
||||
# run a forward pass, should return an object `FlaxBaseModelOutputWithPooling`
|
||||
model(input_ids)
|
||||
```
|
||||
|
||||
## Quickstart flax and jax
|
||||
|
||||
|
||||
Reference in New Issue
Block a user