From f82b19cb6f95af3d5d94bb8d1971e2203968f9ac Mon Sep 17 00:00:00 2001 From: Louie Tsai Date: Tue, 21 Jan 2025 05:09:29 -0800 Subject: [PATCH] add a new flax example for Bert model inference (#34794) * add a new example for flax inference cases * Update examples/flax/language-modeling/README.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update examples/flax/language-modeling/README.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update examples/flax/language-modeling/README.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update examples/flax/language-modeling/README.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update examples/flax/language-modeling/README.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update examples/flax/language-modeling/README.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * fix for "make fixup" --------- Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- examples/flax/language-modeling/README.md | 26 ++++++++- .../flax/language-modeling/run_bert_flax.py | 56 +++++++++++++++++++ 2 files changed, 81 insertions(+), 1 deletion(-) create mode 100644 examples/flax/language-modeling/run_bert_flax.py diff --git a/examples/flax/language-modeling/README.md b/examples/flax/language-modeling/README.md index 10a2a02f7f..9e2dee3621 100644 --- a/examples/flax/language-modeling/README.md +++ b/examples/flax/language-modeling/README.md @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. --> -# Language model training examples +# Language model training and inference examples The following example showcases how to train a language model from scratch using the JAX/Flax backend. @@ -542,3 +542,27 @@ python3 -m torch.distributed.launch --nproc_per_node ${NUM_GPUS} run_mlm.py \ --report_to="tensorboard" \ --save_strategy="no" ``` + +## Language model inference with bfloat16 + +The following example demonstrates performing inference with a language model using the JAX/Flax backend. + +The example script run_bert_flax.py uses bert-base-uncased, and the model is loaded into `FlaxBertModel`. +The input data are randomly generated tokens, and the model is also jitted with JAX. +By default, it uses float32 precision for inference. To enable bfloat16, add the flag shown in the command below. + +```bash +python3 run_bert_flax.py --precision bfloat16 +> NOTE: For JAX Versions after v0.4.33 or later, users will need to set the below environment variables as a \ +> temporary workaround to use Bfloat16 datatype. \ +> This restriction is expected to be removed in future version +```bash +export XLA_FLAGS=--xla_cpu_use_thunk_runtime=false +``` +bfloat16 gives better performance on GPUs and also Intel CPUs (Sapphire Rapids or later) with Advanced Matrix Extension (Intel AMX). +By changing the dtype for `FlaxBertModel `to `jax.numpy.bfloat16`, you get the performance benefits of the underlying hardware. +```python +import jax +model = FlaxBertModel.from_pretrained("bert-base-uncased", config=config, dtype=jax.numpy.bfloat16) +``` +Switching from float32 to bfloat16 can increase the speed of an AWS c7i.4xlarge with Intel Sapphire Rapids by more than 2x. diff --git a/examples/flax/language-modeling/run_bert_flax.py b/examples/flax/language-modeling/run_bert_flax.py new file mode 100644 index 0000000000..2e73af4592 --- /dev/null +++ b/examples/flax/language-modeling/run_bert_flax.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 +import time +from argparse import ArgumentParser + +import jax +import numpy as np + +from transformers import BertConfig, FlaxBertModel + + +parser = ArgumentParser() +parser.add_argument("--precision", type=str, choices=["float32", "bfloat16"], default="float32") +args = parser.parse_args() + +dtype = jax.numpy.float32 +if args.precision == "bfloat16": + dtype = jax.numpy.bfloat16 + +VOCAB_SIZE = 30522 +BS = 32 +SEQ_LEN = 128 + + +def get_input_data(batch_size=1, seq_length=384): + shape = (batch_size, seq_length) + input_ids = np.random.randint(1, VOCAB_SIZE, size=shape).astype(np.int32) + token_type_ids = np.ones(shape).astype(np.int32) + attention_mask = np.ones(shape).astype(np.int32) + return {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": attention_mask} + + +inputs = get_input_data(BS, SEQ_LEN) +config = BertConfig.from_pretrained("bert-base-uncased", hidden_act="gelu_new") +model = FlaxBertModel.from_pretrained("bert-base-uncased", config=config, dtype=dtype) + + +@jax.jit +def func(): + outputs = model(**inputs) + return outputs + + +(nwarmup, nbenchmark) = (5, 100) + +# warmpup +for _ in range(nwarmup): + func() + +# benchmark + +start = time.time() +for _ in range(nbenchmark): + func() +end = time.time() +print(end - start) +print(f"Throughput: {((nbenchmark * BS)/(end-start)):.3f} examples/sec")