Upgrade PyTorch Lightning to 1.0.2 (#7852)
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
@@ -337,7 +337,7 @@ def add_generic_args(parser, root_dir) -> None:
|
|||||||
def generic_train(
|
def generic_train(
|
||||||
model: BaseTransformer,
|
model: BaseTransformer,
|
||||||
args: argparse.Namespace,
|
args: argparse.Namespace,
|
||||||
early_stopping_callback=False,
|
early_stopping_callback=None,
|
||||||
logger=True, # can pass WandbLogger() here
|
logger=True, # can pass WandbLogger() here
|
||||||
extra_callbacks=[],
|
extra_callbacks=[],
|
||||||
checkpoint_callback=None,
|
checkpoint_callback=None,
|
||||||
@@ -355,6 +355,8 @@ def generic_train(
|
|||||||
checkpoint_callback = pl.callbacks.ModelCheckpoint(
|
checkpoint_callback = pl.callbacks.ModelCheckpoint(
|
||||||
filepath=args.output_dir, prefix="checkpoint", monitor="val_loss", mode="min", save_top_k=1
|
filepath=args.output_dir, prefix="checkpoint", monitor="val_loss", mode="min", save_top_k=1
|
||||||
)
|
)
|
||||||
|
if early_stopping_callback:
|
||||||
|
extra_callbacks.append(early_stopping_callback)
|
||||||
if logging_callback is None:
|
if logging_callback is None:
|
||||||
logging_callback = LoggingCallback()
|
logging_callback = LoggingCallback()
|
||||||
|
|
||||||
@@ -376,7 +378,6 @@ def generic_train(
|
|||||||
callbacks=[logging_callback] + extra_callbacks,
|
callbacks=[logging_callback] + extra_callbacks,
|
||||||
logger=logger,
|
logger=logger,
|
||||||
checkpoint_callback=checkpoint_callback,
|
checkpoint_callback=checkpoint_callback,
|
||||||
early_stop_callback=early_stopping_callback,
|
|
||||||
**train_params,
|
**train_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ psutil
|
|||||||
sacrebleu
|
sacrebleu
|
||||||
rouge-score
|
rouge-score
|
||||||
tensorflow_datasets
|
tensorflow_datasets
|
||||||
pytorch-lightning==0.9.0
|
pytorch-lightning==1.0.4
|
||||||
matplotlib
|
matplotlib
|
||||||
git-python==1.0.3
|
git-python==1.0.3
|
||||||
faiss-cpu
|
faiss-cpu
|
||||||
|
|||||||
@@ -102,7 +102,6 @@ def get_checkpoint_callback(output_dir, metric, save_top_k=1, lower_is_better=Fa
|
|||||||
monitor=f"val_{metric}",
|
monitor=f"val_{metric}",
|
||||||
mode="min" if "loss" in metric else "max",
|
mode="min" if "loss" in metric else "max",
|
||||||
save_top_k=save_top_k,
|
save_top_k=save_top_k,
|
||||||
period=0, # maybe save a checkpoint every time val is run, not just end of epoch.
|
|
||||||
)
|
)
|
||||||
return checkpoint_callback
|
return checkpoint_callback
|
||||||
|
|
||||||
|
|||||||
@@ -182,7 +182,6 @@ class SummarizationModule(BaseTransformer):
|
|||||||
return self._generative_step(batch)
|
return self._generative_step(batch)
|
||||||
|
|
||||||
def validation_epoch_end(self, outputs, prefix="val") -> Dict:
|
def validation_epoch_end(self, outputs, prefix="val") -> Dict:
|
||||||
|
|
||||||
self.step_count += 1
|
self.step_count += 1
|
||||||
losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
|
losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
|
||||||
loss = losses["loss"]
|
loss = losses["loss"]
|
||||||
@@ -252,7 +251,7 @@ class SummarizationModule(BaseTransformer):
|
|||||||
def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
|
def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
|
||||||
dataset = self.get_dataset(type_path)
|
dataset = self.get_dataset(type_path)
|
||||||
|
|
||||||
if self.hparams.sortish_sampler and type_path != "test":
|
if self.hparams.sortish_sampler and type_path != "test" and type_path != "val":
|
||||||
sampler = dataset.make_sortish_sampler(batch_size, distributed=self.hparams.gpus > 1)
|
sampler = dataset.make_sortish_sampler(batch_size, distributed=self.hparams.gpus > 1)
|
||||||
return DataLoader(
|
return DataLoader(
|
||||||
dataset,
|
dataset,
|
||||||
@@ -263,7 +262,7 @@ class SummarizationModule(BaseTransformer):
|
|||||||
sampler=sampler,
|
sampler=sampler,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif self.hparams.max_tokens_per_batch is not None and type_path != "test":
|
elif self.hparams.max_tokens_per_batch is not None and type_path != "test" and type_path != "val":
|
||||||
batch_sampler = dataset.make_dynamic_sampler(
|
batch_sampler = dataset.make_dynamic_sampler(
|
||||||
self.hparams.max_tokens_per_batch, distributed=self.hparams.gpus > 1
|
self.hparams.max_tokens_per_batch, distributed=self.hparams.gpus > 1
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -144,6 +144,7 @@ class TestAll(TestCasePlus):
|
|||||||
f"--num_train_epochs={epochs}",
|
f"--num_train_epochs={epochs}",
|
||||||
"--warmup_steps=10",
|
"--warmup_steps=10",
|
||||||
"--val_check_interval=1.0",
|
"--val_check_interval=1.0",
|
||||||
|
"--do_predict",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
with patch.object(sys, "argv", testargs):
|
with patch.object(sys, "argv", testargs):
|
||||||
@@ -151,7 +152,6 @@ class TestAll(TestCasePlus):
|
|||||||
parser = pl.Trainer.add_argparse_args(parser)
|
parser = pl.Trainer.add_argparse_args(parser)
|
||||||
parser = BartSummarizationDistiller.add_model_specific_args(parser, os.getcwd())
|
parser = BartSummarizationDistiller.add_model_specific_args(parser, os.getcwd())
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
args.do_predict = False
|
|
||||||
# assert args.gpus == gpus THIS BREAKS for multigpu
|
# assert args.gpus == gpus THIS BREAKS for multigpu
|
||||||
|
|
||||||
model = distill_main(args)
|
model = distill_main(args)
|
||||||
|
|||||||
@@ -176,7 +176,6 @@ class TestSummarizationDistillerMultiGPU(TestCasePlus):
|
|||||||
print(metrics)
|
print(metrics)
|
||||||
last_step_stats = metrics["val"][-1]
|
last_step_stats = metrics["val"][-1]
|
||||||
self.assertGreaterEqual(last_step_stats["val_avg_gen_time"], 0.01)
|
self.assertGreaterEqual(last_step_stats["val_avg_gen_time"], 0.01)
|
||||||
self.assertGreaterEqual(1.0, last_step_stats["val_avg_gen_time"])
|
|
||||||
self.assertIsInstance(last_step_stats[f"val_avg_{val_metric}"], float)
|
self.assertIsInstance(last_step_stats[f"val_avg_{val_metric}"], float)
|
||||||
self.assertEqual(len(metrics["test"]), 1)
|
self.assertEqual(len(metrics["test"]), 1)
|
||||||
desired_n_evals = int(args_d["max_epochs"] * (1 / args_d["val_check_interval"]) / 2 + 1)
|
desired_n_evals = int(args_d["max_epochs"] * (1 / args_d["val_check_interval"]) / 2 + 1)
|
||||||
|
|||||||
@@ -192,7 +192,7 @@ def main():
|
|||||||
|
|
||||||
# Optionally, predict on dev set and write to output_dir
|
# Optionally, predict on dev set and write to output_dir
|
||||||
if args.do_predict:
|
if args.do_predict:
|
||||||
checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpointepoch=*.ckpt"), recursive=True)))
|
checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpoint-epoch=*.ckpt"), recursive=True)))
|
||||||
model = model.load_from_checkpoint(checkpoints[-1])
|
model = model.load_from_checkpoint(checkpoints[-1])
|
||||||
return trainer.test(model)
|
return trainer.test(model)
|
||||||
|
|
||||||
|
|||||||
@@ -207,9 +207,9 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
if args.do_predict:
|
if args.do_predict:
|
||||||
# See https://github.com/huggingface/transformers/issues/3159
|
# See https://github.com/huggingface/transformers/issues/3159
|
||||||
# pl use this format to create a checkpoint:
|
# pl use this default format to create a checkpoint:
|
||||||
# https://github.com/PyTorchLightning/pytorch-lightning/blob/master\
|
# https://github.com/PyTorchLightning/pytorch-lightning/blob/master\
|
||||||
# /pytorch_lightning/callbacks/model_checkpoint.py#L169
|
# /pytorch_lightning/callbacks/model_checkpoint.py#L322
|
||||||
checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpointepoch=*.ckpt"), recursive=True)))
|
checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpoint-epoch=*.ckpt"), recursive=True)))
|
||||||
model = model.load_from_checkpoint(checkpoints[-1])
|
model = model.load_from_checkpoint(checkpoints[-1])
|
||||||
trainer.test(model)
|
trainer.test(model)
|
||||||
|
|||||||
Reference in New Issue
Block a user