Tensorflow improvements (#4530)
* Better None gradients handling * Apply Style * Apply Style * Create a loss class per task to compute its respective loss * Add loss classes to the ALBERT TF models * Add loss classes to the BERT TF models * Add question answering and multiple choice to TF Camembert * Remove prints * Add multiple choice model to TF DistilBERT + loss computation * Add question answering model to TF Electra + loss computation * Add token classification, question answering and multiple choice models to TF Flaubert * Add multiple choice model to TF Roberta + loss computation * Add multiple choice model to TF XLM + loss computation * Add multiple choice and question answering models to TF XLM-Roberta * Add multiple choice model to TF XLNet + loss computation * Remove unused parameters * Add task loss classes * Reorder TF imports + add new model classes * Add new model classes * Bugfix in TF T5 model * Bugfix for TF T5 tests * Bugfix in TF T5 model * Fix TF T5 model tests * Fix T5 tests + some renaming * Fix inheritance issue in the AutoX tests * Add tests for TF Flaubert and TF XLM Roberta * Add tests for TF Flaubert and TF XLM Roberta * Remove unused piece of code in the TF trainer * bugfix and remove unused code * Bugfix for TF 2.2 * Apply Style * Divide TFSequenceClassificationAndMultipleChoiceLoss into their two respective name * Apply style * Mirror the PT Trainer in the TF one: fp16, optimizers and tb_writer as class parameter and better dataset handling * Fix TF optimizations tests and apply style * Remove useless parameter * Bugfix and apply style * Fix TF Trainer prediction * Now the TF models return the loss such as their PyTorch couterparts * Apply Style * Ignore some tests output * Take into account the SQuAD cls_index, p_mask and is_impossible parameters for the QuestionAnswering task models. * Fix names for SQuAD data * Apply Style * Fix conflicts with 2.11 release * Fix conflicts with 2.11 * Fix wrongname * Add better documentation on the new create_optimizer function * Fix isort * logging_dir: use same default as PyTorch Co-authored-by: Julien Chaumond <chaumond@gmail.com>
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import dataclasses
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
@@ -27,6 +28,17 @@ def is_tpu_available():
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def default_logdir() -> str:
|
||||
"""
|
||||
Same default as PyTorch
|
||||
"""
|
||||
import socket
|
||||
from datetime import datetime
|
||||
|
||||
current_time = datetime.now().strftime("%b%d_%H-%M-%S")
|
||||
return os.path.join("runs", current_time + "_" + socket.gethostname())
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingArguments:
|
||||
"""
|
||||
@@ -97,7 +109,7 @@ class TrainingArguments:
|
||||
)
|
||||
warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
|
||||
|
||||
logging_dir: Optional[str] = field(default=None, metadata={"help": "Tensorboard log dir."})
|
||||
logging_dir: Optional[str] = field(default_factory=default_logdir, metadata={"help": "Tensorboard log dir."})
|
||||
logging_first_step: bool = field(default=False, metadata={"help": "Log and eval the first global_step"})
|
||||
logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
|
||||
save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."})
|
||||
|
||||
Reference in New Issue
Block a user