[cleanup] assign todos, faster bart-cnn test (#7835)
* 2 beam output * unassign/remove TODOs * remove one more
This commit is contained in:
@@ -291,7 +291,8 @@ class LoggingCallback(pl.Callback):
|
|||||||
|
|
||||||
|
|
||||||
def add_generic_args(parser, root_dir) -> None:
|
def add_generic_args(parser, root_dir) -> None:
|
||||||
# TODO(SS): allow all pl args? parser = pl.Trainer.add_argparse_args(parser)
|
# To allow all pl args uncomment the following line
|
||||||
|
# parser = pl.Trainer.add_argparse_args(parser)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--output_dir",
|
"--output_dir",
|
||||||
default=None,
|
default=None,
|
||||||
|
|||||||
@@ -21,7 +21,6 @@ from utils import load_json
|
|||||||
|
|
||||||
|
|
||||||
MODEL_NAME = MBART_TINY
|
MODEL_NAME = MBART_TINY
|
||||||
# TODO(SS): MODEL_NAME = "sshleifer/student_mbart_en_ro_1_1"
|
|
||||||
MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"
|
MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"
|
||||||
|
|
||||||
|
|
||||||
@@ -99,7 +98,7 @@ def test_train_mbart_cc25_enro_script():
|
|||||||
assert expected_key in ckpt["state_dict"]
|
assert expected_key in ckpt["state_dict"]
|
||||||
assert ckpt["state_dict"]["model.model.decoder.layers.0.encoder_attn_layer_norm.weight"].dtype == torch.float32
|
assert ckpt["state_dict"]["model.model.decoder.layers.0.encoder_attn_layer_norm.weight"].dtype == torch.float32
|
||||||
|
|
||||||
# TODO(SS): turn on args.do_predict when PL bug fixed.
|
# TODO: turn on args.do_predict when PL bug fixed.
|
||||||
if args.do_predict:
|
if args.do_predict:
|
||||||
contents = {os.path.basename(p) for p in contents}
|
contents = {os.path.basename(p) for p in contents}
|
||||||
assert "test_generations.txt" in contents
|
assert "test_generations.txt" in contents
|
||||||
@@ -178,7 +177,7 @@ def test_opus_mt_distill_script():
|
|||||||
assert expected_key in ckpt["state_dict"]
|
assert expected_key in ckpt["state_dict"]
|
||||||
assert ckpt["state_dict"]["model.model.decoder.layers.0.encoder_attn_layer_norm.weight"].dtype == torch.float32
|
assert ckpt["state_dict"]["model.model.decoder.layers.0.encoder_attn_layer_norm.weight"].dtype == torch.float32
|
||||||
|
|
||||||
# TODO(SS): turn on args.do_predict when PL bug fixed.
|
# TODO: turn on args.do_predict when PL bug fixed.
|
||||||
if args.do_predict:
|
if args.do_predict:
|
||||||
contents = {os.path.basename(p) for p in contents}
|
contents = {os.path.basename(p) for p in contents}
|
||||||
assert "test_generations.txt" in contents
|
assert "test_generations.txt" in contents
|
||||||
|
|||||||
@@ -25,7 +25,6 @@ def test_finetune_trainer():
|
|||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_finetune_trainer_slow():
|
def test_finetune_trainer_slow():
|
||||||
# TODO(SS): This will fail on devices with more than 1 GPU.
|
|
||||||
# There is a missing call to __init__process_group somewhere
|
# There is a missing call to __init__process_group somewhere
|
||||||
output_dir = run_trainer(eval_steps=2, max_len="128", model_name=MARIAN_MODEL, num_train_epochs=3)
|
output_dir = run_trainer(eval_steps=2, max_len="128", model_name=MARIAN_MODEL, num_train_epochs=3)
|
||||||
|
|
||||||
|
|||||||
@@ -52,7 +52,7 @@ CHEAP_ARGS = {
|
|||||||
"student_decoder_layers": 1,
|
"student_decoder_layers": 1,
|
||||||
"val_check_interval": 1.0,
|
"val_check_interval": 1.0,
|
||||||
"output_dir": "",
|
"output_dir": "",
|
||||||
"fp16": False, # TODO(SS): set this to CUDA_AVAILABLE if ci installs apex or start using native amp
|
"fp16": False, # TODO: set this to CUDA_AVAILABLE if ci installs apex or start using native amp
|
||||||
"no_teacher": False,
|
"no_teacher": False,
|
||||||
"fp16_opt_level": "O1",
|
"fp16_opt_level": "O1",
|
||||||
"gpus": 1 if CUDA_AVAILABLE else 0,
|
"gpus": 1 if CUDA_AVAILABLE else 0,
|
||||||
|
|||||||
@@ -54,8 +54,6 @@ def rename_state_dict_key(k):
|
|||||||
|
|
||||||
# See appendix C of paper for all hyperparams
|
# See appendix C of paper for all hyperparams
|
||||||
|
|
||||||
# TODO(SS): one constant
|
|
||||||
|
|
||||||
|
|
||||||
def convert_pegasus(tf_weights: dict, cfg_updates: dict) -> PegasusForConditionalGeneration:
|
def convert_pegasus(tf_weights: dict, cfg_updates: dict) -> PegasusForConditionalGeneration:
|
||||||
cfg_kwargs = DEFAULTS.copy()
|
cfg_kwargs = DEFAULTS.copy()
|
||||||
|
|||||||
@@ -154,11 +154,8 @@ class PegasusTokenizer(ReformerTokenizer):
|
|||||||
return model_inputs
|
return model_inputs
|
||||||
if max_target_length is not None:
|
if max_target_length is not None:
|
||||||
tokenizer_kwargs["max_length"] = max_target_length
|
tokenizer_kwargs["max_length"] = max_target_length
|
||||||
# TODO(@sshleifer): maybe tgt_texts = [self.pad_token + t for t in tgt_texts] # add decoder_start_token_id
|
|
||||||
labels: BatchEncoding = self(tgt_texts, **tokenizer_kwargs)["input_ids"]
|
labels: BatchEncoding = self(tgt_texts, **tokenizer_kwargs)["input_ids"]
|
||||||
model_inputs["labels"] = labels
|
model_inputs["labels"] = labels
|
||||||
# for k, v in decoder_inputs.items():
|
|
||||||
# model_inputs[f"decoder_{k}"] = v
|
|
||||||
return model_inputs
|
return model_inputs
|
||||||
|
|
||||||
|
|
||||||
@@ -169,10 +166,6 @@ class PegasusTokenizerFast(ReformerTokenizerFast):
|
|||||||
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
||||||
slow_tokenizer_class = PegasusTokenizer
|
slow_tokenizer_class = PegasusTokenizer
|
||||||
|
|
||||||
# def num_special_tokens_to_add(self, pair=False):
|
|
||||||
# """Just EOS"""
|
|
||||||
# return 1
|
|
||||||
|
|
||||||
def _special_token_mask(self, seq):
|
def _special_token_mask(self, seq):
|
||||||
all_special_ids = set(self.all_special_ids) # call it once instead of inside list comp
|
all_special_ids = set(self.all_special_ids) # call it once instead of inside list comp
|
||||||
all_special_ids.remove(self.unk_token_id) # <unk> is only sometimes special
|
all_special_ids.remove(self.unk_token_id) # <unk> is only sometimes special
|
||||||
@@ -236,9 +229,6 @@ class PegasusTokenizerFast(ReformerTokenizerFast):
|
|||||||
return model_inputs
|
return model_inputs
|
||||||
if max_target_length is not None:
|
if max_target_length is not None:
|
||||||
tokenizer_kwargs["max_length"] = max_target_length
|
tokenizer_kwargs["max_length"] = max_target_length
|
||||||
# TODO(@sshleifer): maybe tgt_texts = [self.pad_token + t for t in tgt_texts] # add decoder_start_token_id
|
|
||||||
labels: BatchEncoding = self(tgt_texts, **tokenizer_kwargs)["input_ids"]
|
labels: BatchEncoding = self(tgt_texts, **tokenizer_kwargs)["input_ids"]
|
||||||
model_inputs["labels"] = labels
|
model_inputs["labels"] = labels
|
||||||
# for k, v in decoder_inputs.items():
|
|
||||||
# model_inputs[f"decoder_{k}"] = v
|
|
||||||
return model_inputs
|
return model_inputs
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -125,12 +125,9 @@ class FSMTModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
all_model_classes = (FSMTModel, FSMTForConditionalGeneration) if is_torch_available() else ()
|
all_model_classes = (FSMTModel, FSMTForConditionalGeneration) if is_torch_available() else ()
|
||||||
all_generative_model_classes = (FSMTForConditionalGeneration,) if is_torch_available() else ()
|
all_generative_model_classes = (FSMTForConditionalGeneration,) if is_torch_available() else ()
|
||||||
is_encoder_decoder = True
|
is_encoder_decoder = True
|
||||||
# TODO(SS): fix the below in a separate PR
|
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_torchscript = True
|
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
test_resize_embeddings = True # This requires inputs_dict['input_ids']
|
test_missing_keys = False
|
||||||
test_missing_keys = False # because FSMTForConditionalGeneration and FSMTModel now have identical state_dict
|
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = ModelTester(self)
|
self.model_tester = ModelTester(self)
|
||||||
@@ -326,7 +323,6 @@ class FSMTHeadTests(unittest.TestCase):
|
|||||||
max_length=max_length,
|
max_length=max_length,
|
||||||
)
|
)
|
||||||
self.assertEqual(new_input_ids.shape, (input_ids.shape[0], max_length))
|
self.assertEqual(new_input_ids.shape, (input_ids.shape[0], max_length))
|
||||||
# TODO(SS): uneven length batches, empty inputs
|
|
||||||
|
|
||||||
def test_shift_tokens_right(self):
|
def test_shift_tokens_right(self):
|
||||||
input_ids = torch.Tensor([[71, 82, 18, 33, 2, 1, 1], [68, 34, 26, 58, 30, 82, 2]]).long()
|
input_ids = torch.Tensor([[71, 82, 18, 33, 2, 1, 1], [68, 34, 26, 58, 30, 82, 2]]).long()
|
||||||
|
|||||||
Reference in New Issue
Block a user