Upgrade PyTorch Lightning to 1.0.2 (#7852)
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
@@ -102,7 +102,6 @@ def get_checkpoint_callback(output_dir, metric, save_top_k=1, lower_is_better=Fa
|
||||
monitor=f"val_{metric}",
|
||||
mode="min" if "loss" in metric else "max",
|
||||
save_top_k=save_top_k,
|
||||
period=0, # maybe save a checkpoint every time val is run, not just end of epoch.
|
||||
)
|
||||
return checkpoint_callback
|
||||
|
||||
|
||||
@@ -182,7 +182,6 @@ class SummarizationModule(BaseTransformer):
|
||||
return self._generative_step(batch)
|
||||
|
||||
def validation_epoch_end(self, outputs, prefix="val") -> Dict:
|
||||
|
||||
self.step_count += 1
|
||||
losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
|
||||
loss = losses["loss"]
|
||||
@@ -252,7 +251,7 @@ class SummarizationModule(BaseTransformer):
|
||||
def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
|
||||
dataset = self.get_dataset(type_path)
|
||||
|
||||
if self.hparams.sortish_sampler and type_path != "test":
|
||||
if self.hparams.sortish_sampler and type_path != "test" and type_path != "val":
|
||||
sampler = dataset.make_sortish_sampler(batch_size, distributed=self.hparams.gpus > 1)
|
||||
return DataLoader(
|
||||
dataset,
|
||||
@@ -263,7 +262,7 @@ class SummarizationModule(BaseTransformer):
|
||||
sampler=sampler,
|
||||
)
|
||||
|
||||
elif self.hparams.max_tokens_per_batch is not None and type_path != "test":
|
||||
elif self.hparams.max_tokens_per_batch is not None and type_path != "test" and type_path != "val":
|
||||
batch_sampler = dataset.make_dynamic_sampler(
|
||||
self.hparams.max_tokens_per_batch, distributed=self.hparams.gpus > 1
|
||||
)
|
||||
|
||||
@@ -144,6 +144,7 @@ class TestAll(TestCasePlus):
|
||||
f"--num_train_epochs={epochs}",
|
||||
"--warmup_steps=10",
|
||||
"--val_check_interval=1.0",
|
||||
"--do_predict",
|
||||
]
|
||||
)
|
||||
with patch.object(sys, "argv", testargs):
|
||||
@@ -151,7 +152,6 @@ class TestAll(TestCasePlus):
|
||||
parser = pl.Trainer.add_argparse_args(parser)
|
||||
parser = BartSummarizationDistiller.add_model_specific_args(parser, os.getcwd())
|
||||
args = parser.parse_args()
|
||||
args.do_predict = False
|
||||
# assert args.gpus == gpus THIS BREAKS for multigpu
|
||||
|
||||
model = distill_main(args)
|
||||
|
||||
@@ -176,7 +176,6 @@ class TestSummarizationDistillerMultiGPU(TestCasePlus):
|
||||
print(metrics)
|
||||
last_step_stats = metrics["val"][-1]
|
||||
self.assertGreaterEqual(last_step_stats["val_avg_gen_time"], 0.01)
|
||||
self.assertGreaterEqual(1.0, last_step_stats["val_avg_gen_time"])
|
||||
self.assertIsInstance(last_step_stats[f"val_avg_{val_metric}"], float)
|
||||
self.assertEqual(len(metrics["test"]), 1)
|
||||
desired_n_evals = int(args_d["max_epochs"] * (1 / args_d["val_check_interval"]) / 2 + 1)
|
||||
|
||||
Reference in New Issue
Block a user