Lightning Updates for v0.8.5 (#5798)

Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
Nathan Raw
2020-07-17 20:43:06 -06:00
committed by GitHub
parent 615be03f9d
commit 529850ae7b
7 changed files with 73 additions and 97 deletions

View File

@@ -26,7 +26,7 @@ logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
CUDA_AVAILABLE = torch.cuda.is_available()
CHEAP_ARGS = {
"logger": "default",
"logger_name": "default",
"length_penalty": 0.5,
"cache_dir": "",
"task": "summarization",
@@ -48,7 +48,7 @@ CHEAP_ARGS = {
"max_grad_norm": 1.0,
"do_train": True,
"do_predict": True,
"gradient_accumulation_steps": 1,
"accumulate_grad_batches": 1,
"server_ip": "",
"server_port": "",
"seed": 42,
@@ -60,7 +60,7 @@ CHEAP_ARGS = {
"weight_decay": 0.0,
"adam_epsilon": 1e-08,
"warmup_steps": 0,
"num_train_epochs": 1,
"max_epochs": 1,
"train_batch_size": 2,
"eval_batch_size": 2,
"max_source_length": 12,
@@ -122,7 +122,7 @@ class TestSummarizationDistiller(unittest.TestCase):
updates = dict(
student_encoder_layers=2,
student_decoder_layers=1,
num_train_epochs=4,
max_epochs=4,
val_check_interval=0.25,
alpha_hid=2.0,
model_name_or_path="IGNORE_THIS_IT_DOESNT_GET_USED",
@@ -156,7 +156,7 @@ class TestSummarizationDistiller(unittest.TestCase):
default_updates = dict(
train_batch_size=1,
eval_batch_size=2,
num_train_epochs=2,
max_epochs=2,
alpha_mlm=0.2,
alpha_ce=0.8,
do_predict=True,
@@ -187,7 +187,7 @@ class TestSummarizationDistiller(unittest.TestCase):
self.assertGreaterEqual(last_step_stats["val_avg_gen_time"], 0.01)
self.assertGreaterEqual(1.0, last_step_stats["val_avg_gen_time"])
self.assertIsInstance(last_step_stats[f"val_avg_{model.val_metric}"], float)
desired_n_evals = int(args_d["num_train_epochs"] * (1 / args_d["val_check_interval"]) + 1)
desired_n_evals = int(args_d["max_epochs"] * (1 / args_d["val_check_interval"]) + 1)
self.assertEqual(len(metrics["val"]), desired_n_evals)
self.assertEqual(len(metrics["test"]), 1)
return model