Rework TF trainer (#6038)
* Fully rework training/prediction loops * fix method name * Fix variable name * Fix property name * Fix scope * Fix method name * Fix tuple index * Fix tuple index * Fix indentation * Fix variable name * fix eval before log * Add drop remainder for test dataset * Fix step number + fix logging datetime * fix eval loss value * use global step instead of step + fix logging at step 0 * Fix logging datetime * Fix global_step usage * Fix breaking loop + logging datetime * Fix step in prediction loop * Fix step breaking * Fix train/test loops * Force TF at least 2.2 for the trainer * Use assert_cardinality to facilitate the dataset size computation * Log steps per epoch * Make tfds compliant with TPU * Make tfds compliant with TPU * Use TF dataset enumerate instead of the Python one * revert previous commit * Fix data_dir * Apply style * rebase on master * Address Sylvain's comments * Address Sylvain's and Lysandre comments * Trigger CI * Remove unused import
This commit is contained in:
@@ -204,6 +204,8 @@ if is_tf_available():
|
||||
)
|
||||
|
||||
def get_dataset(self):
|
||||
self.dataset = self.dataset.apply(tf.data.experimental.assert_cardinality(len(self.features)))
|
||||
|
||||
return self.dataset
|
||||
|
||||
def __len__(self):
|
||||
|
||||
Reference in New Issue
Block a user