From 9dab39feeab0d141e2353f2b6402e4f823eec2e0 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Tue, 21 Jul 2020 16:58:45 -0400 Subject: [PATCH] seq2seq/run_eval.py can take decoder_start_token_id (#5949) --- examples/seq2seq/finetune.py | 1 + examples/seq2seq/run_eval.py | 18 +++++++++++++++++- src/transformers/tokenization_utils_base.py | 19 +++++++++++++++++-- 3 files changed, 35 insertions(+), 3 deletions(-) diff --git a/examples/seq2seq/finetune.py b/examples/seq2seq/finetune.py index 88b414860f..a0014b9835 100644 --- a/examples/seq2seq/finetune.py +++ b/examples/seq2seq/finetune.py @@ -327,6 +327,7 @@ class TranslationModule(SummarizationModule): self.dataset_kwargs["tgt_lang"] = hparams.tgt_lang if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer): self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang] + self.model.config.decoder_start_token_id = self.decoder_start_token_id if isinstance(self.tokenizer, MBartTokenizer): self.dataset_class = MBartDataset diff --git a/examples/seq2seq/run_eval.py b/examples/seq2seq/run_eval.py index a805fdcfa5..8248f8b94c 100644 --- a/examples/seq2seq/run_eval.py +++ b/examples/seq2seq/run_eval.py @@ -30,6 +30,7 @@ def generate_summaries_or_translations( device: str = DEFAULT_DEVICE, fp16=False, task="summarization", + decoder_start_token_id=None, **gen_kwargs, ) -> None: fout = Path(out_file).open("w", encoding="utf-8") @@ -37,6 +38,8 @@ def generate_summaries_or_translations( model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device) if fp16: model = model.half() + if decoder_start_token_id is None: + decoder_start_token_id = gen_kwargs.pop("decoder_start_token_id", None) tokenizer = AutoTokenizer.from_pretrained(model_name) @@ -48,7 +51,12 @@ def generate_summaries_or_translations( batch = [model.config.prefix + text for text in batch] batch = tokenizer(batch, return_tensors="pt", truncation=True, padding="max_length").to(device) input_ids, attention_mask = trim_batch(**batch, pad_token_id=tokenizer.pad_token_id) - summaries = model.generate(input_ids=input_ids, attention_mask=attention_mask, **gen_kwargs) + summaries = model.generate( + input_ids=input_ids, + attention_mask=attention_mask, + decoder_start_token_id=decoder_start_token_id, + **gen_kwargs, + ) dec = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False) for hypothesis in dec: fout.write(hypothesis + "\n") @@ -66,6 +74,13 @@ def run_generate(): parser.add_argument("--device", type=str, required=False, default=DEFAULT_DEVICE, help="cuda, cuda:1, cpu etc.") parser.add_argument("--task", type=str, default="summarization", help="typically translation or summarization") parser.add_argument("--bs", type=int, default=8, required=False, help="batch size") + parser.add_argument( + "--decoder_start_token_id", + type=int, + default=None, + required=False, + help="decoder_start_token_id (otherwise will look at config)", + ) parser.add_argument( "--n_obs", type=int, default=-1, required=False, help="How many observations. Defaults to all." ) @@ -83,6 +98,7 @@ def run_generate(): device=args.device, fp16=args.fp16, task=args.task, + decoder_start_token_id=args.decoder_start_token_id, ) if args.reference_path is None: return diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 8fc9f9199c..63d1351b08 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -2255,8 +2255,23 @@ class PreTrainedTokenizerBase(SpecialTokensMixin): return encoded_inputs - def batch_decode(self, sequences: List[List[int]], **kwargs) -> List[str]: - return [self.decode(seq, **kwargs) for seq in sequences] + def batch_decode( + self, sequences: List[List[int]], skip_special_tokens: bool = False, clean_up_tokenization_spaces: bool = True + ) -> List[str]: + """ + Convert a list of lists of token ids into a list of strings by calling decode. + + Args: + token_ids: list of tokenized input ids. Can be obtained using the `encode` or `encode_plus` methods. + skip_special_tokens: if set to True, will replace special tokens. + clean_up_tokenization_spaces: if set to True, will clean up the tokenization spaces. + """ + return [ + self.decode( + seq, skip_special_tokens=skip_special_tokens, clean_up_tokenization_spaces=clean_up_tokenization_spaces + ) + for seq in sequences + ] def decode( self, token_ids: List[int], skip_special_tokens: bool = False, clean_up_tokenization_spaces: bool = True