Rework TF trainer (#6038)
* Fully rework training/prediction loops * fix method name * Fix variable name * Fix property name * Fix scope * Fix method name * Fix tuple index * Fix tuple index * Fix indentation * Fix variable name * fix eval before log * Add drop remainder for test dataset * Fix step number + fix logging datetime * fix eval loss value * use global step instead of step + fix logging at step 0 * Fix logging datetime * Fix global_step usage * Fix breaking loop + logging datetime * Fix step in prediction loop * Fix step breaking * Fix train/test loops * Force TF at least 2.2 for the trainer * Use assert_cardinality to facilitate the dataset size computation * Log steps per epoch * Make tfds compliant with TPU * Make tfds compliant with TPU * Use TF dataset enumerate instead of the Python one * revert previous commit * Fix data_dir * Apply style * rebase on master * Address Sylvain's comments * Address Sylvain's and Lysandre comments * Trigger CI * Remove unused import
This commit is contained in:
@@ -17,7 +17,6 @@
|
||||
|
||||
import logging
|
||||
import os
|
||||
import warnings
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
@@ -185,11 +184,6 @@ def main():
|
||||
|
||||
for i in range(batch_size):
|
||||
for j in range(seq_len):
|
||||
if label_ids[i, j] == -1:
|
||||
label_ids[i, j] = -100
|
||||
warnings.warn(
|
||||
"Using `-1` to mask the loss for the token is depreciated. Please use `-100` instead."
|
||||
)
|
||||
if label_ids[i, j] != -100:
|
||||
out_label_list[i].append(label_map[label_ids[i][j]])
|
||||
preds_list[i].append(label_map[preds[i][j]])
|
||||
|
||||
@@ -146,7 +146,7 @@ if is_tf_available():
|
||||
"""
|
||||
|
||||
features: List[InputFeatures]
|
||||
pad_token_label_id: int = -1
|
||||
pad_token_label_id: int = -100
|
||||
# Use cross entropy ignore_index as padding label id so that only
|
||||
# real label ids contribute to the loss later.
|
||||
|
||||
@@ -221,6 +221,8 @@ if is_tf_available():
|
||||
)
|
||||
|
||||
def get_dataset(self):
|
||||
self.dataset = self.dataset.apply(tf.data.experimental.assert_cardinality(len(self.features)))
|
||||
|
||||
return self.dataset
|
||||
|
||||
def __len__(self):
|
||||
|
||||
Reference in New Issue
Block a user