[s2s]Use prepare_translation_batch for Marian finetuning (#6293)
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -14,7 +14,7 @@ import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from lightning_base import BaseTransformer, add_generic_args, generic_train
|
||||
-from transformers import MBartTokenizer, T5ForConditionalGeneration, get_linear_schedule_with_warmup
|
||||
+from transformers import MarianTokenizer, MBartTokenizer, T5ForConditionalGeneration, get_linear_schedule_with_warmup
|
||||
|
||||
|
||||
try:
|
||||
@@ -32,7 +32,7 @@ try:
|
||||
ROUGE_KEYS,
|
||||
calculate_bleu_score,
|
||||
Seq2SeqDataset,
|
||||
-MBartDataset,
|
||||
+TranslationDataset,
|
||||
label_smoothed_nll_loss,
|
||||
)
|
||||
|
||||
@@ -40,7 +40,7 @@ try:
|
||||
except ImportError:
|
||||
from utils import (
|
||||
Seq2SeqDataset,
|
||||
-MBartDataset,
|
||||
+TranslationDataset,
|
||||
assert_all_frozen,
|
||||
use_task_specific_params,
|
||||
lmap,
|
||||
@@ -108,8 +108,8 @@ class SummarizationModule(BaseTransformer):
|
||||
if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
|
||||
self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
|
||||
self.model.config.decoder_start_token_id = self.decoder_start_token_id
|
||||
-if isinstance(self.tokenizer, MBartTokenizer):
|
||||
-self.dataset_class = MBartDataset
|
||||
+if isinstance(self.tokenizer, MBartTokenizer) or isinstance(self.tokenizer, MarianTokenizer):
|
||||
+self.dataset_class = TranslationDataset
|
||||
else:
|
||||
self.dataset_class = Seq2SeqDataset
|
||||
|
||||
|
||||
Reference in New Issue
Block a user