prepare_seq2seq_batch makes labels/ decoder_input_ids made later. (#6654)

* broken test

* batch parity

* tests pass

* boom boom

* boom boom

* split out bart tokenizer tests

* fix tests

* boom boom

* Fixed dataset bug

* Fix marian

* Undo extra

* Get marian working

* Fix t5 tok tests

* Test passing

* Cleanup

* better assert msg

* require torch

* Fix mbart tests

* undo extra decoder_attn_mask change

* Fix import

* pegasus tokenizer can ignore src_lang kwargs

* unused kwarg test cov

* boom boom

* add todo for pegasus issue

* cover one word translation edge case

* Cleanup

* doc
This commit is contained in:
Sam Shleifer
2020-08-28 11:15:17 -04:00
committed by GitHub
parent cb276b41de
commit 9336086ab5
20 changed files with 429 additions and 290 deletions

View File

@@ -10,18 +10,18 @@ from unittest.mock import patch
import pytest
import pytorch_lightning as pl
import torch
from pytest import param
from torch.utils.data import DataLoader
import lightning_base
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from transformers.modeling_bart import shift_tokens_right
from transformers.testing_utils import CaptureStderr, CaptureStdout, require_multigpu
from .distillation import distill_main, evaluate_checkpoint
from .finetune import SummarizationModule, main
from .pack_dataset import pack_data_dir
from .run_eval import generate_summaries_or_translations, run_generate
from .utils import Seq2SeqDataset, TranslationDataset, label_smoothed_nll_loss, lmap, load_json
from .utils import LegacySeq2SeqDataset, Seq2SeqDataset, label_smoothed_nll_loss, lmap, load_json
logging.basicConfig(level=logging.DEBUG)
@@ -452,18 +452,27 @@ def test_pack_dataset():
assert orig_paths == new_paths
@pytest.mark.parametrize(["tok_name"], [pytest.param(MBART_TINY), pytest.param(MARIAN_TINY)])
def test_mbart_dataset_truncation(tok_name):
@pytest.mark.parametrize(
["tok_name"],
[
pytest.param(MBART_TINY),
pytest.param(MARIAN_TINY),
pytest.param(T5_TINY),
pytest.param(BART_TINY),
pytest.param("google/pegasus-xsum"),
],
)
def test_seq2seq_dataset_truncation(tok_name):
tokenizer = AutoTokenizer.from_pretrained(tok_name)
tmp_dir = make_test_data_dir()
max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
max_src_len = 4
max_tgt_len = 8
assert max_len_target > max_src_len # Truncated
assert max_len_source > max_src_len
src_lang, tgt_lang = "ro_RO", "de_DE" # NOT WHAT IT WAS TRAINED ON
train_dataset = TranslationDataset(
assert max_len_target > max_src_len # Will be truncated
assert max_len_source > max_src_len # Will be truncated
src_lang, tgt_lang = "ro_RO", "de_DE" # ignored for all but mbart, but never causes error.
train_dataset = Seq2SeqDataset(
tokenizer,
data_dir=tmp_dir,
type_path="train",
@@ -479,10 +488,11 @@ def test_mbart_dataset_truncation(tok_name):
# show that articles were trimmed.
assert batch["input_ids"].shape[1] == max_src_len
# show that targets are the same len
assert batch["decoder_input_ids"].shape[1] == max_tgt_len
if tok_name == MARIAN_TINY:
assert batch["labels"].shape[1] == max_tgt_len
if tok_name != MBART_TINY:
continue
# check language codes in correct place
batch["decoder_input_ids"] = shift_tokens_right(batch["labels"], tokenizer.pad_token_id)
assert batch["decoder_input_ids"][0, 0].item() == tokenizer.lang_code_to_id[tgt_lang]
assert batch["decoder_input_ids"][0, -1].item() == tokenizer.eos_token_id
assert batch["input_ids"][0, -2].item() == tokenizer.eos_token_id
@@ -491,14 +501,14 @@ def test_mbart_dataset_truncation(tok_name):
break # No need to test every batch
@pytest.mark.parametrize(["tok"], [pytest.param(T5_TINY), pytest.param(BART_TINY), param(MARIAN_TINY)])
def test_summarization_dataset_truncation(tok):
@pytest.mark.parametrize(["tok"], [pytest.param(BART_TINY), pytest.param("bert-base-cased")])
def test_legacy_dataset_truncation(tok):
tokenizer = AutoTokenizer.from_pretrained(tok)
tmp_dir = make_test_data_dir()
max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
trunc_target = 4
train_dataset = Seq2SeqDataset(
train_dataset = LegacySeq2SeqDataset(
tokenizer,
data_dir=tmp_dir,
type_path="train",
@@ -512,6 +522,6 @@ def test_summarization_dataset_truncation(tok):
assert batch["input_ids"].shape[1] == max_len_source
assert 20 >= batch["input_ids"].shape[1] # trimmed significantly
# show that targets were truncated
assert batch["decoder_input_ids"].shape[1] == trunc_target # Truncated
assert batch["labels"].shape[1] == trunc_target # Truncated
assert max_len_target > trunc_target # Truncated
break # No need to test every batch