add warning if neither pt nor tf are found
This commit is contained in:
@@ -15,12 +15,12 @@ valid_dataset = valid_dataset.batch(64)
|
||||
|
||||
# Compile tf.keras model for training
|
||||
learning_rate = tf.keras.optimizers.schedules.PolynomialDecay(2e-5, 345, end_learning_rate=0)
|
||||
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate, epsilon=1e-08, clipnorm=1.0)
|
||||
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
|
||||
tf_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate, epsilon=1e-08, clipnorm=1.0),
|
||||
loss=loss, metrics=['sparse_categorical_accuracy'])
|
||||
tf_model.compile(optimizer=optimizer, loss=loss, metrics=['sparse_categorical_accuracy'])
|
||||
|
||||
# Train and evaluate using tf.keras.Model.fit()
|
||||
tf_model.fit(train_dataset, epochs=3, steps_per_epoch=115, validation_data=valid_dataset, validation_steps=7)
|
||||
tf_model.fit(train_dataset, epochs=1, steps_per_epoch=115, validation_data=valid_dataset, validation_steps=7)
|
||||
|
||||
# Save the model and load it in PyTorch
|
||||
tf_model.save_pretrained('./runs/')
|
||||
|
||||
@@ -157,3 +157,8 @@ if is_tf_available() and is_torch_available():
|
||||
load_tf2_checkpoint_in_pytorch_model,
|
||||
load_tf2_weights_in_pytorch_model,
|
||||
load_tf2_model_in_pytorch_model)
|
||||
|
||||
if not is_tf_available() and not is_torch_available():
|
||||
logger.warning("Neither PyTorch nor TensorFlow >= 2.0 have been found."
|
||||
"Models won't be available and only tokenizers, configuration"
|
||||
"and file/data utilities can be used.")
|
||||
|
||||
Reference in New Issue
Block a user