per_device instead of per_gpu/error thrown when argument unknown (#4618)
* per_device instead of per_gpu/error thrown when argument unknown
* [docs] Restore examples.md symlink
* Correct absolute links so that symlink to the doc works correctly
* Update src/transformers/hf_argparser.py

  Co-authored-by: Julien Chaumond <chaumond@gmail.com>
* Warning + reorder
* Docs
* Style
* not for squad

Co-authored-by: Julien Chaumond <chaumond@gmail.com>
This commit is contained in:
@@ -126,6 +126,9 @@ class HfArgumentParser(ArgumentParser):
|
||||
if return_remaining_strings:
|
||||
return (*outputs, remaining_args)
|
||||
else:
|
||||
if remaining_args:
|
||||
raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {remaining_args}")
|
||||
|
||||
return (*outputs,)
|
||||
|
||||
def parse_json_file(self, json_file: str) -> Tuple[DataClass, ...]:
|
||||
|
||||
@@ -416,7 +416,7 @@ class Trainer:
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(" Num examples = %d", self.num_examples(train_dataloader))
|
||||
logger.info(" Num Epochs = %d", num_train_epochs)
|
||||
logger.info(" Instantaneous batch size per device = %d", self.args.per_gpu_train_batch_size)
|
||||
logger.info(" Instantaneous batch size per device = %d", self.args.per_device_train_batch_size)
|
||||
logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d", total_train_batch_size)
|
||||
logger.info(" Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
|
||||
logger.info(" Total optimization steps = %d", t_total)
|
||||
|
||||
@@ -58,8 +58,28 @@ class TrainingArguments:
|
||||
default=False, metadata={"help": "Run evaluation during training at each logging step."},
|
||||
)
|
||||
|
||||
per_gpu_train_batch_size: int = field(default=8, metadata={"help": "Batch size per GPU/CPU for training."})
|
||||
per_gpu_eval_batch_size: int = field(default=8, metadata={"help": "Batch size per GPU/CPU for evaluation."})
|
||||
per_device_train_batch_size: int = field(
|
||||
default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."}
|
||||
)
|
||||
per_device_eval_batch_size: int = field(
|
||||
default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
|
||||
)
|
||||
|
||||
per_gpu_train_batch_size: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Deprecated, the use of `--per_device_train_batch_size` is preferred. "
|
||||
"Batch size per GPU/TPU core/CPU for training."
|
||||
},
|
||||
)
|
||||
per_gpu_eval_batch_size: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Deprecated, the use of `--per_device_eval_batch_size` is preferred."
|
||||
"Batch size per GPU/TPU core/CPU for evaluation."
|
||||
},
|
||||
)
|
||||
|
||||
gradient_accumulation_steps: int = field(
|
||||
default=1,
|
||||
metadata={"help": "Number of updates steps to accumulate before performing a backward/update pass."},
|
||||
@@ -115,11 +135,23 @@ class TrainingArguments:
|
||||
|
||||
@property
|
||||
def train_batch_size(self) -> int:
|
||||
return self.per_gpu_train_batch_size * max(1, self.n_gpu)
|
||||
if self.per_gpu_train_batch_size:
|
||||
logger.warning(
|
||||
"Using deprecated `--per_gpu_train_batch_size` argument which will be removed in a future "
|
||||
"version. Using `--per_device_train_batch_size` is preferred."
|
||||
)
|
||||
per_device_batch_size = self.per_gpu_train_batch_size or self.per_device_train_batch_size
|
||||
return per_device_batch_size * max(1, self.n_gpu)
|
||||
|
||||
@property
|
||||
def eval_batch_size(self) -> int:
|
||||
return self.per_gpu_eval_batch_size * max(1, self.n_gpu)
|
||||
if self.per_gpu_eval_batch_size:
|
||||
logger.warning(
|
||||
"Using deprecated `--per_gpu_eval_batch_size` argument which will be removed in a future "
|
||||
"version. Using `--per_device_eval_batch_size` is preferred."
|
||||
)
|
||||
per_device_batch_size = self.per_gpu_eval_batch_size or self.per_device_eval_batch_size
|
||||
return per_device_batch_size * max(1, self.n_gpu)
|
||||
|
||||
@cached_property
|
||||
@torch_required
|
||||
|
||||
Reference in New Issue
Block a user