Move prediction_loss_only to TrainingArguments (#6426)

This commit is contained in:
Sylvain Gugger
2020-08-12 08:03:45 -04:00
committed by GitHub
parent e9c3031463
commit 34fabe1697
5 changed files with 37 additions and 8 deletions

View File

@@ -159,6 +159,8 @@ class Trainer:
A tuple containing the optimizer and the scheduler to use. Will default to an instance of
:class:`~transformers.AdamW` on your model and a scheduler given by
:func:`~transformers.get_linear_schedule_with_warmup` controlled by :obj:`args`.
kwargs:
Deprecated keyword arguments.
"""
def __init__(
@@ -169,9 +171,9 @@ class Trainer:
train_dataset: Optional[Dataset] = None,
eval_dataset: Optional[Dataset] = None,
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
prediction_loss_only=False,
tb_writer: Optional["SummaryWriter"] = None,
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
**kwargs,
):
self.model = model.to(args.device)
self.args = args
@@ -179,9 +181,16 @@ class Trainer:
self.train_dataset = train_dataset
self.eval_dataset = eval_dataset
self.compute_metrics = compute_metrics
self.prediction_loss_only = prediction_loss_only
self.optimizer, self.lr_scheduler = optimizers
self.tb_writer = tb_writer
if "prediction_loss_only" in kwargs:
warnings.warn(
"Passing `prediction_loss_only` as a keyword argument is deprecated and won't be possible in a future version. Use `args.prediction_loss_only` instead.",
FutureWarning,
)
self.args.prediction_loss_only = kwargs.pop("prediction_loss_only")
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
if tb_writer is None and is_tensorboard_available() and self.is_world_process_zero():
self.tb_writer = SummaryWriter(log_dir=self.args.logging_dir)
if not is_tensorboard_available():
@@ -951,7 +960,9 @@ class Trainer:
)
return self._prediction_loop(dataloader, description, prediction_loss_only=prediction_loss_only)
prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else self.prediction_loss_only
prediction_loss_only = (
prediction_loss_only if prediction_loss_only is not None else self.args.prediction_loss_only
)
model = self.model
# multi-gpu eval

View File

@@ -44,8 +44,6 @@ class TFTrainer:
compute_metrics (:obj:`Callable[[EvalPrediction], Dict]`, `optional`):
The function that will be used to compute metrics at evaluation. Must take a
:class:`~transformers.EvalPrediction` and return a dictionary string to metric values.
prediction_loss_only (:obj:`bool`, `optional`, defaults to `False`):
When performing evaluation and predictions, only returns the loss.
tb_writer (:obj:`tf.summary.SummaryWriter`, `optional`):
Object to write to TensorBoard.
optimizers (:obj:`Tuple[tf.keras.optimizers.Optimizer, tf.keras.optimizers.schedules.LearningRateSchedule]`, `optional`):
@@ -54,6 +52,8 @@ class TFTrainer:
:class:`~transformers.AdamWeightDecay`. The scheduler will default to an instance of
:class:`tf.keras.optimizers.schedules.PolynomialDecay` if :obj:`args.num_warmup_steps` is 0 else
an instance of :class:`~transformers.WarmUp`.
kwargs:
Deprecated keyword arguments.
"""
def __init__(
@@ -63,12 +63,12 @@ class TFTrainer:
train_dataset: Optional[tf.data.Dataset] = None,
eval_dataset: Optional[tf.data.Dataset] = None,
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
prediction_loss_only=False,
tb_writer: Optional[tf.summary.SummaryWriter] = None,
optimizers: Tuple[tf.keras.optimizers.Optimizer, tf.keras.optimizers.schedules.LearningRateSchedule] = (
None,
None,
),
**kwargs,
):
assert parse(tf.__version__).release >= (2, 2, 0), (
"You need to run the TensorFlow trainer with at least the version 2.2.0, your version is %r "
@@ -80,11 +80,17 @@ class TFTrainer:
self.train_dataset = train_dataset
self.eval_dataset = eval_dataset
self.compute_metrics = compute_metrics
self.prediction_loss_only = prediction_loss_only
self.optimizer, self.lr_scheduler = optimizers
self.gradient_accumulator = GradientAccumulator()
self.global_step = 0
self.epoch_logging = 0
if "prediction_loss_only" in kwargs:
warnings.warn(
"Passing `prediction_loss_only` as a keyword argument is deprecated and won't be possible in a future version. Use `args.prediction_loss_only` instead.",
FutureWarning,
)
self.args.prediction_loss_only = kwargs.pop("prediction_loss_only")
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
if tb_writer is not None:
self.tb_writer = tb_writer
@@ -282,7 +288,9 @@ class TFTrainer:
dataset, steps, num_examples, description, prediction_loss_only=prediction_loss_only
)
prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else self.prediction_loss_only
prediction_loss_only = (
prediction_loss_only if prediction_loss_only is not None else self.args.prediction_loss_only
)
logger.info("***** Running %s *****", description)
logger.info(" Num examples = %d", num_examples)

View File

@@ -52,6 +52,8 @@ class TrainingArguments:
Whether to run predictions on the test set or not.
evaluate_during_training (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to run evaluation during training at each logging step or not.
prediction_loss_only (:obj:`bool`, `optional`, defaults to `False`):
When performing evaluation and predictions, only returns the loss.
per_device_train_batch_size (:obj:`int`, `optional`, defaults to 8):
The batch size per GPU/TPU core/CPU for training.
per_device_eval_batch_size (:obj:`int`, `optional`, defaults to 8):
@@ -132,6 +134,9 @@ class TrainingArguments:
evaluate_during_training: bool = field(
default=False, metadata={"help": "Run evaluation during training at each logging step."},
)
prediction_loss_only: bool = field(
default=False, metadata={"help": "When performing evaluation and predictions, only returns the loss."},
)
per_device_train_batch_size: int = field(
default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."}