Added automatic mixed precision and XLA options to run_tf_glue.py
This commit is contained in:
@@ -6,6 +6,11 @@ from transformers import BertTokenizer, TFBertForSequenceClassification, glue_co
|
|||||||
# script parameters
|
# script parameters
|
||||||
BATCH_SIZE = 32
|
BATCH_SIZE = 32
|
||||||
EVAL_BATCH_SIZE = BATCH_SIZE * 2
|
EVAL_BATCH_SIZE = BATCH_SIZE * 2
|
||||||
|
USE_XLA = False
|
||||||
|
USE_AMP = False
|
||||||
|
|
||||||
|
tf.config.optimizer.set_jit(USE_XLA)
|
||||||
|
tf.config.optimizer.set_experimental_options({"auto_mixed_precision": USE_AMP})
|
||||||
|
|
||||||
# Load tokenizer and model from pretrained model/vocabulary
|
# Load tokenizer and model from pretrained model/vocabulary
|
||||||
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
|
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
|
||||||
@@ -23,10 +28,13 @@ train_dataset = train_dataset.shuffle(128).batch(BATCH_SIZE).repeat(-1)
|
|||||||
valid_dataset = valid_dataset.batch(EVAL_BATCH_SIZE)
|
valid_dataset = valid_dataset.batch(EVAL_BATCH_SIZE)
|
||||||
|
|
||||||
# Prepare training: Compile tf.keras model with optimizer, loss and learning rate schedule
|
# Prepare training: Compile tf.keras model with optimizer, loss and learning rate schedule
|
||||||
optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08)
|
opt = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08)
|
||||||
|
if USE_AMP:
|
||||||
|
# loss scaling is currently required when using mixed precision, so that small float16 gradients do not underflow to zero
|
||||||
|
opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(opt, 'dynamic')
|
||||||
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
|
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
|
||||||
metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')
|
metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')
|
||||||
model.compile(optimizer=optimizer, loss=loss, metrics=[metric])
|
model.compile(optimizer=opt, loss=loss, metrics=[metric])
|
||||||
|
|
||||||
# Train and evaluate using tf.keras.Model.fit()
|
# Train and evaluate using tf.keras.Model.fit()
|
||||||
train_steps = train_examples//BATCH_SIZE
|
train_steps = train_examples//BATCH_SIZE
|
||||||
|
|||||||
Reference in New Issue
Block a user