HuggingFace_transformer/src/transformers/trainer_utils.py
Julien Plu aad50151f3 TF version of the trainer (#4017)
* First commit to add a TF version of the trainer.

* Make the TF trainer closer to the PT trainer

* Refactor common code between the PT and TF trainers into a util file.

* Some bugfixes + better similarity with the PT trainer

* Add missing class in transformers init

* Bugfix over prediction + use classification report instead of simple metrics

* Fix name error

* Fix optimization tests + style

* Apply style

* Several bugfixes for multi-GPU training

* Apply style

* Apply style

* Add glue example for the TF trainer

* Several bugfixes + address the reviews

* Fix on the TF training args file

* Add a debug mode

* Bugfix in utils_ner.py when segment_ids is None

* Apply style

* Apply style

* Add TPU strategy

* Fix selection strategy
2020-05-06 12:56:52 -04:00


from typing import Dict, NamedTuple, Optional

import numpy as np


class EvalPrediction(NamedTuple):
    """
    Evaluation output (always contains labels), to be used
    to compute metrics.
    """

    predictions: np.ndarray
    label_ids: np.ndarray


class PredictionOutput(NamedTuple):
    predictions: np.ndarray
    label_ids: Optional[np.ndarray]
    metrics: Optional[Dict[str, float]]


class TrainOutput(NamedTuple):
    global_step: int
    training_loss: float


PREFIX_CHECKPOINT_DIR = "checkpoint"
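
A minimal sketch of how `TrainOutput` and `PREFIX_CHECKPOINT_DIR` might be used together. It assumes checkpoints live in directories named `<prefix>-<global_step>`, which this file does not state. The helpers `checkpoint_dir` and `parse_global_step` are hypothetical and are not part of the module. `TrainOutput` and the constant are redefined here only so the snippet runs on its own.

```python
import os
import re
from typing import NamedTuple, Optional

# Redefined from the module above so the sketch is self-contained.
PREFIX_CHECKPOINT_DIR = "checkpoint"


class TrainOutput(NamedTuple):
    global_step: int
    training_loss: float


def checkpoint_dir(output_dir: str, global_step: int) -> str:
    """Hypothetical helper: build a path such as <output_dir>/checkpoint-500.

    The "<prefix>-<step>" naming scheme is an assumption of this sketch.
    """
    return os.path.join(output_dir, f"{PREFIX_CHECKPOINT_DIR}-{global_step}")


def parse_global_step(path: str) -> Optional[int]:
    """Hypothetical helper: recover the step from a checkpoint directory name.

    Returns None when the last path component does not match the scheme.
    """
    pattern = rf"{re.escape(PREFIX_CHECKPOINT_DIR)}-(\d+)"
    match = re.fullmatch(pattern, os.path.basename(path))
    return int(match.group(1)) if match else None


# Illustrative values only; they are not results from a real training run.
result = TrainOutput(global_step=500, training_loss=0.42)
path = checkpoint_dir("output", result.global_step)
print(path)
print(parse_global_step(path))
```

Because `TrainOutput` is a `NamedTuple`, it can also be unpacked positionally, e.g. `step, loss = result`.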