feat(wandb): save model as artifact (#8119)

* feat(wandb): log artifacts

* fix: typo

* feat(wandb): ensure name is allowed

* feat(wandb): log artifact

* feat(wandb): saving logic

* style: improve formatting

* fix: unrelated typo

* feat: use a fake trainer

* fix: simplify

* feat(wandb): log model files as artifact

* style: fix style

* docs(wandb): correct description

* feat: unpack model + allow env Truethy values

* feat: TrainerCallback can access tokenizer

* style: fix style

* feat(wandb): log more interesting metadata

* feat: unpack tokenizer

* feat(wandb): metadata with load_best_model_at_end

* feat(wandb): more robust metadata

* style(wandb): fix formatting
This commit is contained in:
Boris Dayma
2021-01-05 02:30:46 -06:00
committed by GitHub
parent 143289dcf7
commit 30fa0b780f
3 changed files with 48 additions and 3 deletions

View File

@@ -168,6 +168,8 @@ class TrainerCallback:
The object that is returned to the :class:`~transformers.Trainer` and can be used to make some decisions.
model (:class:`~transformers.PreTrainedModel` or :obj:`torch.nn.Module`):
The model being trained.
tokenizer (:class:`~transformers.PreTrainedTokenizer`):
The tokenizer used for encoding the data.
optimizer (:obj:`torch.optim.Optimizer`):
The optimizer used for the training steps.
lr_scheduler (:obj:`torch.optim.lr_scheduler.LambdaLR`):
@@ -274,11 +276,12 @@ class TrainerCallback:
class CallbackHandler(TrainerCallback):
""" Internal class that just calls the list of callbacks in order. """
def __init__(self, callbacks, model, optimizer, lr_scheduler):
def __init__(self, callbacks, model, tokenizer, optimizer, lr_scheduler):
self.callbacks = []
for cb in callbacks:
self.add_callback(cb)
self.model = model
self.tokenizer = tokenizer
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
self.train_dataloader = None
@@ -376,6 +379,7 @@ class CallbackHandler(TrainerCallback):
state,
control,
model=self.model,
tokenizer=self.tokenizer,
optimizer=self.optimizer,
lr_scheduler=self.lr_scheduler,
train_dataloader=self.train_dataloader,