From 500be01c5d53af1f4dc6e20430f4591239a6281b Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Tue, 6 Oct 2020 16:11:56 -0400 Subject: [PATCH] [s2s] save first batch to json for debugging purposes (#6810) --- examples/seq2seq/finetune.py | 16 ++++++++++++++++ examples/seq2seq/test_seq2seq_examples.py | 4 ++++ 2 files changed, 20 insertions(+) diff --git a/examples/seq2seq/finetune.py b/examples/seq2seq/finetune.py index 65343b2f0c..b401add5cf 100755 --- a/examples/seq2seq/finetune.py +++ b/examples/seq2seq/finetune.py @@ -33,6 +33,7 @@ from utils import ( lmap, pickle_save, save_git_info, + save_json, use_task_specific_params, ) @@ -105,6 +106,7 @@ class SummarizationModule(BaseTransformer): self.dataset_class = ( Seq2SeqDataset if hasattr(self.tokenizer, "prepare_seq2seq_batch") else LegacySeq2SeqDataset ) + self.already_saved_batch = False self.eval_beams = self.model.config.num_beams if self.hparams.eval_beams is None else self.hparams.eval_beams if self.hparams.eval_max_gen_length is not None: self.eval_max_length = self.hparams.eval_max_gen_length @@ -112,6 +114,17 @@ class SummarizationModule(BaseTransformer): self.eval_max_length = self.model.config.max_length self.val_metric = self.default_val_metric if self.hparams.val_metric is None else self.hparams.val_metric + def save_readable_batch(self, batch: Dict[str, torch.Tensor]) -> Dict[str, List[str]]: + """A debugging utility: write the batch to text_batch.json (tokenizer-decoded; keys containing "mask" store the tensor shape instead) and to tok_batch.json (raw token id lists) under self.output_dir, set self.already_saved_batch, and return the decoded dict.""" + readable_batch = { + k: self.tokenizer.batch_decode(v.tolist()) if "mask" not in k else v.shape for k, v in batch.items() + } + save_json(readable_batch, Path(self.output_dir) / "text_batch.json") + save_json({k: v.tolist() for k, v in batch.items()}, Path(self.output_dir) / "tok_batch.json") + + self.already_saved_batch = True + return readable_batch + def forward(self, input_ids, **kwargs): return self.model(input_ids, **kwargs) @@ -129,6 +142,9 @@ class SummarizationModule(BaseTransformer): decoder_input_ids = self.model._shift_right(tgt_ids) else: decoder_input_ids = 
shift_tokens_right(tgt_ids, pad_token_id) + if not self.already_saved_batch: # This would be slightly better if it only happened on rank zero + batch["decoder_input_ids"] = decoder_input_ids + self.save_readable_batch(batch) outputs = self(src_ids, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, use_cache=False) lm_logits = outputs[0] diff --git a/examples/seq2seq/test_seq2seq_examples.py b/examples/seq2seq/test_seq2seq_examples.py index 1b6c505c94..e28acc3131 100644 --- a/examples/seq2seq/test_seq2seq_examples.py +++ b/examples/seq2seq/test_seq2seq_examples.py @@ -422,6 +422,10 @@ def test_finetune(model): assert bart.decoder.embed_tokens == bart.encoder.embed_tokens assert bart.decoder.embed_tokens == bart.shared + example_batch = load_json(module.output_dir / "text_batch.json") + assert isinstance(example_batch, dict) + assert len(example_batch) >= 4 + def test_finetune_extra_model_args(): args_d: dict = CHEAP_ARGS.copy()