TF version of the trainer (#4017)
* First commit to add a TF version of the trainer. * Make the TF trainer closer to what looks the PT trainer * Refactoring common code between the PT and TF trainer into an util file. * Some bugfix + better similarity with the PT trainer * Add missing class in transformers init * Bugfix over prediction + use classification report instead of simple metrics * Fix name error * Fix optimization tests + style * Apply style * Several bugfix for multi-gpu training * Apply style * Apply style * Add glue example for the TF trainer * Several bugix + address the reviews * Fix on the TF training args file * Add a debug mode * Bugfix in utils_ner.py when segment_ids is None * Apply style * Apply style * Add TPU strategy * Fix selection strategy
This commit is contained in:
@@ -145,7 +145,9 @@ from .tokenization_utils import PreTrainedTokenizer
|
||||
from .tokenization_xlm import XLMTokenizer
|
||||
from .tokenization_xlm_roberta import XLMRobertaTokenizer
|
||||
from .tokenization_xlnet import SPIECE_UNDERLINE, XLNetTokenizer
|
||||
from .trainer_utils import EvalPrediction
|
||||
from .training_args import TrainingArguments
|
||||
from .training_args_tf import TFTrainingArguments
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
||||
@@ -502,6 +504,9 @@ if is_tf_available():
|
||||
# Optimization
|
||||
from .optimization_tf import WarmUp, create_optimizer, AdamWeightDecay, GradientAccumulator
|
||||
|
||||
# Trainer
|
||||
from .trainer_tf import TFTrainer
|
||||
|
||||
|
||||
if not is_tf_available() and not is_torch_available():
|
||||
logger.warning(
|
||||
|
||||
Reference in New Issue
Block a user