[s2s]Use prepare_translation_batch for Marian finetuning (#6293)

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Sam Shleifer
2020-08-06 14:58:38 -04:00
committed by GitHub
parent 2f2aa0c89c
commit 2804fff839
5 changed files with 22 additions and 12 deletions

View File

@@ -14,14 +14,14 @@ from pytest import param
from torch.utils.data import DataLoader
import lightning_base
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, MBartTokenizer
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from transformers.testing_utils import require_multigpu
from .distillation import distill_main, evaluate_checkpoint
from .finetune import SummarizationModule, main
from .pack_dataset import pack_data_dir
from .run_eval import generate_summaries_or_translations, run_generate
from .utils import MBartDataset, Seq2SeqDataset, label_smoothed_nll_loss, lmap, load_json
from .utils import Seq2SeqDataset, TranslationDataset, label_smoothed_nll_loss, lmap, load_json
logging.basicConfig(level=logging.DEBUG)
@@ -406,8 +406,9 @@ def test_pack_dataset():
assert orig_paths == new_paths
def test_mbart_dataset_truncation():
tokenizer = MBartTokenizer.from_pretrained(MBART_TINY)
@pytest.mark.parametrize(["tok_name"], [pytest.param(MBART_TINY), pytest.param(MARIAN_TINY)])
def test_mbart_dataset_truncation(tok_name):
tokenizer = AutoTokenizer.from_pretrained(tok_name)
tmp_dir = make_test_data_dir()
max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
@@ -416,7 +417,7 @@ def test_mbart_dataset_truncation():
assert max_len_target > max_src_len # Truncated
assert max_len_source > max_src_len
src_lang, tgt_lang = "ro_RO", "de_DE" # NOT WHAT IT WAS TRAINED ON
train_dataset = MBartDataset(
train_dataset = TranslationDataset(
tokenizer,
data_dir=tmp_dir,
type_path="train",
@@ -433,6 +434,8 @@ def test_mbart_dataset_truncation():
assert batch["input_ids"].shape[1] == max_src_len
# show that targets are the same len
assert batch["decoder_input_ids"].shape[1] == max_tgt_len
if tok_name == MARIAN_TINY:
continue
# check language codes in correct place
assert batch["decoder_input_ids"][0, 0].item() == tokenizer.lang_code_to_id[tgt_lang]
assert batch["decoder_input_ids"][0, -1].item() == tokenizer.eos_token_id