prepare_seq2seq_batch makes labels/ decoder_input_ids made later. (#6654)
* broken test * batch parity * tests pass * boom boom * boom boom * split out bart tokenizer tests * fix tests * boom boom * Fixed dataset bug * Fix marian * Undo extra * Get marian working * Fix t5 tok tests * Test passing * Cleanup * better assert msg * require torch * Fix mbart tests * undo extra decoder_attn_mask change * Fix import * pegasus tokenizer can ignore src_lang kwargs * unused kwarg test cov * boom boom * add todo for pegasus issue * cover one word translation edge case * Cleanup * doc
This commit is contained in:
@@ -10,18 +10,18 @@ from unittest.mock import patch
|
||||
import pytest
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from pytest import param
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
import lightning_base
|
||||
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
||||
from transformers.modeling_bart import shift_tokens_right
|
||||
from transformers.testing_utils import CaptureStderr, CaptureStdout, require_multigpu
|
||||
|
||||
from .distillation import distill_main, evaluate_checkpoint
|
||||
from .finetune import SummarizationModule, main
|
||||
from .pack_dataset import pack_data_dir
|
||||
from .run_eval import generate_summaries_or_translations, run_generate
|
||||
from .utils import Seq2SeqDataset, TranslationDataset, label_smoothed_nll_loss, lmap, load_json
|
||||
from .utils import LegacySeq2SeqDataset, Seq2SeqDataset, label_smoothed_nll_loss, lmap, load_json
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
@@ -452,18 +452,27 @@ def test_pack_dataset():
|
||||
assert orig_paths == new_paths
|
||||
|
||||
|
||||
@pytest.mark.parametrize(["tok_name"], [pytest.param(MBART_TINY), pytest.param(MARIAN_TINY)])
|
||||
def test_mbart_dataset_truncation(tok_name):
|
||||
@pytest.mark.parametrize(
|
||||
["tok_name"],
|
||||
[
|
||||
pytest.param(MBART_TINY),
|
||||
pytest.param(MARIAN_TINY),
|
||||
pytest.param(T5_TINY),
|
||||
pytest.param(BART_TINY),
|
||||
pytest.param("google/pegasus-xsum"),
|
||||
],
|
||||
)
|
||||
def test_seq2seq_dataset_truncation(tok_name):
|
||||
tokenizer = AutoTokenizer.from_pretrained(tok_name)
|
||||
tmp_dir = make_test_data_dir()
|
||||
max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
|
||||
max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
|
||||
max_src_len = 4
|
||||
max_tgt_len = 8
|
||||
assert max_len_target > max_src_len # Truncated
|
||||
assert max_len_source > max_src_len
|
||||
src_lang, tgt_lang = "ro_RO", "de_DE" # NOT WHAT IT WAS TRAINED ON
|
||||
train_dataset = TranslationDataset(
|
||||
assert max_len_target > max_src_len # Will be truncated
|
||||
assert max_len_source > max_src_len # Will be truncated
|
||||
src_lang, tgt_lang = "ro_RO", "de_DE" # ignored for all but mbart, but never causes error.
|
||||
train_dataset = Seq2SeqDataset(
|
||||
tokenizer,
|
||||
data_dir=tmp_dir,
|
||||
type_path="train",
|
||||
@@ -479,10 +488,11 @@ def test_mbart_dataset_truncation(tok_name):
|
||||
# show that articles were trimmed.
|
||||
assert batch["input_ids"].shape[1] == max_src_len
|
||||
# show that targets are the same len
|
||||
assert batch["decoder_input_ids"].shape[1] == max_tgt_len
|
||||
if tok_name == MARIAN_TINY:
|
||||
assert batch["labels"].shape[1] == max_tgt_len
|
||||
if tok_name != MBART_TINY:
|
||||
continue
|
||||
# check language codes in correct place
|
||||
batch["decoder_input_ids"] = shift_tokens_right(batch["labels"], tokenizer.pad_token_id)
|
||||
assert batch["decoder_input_ids"][0, 0].item() == tokenizer.lang_code_to_id[tgt_lang]
|
||||
assert batch["decoder_input_ids"][0, -1].item() == tokenizer.eos_token_id
|
||||
assert batch["input_ids"][0, -2].item() == tokenizer.eos_token_id
|
||||
@@ -491,14 +501,14 @@ def test_mbart_dataset_truncation(tok_name):
|
||||
break # No need to test every batch
|
||||
|
||||
|
||||
@pytest.mark.parametrize(["tok"], [pytest.param(T5_TINY), pytest.param(BART_TINY), param(MARIAN_TINY)])
|
||||
def test_summarization_dataset_truncation(tok):
|
||||
@pytest.mark.parametrize(["tok"], [pytest.param(BART_TINY), pytest.param("bert-base-cased")])
|
||||
def test_legacy_dataset_truncation(tok):
|
||||
tokenizer = AutoTokenizer.from_pretrained(tok)
|
||||
tmp_dir = make_test_data_dir()
|
||||
max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
|
||||
max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
|
||||
trunc_target = 4
|
||||
train_dataset = Seq2SeqDataset(
|
||||
train_dataset = LegacySeq2SeqDataset(
|
||||
tokenizer,
|
||||
data_dir=tmp_dir,
|
||||
type_path="train",
|
||||
@@ -512,6 +522,6 @@ def test_summarization_dataset_truncation(tok):
|
||||
assert batch["input_ids"].shape[1] == max_len_source
|
||||
assert 20 >= batch["input_ids"].shape[1] # trimmed significantly
|
||||
# show that targets were truncated
|
||||
assert batch["decoder_input_ids"].shape[1] == trunc_target # Truncated
|
||||
assert batch["labels"].shape[1] == trunc_target # Truncated
|
||||
assert max_len_target > trunc_target # Truncated
|
||||
break # No need to test every batch
|
||||
|
||||
Reference in New Issue
Block a user