add a new flax example for Bert model inference (#34794)
* add a new example for flax inference cases * Update examples/flax/language-modeling/README.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update examples/flax/language-modeling/README.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update examples/flax/language-modeling/README.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update examples/flax/language-modeling/README.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update examples/flax/language-modeling/README.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update examples/flax/language-modeling/README.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * fix for "make fixup" --------- Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
This commit is contained in:
@@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
-->
|
||||
|
||||
# Language model training examples
|
||||
# Language model training and inference examples
|
||||
|
||||
The following example showcases how to train a language model from scratch
|
||||
using the JAX/Flax backend.
|
||||
@@ -542,3 +542,27 @@ python3 -m torch.distributed.launch --nproc_per_node ${NUM_GPUS} run_mlm.py \
|
||||
--report_to="tensorboard" \
|
||||
--save_strategy="no"
|
||||
```
|
||||
|
||||
## Language model inference with bfloat16
|
||||
|
||||
The following example demonstrates performing inference with a language model using the JAX/Flax backend.
|
||||
|
||||
The example script run_bert_flax.py uses bert-base-uncased, and the model is loaded into `FlaxBertModel`.
|
||||
The input data are randomly generated tokens, and the model is also jitted with JAX.
|
||||
By default, it uses float32 precision for inference. To enable bfloat16, add the flag shown in the command below.
|
||||
|
||||
```bash
|
||||
python3 run_bert_flax.py --precision bfloat16
|
||||
> NOTE: For JAX Versions after v0.4.33 or later, users will need to set the below environment variables as a \
|
||||
> temporary workaround to use Bfloat16 datatype. \
|
||||
> This restriction is expected to be removed in future version
|
||||
```bash
|
||||
export XLA_FLAGS=--xla_cpu_use_thunk_runtime=false
|
||||
```
|
||||
bfloat16 gives better performance on GPUs and also Intel CPUs (Sapphire Rapids or later) with Advanced Matrix Extension (Intel AMX).
|
||||
By changing the dtype for `FlaxBertModel `to `jax.numpy.bfloat16`, you get the performance benefits of the underlying hardware.
|
||||
```python
|
||||
import jax
|
||||
model = FlaxBertModel.from_pretrained("bert-base-uncased", config=config, dtype=jax.numpy.bfloat16)
|
||||
```
|
||||
Switching from float32 to bfloat16 can increase the speed of an AWS c7i.4xlarge with Intel Sapphire Rapids by more than 2x.
|
||||
|
||||
56
examples/flax/language-modeling/run_bert_flax.py
Normal file
56
examples/flax/language-modeling/run_bert_flax.py
Normal file
@@ -0,0 +1,56 @@
|
||||
#!/usr/bin/env python3
|
||||
import time
|
||||
from argparse import ArgumentParser
|
||||
|
||||
import jax
|
||||
import numpy as np
|
||||
|
||||
from transformers import BertConfig, FlaxBertModel
|
||||
|
||||
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("--precision", type=str, choices=["float32", "bfloat16"], default="float32")
|
||||
args = parser.parse_args()
|
||||
|
||||
dtype = jax.numpy.float32
|
||||
if args.precision == "bfloat16":
|
||||
dtype = jax.numpy.bfloat16
|
||||
|
||||
VOCAB_SIZE = 30522
|
||||
BS = 32
|
||||
SEQ_LEN = 128
|
||||
|
||||
|
||||
def get_input_data(batch_size=1, seq_length=384):
|
||||
shape = (batch_size, seq_length)
|
||||
input_ids = np.random.randint(1, VOCAB_SIZE, size=shape).astype(np.int32)
|
||||
token_type_ids = np.ones(shape).astype(np.int32)
|
||||
attention_mask = np.ones(shape).astype(np.int32)
|
||||
return {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": attention_mask}
|
||||
|
||||
|
||||
inputs = get_input_data(BS, SEQ_LEN)
|
||||
config = BertConfig.from_pretrained("bert-base-uncased", hidden_act="gelu_new")
|
||||
model = FlaxBertModel.from_pretrained("bert-base-uncased", config=config, dtype=dtype)
|
||||
|
||||
|
||||
@jax.jit
|
||||
def func():
|
||||
outputs = model(**inputs)
|
||||
return outputs
|
||||
|
||||
|
||||
(nwarmup, nbenchmark) = (5, 100)
|
||||
|
||||
# warmpup
|
||||
for _ in range(nwarmup):
|
||||
func()
|
||||
|
||||
# benchmark
|
||||
|
||||
start = time.time()
|
||||
for _ in range(nbenchmark):
|
||||
func()
|
||||
end = time.time()
|
||||
print(end - start)
|
||||
print(f"Throughput: {((nbenchmark * BS)/(end-start)):.3f} examples/sec")
|
||||
Reference in New Issue
Block a user