seq2seq/run_eval.py can take decoder_start_token_id (#5949)

This commit is contained in:
Sam Shleifer
2020-07-21 16:58:45 -04:00
committed by GitHub
parent 5b193b39b0
commit 9dab39feea
3 changed files with 35 additions and 3 deletions

View File

@@ -327,6 +327,7 @@ class TranslationModule(SummarizationModule):
self.dataset_kwargs["tgt_lang"] = hparams.tgt_lang
if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
self.model.config.decoder_start_token_id = self.decoder_start_token_id
if isinstance(self.tokenizer, MBartTokenizer):
self.dataset_class = MBartDataset