TF version of the trainer (#4017)
* First commit to add a TF version of the trainer. * Make the TF trainer closer to what looks the PT trainer * Refactoring common code between the PT and TF trainer into an util file. * Some bugfix + better similarity with the PT trainer * Add missing class in transformers init * Bugfix over prediction + use classification report instead of simple metrics * Fix name error * Fix optimization tests + style * Apply style * Several bugfix for multi-gpu training * Apply style * Apply style * Add glue example for the TF trainer * Several bugix + address the reviews * Fix on the TF training args file * Add a debug mode * Bugfix in utils_ner.py when segment_ids is None * Apply style * Apply style * Add TPU strategy * Fix selection strategy
This commit is contained in:
27
src/transformers/trainer_utils.py
Normal file
27
src/transformers/trainer_utils.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from typing import Dict, NamedTuple, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class EvalPrediction(NamedTuple):
|
||||
"""
|
||||
Evaluation output (always contains labels), to be used
|
||||
to compute metrics.
|
||||
"""
|
||||
|
||||
predictions: np.ndarray
|
||||
label_ids: np.ndarray
|
||||
|
||||
|
||||
class PredictionOutput(NamedTuple):
|
||||
predictions: np.ndarray
|
||||
label_ids: Optional[np.ndarray]
|
||||
metrics: Optional[Dict[str, float]]
|
||||
|
||||
|
||||
class TrainOutput(NamedTuple):
|
||||
global_step: int
|
||||
training_loss: float
|
||||
|
||||
|
||||
PREFIX_CHECKPOINT_DIR = "checkpoint"
|
||||
Reference in New Issue
Block a user