[s2s]Use prepare_translation_batch for Marian finetuning (#6293)

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Sam Shleifer
2020-08-06 14:58:38 -04:00
committed by GitHub
parent 2f2aa0c89c
commit 2804fff839
5 changed files with 22 additions and 12 deletions

View File

@@ -14,7 +14,7 @@ import torch
from torch.utils.data import DataLoader
from lightning_base import BaseTransformer, add_generic_args, generic_train
-from transformers import MBartTokenizer, T5ForConditionalGeneration, get_linear_schedule_with_warmup
+from transformers import MarianTokenizer, MBartTokenizer, T5ForConditionalGeneration, get_linear_schedule_with_warmup
try:
@@ -32,7 +32,7 @@ try:
ROUGE_KEYS,
calculate_bleu_score,
Seq2SeqDataset,
-MBartDataset,
+TranslationDataset,
label_smoothed_nll_loss,
)
@@ -40,7 +40,7 @@ try:
except ImportError:
from utils import (
Seq2SeqDataset,
-MBartDataset,
+TranslationDataset,
assert_all_frozen,
use_task_specific_params,
lmap,
@@ -108,8 +108,8 @@ class SummarizationModule(BaseTransformer):
if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
self.model.config.decoder_start_token_id = self.decoder_start_token_id
-if isinstance(self.tokenizer, MBartTokenizer):
-self.dataset_class = MBartDataset
+if isinstance(self.tokenizer, MBartTokenizer) or isinstance(self.tokenizer, MarianTokenizer):
+self.dataset_class = TranslationDataset
else:
self.dataset_class = Seq2SeqDataset