Seq2SeqDataset uses linecache to save memory by @Pradhy729 (#5792)

Co-authored-by: Pradhy729 <49659913+Pradhy729@users.noreply.github.com>
This commit is contained in:
Sam Shleifer
2020-07-18 13:57:33 -04:00
committed by GitHub
parent 4b506a37e3
commit 09a2f40684
6 changed files with 182 additions and 170 deletions

View File

@@ -9,16 +9,17 @@ from unittest.mock import patch
import pytest
import torch
from pytest import param
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from transformers import AutoTokenizer, MBartTokenizer
from transformers.testing_utils import require_multigpu
from .distillation import distill_main, evaluate_checkpoint
from .finetune import main
from .pack_dataset import pack_data_dir
from .run_eval import generate_summaries_or_translations, run_generate
from .utils import SummarizationDataset, lmap, load_json
from .utils import MBartDataset, Seq2SeqDataset, lmap, load_json
logging.basicConfig(level=logging.DEBUG)
@@ -26,6 +27,7 @@ logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
CUDA_AVAILABLE = torch.cuda.is_available()
CHEAP_ARGS = {
"label_smoothing_eps": 0.2,
"logger_name": "default",
"length_penalty": 0.5,
"cache_dir": "",
@@ -80,11 +82,11 @@ CHEAP_ARGS = {
def _dump_articles(path: Path, articles: list):
    """Write ``articles`` to ``path``, one article per line.

    Articles are joined with a single newline; no trailing newline is added
    after the last article, so an empty list produces an empty file.

    Args:
        path: Destination file. Coerced through ``Path`` so a plain string
            path is accepted as well.
        articles: Strings to write, in order.
    """
    content = "\n".join(articles)
    # write_text opens, writes and closes the file in one call, so the data is
    # flushed to disk before any reader opens it (the previous
    # ``open("w").writelines(content)`` form left the handle to be closed by
    # the garbage collector and iterated the string character by character).
    Path(path).write_text(content)
ARTICLES = [" Sam ate lunch today", "Sams lunch ingredients"]
ARTICLES = [" Sam ate lunch today.", "Sams lunch ingredients."]
SUMMARIES = ["A very interesting story about what I ate for lunch.", "Avocado, celery, turkey, coffee"]
T5_TINY = "patrickvonplaten/t5-tiny-random"
BART_TINY = "sshleifer/bart-tiny-random"
@@ -208,7 +210,7 @@ def test_run_eval_bart(model):
@pytest.mark.parametrize(
["model"], [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY), pytest.param(MARIAN_TINY)]
["model"], [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY), pytest.param(MARIAN_TINY)],
)
def test_finetune(model):
args_d: dict = CHEAP_ARGS.copy()
@@ -260,22 +262,50 @@ def test_pack_dataset():
assert orig_paths == new_paths
@pytest.mark.parametrize(
["tok"], [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY), pytest.param(MARIAN_TINY)]
)
def test_dataset(tok):
def test_mbart_dataset_truncation():
    """Check truncation and language-code placement in MBartDataset batches.

    Loads the first batch from an ``MBartDataset`` built over the directory
    returned by ``make_test_data_dir()`` with ``max_source_length`` set to 4,
    then asserts on the tensor shapes and on where the EOS token and the
    source/target language-code ids sit in the first example.
    """
    tok = MBartTokenizer.from_pretrained(MBART_TINY)
    data_path = make_test_data_dir()
    longest_source = max(len(tok.encode(text)) for text in ARTICLES)
    longest_target = max(len(tok.encode(text)) for text in SUMMARIES)
    limit = 4
    # NOT WHAT IT WAS TRAINED ON
    source_code = "ro_RO"
    target_code = "de_DE"
    train_set = MBartDataset(
        tok,
        data_dir=data_path,
        type_path="train",
        max_source_length=limit,
        max_target_length=1000,  # ignored
        src_lang=source_code,
        tgt_lang=target_code,
    )
    loader = DataLoader(train_set, batch_size=2, collate_fn=train_set.collate_fn)
    for first_batch in loader:
        assert isinstance(first_batch, dict)
        mask = first_batch["attention_mask"]
        source_ids = first_batch["input_ids"]
        assert mask.shape == source_ids.shape
        # show that articles were trimmed.
        assert source_ids.shape[1] == limit
        # show that targets are the same len
        target_ids = first_batch["decoder_input_ids"]
        assert target_ids.shape[1] == limit
        # check language codes in correct place
        first_target = target_ids[0]
        first_source = source_ids[0]
        assert first_target[0].item() == tok.lang_code_to_id[target_code]
        assert first_target[-1].item() == tok.eos_token_id
        assert first_source[-2].item() == tok.eos_token_id
        assert first_source[-1].item() == tok.lang_code_to_id[source_code]

        # The untruncated texts must be longer than the limit, otherwise the
        # shape checks above would not demonstrate truncation.
        assert longest_target > limit  # Truncated
        assert longest_source > limit
        break  # No need to test every batch
@pytest.mark.parametrize(["tok"], [pytest.param(T5_TINY), pytest.param(BART_TINY), param(MARIAN_TINY)])
def test_summarization_dataset_truncation(tok):
tokenizer = AutoTokenizer.from_pretrained(tok)
tmp_dir = make_test_data_dir()
max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
trunc_target = 4
train_dataset = SummarizationDataset(
tokenizer,
data_dir=tmp_dir,
type_path="train",
max_source_length=20,
max_target_length=trunc_target,
tgt_lang="ro_RO",
train_dataset = Seq2SeqDataset(
tokenizer, data_dir=tmp_dir, type_path="train", max_source_length=20, max_target_length=trunc_target,
)
dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
for batch in dataloader:
@@ -286,3 +316,4 @@ def test_dataset(tok):
# show that targets were truncated
assert batch["decoder_input_ids"].shape[1] == trunc_target # Truncated
assert max_len_target > trunc_target # Truncated
break # No need to test every batch