* First commit to add a TF version of the trainer. * Make the TF trainer closer to what the PT trainer looks like * Refactor common code between the PT and TF trainers into a util file. * Some bugfixes + better similarity with the PT trainer * Add missing class in transformers init * Bugfix over prediction + use classification report instead of simple metrics * Fix name error * Fix optimization tests + style * Apply style * Several bugfixes for multi-GPU training * Apply style * Apply style * Add glue example for the TF trainer * Several bugfixes + address the reviews * Fix on the TF training args file * Add a debug mode * Bugfix in utils_ner.py when segment_ids is None * Apply style * Apply style * Add TPU strategy * Fix selection strategy
28 lines
516 B
Python
28 lines
516 B
Python
from typing import Dict, NamedTuple, Optional
|
|
|
|
import numpy as np
|
|
|
|
|
|
class EvalPrediction(NamedTuple):
    """
    Evaluation output (always contains labels), to be used
    to compute metrics.

    Being a ``NamedTuple``, instances are immutable and can be unpacked
    positionally as ``(predictions, label_ids)``.
    """

    # Predictions as a NumPy array.
    predictions: np.ndarray
    # Labels as a NumPy array; unlike ``PredictionOutput`` this field is
    # not optional.
    label_ids: np.ndarray
|
|
|
|
|
|
class PredictionOutput(NamedTuple):
    """
    Prediction output bundling predictions with optional labels and metrics.

    All three fields must be supplied when constructing an instance;
    ``label_ids`` and ``metrics`` may be passed as ``None``.
    """

    # Predictions as a NumPy array.
    predictions: np.ndarray
    # Labels as a NumPy array, or None.
    label_ids: Optional[np.ndarray]
    # Mapping from metric name to float value, or None.
    metrics: Optional[Dict[str, float]]
|
|
|
|
|
|
class TrainOutput(NamedTuple):
    """
    Training output pairing a step counter with a loss value.

    Being a ``NamedTuple``, instances are immutable and can be unpacked
    positionally as ``(global_step, training_loss)``.
    """

    # Integer step counter.
    global_step: int
    # Loss value as a float.
    training_loss: float
|
|
|
|
|
|
# Module-level string constant. NOTE(review): presumably the name prefix used
# for checkpoint directories; its usage is not visible here, so confirm
# against callers.
PREFIX_CHECKPOINT_DIR = "checkpoint"
|