TF version of the trainer (#4017)

* First commit to add a TF version of the trainer.

* Make the TF trainer closer to what looks the PT trainer

* Refactoring common code between the PT and TF trainer into an util file.

* Some bugfix + better similarity with the PT trainer

* Add missing class in transformers init

* Bugfix over prediction + use classification report instead of simple metrics

* Fix name error

* Fix optimization tests + style

* Apply style

* Several bugfix for multi-gpu training

* Apply style

* Apply style

* Add glue example for the TF trainer

* Several bugix + address the reviews

* Fix on the TF training args file

* Add a debug mode

* Bugfix in utils_ner.py when segment_ids is None

* Apply style

* Apply style

* Add TPU strategy

* Fix selection strategy
This commit is contained in:
Julien Plu
2020-05-06 18:56:52 +02:00
committed by GitHub
parent 25296b12aa
commit aad50151f3
10 changed files with 1206 additions and 819 deletions

View File

@@ -145,7 +145,9 @@ from .tokenization_utils import PreTrainedTokenizer
from .tokenization_xlm import XLMTokenizer
from .tokenization_xlm_roberta import XLMRobertaTokenizer
from .tokenization_xlnet import SPIECE_UNDERLINE, XLNetTokenizer
from .trainer_utils import EvalPrediction
from .training_args import TrainingArguments
from .training_args_tf import TFTrainingArguments
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
@@ -502,6 +504,9 @@ if is_tf_available():
# Optimization
from .optimization_tf import WarmUp, create_optimizer, AdamWeightDecay, GradientAccumulator
# Trainer
from .trainer_tf import TFTrainer
if not is_tf_available() and not is_torch_available():
logger.warning(