add warning if neither pt nor tf are found
This commit is contained in:
@@ -15,12 +15,12 @@ valid_dataset = valid_dataset.batch(64)
|
||||
|
||||
# Compile tf.keras model for training
|
||||
learning_rate = tf.keras.optimizers.schedules.PolynomialDecay(2e-5, 345, end_learning_rate=0)
|
||||
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate, epsilon=1e-08, clipnorm=1.0)
|
||||
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
|
||||
tf_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate, epsilon=1e-08, clipnorm=1.0),
|
||||
loss=loss, metrics=['sparse_categorical_accuracy'])
|
||||
tf_model.compile(optimizer=optimizer, loss=loss, metrics=['sparse_categorical_accuracy'])
|
||||
|
||||
# Train and evaluate using tf.keras.Model.fit()
|
||||
tf_model.fit(train_dataset, epochs=3, steps_per_epoch=115, validation_data=valid_dataset, validation_steps=7)
|
||||
tf_model.fit(train_dataset, epochs=1, steps_per_epoch=115, validation_data=valid_dataset, validation_steps=7)
|
||||
|
||||
# Save the model and load it in PyTorch
|
||||
tf_model.save_pretrained('./runs/')
|
||||
|
||||
Reference in New Issue
Block a user