Floating-point operations logging in trainer (#6768)
* neFLOs calculation, logging, and reloading (#1) * testing distributed consecutive batches * fixed AttributeError from DataParallel * removed verbosity * rotate with use_mtime=True * removed print * fixed interaction with gradient accumulation * indent formatting * distributed neflo counting * fixed typo * fixed typo * mean distributed losses * exporting log history * moved a few functions * floating_point_ops clarification for transformers with parameter-reuse * code quality * double import * made flo estimation more task-agnostic * only logging flos if computed * code quality * unused import * Update src/transformers/trainer.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/modeling_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Sylvain review * Update src/transformers/modeling_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * black Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -17,8 +17,9 @@
|
|||||||
import inspect
|
import inspect
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
import warnings
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Callable, Dict, List, Optional, Set, Tuple, Union
|
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor, device, dtype, nn
|
from torch import Tensor, device, dtype, nn
|
||||||
@@ -45,7 +46,6 @@ from .utils import logging
|
|||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from torch.nn import Identity
|
from torch.nn import Identity
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@@ -91,20 +91,6 @@ class ModuleUtilsMixin:
|
|||||||
A few utilities for :obj:`torch.nn.Modules`, to be used as a mixin.
|
A few utilities for :obj:`torch.nn.Modules`, to be used as a mixin.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def num_parameters(self, only_trainable: bool = False) -> int:
|
|
||||||
"""
|
|
||||||
Get the number of (optionally, trainable) parameters in the model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
only_trainable (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
|
||||||
Whether or not to return only the number of trainable parameters
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
:obj:`int`: The number of parameters.
|
|
||||||
"""
|
|
||||||
params = filter(lambda x: x.requires_grad, self.parameters()) if only_trainable else self.parameters()
|
|
||||||
return sum(p.numel() for p in params)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _hook_rss_memory_pre_forward(module, *args, **kwargs):
|
def _hook_rss_memory_pre_forward(module, *args, **kwargs):
|
||||||
try:
|
try:
|
||||||
@@ -307,9 +293,77 @@ class ModuleUtilsMixin:
|
|||||||
elif head_mask.dim() == 2:
|
elif head_mask.dim() == 2:
|
||||||
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
|
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
|
||||||
assert head_mask.dim() == 5, f"head_mask.dim != 5, instead {head_mask.dim()}"
|
assert head_mask.dim() == 5, f"head_mask.dim != 5, instead {head_mask.dim()}"
|
||||||
head_mask = head_mask.to(dtype=self.dtype) # switch to fload if need + fp16 compatibility
|
head_mask = head_mask.to(dtype=self.dtype) # switch to float if need + fp16 compatibility
|
||||||
return head_mask
|
return head_mask
|
||||||
|
|
||||||
|
def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
|
||||||
|
"""
|
||||||
|
Get number of (optionally, trainable or non-embeddings) parameters in the module.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
only_trainable (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
|
Whether or not to return only the number of trainable parameters
|
||||||
|
|
||||||
|
exclude_embeddings (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
|
Whether or not to return only the number of non-embeddings parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
:obj:`int`: The number of parameters.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def parameter_filter(x):
|
||||||
|
return (x.requires_grad or not only_trainable) and not (
|
||||||
|
isinstance(x, torch.nn.Embedding) and exclude_embeddings
|
||||||
|
)
|
||||||
|
|
||||||
|
params = filter(parameter_filter, self.parameters()) if only_trainable else self.parameters()
|
||||||
|
return sum(p.numel() for p in params)
|
||||||
|
|
||||||
|
def estimate_tokens(self, input_dict: Dict[str, Union[torch.Tensor, Any]]) -> int:
|
||||||
|
"""
|
||||||
|
Helper function to estimate the total number of tokens from the model inputs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inputs (:obj:`dict`): The model inputs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
:obj:`int`: The total number of tokens.
|
||||||
|
"""
|
||||||
|
token_inputs = [tensor for key, tensor in input_dict.items() if "input" in key]
|
||||||
|
if token_inputs:
|
||||||
|
return sum([token_input.numel() for token_input in token_inputs])
|
||||||
|
else:
|
||||||
|
warnings.warn(
|
||||||
|
"Could not estimate the number of tokens of the input, floating-point operations will not be computed"
|
||||||
|
)
|
||||||
|
return 0
|
||||||
|
|
||||||
|
def floating_point_ops(
|
||||||
|
self, input_dict: Dict[str, Union[torch.Tensor, Any]], exclude_embeddings: bool = True
|
||||||
|
) -> int:
|
||||||
|
"""
|
||||||
|
Get number of (optionally, non-embeddings) floating-point operations for the forward and backward passes of a
|
||||||
|
batch with this transformer model. Default approximation neglects the quadratic dependency on the number of
|
||||||
|
tokens (valid if :obj:`12 * d_model << sequence_length`) as laid out in `this paper <https://arxiv.org/pdf/2001.08361.pdf>`__ section
|
||||||
|
2.1. Should be overriden for transformers with parameter re-use e.g. Albert or Universal Transformers, or
|
||||||
|
if doing long-range modeling with very high sequence lengths.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch_size (:obj:`int`):
|
||||||
|
The batch size for the forward pass.
|
||||||
|
|
||||||
|
sequence_length (:obj:`int`):
|
||||||
|
The number of tokens in each line of the batch.
|
||||||
|
|
||||||
|
exclude_embeddings (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||||
|
Whether or not to count embedding and softmax operations.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
:obj:`int`: The number of floating-point operations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings)
|
||||||
|
|
||||||
|
|
||||||
class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||||
r"""
|
r"""
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import inspect
|
import inspect
|
||||||
|
import json
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
@@ -42,6 +43,8 @@ from .trainer_utils import (
|
|||||||
TrainOutput,
|
TrainOutput,
|
||||||
default_compute_objective,
|
default_compute_objective,
|
||||||
default_hp_space,
|
default_hp_space,
|
||||||
|
distributed_broadcast_scalars,
|
||||||
|
distributed_concat,
|
||||||
set_seed,
|
set_seed,
|
||||||
)
|
)
|
||||||
from .training_args import TrainingArguments
|
from .training_args import TrainingArguments
|
||||||
@@ -146,7 +149,7 @@ class SequentialDistributedSampler(Sampler):
|
|||||||
indices = indices[self.rank * self.num_samples : (self.rank + 1) * self.num_samples]
|
indices = indices[self.rank * self.num_samples : (self.rank + 1) * self.num_samples]
|
||||||
assert (
|
assert (
|
||||||
len(indices) == self.num_samples
|
len(indices) == self.num_samples
|
||||||
), f"Indices length {len(indices)} and and sample number {self.num_samples} mismatched"
|
), f"Indices length {len(indices)} and sample number {self.num_samples} mismatched"
|
||||||
|
|
||||||
return iter(indices)
|
return iter(indices)
|
||||||
|
|
||||||
@@ -241,6 +244,7 @@ class Trainer:
|
|||||||
"You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
|
"You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
|
||||||
)
|
)
|
||||||
self.tb_writer = tb_writer
|
self.tb_writer = tb_writer
|
||||||
|
self.log_history = []
|
||||||
if "prediction_loss_only" in kwargs:
|
if "prediction_loss_only" in kwargs:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"Passing `prediction_loss_only` as a keyword argument is deprecated and won't be possible in a future version. Use `args.prediction_loss_only` instead.",
|
"Passing `prediction_loss_only` as a keyword argument is deprecated and won't be possible in a future version. Use `args.prediction_loss_only` instead.",
|
||||||
@@ -284,6 +288,7 @@ class Trainer:
|
|||||||
|
|
||||||
self.global_step = None
|
self.global_step = None
|
||||||
self.epoch = None
|
self.epoch = None
|
||||||
|
self.total_flos = None
|
||||||
if self.args.fp16 and _use_native_amp:
|
if self.args.fp16 and _use_native_amp:
|
||||||
self.scaler = torch.cuda.amp.GradScaler()
|
self.scaler = torch.cuda.amp.GradScaler()
|
||||||
self.hp_search_backend = None
|
self.hp_search_backend = None
|
||||||
@@ -461,7 +466,11 @@ class Trainer:
|
|||||||
logger.info(
|
logger.info(
|
||||||
'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
|
'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
|
||||||
)
|
)
|
||||||
|
try:
|
||||||
combined_dict = {**self.model.config.to_dict(), **self.args.to_sanitized_dict()}
|
combined_dict = {**self.model.config.to_dict(), **self.args.to_sanitized_dict()}
|
||||||
|
except AttributeError:
|
||||||
|
# in case the model has no config
|
||||||
|
combined_dict = {**self.args.to_sanitized_dict()}
|
||||||
wandb.init(
|
wandb.init(
|
||||||
project=os.getenv("WANDB_PROJECT", "huggingface"), config=combined_dict, name=self.args.run_name
|
project=os.getenv("WANDB_PROJECT", "huggingface"), config=combined_dict, name=self.args.run_name
|
||||||
)
|
)
|
||||||
@@ -663,6 +672,7 @@ class Trainer:
|
|||||||
|
|
||||||
self.global_step = 0
|
self.global_step = 0
|
||||||
self.epoch = 0
|
self.epoch = 0
|
||||||
|
self.total_flos = 0
|
||||||
epochs_trained = 0
|
epochs_trained = 0
|
||||||
steps_trained_in_current_epoch = 0
|
steps_trained_in_current_epoch = 0
|
||||||
# Check if continuing training from a checkpoint
|
# Check if continuing training from a checkpoint
|
||||||
@@ -670,6 +680,8 @@ class Trainer:
|
|||||||
# set global_step to global_step of last saved checkpoint from model path
|
# set global_step to global_step of last saved checkpoint from model path
|
||||||
try:
|
try:
|
||||||
self.global_step = int(model_path.split("-")[-1].split(os.path.sep)[0])
|
self.global_step = int(model_path.split("-")[-1].split(os.path.sep)[0])
|
||||||
|
self.total_flos = getattr(model.config, "total_flos", 0)
|
||||||
|
|
||||||
epochs_trained = self.global_step // (len(train_dataloader) // self.args.gradient_accumulation_steps)
|
epochs_trained = self.global_step // (len(train_dataloader) // self.args.gradient_accumulation_steps)
|
||||||
steps_trained_in_current_epoch = self.global_step % (
|
steps_trained_in_current_epoch = self.global_step % (
|
||||||
len(train_dataloader) // self.args.gradient_accumulation_steps
|
len(train_dataloader) // self.args.gradient_accumulation_steps
|
||||||
@@ -678,9 +690,11 @@ class Trainer:
|
|||||||
logger.info(" Continuing training from checkpoint, will skip to saved global_step")
|
logger.info(" Continuing training from checkpoint, will skip to saved global_step")
|
||||||
logger.info(" Continuing training from epoch %d", epochs_trained)
|
logger.info(" Continuing training from epoch %d", epochs_trained)
|
||||||
logger.info(" Continuing training from global step %d", self.global_step)
|
logger.info(" Continuing training from global step %d", self.global_step)
|
||||||
|
logger.info(" Continuing training from %d non-embedding floating-point operations", self.total_flos)
|
||||||
logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
|
logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
self.global_step = 0
|
self.global_step = 0
|
||||||
|
self.total_flos = 0
|
||||||
logger.info(" Starting fine-tuning.")
|
logger.info(" Starting fine-tuning.")
|
||||||
|
|
||||||
tr_loss = torch.tensor(0.0).to(self.args.device)
|
tr_loss = torch.tensor(0.0).to(self.args.device)
|
||||||
@@ -714,6 +728,7 @@ class Trainer:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
tr_loss += self.training_step(model, inputs)
|
tr_loss += self.training_step(model, inputs)
|
||||||
|
self.total_flos += self.floating_point_ops(inputs)
|
||||||
|
|
||||||
if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
|
if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
|
||||||
# last step in epoch but step is always smaller than gradient_accumulation_steps
|
# last step in epoch but step is always smaller than gradient_accumulation_steps
|
||||||
@@ -784,7 +799,7 @@ class Trainer:
|
|||||||
self.save_model(output_dir)
|
self.save_model(output_dir)
|
||||||
|
|
||||||
if self.is_world_process_zero():
|
if self.is_world_process_zero():
|
||||||
self._rotate_checkpoints()
|
self._rotate_checkpoints(use_mtime=True)
|
||||||
|
|
||||||
if is_torch_tpu_available():
|
if is_torch_tpu_available():
|
||||||
xm.rendezvous("saving_optimizer_states")
|
xm.rendezvous("saving_optimizer_states")
|
||||||
@@ -924,6 +939,13 @@ class Trainer:
|
|||||||
|
|
||||||
if self.epoch is not None:
|
if self.epoch is not None:
|
||||||
logs["epoch"] = self.epoch
|
logs["epoch"] = self.epoch
|
||||||
|
if self.total_flos is not None:
|
||||||
|
if self.args.local_rank != -1:
|
||||||
|
total_flos = distributed_broadcast_scalars([self.total_flos]).sum().item()
|
||||||
|
else:
|
||||||
|
total_flos = self.total_flos
|
||||||
|
if total_flos > 0:
|
||||||
|
logs["total_flos"] = self.total_flos
|
||||||
if self.global_step is None:
|
if self.global_step is None:
|
||||||
# when logging evaluation metrics without training
|
# when logging evaluation metrics without training
|
||||||
self.global_step = 0
|
self.global_step = 0
|
||||||
@@ -951,6 +973,8 @@ class Trainer:
|
|||||||
if experiment is not None:
|
if experiment is not None:
|
||||||
experiment._log_metrics(logs, step=self.global_step, epoch=self.epoch, framework="transformers")
|
experiment._log_metrics(logs, step=self.global_step, epoch=self.epoch, framework="transformers")
|
||||||
output = {**logs, **{"step": self.global_step}}
|
output = {**logs, **{"step": self.global_step}}
|
||||||
|
if self.is_world_process_zero():
|
||||||
|
self.log_history.append(output)
|
||||||
if iterator is not None:
|
if iterator is not None:
|
||||||
iterator.write(output)
|
iterator.write(output)
|
||||||
else:
|
else:
|
||||||
@@ -1089,6 +1113,9 @@ class Trainer:
|
|||||||
if xm.is_master_ordinal():
|
if xm.is_master_ordinal():
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
|
torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
|
||||||
|
json.dump(
|
||||||
|
self.log_history, open(os.path.join(output_dir, "log_history.json"), "w"), indent=2, ensure_ascii=False
|
||||||
|
)
|
||||||
|
|
||||||
# Save a trained model and configuration using `save_pretrained()`.
|
# Save a trained model and configuration using `save_pretrained()`.
|
||||||
# They can then be reloaded using `from_pretrained()`
|
# They can then be reloaded using `from_pretrained()`
|
||||||
@@ -1096,6 +1123,7 @@ class Trainer:
|
|||||||
raise ValueError("Trainer.model appears to not be a PreTrainedModel")
|
raise ValueError("Trainer.model appears to not be a PreTrainedModel")
|
||||||
|
|
||||||
xm.rendezvous("saving_checkpoint")
|
xm.rendezvous("saving_checkpoint")
|
||||||
|
self._store_flos()
|
||||||
self.model.save_pretrained(output_dir)
|
self.model.save_pretrained(output_dir)
|
||||||
if self.tokenizer is not None:
|
if self.tokenizer is not None:
|
||||||
self.tokenizer.save_pretrained(output_dir)
|
self.tokenizer.save_pretrained(output_dir)
|
||||||
@@ -1108,12 +1136,26 @@ class Trainer:
|
|||||||
# They can then be reloaded using `from_pretrained()`
|
# They can then be reloaded using `from_pretrained()`
|
||||||
if not isinstance(self.model, PreTrainedModel):
|
if not isinstance(self.model, PreTrainedModel):
|
||||||
raise ValueError("Trainer.model appears to not be a PreTrainedModel")
|
raise ValueError("Trainer.model appears to not be a PreTrainedModel")
|
||||||
|
self._store_flos()
|
||||||
self.model.save_pretrained(output_dir)
|
self.model.save_pretrained(output_dir)
|
||||||
if self.tokenizer is not None:
|
if self.tokenizer is not None:
|
||||||
self.tokenizer.save_pretrained(output_dir)
|
self.tokenizer.save_pretrained(output_dir)
|
||||||
|
|
||||||
# Good practice: save your training arguments together with the trained model
|
# Good practice: save your training arguments together with the trained model
|
||||||
torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
|
torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
|
||||||
|
json.dump(
|
||||||
|
self.log_history, open(os.path.join(output_dir, "log_history.json"), "w"), indent=2, ensure_ascii=False
|
||||||
|
)
|
||||||
|
|
||||||
|
def _store_flos(self):
|
||||||
|
# Storing the number of floating-point operations that went into the model
|
||||||
|
if self.total_flos is not None:
|
||||||
|
if self.args.local_rank != -1:
|
||||||
|
total_flos = distributed_broadcast_scalars([self.total_flos]).sum().item()
|
||||||
|
else:
|
||||||
|
total_flos = self.total_flos
|
||||||
|
if total_flos > 0:
|
||||||
|
self.model.config.total_flos = total_flos
|
||||||
|
|
||||||
def _sorted_checkpoints(self, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False) -> List[str]:
|
def _sorted_checkpoints(self, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False) -> List[str]:
|
||||||
ordering_and_checkpoint_path = []
|
ordering_and_checkpoint_path = []
|
||||||
@@ -1245,13 +1287,11 @@ class Trainer:
|
|||||||
self._past = None
|
self._past = None
|
||||||
|
|
||||||
disable_tqdm = not self.is_local_process_zero() or self.args.disable_tqdm
|
disable_tqdm = not self.is_local_process_zero() or self.args.disable_tqdm
|
||||||
samples_count = 0
|
|
||||||
for inputs in tqdm(dataloader, desc=description, disable=disable_tqdm):
|
for inputs in tqdm(dataloader, desc=description, disable=disable_tqdm):
|
||||||
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only)
|
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only)
|
||||||
batch_size = inputs[list(inputs.keys())[0]].shape[0]
|
batch_size = inputs[list(inputs.keys())[0]].shape[0]
|
||||||
samples_count += batch_size
|
|
||||||
if loss is not None:
|
if loss is not None:
|
||||||
eval_losses.append(loss * batch_size)
|
eval_losses.extend([loss] * batch_size)
|
||||||
if logits is not None:
|
if logits is not None:
|
||||||
preds = logits if preds is None else torch.cat((preds, logits), dim=0)
|
preds = logits if preds is None else torch.cat((preds, logits), dim=0)
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
@@ -1264,9 +1304,9 @@ class Trainer:
|
|||||||
if self.args.local_rank != -1:
|
if self.args.local_rank != -1:
|
||||||
# In distributed mode, concatenate all results from all nodes:
|
# In distributed mode, concatenate all results from all nodes:
|
||||||
if preds is not None:
|
if preds is not None:
|
||||||
preds = self.distributed_concat(preds, num_total_examples=self.num_examples(dataloader))
|
preds = distributed_concat(preds, num_total_examples=self.num_examples(dataloader))
|
||||||
if label_ids is not None:
|
if label_ids is not None:
|
||||||
label_ids = self.distributed_concat(label_ids, num_total_examples=self.num_examples(dataloader))
|
label_ids = distributed_concat(label_ids, num_total_examples=self.num_examples(dataloader))
|
||||||
elif is_torch_tpu_available():
|
elif is_torch_tpu_available():
|
||||||
# tpu-comment: Get all predictions and labels from all worker shards of eval dataset
|
# tpu-comment: Get all predictions and labels from all worker shards of eval dataset
|
||||||
if preds is not None:
|
if preds is not None:
|
||||||
@@ -1289,7 +1329,14 @@ class Trainer:
|
|||||||
else:
|
else:
|
||||||
metrics = {}
|
metrics = {}
|
||||||
if len(eval_losses) > 0:
|
if len(eval_losses) > 0:
|
||||||
metrics["eval_loss"] = np.sum(eval_losses) / samples_count
|
if self.args.local_rank != -1:
|
||||||
|
metrics["eval_loss"] = (
|
||||||
|
distributed_broadcast_scalars(eval_losses, num_total_examples=self.num_examples(dataloader))
|
||||||
|
.mean()
|
||||||
|
.item()
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
metrics["eval_loss"] = np.mean(eval_losses)
|
||||||
|
|
||||||
# Prefix all keys with eval_
|
# Prefix all keys with eval_
|
||||||
for key in list(metrics.keys()):
|
for key in list(metrics.keys()):
|
||||||
@@ -1298,18 +1345,6 @@ class Trainer:
|
|||||||
|
|
||||||
return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)
|
return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)
|
||||||
|
|
||||||
def distributed_concat(self, tensor: torch.Tensor, num_total_examples: int) -> torch.Tensor:
|
|
||||||
assert self.args.local_rank != -1
|
|
||||||
|
|
||||||
output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())]
|
|
||||||
torch.distributed.all_gather(output_tensors, tensor)
|
|
||||||
|
|
||||||
concat = torch.cat(output_tensors, dim=0)
|
|
||||||
|
|
||||||
# truncate the dummy elements added by SequentialDistributedSampler
|
|
||||||
output = concat[:num_total_examples]
|
|
||||||
return output
|
|
||||||
|
|
||||||
def prediction_step(
|
def prediction_step(
|
||||||
self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], prediction_loss_only: bool
|
self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], prediction_loss_only: bool
|
||||||
) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||||
@@ -1355,3 +1390,32 @@ class Trainer:
|
|||||||
if labels is not None:
|
if labels is not None:
|
||||||
labels = labels.detach()
|
labels = labels.detach()
|
||||||
return (loss, logits.detach(), labels)
|
return (loss, logits.detach(), labels)
|
||||||
|
|
||||||
|
def floating_point_ops(self, inputs: Dict[str, Union[torch.Tensor, Any]]):
|
||||||
|
"""
|
||||||
|
For models that inherit from :class:`~transformers.PretrainedModel`, uses
|
||||||
|
that method to compute the number of floating point operations for every backward + forward pass. If using
|
||||||
|
another model, either implement such a method in the model or subclass and override this method.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (:obj:`nn.Module`):
|
||||||
|
The model to evaluate.
|
||||||
|
inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
|
||||||
|
The inputs and targets of the model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
:obj:`int`: The number of floating-point operations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if isinstance(self.model, torch.nn.DataParallel) or isinstance(
|
||||||
|
self.model, torch.nn.parallel.DistributedDataParallel
|
||||||
|
):
|
||||||
|
model = self.model.module
|
||||||
|
else:
|
||||||
|
model = self.model
|
||||||
|
|
||||||
|
if hasattr(model, "floating_point_ops"):
|
||||||
|
return model.floating_point_ops(inputs)
|
||||||
|
|
||||||
|
else:
|
||||||
|
return 0
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
import random
|
import random
|
||||||
from typing import Any, Dict, NamedTuple, Optional
|
from typing import Any, Dict, List, NamedTuple, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
from .file_utils import is_tf_available, is_torch_available
|
from .file_utils import is_tf_available, is_torch_available
|
||||||
from .tokenization_utils_base import ExplicitEnum
|
from .tokenization_utils_base import ExplicitEnum
|
||||||
@@ -126,3 +127,32 @@ default_hp_space = {
|
|||||||
HPSearchBackend.OPTUNA: default_hp_space_optuna,
|
HPSearchBackend.OPTUNA: default_hp_space_optuna,
|
||||||
HPSearchBackend.RAY: default_hp_space_ray,
|
HPSearchBackend.RAY: default_hp_space_ray,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def distributed_concat(self, tensor: torch.Tensor, num_total_examples: Optional[int] = None) -> torch.Tensor:
|
||||||
|
assert self.args.local_rank != -1
|
||||||
|
|
||||||
|
output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())]
|
||||||
|
torch.distributed.all_gather(output_tensors, tensor)
|
||||||
|
concat = torch.cat(output_tensors, dim=0)
|
||||||
|
|
||||||
|
# truncate the dummy elements added by SequentialDistributedSampler
|
||||||
|
if num_total_examples is not None:
|
||||||
|
concat = concat[:num_total_examples]
|
||||||
|
return concat
|
||||||
|
|
||||||
|
|
||||||
|
def distributed_broadcast_scalars(
|
||||||
|
self, scalars: List[Union[int, float]], num_total_examples: Optional[int] = None
|
||||||
|
) -> torch.Tensor:
|
||||||
|
assert self.args.local_rank != -1
|
||||||
|
|
||||||
|
tensorized_scalar = torch.Tensor(scalars).cuda()
|
||||||
|
output_tensors = [tensorized_scalar.clone() for _ in range(torch.distributed.get_world_size())]
|
||||||
|
torch.distributed.all_gather(output_tensors, tensorized_scalar)
|
||||||
|
concat = torch.cat(output_tensors, dim=0)
|
||||||
|
|
||||||
|
# truncate the dummy elements added by SequentialDistributedSampler
|
||||||
|
if num_total_examples is not None:
|
||||||
|
concat = concat[:num_total_examples]
|
||||||
|
return concat
|
||||||
|
|||||||
Reference in New Issue
Block a user