From 8b5e5ebcf9b413f7c0928141ae9719ac3d47caf3 Mon Sep 17 00:00:00 2001 From: Suraj Parmar Date: Fri, 1 May 2020 07:44:08 +0530 Subject: [PATCH] Continue training args and tqdm in notebooks (#3939) * Continue training args * Continue training args * added explanation * added explanation * added explanation * Fixed tqdm auto * Update src/transformers/training_args.py Co-Authored-By: Julien Chaumond * Update src/transformers/training_args.py * Update src/transformers/training_args.py Co-authored-by: Julien Chaumond --- src/transformers/trainer.py | 2 +- src/transformers/training_args.py | 24 ++++++++++++++++++------ 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 6524ba42ab..a7b2bae457 100644 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -15,7 +15,7 @@ from torch.utils.data.dataloader import DataLoader from torch.utils.data.dataset import Dataset from torch.utils.data.distributed import DistributedSampler from torch.utils.data.sampler import RandomSampler -from tqdm import tqdm, trange +from tqdm.auto import tqdm, trange from .data.data_collator import DataCollator, DefaultDataCollator from .modeling_utils import PreTrainedModel diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index af32eac25b..c4bc9b6456 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -29,20 +29,27 @@ class TrainingArguments: metadata={"help": "The output directory where the model predictions and checkpoints will be written."} ) overwrite_output_dir: bool = field( - default=False, metadata={"help": "Overwrite the content of the output directory"} + default=False, + metadata={ + "help": ( + "Overwrite the content of the output directory." + "Use this to continue training if output_dir points to a checkpoint directory." 
+ ) + }, ) do_train: bool = field(default=False, metadata={"help": "Whether to run training."}) do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."}) do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."}) evaluate_during_training: bool = field( - default=False, metadata={"help": "Run evaluation during training at each logging step."} + default=False, metadata={"help": "Run evaluation during training at each logging step."}, ) per_gpu_train_batch_size: int = field(default=8, metadata={"help": "Batch size per GPU/CPU for training."}) per_gpu_eval_batch_size: int = field(default=8, metadata={"help": "Batch size per GPU/CPU for evaluation."}) gradient_accumulation_steps: int = field( - default=1, metadata={"help": "Number of updates steps to accumulate before performing a backward/update pass."} + default=1, + metadata={"help": "Number of updates steps to accumulate before performing a backward/update pass."}, ) learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for Adam."}) @@ -64,7 +71,10 @@ class TrainingArguments: save_total_limit: Optional[int] = field( default=None, metadata={ - "help": "Limit the total amount of checkpoints, delete the older checkpoints in the output_dir, does not delete by default" + "help": ( + "Limit the total amount of checkpoints." + "Deletes the older checkpoints in the output_dir. Default is unlimited checkpoints" + ) }, ) no_cuda: bool = field(default=False, metadata={"help": "Avoid using CUDA even if it is available"}) @@ -77,8 +87,10 @@ class TrainingArguments: fp16_opt_level: str = field( default="O1", metadata={ - "help": "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." - "See details at https://nvidia.github.io/apex/amp.html" + "help": ( + "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." 
+ "See details at https://nvidia.github.io/apex/amp.html" + ) }, ) local_rank: int = field(default=-1, metadata={"help": "For distributed training: local_rank"})