[cleanup] assign todos, faster bart-cnn test (#7835)
* 2 beam output * unassign/remove TODOs * remove one more
This commit is contained in:
@@ -291,7 +291,8 @@ class LoggingCallback(pl.Callback):
|
||||
|
||||
|
||||
def add_generic_args(parser, root_dir) -> None:
|
||||
# TODO(SS): allow all pl args? parser = pl.Trainer.add_argparse_args(parser)
|
||||
# To allow all pl args uncomment the following line
|
||||
# parser = pl.Trainer.add_argparse_args(parser)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
default=None,
|
||||
|
||||
@@ -21,7 +21,6 @@ from utils import load_json
|
||||
|
||||
|
||||
MODEL_NAME = MBART_TINY
|
||||
# TODO(SS): MODEL_NAME = "sshleifer/student_mbart_en_ro_1_1"
|
||||
MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"
|
||||
|
||||
|
||||
@@ -99,7 +98,7 @@ def test_train_mbart_cc25_enro_script():
|
||||
assert expected_key in ckpt["state_dict"]
|
||||
assert ckpt["state_dict"]["model.model.decoder.layers.0.encoder_attn_layer_norm.weight"].dtype == torch.float32
|
||||
|
||||
# TODO(SS): turn on args.do_predict when PL bug fixed.
|
||||
# TODO: turn on args.do_predict when PL bug fixed.
|
||||
if args.do_predict:
|
||||
contents = {os.path.basename(p) for p in contents}
|
||||
assert "test_generations.txt" in contents
|
||||
@@ -178,7 +177,7 @@ def test_opus_mt_distill_script():
|
||||
assert expected_key in ckpt["state_dict"]
|
||||
assert ckpt["state_dict"]["model.model.decoder.layers.0.encoder_attn_layer_norm.weight"].dtype == torch.float32
|
||||
|
||||
# TODO(SS): turn on args.do_predict when PL bug fixed.
|
||||
# TODO: turn on args.do_predict when PL bug fixed.
|
||||
if args.do_predict:
|
||||
contents = {os.path.basename(p) for p in contents}
|
||||
assert "test_generations.txt" in contents
|
||||
|
||||
@@ -25,7 +25,6 @@ def test_finetune_trainer():
|
||||
|
||||
@slow
|
||||
def test_finetune_trainer_slow():
|
||||
# TODO(SS): This will fail on devices with more than 1 GPU.
|
||||
# There is a missing call to __init__process_group somewhere
|
||||
output_dir = run_trainer(eval_steps=2, max_len="128", model_name=MARIAN_MODEL, num_train_epochs=3)
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ CHEAP_ARGS = {
|
||||
"student_decoder_layers": 1,
|
||||
"val_check_interval": 1.0,
|
||||
"output_dir": "",
|
||||
"fp16": False, # TODO(SS): set this to CUDA_AVAILABLE if ci installs apex or start using native amp
|
||||
"fp16": False, # TODO: set this to CUDA_AVAILABLE if ci installs apex or start using native amp
|
||||
"no_teacher": False,
|
||||
"fp16_opt_level": "O1",
|
||||
"gpus": 1 if CUDA_AVAILABLE else 0,
|
||||
|
||||
Reference in New Issue
Block a user