seq2seq/run_eval.py can take decoder_start_token_id (#5949)
This commit is contained in:
@@ -327,6 +327,7 @@ class TranslationModule(SummarizationModule):
|
||||
self.dataset_kwargs["tgt_lang"] = hparams.tgt_lang
|
||||
if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
|
||||
self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
|
||||
self.model.config.decoder_start_token_id = self.decoder_start_token_id
|
||||
if isinstance(self.tokenizer, MBartTokenizer):
|
||||
self.dataset_class = MBartDataset
|
||||
|
||||
|
||||
Reference in New Issue
Block a user