Seq2SeqDataset uses linecache to save memory by @Pradhy729 (#5792)
Co-authored-by: Pradhy729 <49659913+Pradhy729@users.noreply.github.com>
This commit is contained in:
@@ -9,16 +9,17 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from pytest import param
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
from transformers import AutoTokenizer, MBartTokenizer
|
||||
from transformers.testing_utils import require_multigpu
|
||||
|
||||
from .distillation import distill_main, evaluate_checkpoint
|
||||
from .finetune import main
|
||||
from .pack_dataset import pack_data_dir
|
||||
from .run_eval import generate_summaries_or_translations, run_generate
|
||||
from .utils import SummarizationDataset, lmap, load_json
|
||||
from .utils import MBartDataset, Seq2SeqDataset, lmap, load_json
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
@@ -26,6 +27,7 @@ logging.basicConfig(level=logging.DEBUG)
|
||||
logger = logging.getLogger()
|
||||
CUDA_AVAILABLE = torch.cuda.is_available()
|
||||
CHEAP_ARGS = {
|
||||
"label_smoothing_eps": 0.2,
|
||||
"logger_name": "default",
|
||||
"length_penalty": 0.5,
|
||||
"cache_dir": "",
|
||||
@@ -80,11 +82,11 @@ CHEAP_ARGS = {
|
||||
|
||||
|
||||
def _dump_articles(path: Path, articles: list):
|
||||
with path.open("w") as f:
|
||||
f.write("\n".join(articles))
|
||||
content = "\n".join(articles)
|
||||
Path(path).open("w").writelines(content)
|
||||
|
||||
|
||||
ARTICLES = [" Sam ate lunch today", "Sams lunch ingredients"]
|
||||
ARTICLES = [" Sam ate lunch today.", "Sams lunch ingredients."]
|
||||
SUMMARIES = ["A very interesting story about what I ate for lunch.", "Avocado, celery, turkey, coffee"]
|
||||
T5_TINY = "patrickvonplaten/t5-tiny-random"
|
||||
BART_TINY = "sshleifer/bart-tiny-random"
|
||||
@@ -208,7 +210,7 @@ def test_run_eval_bart(model):
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["model"], [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY), pytest.param(MARIAN_TINY)]
|
||||
["model"], [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY), pytest.param(MARIAN_TINY)],
|
||||
)
|
||||
def test_finetune(model):
|
||||
args_d: dict = CHEAP_ARGS.copy()
|
||||
@@ -260,22 +262,50 @@ def test_pack_dataset():
|
||||
assert orig_paths == new_paths
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["tok"], [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY), pytest.param(MARIAN_TINY)]
|
||||
)
|
||||
def test_dataset(tok):
|
||||
def test_mbart_dataset_truncation():
|
||||
tokenizer = MBartTokenizer.from_pretrained(MBART_TINY)
|
||||
tmp_dir = make_test_data_dir()
|
||||
max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
|
||||
max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
|
||||
trunc = 4
|
||||
src_lang, tgt_lang = "ro_RO", "de_DE" # NOT WHAT IT WAS TRAINED ON
|
||||
train_dataset = MBartDataset(
|
||||
tokenizer,
|
||||
data_dir=tmp_dir,
|
||||
type_path="train",
|
||||
max_source_length=trunc,
|
||||
max_target_length=1000, # ignored
|
||||
src_lang=src_lang,
|
||||
tgt_lang=tgt_lang,
|
||||
)
|
||||
dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
|
||||
for batch in dataloader:
|
||||
assert isinstance(batch, dict)
|
||||
assert batch["attention_mask"].shape == batch["input_ids"].shape
|
||||
# show that articles were trimmed.
|
||||
assert batch["input_ids"].shape[1] == trunc
|
||||
# show that targets are the same len
|
||||
assert batch["decoder_input_ids"].shape[1] == trunc
|
||||
# check language codes in correct place
|
||||
assert batch["decoder_input_ids"][0, 0].item() == tokenizer.lang_code_to_id[tgt_lang]
|
||||
assert batch["decoder_input_ids"][0, -1].item() == tokenizer.eos_token_id
|
||||
assert batch["input_ids"][0, -2].item() == tokenizer.eos_token_id
|
||||
assert batch["input_ids"][0, -1].item() == tokenizer.lang_code_to_id[src_lang]
|
||||
|
||||
assert max_len_target > trunc # Truncated
|
||||
assert max_len_source > trunc
|
||||
break # No need to test every batch
|
||||
|
||||
|
||||
@pytest.mark.parametrize(["tok"], [pytest.param(T5_TINY), pytest.param(BART_TINY), param(MARIAN_TINY)])
|
||||
def test_summarization_dataset_truncation(tok):
|
||||
tokenizer = AutoTokenizer.from_pretrained(tok)
|
||||
tmp_dir = make_test_data_dir()
|
||||
max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
|
||||
max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
|
||||
trunc_target = 4
|
||||
train_dataset = SummarizationDataset(
|
||||
tokenizer,
|
||||
data_dir=tmp_dir,
|
||||
type_path="train",
|
||||
max_source_length=20,
|
||||
max_target_length=trunc_target,
|
||||
tgt_lang="ro_RO",
|
||||
train_dataset = Seq2SeqDataset(
|
||||
tokenizer, data_dir=tmp_dir, type_path="train", max_source_length=20, max_target_length=trunc_target,
|
||||
)
|
||||
dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
|
||||
for batch in dataloader:
|
||||
@@ -286,3 +316,4 @@ def test_dataset(tok):
|
||||
# show that targets were truncated
|
||||
assert batch["decoder_input_ids"].shape[1] == trunc_target # Truncated
|
||||
assert max_len_target > trunc_target # Truncated
|
||||
break # No need to test every batch
|
||||
|
||||
Reference in New Issue
Block a user