[pl_examples] default warmup steps=0 (#5316)

This commit is contained in:
Sam Shleifer
2020-06-26 15:03:41 -04:00
committed by GitHub
parent bf0d12c220
commit 5543b30aa6
6 changed files with 14 additions and 13 deletions

View File

@@ -122,12 +122,9 @@ class BaseTransformer(pl.LightningModule):
else:
optimizer.step()
optimizer.zero_grad()
self.lr_scheduler.step()
def get_tqdm_dict(self):
avg_loss = getattr(self.trainer, "avg_loss", 0.0)
tqdm_dict = {"loss": "{:.3f}".format(avg_loss), "lr": self.lr_scheduler.get_last_lr()[-1]}
return tqdm_dict
self.lr_scheduler.step() # By default, PL will only step every epoch.
lrs = {f"lr_group_{i}": lr for i, lr in enumerate(self.lr_scheduler.get_lr())}
self.logger.log_metrics(lrs)
def test_step(self, batch, batch_nb):
return self.validation_step(batch, batch_nb)
@@ -202,7 +199,7 @@ class BaseTransformer(pl.LightningModule):
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
parser.add_argument("--warmup_steps", default=500, type=int, help="Linear warmup over warmup_steps.")
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
parser.add_argument("--num_workers", default=4, type=int, help="kwarg passed to DataLoader")
parser.add_argument(
"--num_train_epochs", default=3, type=int, help="Total number of training epochs to perform."