Rework TF trainer (#6038)
* Fully rework training/prediction loops * fix method name * Fix variable name * Fix property name * Fix scope * Fix method name * Fix tuple index * Fix tuple index * Fix indentation * Fix variable name * fix eval before log * Add drop remainder for test dataset * Fix step number + fix logging datetime * fix eval loss value * use global step instead of step + fix logging at step 0 * Fix logging datetime * Fix global_step usage * Fix breaking loop + logging datetime * Fix step in prediction loop * Fix step breaking * Fix train/test loops * Force TF at least 2.2 for the trainer * Use assert_cardinality to facilitate the dataset size computation * Log steps per epoch * Make tfds compliant with TPU * Make tfds compliant with TPU * Use TF dataset enumerate instead of the Python one * revert previous commit * Fix data_dir * Apply style * rebase on master * Address Sylvain's comments * Address Sylvain's and Lysandre comments * Trigger CI * Remove unused import
This commit is contained in:
@@ -1,7 +1,7 @@
|
|||||||
# Examples
|
# Examples
|
||||||
|
|
||||||
Version 2.9 of 🤗 Transformers introduces a new [`Trainer`](https://github.com/huggingface/transformers/blob/master/src/transformers/trainer.py) class for PyTorch, and its equivalent [`TFTrainer`](https://github.com/huggingface/transformers/blob/master/src/transformers/trainer_tf.py) for TF 2.
|
Version 2.9 of 🤗 Transformers introduces a new [`Trainer`](https://github.com/huggingface/transformers/blob/master/src/transformers/trainer.py) class for PyTorch, and its equivalent [`TFTrainer`](https://github.com/huggingface/transformers/blob/master/src/transformers/trainer_tf.py) for TF 2.
|
||||||
Running the examples requires PyTorch 1.3.1+ or TensorFlow 2.1+.
|
Running the examples requires PyTorch 1.3.1+ or TensorFlow 2.2+.
|
||||||
|
|
||||||
Here is the list of all our examples:
|
Here is the list of all our examples:
|
||||||
- **grouped by task** (all official examples work for multiple models)
|
- **grouped by task** (all official examples work for multiple models)
|
||||||
|
|||||||
@@ -204,6 +204,8 @@ if is_tf_available():
|
|||||||
)
|
)
|
||||||
|
|
||||||
def get_dataset(self):
|
def get_dataset(self):
|
||||||
|
self.dataset = self.dataset.apply(tf.data.experimental.assert_cardinality(len(self.features)))
|
||||||
|
|
||||||
return self.dataset
|
return self.dataset
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
|
|||||||
@@ -21,6 +21,8 @@ import os
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AutoConfig,
|
AutoConfig,
|
||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
@@ -68,6 +70,7 @@ class DataTrainingArguments:
|
|||||||
data_dir: Optional[str] = field(
|
data_dir: Optional[str] = field(
|
||||||
default=None, metadata={"help": "The input data dir. Should contain the .json files for the SQuAD task."}
|
default=None, metadata={"help": "The input data dir. Should contain the .json files for the SQuAD task."}
|
||||||
)
|
)
|
||||||
|
use_tfds: Optional[bool] = field(default=True, metadata={"help": "If TFDS should be used or not."})
|
||||||
max_seq_length: int = field(
|
max_seq_length: int = field(
|
||||||
default=128,
|
default=128,
|
||||||
metadata={
|
metadata={
|
||||||
@@ -170,7 +173,7 @@ def main():
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Get datasets
|
# Get datasets
|
||||||
if not data_args.data_dir:
|
if data_args.use_tfds:
|
||||||
if data_args.version_2_with_negative:
|
if data_args.version_2_with_negative:
|
||||||
logger.warn("tensorflow_datasets does not handle version 2 of SQuAD. Switch to version 1 automatically")
|
logger.warn("tensorflow_datasets does not handle version 2 of SQuAD. Switch to version 1 automatically")
|
||||||
|
|
||||||
@@ -179,7 +182,7 @@ def main():
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError("If not data_dir is specified, tensorflow_datasets needs to be installed.")
|
raise ImportError("If not data_dir is specified, tensorflow_datasets needs to be installed.")
|
||||||
|
|
||||||
tfds_examples = tfds.load("squad")
|
tfds_examples = tfds.load("squad", data_dir=data_args.data_dir)
|
||||||
train_examples = (
|
train_examples = (
|
||||||
SquadV1Processor().get_examples_from_dataset(tfds_examples, evaluate=False)
|
SquadV1Processor().get_examples_from_dataset(tfds_examples, evaluate=False)
|
||||||
if training_args.do_train
|
if training_args.do_train
|
||||||
@@ -209,6 +212,8 @@ def main():
|
|||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
train_dataset = train_dataset.apply(tf.data.experimental.assert_cardinality(len(train_examples)))
|
||||||
|
|
||||||
eval_dataset = (
|
eval_dataset = (
|
||||||
squad_convert_examples_to_features(
|
squad_convert_examples_to_features(
|
||||||
examples=eval_examples,
|
examples=eval_examples,
|
||||||
@@ -223,6 +228,8 @@ def main():
|
|||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
eval_dataset = eval_dataset.apply(tf.data.experimental.assert_cardinality(len(eval_examples)))
|
||||||
|
|
||||||
# Initialize our Trainer
|
# Initialize our Trainer
|
||||||
trainer = TFTrainer(model=model, args=training_args, train_dataset=train_dataset, eval_dataset=eval_dataset,)
|
trainer = TFTrainer(model=model, args=training_args, train_dataset=train_dataset, eval_dataset=eval_dataset,)
|
||||||
|
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from enum import Enum
|
|||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import tensorflow as tf
|
||||||
import tensorflow_datasets as tfds
|
import tensorflow_datasets as tfds
|
||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
@@ -35,7 +36,11 @@ class Split(Enum):
|
|||||||
|
|
||||||
|
|
||||||
def get_tfds(
|
def get_tfds(
|
||||||
task_name: str, tokenizer: PreTrainedTokenizer, max_seq_length: Optional[int] = None, mode: Split = Split.train
|
task_name: str,
|
||||||
|
tokenizer: PreTrainedTokenizer,
|
||||||
|
max_seq_length: Optional[int] = None,
|
||||||
|
mode: Split = Split.train,
|
||||||
|
data_dir: str = None,
|
||||||
):
|
):
|
||||||
if task_name == "mnli-mm" and mode == Split.dev:
|
if task_name == "mnli-mm" and mode == Split.dev:
|
||||||
tfds_name = "mnli_mismatched"
|
tfds_name = "mnli_mismatched"
|
||||||
@@ -50,9 +55,11 @@ def get_tfds(
|
|||||||
else:
|
else:
|
||||||
tfds_name = task_name
|
tfds_name = task_name
|
||||||
|
|
||||||
ds = tfds.load("glue/" + tfds_name, split=mode.value)
|
ds, info = tfds.load("glue/" + tfds_name, split=mode.value, with_info=True, data_dir=data_dir)
|
||||||
|
ds = glue_convert_examples_to_features(ds, tokenizer, max_seq_length, task_name)
|
||||||
|
ds = ds.apply(tf.data.experimental.assert_cardinality(info.splits[mode.value].num_examples))
|
||||||
|
|
||||||
return glue_convert_examples_to_features(ds, tokenizer, max_seq_length, task_name)
|
return ds
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -69,6 +76,7 @@ class GlueDataTrainingArguments:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
task_name: str = field(metadata={"help": "The name of the task to train on: " + ", ".join(glue_processors.keys())})
|
task_name: str = field(metadata={"help": "The name of the task to train on: " + ", ".join(glue_processors.keys())})
|
||||||
|
data_dir: Optional[str] = field(default=None, metadata={"help": "The input/output data dir for TFDS."})
|
||||||
max_seq_length: int = field(
|
max_seq_length: int = field(
|
||||||
default=128,
|
default=128,
|
||||||
metadata={
|
metadata={
|
||||||
@@ -171,13 +179,22 @@ def main():
|
|||||||
|
|
||||||
# Get datasets
|
# Get datasets
|
||||||
train_dataset = (
|
train_dataset = (
|
||||||
get_tfds(task_name=data_args.task_name, tokenizer=tokenizer, max_seq_length=data_args.max_seq_length)
|
get_tfds(
|
||||||
|
task_name=data_args.task_name,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
max_seq_length=data_args.max_seq_length,
|
||||||
|
data_dir=data_args.data_dir,
|
||||||
|
)
|
||||||
if training_args.do_train
|
if training_args.do_train
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
eval_dataset = (
|
eval_dataset = (
|
||||||
get_tfds(
|
get_tfds(
|
||||||
task_name=data_args.task_name, tokenizer=tokenizer, max_seq_length=data_args.max_seq_length, mode=Split.dev
|
task_name=data_args.task_name,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
max_seq_length=data_args.max_seq_length,
|
||||||
|
mode=Split.dev,
|
||||||
|
data_dir=data_args.data_dir,
|
||||||
)
|
)
|
||||||
if training_args.do_eval
|
if training_args.do_eval
|
||||||
else None
|
else None
|
||||||
|
|||||||
@@ -17,7 +17,6 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import warnings
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
@@ -185,11 +184,6 @@ def main():
|
|||||||
|
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
for j in range(seq_len):
|
for j in range(seq_len):
|
||||||
if label_ids[i, j] == -1:
|
|
||||||
label_ids[i, j] = -100
|
|
||||||
warnings.warn(
|
|
||||||
"Using `-1` to mask the loss for the token is depreciated. Please use `-100` instead."
|
|
||||||
)
|
|
||||||
if label_ids[i, j] != -100:
|
if label_ids[i, j] != -100:
|
||||||
out_label_list[i].append(label_map[label_ids[i][j]])
|
out_label_list[i].append(label_map[label_ids[i][j]])
|
||||||
preds_list[i].append(label_map[preds[i][j]])
|
preds_list[i].append(label_map[preds[i][j]])
|
||||||
|
|||||||
@@ -146,7 +146,7 @@ if is_tf_available():
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
features: List[InputFeatures]
|
features: List[InputFeatures]
|
||||||
pad_token_label_id: int = -1
|
pad_token_label_id: int = -100
|
||||||
# Use cross entropy ignore_index as padding label id so that only
|
# Use cross entropy ignore_index as padding label id so that only
|
||||||
# real label ids contribute to the loss later.
|
# real label ids contribute to the loss later.
|
||||||
|
|
||||||
@@ -221,6 +221,8 @@ if is_tf_available():
|
|||||||
)
|
)
|
||||||
|
|
||||||
def get_dataset(self):
|
def get_dataset(self):
|
||||||
|
self.dataset = self.dataset.apply(tf.data.experimental.assert_cardinality(len(self.features)))
|
||||||
|
|
||||||
return self.dataset
|
return self.dataset
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
|
|||||||
@@ -17,7 +17,6 @@
|
|||||||
import functools
|
import functools
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import warnings
|
|
||||||
from typing import Dict, List, Optional, Union
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
import h5py
|
import h5py
|
||||||
@@ -174,11 +173,7 @@ class TFTokenClassificationLoss:
|
|||||||
)
|
)
|
||||||
# make sure only labels that are not equal to -100
|
# make sure only labels that are not equal to -100
|
||||||
# are taken into account as loss
|
# are taken into account as loss
|
||||||
if tf.math.reduce_any(labels == -1).numpy() is True:
|
active_loss = tf.reshape(labels, (-1,)) != -100
|
||||||
warnings.warn("Using `-1` to mask the loss for the token is deprecated. Please use `-100` instead.")
|
|
||||||
active_loss = tf.reshape(labels, (-1,)) != -1
|
|
||||||
else:
|
|
||||||
active_loss = tf.reshape(labels, (-1,)) != -100
|
|
||||||
reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
|
reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
|
||||||
labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
|
labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
|
||||||
|
|
||||||
|
|||||||
@@ -1,12 +1,15 @@
|
|||||||
"""Tensorflow trainer class."""
|
"""Tensorflow trainer class."""
|
||||||
|
|
||||||
|
import datetime
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
|
import sys
|
||||||
from typing import Callable, Dict, Optional, Tuple
|
from typing import Callable, Dict, Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
from packaging.version import parse
|
||||||
|
|
||||||
from .modeling_tf_utils import TFPreTrainedModel
|
from .modeling_tf_utils import TFPreTrainedModel
|
||||||
from .optimization_tf import GradientAccumulator, create_optimizer
|
from .optimization_tf import GradientAccumulator, create_optimizer
|
||||||
@@ -21,6 +24,15 @@ if is_wandb_available():
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
if parse(tf.__version__).release < (2, 2, 0):
|
||||||
|
logger.info(
|
||||||
|
"You need to run the TensorFlow trainer with at least the version 2.2.0, your version is {}".format(
|
||||||
|
tf.__version__
|
||||||
|
)
|
||||||
|
)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
class TFTrainer:
|
class TFTrainer:
|
||||||
"""
|
"""
|
||||||
TFTrainer is a simple but feature-complete training and eval loop for TensorFlow,
|
TFTrainer is a simple but feature-complete training and eval loop for TensorFlow,
|
||||||
@@ -57,7 +69,7 @@ class TFTrainer:
|
|||||||
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None
|
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None
|
||||||
prediction_loss_only: bool
|
prediction_loss_only: bool
|
||||||
tb_writer: Optional[tf.summary.SummaryWriter] = None
|
tb_writer: Optional[tf.summary.SummaryWriter] = None
|
||||||
optimizers: Tuple[tf.keras.optimizers.Optimizer, tf.keras.optimizers.schedules.LearningRateSchedule] = None
|
optimizers: Tuple[tf.keras.optimizers.Optimizer, tf.keras.optimizers.schedules.LearningRateSchedule] = (None, None)
|
||||||
global_step: Optional[int] = None
|
global_step: Optional[int] = None
|
||||||
epoch_logging: Optional[float] = None
|
epoch_logging: Optional[float] = None
|
||||||
|
|
||||||
@@ -70,7 +82,10 @@ class TFTrainer:
|
|||||||
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
|
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
|
||||||
prediction_loss_only=False,
|
prediction_loss_only=False,
|
||||||
tb_writer: Optional[tf.summary.SummaryWriter] = None,
|
tb_writer: Optional[tf.summary.SummaryWriter] = None,
|
||||||
optimizers: Tuple[tf.keras.optimizers.Optimizer, tf.keras.optimizers.schedules.LearningRateSchedule] = None,
|
optimizers: Tuple[tf.keras.optimizers.Optimizer, tf.keras.optimizers.schedules.LearningRateSchedule] = (
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
),
|
||||||
):
|
):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.args = args
|
self.args = args
|
||||||
@@ -78,7 +93,7 @@ class TFTrainer:
|
|||||||
self.eval_dataset = eval_dataset
|
self.eval_dataset = eval_dataset
|
||||||
self.compute_metrics = compute_metrics
|
self.compute_metrics = compute_metrics
|
||||||
self.prediction_loss_only = prediction_loss_only
|
self.prediction_loss_only = prediction_loss_only
|
||||||
self.optimizers = optimizers
|
self.optimizer, self.lr_scheduler = optimizers
|
||||||
self.gradient_accumulator = GradientAccumulator()
|
self.gradient_accumulator = GradientAccumulator()
|
||||||
self.global_step = 0
|
self.global_step = 0
|
||||||
self.epoch_logging = 0
|
self.epoch_logging = 0
|
||||||
@@ -105,23 +120,19 @@ class TFTrainer:
|
|||||||
if self.train_dataset is None:
|
if self.train_dataset is None:
|
||||||
raise ValueError("Trainer: training requires a train_dataset.")
|
raise ValueError("Trainer: training requires a train_dataset.")
|
||||||
|
|
||||||
self.num_train_examples = self.train_dataset.reduce(tf.constant(0), lambda x, _: x + 1).numpy()
|
self.total_train_batch_size = self.args.train_batch_size * self.args.gradient_accumulation_steps
|
||||||
|
self.num_train_examples = tf.data.experimental.cardinality(self.train_dataset).numpy()
|
||||||
|
|
||||||
if self.args.max_steps > 0:
|
if self.num_train_examples < 0:
|
||||||
self.train_steps = self.args.max_steps
|
raise ValueError("The training dataset must have an asserted cardinality")
|
||||||
else:
|
|
||||||
self.train_steps: int = math.ceil(self.num_train_examples / self.args.train_batch_size)
|
|
||||||
|
|
||||||
ds = (
|
ds = (
|
||||||
self.train_dataset.cache()
|
self.train_dataset.repeat()
|
||||||
.shuffle(self.num_train_examples)
|
.shuffle(self.num_train_examples, seed=self.args.seed)
|
||||||
.batch(self.args.train_batch_size, drop_remainder=self.args.dataloader_drop_last)
|
.batch(self.total_train_batch_size, drop_remainder=self.args.dataloader_drop_last)
|
||||||
.prefetch(tf.data.experimental.AUTOTUNE)
|
.prefetch(tf.data.experimental.AUTOTUNE)
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.args.max_steps > 0:
|
|
||||||
self.train_dataset = self.train_dataset.repeat(-1)
|
|
||||||
|
|
||||||
return self.args.strategy.experimental_distribute_dataset(ds)
|
return self.args.strategy.experimental_distribute_dataset(ds)
|
||||||
|
|
||||||
def get_eval_tfdataset(self, eval_dataset: Optional[tf.data.Dataset] = None) -> tf.data.Dataset:
|
def get_eval_tfdataset(self, eval_dataset: Optional[tf.data.Dataset] = None) -> tf.data.Dataset:
|
||||||
@@ -136,13 +147,20 @@ class TFTrainer:
|
|||||||
raise ValueError("Trainer: evaluation requires an eval_dataset.")
|
raise ValueError("Trainer: evaluation requires an eval_dataset.")
|
||||||
|
|
||||||
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
|
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
|
||||||
|
num_examples = tf.data.experimental.cardinality(eval_dataset).numpy()
|
||||||
|
|
||||||
|
if num_examples < 0:
|
||||||
|
raise ValueError("The training dataset must have an asserted cardinality")
|
||||||
|
|
||||||
|
approx = math.floor if self.args.dataloader_drop_last else math.ceil
|
||||||
|
steps = approx(num_examples / self.args.eval_batch_size)
|
||||||
ds = (
|
ds = (
|
||||||
eval_dataset.cache()
|
eval_dataset.repeat()
|
||||||
.batch(self.args.eval_batch_size, drop_remainder=self.args.dataloader_drop_last)
|
.batch(self.args.eval_batch_size, drop_remainder=self.args.dataloader_drop_last)
|
||||||
.prefetch(tf.data.experimental.AUTOTUNE)
|
.prefetch(tf.data.experimental.AUTOTUNE)
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.args.strategy.experimental_distribute_dataset(ds)
|
return self.args.strategy.experimental_distribute_dataset(ds), steps, num_examples
|
||||||
|
|
||||||
def get_test_tfdataset(self, test_dataset: tf.data.Dataset) -> tf.data.Dataset:
|
def get_test_tfdataset(self, test_dataset: tf.data.Dataset) -> tf.data.Dataset:
|
||||||
"""
|
"""
|
||||||
@@ -151,11 +169,23 @@ class TFTrainer:
|
|||||||
Args:
|
Args:
|
||||||
test_dataset (:class:`~tf.data.Dataset`): The dataset to use.
|
test_dataset (:class:`~tf.data.Dataset`): The dataset to use.
|
||||||
"""
|
"""
|
||||||
ds = test_dataset.batch(self.args.eval_batch_size, drop_remainder=self.args.dataloader_drop_last)
|
|
||||||
|
|
||||||
return self.args.strategy.experimental_distribute_dataset(ds)
|
num_examples = tf.data.experimental.cardinality(test_dataset).numpy()
|
||||||
|
|
||||||
def get_optimizers(
|
if num_examples < 0:
|
||||||
|
raise ValueError("The training dataset must have an asserted cardinality")
|
||||||
|
|
||||||
|
approx = math.floor if self.args.dataloader_drop_last else math.ceil
|
||||||
|
steps = approx(num_examples / self.args.eval_batch_size)
|
||||||
|
ds = (
|
||||||
|
test_dataset.repeat()
|
||||||
|
.batch(self.args.eval_batch_size, drop_remainder=self.args.dataloader_drop_last)
|
||||||
|
.prefetch(tf.data.experimental.AUTOTUNE)
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.args.strategy.experimental_distribute_dataset(ds), steps, num_examples
|
||||||
|
|
||||||
|
def create_optimizer_and_scheduler(
|
||||||
self, num_training_steps: int,
|
self, num_training_steps: int,
|
||||||
) -> Tuple[tf.keras.optimizers.Optimizer, tf.keras.optimizers.schedules.LearningRateSchedule]:
|
) -> Tuple[tf.keras.optimizers.Optimizer, tf.keras.optimizers.schedules.LearningRateSchedule]:
|
||||||
"""
|
"""
|
||||||
@@ -164,20 +194,16 @@ class TFTrainer:
|
|||||||
We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
|
We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
|
||||||
TFTrainer's init through :obj:`optimizers`, or override this method in a subclass.
|
TFTrainer's init through :obj:`optimizers`, or override this method in a subclass.
|
||||||
"""
|
"""
|
||||||
if self.optimizers is not None:
|
if not self.optimizer and not self.lr_scheduler:
|
||||||
return self.optimizers
|
self.optimizer, self.lr_scheduler = create_optimizer(
|
||||||
|
self.args.learning_rate,
|
||||||
optimizer, scheduler = create_optimizer(
|
num_training_steps,
|
||||||
self.args.learning_rate,
|
self.args.warmup_steps,
|
||||||
num_training_steps,
|
adam_beta1=self.args.adam_beta1,
|
||||||
self.args.warmup_steps,
|
adam_beta2=self.args.adam_beta2,
|
||||||
adam_beta1=self.args.adam_beta1,
|
adam_epsilon=self.args.adam_epsilon,
|
||||||
adam_beta2=self.args.adam_beta2,
|
weight_decay_rate=self.args.weight_decay,
|
||||||
adam_epsilon=self.args.adam_epsilon,
|
)
|
||||||
weight_decay_rate=self.args.weight_decay,
|
|
||||||
)
|
|
||||||
|
|
||||||
return optimizer, scheduler
|
|
||||||
|
|
||||||
def _setup_wandb(self):
|
def _setup_wandb(self):
|
||||||
"""
|
"""
|
||||||
@@ -195,29 +221,13 @@ class TFTrainer:
|
|||||||
logger.info('Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"')
|
logger.info('Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"')
|
||||||
wandb.init(project=os.getenv("WANDB_PROJECT", "huggingface"), config=vars(self.args))
|
wandb.init(project=os.getenv("WANDB_PROJECT", "huggingface"), config=vars(self.args))
|
||||||
|
|
||||||
@tf.function
|
|
||||||
def _evaluate_steps(self, per_replica_features, per_replica_labels):
|
|
||||||
"""
|
|
||||||
One step evaluation across replica.
|
|
||||||
Args:
|
|
||||||
per_replica_features: the batched features.
|
|
||||||
per_replica_labels: the batched labels.
|
|
||||||
Returns:
|
|
||||||
The loss corresponding to the given batch.
|
|
||||||
"""
|
|
||||||
per_replica_loss, per_replica_logits = self.args.strategy.experimental_run_v2(
|
|
||||||
self._run_model, args=(per_replica_features, per_replica_labels, False)
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
reduced_loss = self.args.strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_loss, axis=0)
|
|
||||||
except ValueError:
|
|
||||||
reduced_loss = self.args.strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_loss, None)
|
|
||||||
|
|
||||||
return reduced_loss, per_replica_logits
|
|
||||||
|
|
||||||
def _prediction_loop(
|
def _prediction_loop(
|
||||||
self, dataset: tf.data.Dataset, description: str, prediction_loss_only: Optional[bool] = None
|
self,
|
||||||
|
dataset: tf.data.Dataset,
|
||||||
|
steps: int,
|
||||||
|
num_examples: int,
|
||||||
|
description: str,
|
||||||
|
prediction_loss_only: Optional[bool] = None,
|
||||||
) -> PredictionOutput:
|
) -> PredictionOutput:
|
||||||
"""
|
"""
|
||||||
Prediction/evaluation loop, shared by `evaluate()` and `predict()`.
|
Prediction/evaluation loop, shared by `evaluate()` and `predict()`.
|
||||||
@@ -228,21 +238,20 @@ class TFTrainer:
|
|||||||
prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else self.prediction_loss_only
|
prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else self.prediction_loss_only
|
||||||
|
|
||||||
logger.info("***** Running %s *****", description)
|
logger.info("***** Running %s *****", description)
|
||||||
|
logger.info(" Num examples = %d", num_examples)
|
||||||
logger.info(" Batch size = %d", self.args.eval_batch_size)
|
logger.info(" Batch size = %d", self.args.eval_batch_size)
|
||||||
|
|
||||||
label_ids: np.ndarray = None
|
label_ids: np.ndarray = None
|
||||||
preds: np.ndarray = None
|
preds: np.ndarray = None
|
||||||
|
self.eval_loss = tf.keras.metrics.Sum()
|
||||||
step: int = 1
|
|
||||||
|
|
||||||
# Reset the past mems state at the beginning of the evaluation if necessary.
|
# Reset the past mems state at the beginning of the evaluation if necessary.
|
||||||
if self.args.past_index >= 0:
|
if self.args.past_index >= 0:
|
||||||
self._past = None
|
self._past = None
|
||||||
|
|
||||||
for features, labels in dataset:
|
for step, batch in enumerate(dataset):
|
||||||
step = tf.convert_to_tensor(step, dtype=tf.int64)
|
logits = self.distributed_test_steps(batch)
|
||||||
loss, logits = self._evaluate_steps(features, labels)
|
_, labels = batch
|
||||||
loss = tf.reduce_mean(loss)
|
|
||||||
|
|
||||||
if not prediction_loss_only:
|
if not prediction_loss_only:
|
||||||
if isinstance(logits, tuple):
|
if isinstance(logits, tuple):
|
||||||
@@ -274,14 +283,15 @@ class TFTrainer:
|
|||||||
else:
|
else:
|
||||||
label_ids = np.append(label_ids, labels.numpy(), axis=0)
|
label_ids = np.append(label_ids, labels.numpy(), axis=0)
|
||||||
|
|
||||||
step += 1
|
if step == steps:
|
||||||
|
break
|
||||||
|
|
||||||
if self.compute_metrics is not None and preds is not None and label_ids is not None:
|
if self.compute_metrics is not None and preds is not None and label_ids is not None:
|
||||||
metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
|
metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
|
||||||
else:
|
else:
|
||||||
metrics = {}
|
metrics = {}
|
||||||
|
|
||||||
metrics["eval_loss"] = loss.numpy()
|
metrics["eval_loss"] = self.eval_loss.result().numpy() / (steps * self.args.eval_batch_size)
|
||||||
|
|
||||||
for key in list(metrics.keys()):
|
for key in list(metrics.keys()):
|
||||||
if not key.startswith("eval_"):
|
if not key.startswith("eval_"):
|
||||||
@@ -322,9 +332,9 @@ class TFTrainer:
|
|||||||
Returns:
|
Returns:
|
||||||
A dictionary containing the evaluation loss and the potential metrics computed from the predictions.
|
A dictionary containing the evaluation loss and the potential metrics computed from the predictions.
|
||||||
"""
|
"""
|
||||||
eval_ds = self.get_eval_tfdataset(eval_dataset)
|
eval_ds, steps, num_examples = self.get_eval_tfdataset(eval_dataset)
|
||||||
|
|
||||||
output = self._prediction_loop(eval_ds, description="Evaluation")
|
output = self._prediction_loop(eval_ds, steps, num_examples, description="Evaluation")
|
||||||
|
|
||||||
logs = {**output.metrics}
|
logs = {**output.metrics}
|
||||||
logs["epoch"] = self.epoch_logging
|
logs["epoch"] = self.epoch_logging
|
||||||
@@ -333,6 +343,19 @@ class TFTrainer:
|
|||||||
|
|
||||||
return output.metrics
|
return output.metrics
|
||||||
|
|
||||||
|
def test_step(self, features, labels):
|
||||||
|
per_example_loss, logits = self._run_model(features, labels, False)
|
||||||
|
|
||||||
|
self.eval_loss.update_state(per_example_loss)
|
||||||
|
|
||||||
|
return logits
|
||||||
|
|
||||||
|
@tf.function
|
||||||
|
def distributed_test_steps(self, batch):
|
||||||
|
logits = self.args.strategy.run(self.test_step, batch)
|
||||||
|
|
||||||
|
return logits
|
||||||
|
|
||||||
def train(self) -> None:
|
def train(self) -> None:
|
||||||
"""
|
"""
|
||||||
Train method to train the model.
|
Train method to train the model.
|
||||||
@@ -346,24 +369,18 @@ class TFTrainer:
|
|||||||
|
|
||||||
if self.args.max_steps > 0:
|
if self.args.max_steps > 0:
|
||||||
t_total = self.args.max_steps
|
t_total = self.args.max_steps
|
||||||
steps_per_epoch = self.args.max_steps
|
self.steps_per_epoch = self.args.max_steps
|
||||||
else:
|
else:
|
||||||
if self.args.dataloader_drop_last:
|
approx = math.floor if self.args.dataloader_drop_last else math.ceil
|
||||||
approx = math.floor
|
self.steps_per_epoch = approx(self.num_train_examples / self.total_train_batch_size)
|
||||||
else:
|
t_total = self.steps_per_epoch * self.args.num_train_epochs
|
||||||
approx = math.ceil
|
|
||||||
|
|
||||||
steps_per_epoch = approx(
|
|
||||||
self.num_train_examples / (self.args.train_batch_size * self.args.gradient_accumulation_steps)
|
|
||||||
)
|
|
||||||
t_total = steps_per_epoch * self.args.num_train_epochs
|
|
||||||
|
|
||||||
with self.args.strategy.scope():
|
with self.args.strategy.scope():
|
||||||
optimizer, lr_scheduler = self.get_optimizers(num_training_steps=t_total)
|
self.create_optimizer_and_scheduler(num_training_steps=t_total)
|
||||||
iterations = optimizer.iterations
|
iterations = self.optimizer.iterations
|
||||||
self.global_step = iterations.numpy()
|
self.global_step = iterations.numpy()
|
||||||
folder = os.path.join(self.args.output_dir, PREFIX_CHECKPOINT_DIR)
|
folder = os.path.join(self.args.output_dir, PREFIX_CHECKPOINT_DIR)
|
||||||
ckpt = tf.train.Checkpoint(optimizer=optimizer, model=self.model)
|
ckpt = tf.train.Checkpoint(optimizer=self.optimizer, model=self.model)
|
||||||
self.model.ckpt_manager = tf.train.CheckpointManager(ckpt, folder, max_to_keep=self.args.save_total_limit)
|
self.model.ckpt_manager = tf.train.CheckpointManager(ckpt, folder, max_to_keep=self.args.save_total_limit)
|
||||||
|
|
||||||
if self.model.ckpt_manager.latest_checkpoint:
|
if self.model.ckpt_manager.latest_checkpoint:
|
||||||
@@ -384,141 +401,138 @@ class TFTrainer:
|
|||||||
else:
|
else:
|
||||||
epochs_trained = 1
|
epochs_trained = 1
|
||||||
|
|
||||||
tf.summary.experimental.set_step(iterations)
|
tf.summary.experimental.set_step(iterations)
|
||||||
|
|
||||||
epochs = 1 if self.args.max_steps > 0 else self.args.num_train_epochs
|
epochs = 1 if self.args.max_steps > 0 else self.args.num_train_epochs
|
||||||
|
|
||||||
if self.args.fp16:
|
if self.args.fp16:
|
||||||
policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16")
|
policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16")
|
||||||
tf.keras.mixed_precision.experimental.set_policy(policy)
|
tf.keras.mixed_precision.experimental.set_policy(policy)
|
||||||
|
|
||||||
with self.tb_writer.as_default():
|
with self.tb_writer.as_default():
|
||||||
tf.summary.text("args", self.args.to_json_string())
|
tf.summary.text("args", self.args.to_json_string())
|
||||||
|
|
||||||
self.tb_writer.flush()
|
self.tb_writer.flush()
|
||||||
|
|
||||||
logger.info("***** Running training *****")
|
logger.info("***** Running training *****")
|
||||||
logger.info(" Num examples = %d", self.num_train_examples)
|
logger.info(" Num examples = %d", self.num_train_examples)
|
||||||
logger.info(" Num Epochs = %d", epochs)
|
logger.info(" Num Epochs = %d", epochs)
|
||||||
logger.info(" Instantaneous batch size per device = %d", self.args.per_device_train_batch_size)
|
logger.info(" Instantaneous batch size per device = %d", self.args.per_device_train_batch_size)
|
||||||
logger.info(
|
logger.info(
|
||||||
" Total train batch size (w. parallel, distributed & accumulation) = %d", self.args.train_batch_size
|
" Total train batch size (w. parallel, distributed & accumulation) = %d", self.total_train_batch_size
|
||||||
)
|
)
|
||||||
logger.info(" Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
|
logger.info(" Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
|
||||||
logger.info(" Total optimization steps = %d", t_total)
|
logger.info(" Steps per epoch = %d", self.steps_per_epoch)
|
||||||
|
logger.info(" Total optimization steps = %d", t_total)
|
||||||
|
|
||||||
for epoch_iter in range(epochs_trained, int(epochs + 1)):
|
self.train_loss = tf.keras.metrics.Sum()
|
||||||
# Reset the past mems state at the beginning of each epoch if necessary.
|
start_time = datetime.datetime.now()
|
||||||
if self.args.past_index >= 0:
|
|
||||||
self._past = None
|
|
||||||
for step, training_loss in enumerate(self._training_steps(train_ds, optimizer)):
|
|
||||||
self.global_step = iterations.numpy()
|
|
||||||
self.epoch_logging = epoch_iter - 1 + (step + 1) / steps_per_epoch
|
|
||||||
|
|
||||||
if self.args.debug:
|
for epoch_iter in range(epochs_trained, int(epochs + 1)):
|
||||||
logs = {}
|
# Reset the past mems state at the beginning of each epoch if necessary.
|
||||||
logs["loss"] = training_loss.numpy()
|
if self.args.past_index >= 0:
|
||||||
logs["epoch"] = self.epoch_logging
|
self._past = None
|
||||||
|
|
||||||
self._log(logs)
|
for step, batch in enumerate(train_ds):
|
||||||
|
self.global_step = iterations.numpy()
|
||||||
|
self.epoch_logging = epoch_iter - 1 + (step + 1) / self.steps_per_epoch
|
||||||
|
|
||||||
if self.global_step == 1 and self.args.debug:
|
self.distributed_training_steps(batch)
|
||||||
with self.tb_writer.as_default():
|
|
||||||
tf.summary.trace_export(
|
|
||||||
name="training", step=self.global_step, profiler_outdir=self.args.logging_dir
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.args.evaluate_during_training and self.global_step % self.args.eval_steps == 0:
|
training_loss = self.train_loss.result() / ((step + 1) * self.total_train_batch_size)
|
||||||
self.evaluate()
|
|
||||||
|
|
||||||
if (
|
if self.args.debug:
|
||||||
self.global_step % self.args.logging_steps == 0
|
logs = {}
|
||||||
or self.global_step == 1
|
logs["loss"] = training_loss.numpy()
|
||||||
and self.args.logging_first_step
|
logs["epoch"] = self.epoch_logging
|
||||||
):
|
|
||||||
logs = {}
|
|
||||||
logs["loss"] = training_loss.numpy()
|
|
||||||
logs["learning_rate"] = lr_scheduler(self.global_step).numpy()
|
|
||||||
logs["epoch"] = self.epoch_logging
|
|
||||||
|
|
||||||
self._log(logs)
|
self._log(logs)
|
||||||
|
|
||||||
if self.global_step % self.args.save_steps == 0:
|
if self.global_step == 1 and self.args.debug:
|
||||||
ckpt_save_path = self.model.ckpt_manager.save()
|
with self.tb_writer.as_default():
|
||||||
logger.info("Saving checkpoint for step {} at {}".format(self.global_step, ckpt_save_path))
|
tf.summary.trace_export(
|
||||||
|
name="training", step=self.global_step, profiler_outdir=self.args.logging_dir
|
||||||
|
)
|
||||||
|
|
||||||
if self.args.max_steps > 0 and self.global_step % self.args.max_steps == 0:
|
if (
|
||||||
break
|
self.global_step > 0
|
||||||
|
and self.args.evaluate_during_training
|
||||||
|
and self.global_step % self.args.eval_steps == 0
|
||||||
|
):
|
||||||
|
self.evaluate()
|
||||||
|
|
||||||
|
if (self.global_step > 0 and self.global_step % self.args.logging_steps == 0) or (
|
||||||
|
self.global_step == 1 and self.args.logging_first_step
|
||||||
|
):
|
||||||
|
logs = {}
|
||||||
|
logs["loss"] = training_loss.numpy()
|
||||||
|
logs["learning_rate"] = self.lr_scheduler(self.global_step).numpy()
|
||||||
|
logs["epoch"] = self.epoch_logging
|
||||||
|
|
||||||
|
self._log(logs)
|
||||||
|
|
||||||
|
if self.global_step > 0 and self.global_step % self.args.save_steps == 0:
|
||||||
|
ckpt_save_path = self.model.ckpt_manager.save()
|
||||||
|
|
||||||
|
logger.info("Saving checkpoint for step {} at {}".format(self.global_step, ckpt_save_path))
|
||||||
|
|
||||||
|
if self.global_step > 0 and self.global_step % self.steps_per_epoch == 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
self.train_loss.reset_states()
|
||||||
|
|
||||||
|
end_time = datetime.datetime.now()
|
||||||
|
|
||||||
|
logger.info("Training took: {}".format(str(end_time - start_time)))
|
||||||
|
|
||||||
if self.args.past_index and hasattr(self, "_past"):
|
if self.args.past_index and hasattr(self, "_past"):
|
||||||
# Clean the state at the end of training
|
# Clean the state at the end of training
|
||||||
delattr(self, "_past")
|
delattr(self, "_past")
|
||||||
|
|
||||||
def _training_steps(self, ds, optimizer):
|
def training_step(self, features, labels):
|
||||||
"""
|
|
||||||
Returns a generator over training steps (i.e. parameters update).
|
|
||||||
"""
|
|
||||||
for i, loss in enumerate(self._accumulate_next_gradients(ds)):
|
|
||||||
if i % self.args.gradient_accumulation_steps == 0:
|
|
||||||
self._apply_gradients(optimizer)
|
|
||||||
yield loss
|
|
||||||
|
|
||||||
@tf.function
|
|
||||||
def _apply_gradients(self, optimizer):
|
|
||||||
"""Applies the gradients (cross-replica)."""
|
|
||||||
self.args.strategy.experimental_run_v2(self._step, args=(optimizer,))
|
|
||||||
|
|
||||||
def _step(self, optimizer):
|
|
||||||
"""Applies gradients and resets accumulation."""
|
|
||||||
gradient_scale = self.gradient_accumulator.step * self.args.strategy.num_replicas_in_sync
|
|
||||||
gradients = [
|
|
||||||
gradient / tf.cast(gradient_scale, gradient.dtype) for gradient in self.gradient_accumulator.gradients
|
|
||||||
]
|
|
||||||
gradients = [(tf.clip_by_value(grad, -self.args.max_grad_norm, self.args.max_grad_norm)) for grad in gradients]
|
|
||||||
|
|
||||||
optimizer.apply_gradients(list(zip(gradients, self.model.trainable_variables)))
|
|
||||||
self.gradient_accumulator.reset()
|
|
||||||
|
|
||||||
def _accumulate_next_gradients(self, ds):
|
|
||||||
"""Accumulates the gradients from the next element in dataset."""
|
|
||||||
iterator = iter(ds)
|
|
||||||
|
|
||||||
@tf.function
|
|
||||||
def _accumulate_next():
|
|
||||||
per_replica_features, per_replica_labels = next(iterator)
|
|
||||||
|
|
||||||
return self._accumulate_gradients(per_replica_features, per_replica_labels)
|
|
||||||
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
yield _accumulate_next()
|
|
||||||
except tf.errors.OutOfRangeError:
|
|
||||||
break
|
|
||||||
|
|
||||||
def _accumulate_gradients(self, per_replica_features, per_replica_labels):
|
|
||||||
"""Accumulates the gradients across all the replica."""
|
|
||||||
per_replica_loss = self.args.strategy.experimental_run_v2(
|
|
||||||
self._forward, args=(per_replica_features, per_replica_labels)
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
reduced_loss = self.args.strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_loss, axis=0)
|
|
||||||
except ValueError:
|
|
||||||
reduced_loss = self.args.strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_loss, None)
|
|
||||||
|
|
||||||
return reduced_loss
|
|
||||||
|
|
||||||
def _forward(self, features, labels):
|
|
||||||
"""Forwards a training example and accumulates the gradients."""
|
|
||||||
per_example_loss, _ = self._run_model(features, labels, True)
|
per_example_loss, _ = self._run_model(features, labels, True)
|
||||||
gradients = tf.gradients(per_example_loss, self.model.trainable_variables)
|
scaled_loss = per_example_loss / self.total_train_batch_size
|
||||||
|
gradients = tf.gradients(scaled_loss, self.model.trainable_variables)
|
||||||
gradients = [
|
gradients = [
|
||||||
g if g is not None else tf.zeros_like(v) for g, v in zip(gradients, self.model.trainable_variables)
|
g if g is not None else tf.zeros_like(v) for g, v in zip(gradients, self.model.trainable_variables)
|
||||||
]
|
]
|
||||||
|
|
||||||
self.gradient_accumulator(gradients)
|
if self.args.gradient_accumulation_steps > 1:
|
||||||
|
self.gradient_accumulator(gradients)
|
||||||
|
|
||||||
return per_example_loss
|
self.train_loss.update_state(per_example_loss)
|
||||||
|
|
||||||
|
if self.args.gradient_accumulation_steps == 1:
|
||||||
|
return gradients
|
||||||
|
|
||||||
|
def apply_gradients(self, features, labels):
|
||||||
|
if self.args.gradient_accumulation_steps == 1:
|
||||||
|
gradients = self.training_step(features, labels)
|
||||||
|
|
||||||
|
self.optimizer.apply_gradients(list(zip(gradients, self.model.trainable_variables)))
|
||||||
|
else:
|
||||||
|
for _ in tf.range(self.args.gradient_accumulation_steps):
|
||||||
|
reduced_features = features[: self.args.train_batch_size / self.args.n_replicas]
|
||||||
|
reduced_labels = labels[: self.args.train_batch_size / self.args.n_replicas]
|
||||||
|
|
||||||
|
self.training_step(reduced_features, reduced_labels)
|
||||||
|
|
||||||
|
features = tf.concat(
|
||||||
|
[features[self.args.train_batch_size / self.args.n_replicas :], reduced_features], axis=0
|
||||||
|
)
|
||||||
|
|
||||||
|
gradients = self.gradient_accumulator.gradients
|
||||||
|
gradients = [
|
||||||
|
(tf.clip_by_value(grad, -self.args.max_grad_norm, self.args.max_grad_norm)) for grad in gradients
|
||||||
|
]
|
||||||
|
|
||||||
|
self.optimizer.apply_gradients(list(zip(gradients, self.model.trainable_variables)))
|
||||||
|
self.gradient_accumulator.reset()
|
||||||
|
|
||||||
|
@tf.function
|
||||||
|
def distributed_training_steps(self, batch):
|
||||||
|
with self.args.strategy.scope():
|
||||||
|
self.args.strategy.run(self.apply_gradients, batch)
|
||||||
|
|
||||||
def _run_model(self, features, labels, training):
|
def _run_model(self, features, labels, training):
|
||||||
"""
|
"""
|
||||||
@@ -530,14 +544,16 @@ class TFTrainer:
|
|||||||
"""
|
"""
|
||||||
if self.args.past_index >= 0 and getattr(self, "_past", None) is not None:
|
if self.args.past_index >= 0 and getattr(self, "_past", None) is not None:
|
||||||
features["mems"] = self._past
|
features["mems"] = self._past
|
||||||
|
|
||||||
if isinstance(labels, (dict)):
|
if isinstance(labels, (dict)):
|
||||||
outputs = self.model(features, training=training, **labels)[:2]
|
outputs = self.model(features, training=training, **labels)[:2]
|
||||||
else:
|
else:
|
||||||
outputs = self.model(features, labels=labels, training=training)[:2]
|
outputs = self.model(features, labels=labels, training=training)[:2]
|
||||||
|
|
||||||
loss, logits = outputs[:2]
|
loss, logits = outputs[:2]
|
||||||
|
|
||||||
if self.args.past_index >= 0:
|
if self.args.past_index >= 0:
|
||||||
self._past = outputs[self.args.past_index]
|
self._past = outputs[self.args.past_index]
|
||||||
loss += sum(self.model.losses) * (1.0 / self.args.n_replicas)
|
|
||||||
|
|
||||||
return loss, logits
|
return loss, logits
|
||||||
|
|
||||||
@@ -560,9 +576,9 @@ class TFTrainer:
|
|||||||
metrics (:obj:`Dict[str, float]`, `optional`):
|
metrics (:obj:`Dict[str, float]`, `optional`):
|
||||||
The potential dictionary of metrics (if the dataset contained labels).
|
The potential dictionary of metrics (if the dataset contained labels).
|
||||||
"""
|
"""
|
||||||
test_ds = self.get_test_tfdataset(test_dataset)
|
test_ds, steps, num_examples = self.get_test_tfdataset(test_dataset)
|
||||||
|
|
||||||
return self._prediction_loop(test_ds, description="Prediction")
|
return self._prediction_loop(test_ds, steps, num_examples, description="Prediction")
|
||||||
|
|
||||||
def save_model(self, output_dir: Optional[str] = None):
|
def save_model(self, output_dir: Optional[str] = None):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -162,7 +162,7 @@ class TFTrainingArguments(TrainingArguments):
|
|||||||
"version. Using `--per_device_train_batch_size` is preferred."
|
"version. Using `--per_device_train_batch_size` is preferred."
|
||||||
)
|
)
|
||||||
per_device_batch_size = self.per_gpu_train_batch_size or self.per_device_train_batch_size
|
per_device_batch_size = self.per_gpu_train_batch_size or self.per_device_train_batch_size
|
||||||
return per_device_batch_size * max(1, self.n_replicas)
|
return per_device_batch_size * self.n_replicas
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def eval_batch_size(self) -> int:
|
def eval_batch_size(self) -> int:
|
||||||
@@ -175,7 +175,7 @@ class TFTrainingArguments(TrainingArguments):
|
|||||||
"version. Using `--per_device_eval_batch_size` is preferred."
|
"version. Using `--per_device_eval_batch_size` is preferred."
|
||||||
)
|
)
|
||||||
per_device_batch_size = self.per_gpu_eval_batch_size or self.per_device_eval_batch_size
|
per_device_batch_size = self.per_gpu_eval_batch_size or self.per_device_eval_batch_size
|
||||||
return per_device_batch_size * max(1, self.n_replicas)
|
return per_device_batch_size * self.n_replicas
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@tf_required
|
@tf_required
|
||||||
|
|||||||
Reference in New Issue
Block a user