feat(wandb): save model as artifact (#8119)
* feat(wandb): log artifacts * fix: typo * feat(wandb): ensure name is allowed * feat(wandb): log artifact * feat(wandb): saving logic * style: improve formatting * fix: unrelated typo * feat: use a fake trainer * fix: simplify * feat(wandb): log model files as artifact * style: fix style * docs(wandb): correct description * feat: unpack model + allow env Truethy values * feat: TrainerCallback can access tokenizer * style: fix style * feat(wandb): log more interesting metadata * feat: unpack tokenizer * feat(wandb): metadata with load_best_model_at_end * feat(wandb): more robust metadata * style(wandb): fix formatting
This commit is contained in:
@@ -168,6 +168,8 @@ class TrainerCallback:
|
||||
The object that is returned to the :class:`~transformers.Trainer` and can be used to make some decisions.
|
||||
model (:class:`~transformers.PreTrainedModel` or :obj:`torch.nn.Module`):
|
||||
The model being trained.
|
||||
tokenizer (:class:`~transformers.PreTrainedTokenizer`):
|
||||
The tokenizer used for encoding the data.
|
||||
optimizer (:obj:`torch.optim.Optimizer`):
|
||||
The optimizer used for the training steps.
|
||||
lr_scheduler (:obj:`torch.optim.lr_scheduler.LambdaLR`):
|
||||
@@ -274,11 +276,12 @@ class TrainerCallback:
|
||||
class CallbackHandler(TrainerCallback):
|
||||
""" Internal class that just calls the list of callbacks in order. """
|
||||
|
||||
def __init__(self, callbacks, model, optimizer, lr_scheduler):
|
||||
def __init__(self, callbacks, model, tokenizer, optimizer, lr_scheduler):
|
||||
self.callbacks = []
|
||||
for cb in callbacks:
|
||||
self.add_callback(cb)
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
self.optimizer = optimizer
|
||||
self.lr_scheduler = lr_scheduler
|
||||
self.train_dataloader = None
|
||||
@@ -376,6 +379,7 @@ class CallbackHandler(TrainerCallback):
|
||||
state,
|
||||
control,
|
||||
model=self.model,
|
||||
tokenizer=self.tokenizer,
|
||||
optimizer=self.optimizer,
|
||||
lr_scheduler=self.lr_scheduler,
|
||||
train_dataloader=self.train_dataloader,
|
||||
|
||||
Reference in New Issue
Block a user