add a new flax example for Bert model inference (#34794)
* add a new example for flax inference cases * Update examples/flax/language-modeling/README.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update examples/flax/language-modeling/README.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update examples/flax/language-modeling/README.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update examples/flax/language-modeling/README.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update examples/flax/language-modeling/README.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update examples/flax/language-modeling/README.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * fix for "make fixup" --------- Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
This commit is contained in:
@@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
-->
|
-->
|
||||||
|
|
||||||
# Language model training examples
|
# Language model training and inference examples
|
||||||
|
|
||||||
The following example showcases how to train a language model from scratch
|
The following example showcases how to train a language model from scratch
|
||||||
using the JAX/Flax backend.
|
using the JAX/Flax backend.
|
||||||
@@ -542,3 +542,27 @@ python3 -m torch.distributed.launch --nproc_per_node ${NUM_GPUS} run_mlm.py \
|
|||||||
--report_to="tensorboard" \
|
--report_to="tensorboard" \
|
||||||
--save_strategy="no"
|
--save_strategy="no"
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Language model inference with bfloat16
|
||||||
|
|
||||||
|
The following example demonstrates performing inference with a language model using the JAX/Flax backend.
|
||||||
|
|
||||||
|
The example script run_bert_flax.py uses bert-base-uncased, and the model is loaded into `FlaxBertModel`.
|
||||||
|
The input data are randomly generated tokens, and the model is also jitted with JAX.
|
||||||
|
By default, it uses float32 precision for inference. To enable bfloat16, add the flag shown in the command below.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python3 run_bert_flax.py --precision bfloat16
|
||||||
|
> NOTE: For JAX Versions after v0.4.33 or later, users will need to set the below environment variables as a \
|
||||||
|
> temporary workaround to use Bfloat16 datatype. \
|
||||||
|
> This restriction is expected to be removed in future version
|
||||||
|
```bash
|
||||||
|
export XLA_FLAGS=--xla_cpu_use_thunk_runtime=false
|
||||||
|
```
|
||||||
|
bfloat16 gives better performance on GPUs and also Intel CPUs (Sapphire Rapids or later) with Advanced Matrix Extension (Intel AMX).
|
||||||
|
By changing the dtype for `FlaxBertModel `to `jax.numpy.bfloat16`, you get the performance benefits of the underlying hardware.
|
||||||
|
```python
|
||||||
|
import jax
|
||||||
|
model = FlaxBertModel.from_pretrained("bert-base-uncased", config=config, dtype=jax.numpy.bfloat16)
|
||||||
|
```
|
||||||
|
Switching from float32 to bfloat16 can increase the speed of an AWS c7i.4xlarge with Intel Sapphire Rapids by more than 2x.
|
||||||
|
|||||||
56
examples/flax/language-modeling/run_bert_flax.py
Normal file
56
examples/flax/language-modeling/run_bert_flax.py
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
import time
|
||||||
|
from argparse import ArgumentParser
|
||||||
|
|
||||||
|
import jax
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from transformers import BertConfig, FlaxBertModel
|
||||||
|
|
||||||
|
|
||||||
|
parser = ArgumentParser()
|
||||||
|
parser.add_argument("--precision", type=str, choices=["float32", "bfloat16"], default="float32")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
dtype = jax.numpy.float32
|
||||||
|
if args.precision == "bfloat16":
|
||||||
|
dtype = jax.numpy.bfloat16
|
||||||
|
|
||||||
|
VOCAB_SIZE = 30522
|
||||||
|
BS = 32
|
||||||
|
SEQ_LEN = 128
|
||||||
|
|
||||||
|
|
||||||
|
def get_input_data(batch_size=1, seq_length=384):
|
||||||
|
shape = (batch_size, seq_length)
|
||||||
|
input_ids = np.random.randint(1, VOCAB_SIZE, size=shape).astype(np.int32)
|
||||||
|
token_type_ids = np.ones(shape).astype(np.int32)
|
||||||
|
attention_mask = np.ones(shape).astype(np.int32)
|
||||||
|
return {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": attention_mask}
|
||||||
|
|
||||||
|
|
||||||
|
inputs = get_input_data(BS, SEQ_LEN)
|
||||||
|
config = BertConfig.from_pretrained("bert-base-uncased", hidden_act="gelu_new")
|
||||||
|
model = FlaxBertModel.from_pretrained("bert-base-uncased", config=config, dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
|
@jax.jit
|
||||||
|
def func():
|
||||||
|
outputs = model(**inputs)
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
(nwarmup, nbenchmark) = (5, 100)
|
||||||
|
|
||||||
|
# warmpup
|
||||||
|
for _ in range(nwarmup):
|
||||||
|
func()
|
||||||
|
|
||||||
|
# benchmark
|
||||||
|
|
||||||
|
start = time.time()
|
||||||
|
for _ in range(nbenchmark):
|
||||||
|
func()
|
||||||
|
end = time.time()
|
||||||
|
print(end - start)
|
||||||
|
print(f"Throughput: {((nbenchmark * BS)/(end-start)):.3f} examples/sec")
|
||||||
Reference in New Issue
Block a user