Remove config assumption in Trainer (#7464)
* Remove config assumption in Trainer
* Initialize for eval
This commit is contained in:
@@ -282,7 +282,7 @@ class Trainer:
|
||||
# Create output directory if needed
|
||||
if self.is_world_process_zero():
|
||||
os.makedirs(self.args.output_dir, exist_ok=True)
|
||||
if is_torch_tpu_available():
|
||||
if is_torch_tpu_available() and isinstance(self.model, PreTrainedModel):
|
||||
# Set an xla_device flag on the model's config.
|
||||
# We'll find a more elegant solution and not need to do this in the future.
|
||||
self.model.config.xla_device = True
|
||||
@@ -490,11 +490,9 @@ class Trainer:
|
||||
logger.info(
|
||||
'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
|
||||
)
|
||||
try:
|
||||
combined_dict = {**self.model.config.to_dict(), **self.args.to_sanitized_dict()}
|
||||
except AttributeError:
|
||||
# in case the model has no config
|
||||
combined_dict = {**self.args.to_sanitized_dict()}
|
||||
if isinstance(self.model, PreTrainedModel):
|
||||
combined_dict = {**self.model.config.to_dict(), **combined_dict}
|
||||
wandb.init(
|
||||
project=os.getenv("WANDB_PROJECT", "huggingface"), config=combined_dict, name=self.args.run_name
|
||||
)
|
||||
@@ -533,6 +531,7 @@ class Trainer:
|
||||
if experiment is not None:
|
||||
experiment._set_model_graph(self.model, framework="transformers")
|
||||
experiment._log_parameters(self.args, prefix="args/", framework="transformers")
|
||||
if isinstance(self.model, PreTrainedModel):
|
||||
experiment._log_parameters(self.model.config, prefix="config/", framework="transformers")
|
||||
|
||||
def num_examples(self, dataloader: DataLoader) -> int:
|
||||
@@ -679,7 +678,11 @@ class Trainer:
|
||||
model,
|
||||
device_ids=[self.args.local_rank],
|
||||
output_device=self.args.local_rank,
|
||||
find_unused_parameters=not getattr(model.config, "gradient_checkpointing", False),
|
||||
find_unused_parameters=(
|
||||
not getattr(model.config, "gradient_checkpointing", False)
|
||||
if isinstance(model, PreTrainedModel)
|
||||
else True
|
||||
),
|
||||
)
|
||||
# find_unused_parameters breaks checkpointing as per
|
||||
# https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
|
||||
@@ -707,15 +710,14 @@ class Trainer:
|
||||
|
||||
self.global_step = 0
|
||||
self.epoch = 0
|
||||
self.total_flos = 0
|
||||
epochs_trained = 0
|
||||
steps_trained_in_current_epoch = 0
|
||||
|
||||
# Check if continuing training from a checkpoint
|
||||
if model_path is not None:
|
||||
# set global_step to global_step of last saved checkpoint from model path
|
||||
try:
|
||||
self.global_step = int(model_path.split("-")[-1].split(os.path.sep)[0])
|
||||
self.total_flos = getattr(self._actual_model(model).config, "total_flos", 0)
|
||||
|
||||
epochs_trained = self.global_step // num_update_steps_per_epoch
|
||||
steps_trained_in_current_epoch = self.global_step % (num_update_steps_per_epoch)
|
||||
@@ -723,14 +725,13 @@ class Trainer:
|
||||
logger.info(" Continuing training from checkpoint, will skip to saved global_step")
|
||||
logger.info(" Continuing training from epoch %d", epochs_trained)
|
||||
logger.info(" Continuing training from global step %d", self.global_step)
|
||||
logger.info(" Continuing training from %d non-embedding floating-point operations", self.total_flos)
|
||||
logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
|
||||
except ValueError:
|
||||
self.global_step = 0
|
||||
self.total_flos = 0
|
||||
logger.info(" Starting fine-tuning.")
|
||||
|
||||
tr_loss = torch.tensor(0.0).to(self.args.device)
|
||||
self.total_flos = self.state.total_flos
|
||||
logging_loss_scalar = 0.0
|
||||
model.zero_grad()
|
||||
disable_tqdm = self.args.disable_tqdm or not self.is_local_process_zero()
|
||||
@@ -1029,7 +1030,7 @@ class Trainer:
|
||||
else:
|
||||
total_flos = self.total_flos
|
||||
if total_flos > 0:
|
||||
logs["total_flos"] = self.total_flos
|
||||
logs["total_flos"] = total_flos
|
||||
if self.global_step is None:
|
||||
# when logging evaluation metrics without training
|
||||
self.global_step = 0
|
||||
@@ -1245,11 +1246,9 @@ class Trainer:
|
||||
# Storing the number of floating-point operations that went into the model
|
||||
if self.total_flos is not None:
|
||||
if self.args.local_rank != -1:
|
||||
total_flos = distributed_broadcast_scalars([self.total_flos]).sum().item()
|
||||
self.state.total_flos = distributed_broadcast_scalars([self.total_flos]).sum().item()
|
||||
else:
|
||||
total_flos = self.total_flos
|
||||
if total_flos > 0:
|
||||
self.model.config.total_flos = total_flos
|
||||
self.state.total_flos = self.total_flos
|
||||
|
||||
def _sorted_checkpoints(self, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False) -> List[str]:
|
||||
ordering_and_checkpoint_path = []
|
||||
@@ -1363,13 +1362,6 @@ class Trainer:
|
||||
prediction_loss_only if prediction_loss_only is not None else self.args.prediction_loss_only
|
||||
)
|
||||
|
||||
assert not getattr(
|
||||
self.model.config, "output_attentions", False
|
||||
), "The prediction loop does not work with `output_attentions=True`."
|
||||
assert not getattr(
|
||||
self.model.config, "output_hidden_states", False
|
||||
), "The prediction loop does not work with `output_hidden_states=True`."
|
||||
|
||||
model = self.model
|
||||
# multi-gpu eval
|
||||
if self.args.n_gpu > 1:
|
||||
|
||||
@@ -224,6 +224,7 @@ class TrainerState:
|
||||
A class containing the `Trainer` fields that will be saved along the model and optimizer.
|
||||
"""
|
||||
|
||||
total_flos: int = 0
|
||||
best_metric: Optional[float] = None
|
||||
best_model_checkpoint: Optional[str] = None
|
||||
|
||||
|
||||
Reference in New Issue
Block a user