Seq2SeqDataset uses linecache to save memory by @Pradhy729 (#5792)

Co-authored-by: Pradhy729 <49659913+Pradhy729@users.noreply.github.com>
This commit is contained in:
Sam Shleifer
2020-07-18 13:57:33 -04:00
committed by GitHub
parent 4b506a37e3
commit 09a2f40684
6 changed files with 182 additions and 170 deletions

View File

@@ -7,6 +7,15 @@ For `bertabs` instructions, see `bertabs/README.md`.
### Data
XSUM Data:
```bash
cd examples/seq2seq
wget https://s3.amazonaws.com/datasets.huggingface.co/summarization/xsum.tar.gz
tar -xzvf xsum.tar.gz
export XSUM_DIR=${PWD}/xsum
```
This should make a directory called `xsum/` with files like `test.source`.
To use your own data, copy that file format. Each article to be summarized is on its own line.
CNN/DailyMail data
```bash
@@ -17,18 +26,6 @@ tar -xzvf cnn_dm.tgz
export CNN_DIR=${PWD}/cnn_dm
```
this should make a directory called cnn_dm/ with files like `test.source`.
To use your own data, copy that files format. Each article to be summarized is on its own line.
XSUM Data:
```bash
cd examples/seq2seq
wget https://s3.amazonaws.com/datasets.huggingface.co/summarization/xsum.tar.gz
tar -xzvf xsum.tar.gz
export XSUM_DIR=${PWD}/xsum
```
WMT16 English-Romanian Translation Data:
```bash
cd examples/seq2seq
@@ -40,7 +37,7 @@ export ENRO_DIR=${PWD}/wmt_en_ro
If you are using your own data, it must be formatted as one directory with 6 files: train.source, train.target, val.source, val.target, test.source, test.target.
The `.source` files are the input, the `.target` files are the desired output.
### Tips and Tricks
General Tips:
@@ -64,6 +61,10 @@ Summarization Tips:
- If you are finetuning on your own dataset, start from `distilbart-cnn-12-6` if you want long summaries and `distilbart-xsum-12-6` if you want short summaries.
(It rarely makes sense to start from `bart-large` unless you are researching finetuning methods).
**Update 2020-07-18**
Datasets: `Seq2SeqDataset` will be used for all models besides MBart, for which `MBartDataset` will be used.
A new dataset is needed to support multilingual tasks.
### Summarization Finetuning
Run/modify `finetune.sh`
@@ -78,8 +79,6 @@ The following command should work on a 16GB GPU:
--model_name_or_path facebook/bart-large
```
### Translation Finetuning
First, follow the wmt_en_ro download instructions.
@@ -124,23 +123,6 @@ from transformers import AutoModelForSeq2SeqLM
model = AutoModelForSeq2SeqLM.from_pretrained(f'{output_dir}/best_tfmr')
```
#### XSUM Shared Task
Compare XSUM results with others by using `--logger_name wandb_shared`. This requires `wandb` registration.
Here is an example command, but you can do whatever you want. Hopefully this will make debugging and collaboration easier!
```bash
WANDB_PROJECT='hf_xsum' ./finetune.sh \
--data_dir $XSUM_DIR \
--output_dir xsum_frozen_embs \
--model_name_or_path facebook/bart-large \
--train_batch_size 16 --eval_batch_size 16 --freeze_embeds --freeze_encoder \
--num_train_epochs 6 \
--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100 \
--logger_name wandb
```
You can see your wandb logs [here](https://app.wandb.ai/sshleifer/hf_xsum?workspace=user-)
### Evaluation Commands
To create summaries for each article in dataset, we use `run_eval.py`, here are a few commands that run eval for different tasks and models.

View File

@@ -15,28 +15,15 @@ from transformers import AdamW, BartConfig, BartForConditionalGeneration, T5Conf
try:
from .finetune import SummarizationModule
from .initialization_utils import init_student, copy_layers
from .utils import (
use_task_specific_params,
SummarizationDataset,
pickle_load,
freeze_params,
assert_all_frozen,
any_requires_grad,
)
from .finetune import main as ft_main
from .initialization_utils import init_student, copy_layers
from .utils import use_task_specific_params, pickle_load, freeze_params, assert_all_frozen, any_requires_grad
except ImportError:
from finetune import SummarizationModule
from finetune import main as ft_main
from initialization_utils import init_student, copy_layers
from utils import (
use_task_specific_params,
SummarizationDataset,
pickle_load,
freeze_params,
assert_all_frozen,
any_requires_grad,
)
from utils import use_task_specific_params, pickle_load, freeze_params, assert_all_frozen, any_requires_grad
class BartSummarizationDistiller(SummarizationModule):
@@ -115,11 +102,6 @@ class BartSummarizationDistiller(SummarizationModule):
if self.different_encoder:
copy_layers(teacher.encoder.block, student.encoder.block, e_layers_to_copy)
def get_dataset(self, type_path) -> SummarizationDataset:
n_obs = self.n_obs[type_path]
dataset = SummarizationDataset(self.tokenizer, type_path=type_path, n_obs=n_obs, **self.dataset_kwargs)
return dataset
def calc_mse_loss(self, teacher_outputs: torch.Tensor, student_outputs: torch.Tensor, mask) -> torch.FloatTensor:
if mask is not None:
# mask has False at padding_idx

View File

@@ -21,7 +21,6 @@ try:
from .utils import (
assert_all_frozen,
use_task_specific_params,
SummarizationDataset,
lmap,
flatten_list,
pickle_save,
@@ -32,12 +31,17 @@ try:
get_git_info,
ROUGE_KEYS,
calculate_bleu_score,
Seq2SeqDataset,
MBartDataset,
)
from .callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback
except ImportError:
from utils import (
Seq2SeqDataset,
MBartDataset,
assert_all_frozen,
use_task_specific_params,
SummarizationDataset,
lmap,
flatten_list,
pickle_save,
@@ -48,7 +52,6 @@ except ImportError:
get_git_info,
ROUGE_KEYS,
calculate_bleu_score,
assert_all_frozen,
)
from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback
@@ -100,6 +103,7 @@ class SummarizationModule(BaseTransformer):
self.hparams.git_sha = get_git_info()["repo_sha"]
self.num_workers = hparams.num_workers
self.decoder_start_token_id = None
self.dataset_class = Seq2SeqDataset
def freeze_embeds(self):
"""Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
@@ -163,7 +167,7 @@ class SummarizationModule(BaseTransformer):
def _generative_step(self, batch: dict) -> dict:
pad_token_id = self.tokenizer.pad_token_id
source_ids, source_mask, y = SummarizationDataset.trim_seq2seq_batch(batch, pad_token_id)
source_ids, source_mask, y = Seq2SeqDataset.trim_seq2seq_batch(batch, pad_token_id)
t0 = time.time()
generated_ids = self.model.generate(
input_ids=source_ids,
@@ -187,10 +191,10 @@ class SummarizationModule(BaseTransformer):
def test_epoch_end(self, outputs):
return self.validation_epoch_end(outputs, prefix="test")
def get_dataset(self, type_path) -> SummarizationDataset:
def get_dataset(self, type_path) -> Seq2SeqDataset:
n_obs = self.n_obs[type_path]
max_target_length = self.target_lens[type_path]
dataset = SummarizationDataset(
dataset = self.dataset_class(
self.tokenizer,
type_path=type_path,
n_obs=n_obs,
@@ -303,6 +307,8 @@ class TranslationModule(SummarizationModule):
self.dataset_kwargs["tgt_lang"] = hparams.tgt_lang
if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
if isinstance(self.tokenizer, MBartTokenizer):
self.dataset_class = MBartDataset
def calc_generative_metrics(self, preds, target) -> dict:
return calculate_bleu_score(preds, target)

View File

@@ -9,16 +9,17 @@ from unittest.mock import patch
import pytest
import torch
from pytest import param
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from transformers import AutoTokenizer, MBartTokenizer
from transformers.testing_utils import require_multigpu
from .distillation import distill_main, evaluate_checkpoint
from .finetune import main
from .pack_dataset import pack_data_dir
from .run_eval import generate_summaries_or_translations, run_generate
from .utils import SummarizationDataset, lmap, load_json
from .utils import MBartDataset, Seq2SeqDataset, lmap, load_json
logging.basicConfig(level=logging.DEBUG)
@@ -26,6 +27,7 @@ logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
CUDA_AVAILABLE = torch.cuda.is_available()
CHEAP_ARGS = {
"label_smoothing_eps": 0.2,
"logger_name": "default",
"length_penalty": 0.5,
"cache_dir": "",
@@ -80,11 +82,11 @@ CHEAP_ARGS = {
def _dump_articles(path: Path, articles: list):
with path.open("w") as f:
f.write("\n".join(articles))
content = "\n".join(articles)
Path(path).open("w").writelines(content)
ARTICLES = [" Sam ate lunch today", "Sams lunch ingredients"]
ARTICLES = [" Sam ate lunch today.", "Sams lunch ingredients."]
SUMMARIES = ["A very interesting story about what I ate for lunch.", "Avocado, celery, turkey, coffee"]
T5_TINY = "patrickvonplaten/t5-tiny-random"
BART_TINY = "sshleifer/bart-tiny-random"
@@ -208,7 +210,7 @@ def test_run_eval_bart(model):
@pytest.mark.parametrize(
["model"], [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY), pytest.param(MARIAN_TINY)]
["model"], [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY), pytest.param(MARIAN_TINY)],
)
def test_finetune(model):
args_d: dict = CHEAP_ARGS.copy()
@@ -260,22 +262,50 @@ def test_pack_dataset():
assert orig_paths == new_paths
@pytest.mark.parametrize(
["tok"], [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY), pytest.param(MARIAN_TINY)]
)
def test_dataset(tok):
def test_mbart_dataset_truncation():
    """Check that MBartDataset batches are truncated to ``max_source_length``
    and that EOS / language-code tokens sit at the expected positions."""
    tokenizer = MBartTokenizer.from_pretrained(MBART_TINY)
    tmp_dir = make_test_data_dir()
    max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
    max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
    trunc = 4
    # Preconditions: the shape checks below only demonstrate truncation if the
    # untruncated texts are longer than ``trunc``.
    assert max_len_source > trunc
    assert max_len_target > trunc  # Truncated
    src_lang, tgt_lang = "ro_RO", "de_DE"  # NOT WHAT IT WAS TRAINED ON
    train_dataset = MBartDataset(
        tokenizer,
        data_dir=tmp_dir,
        type_path="train",
        max_source_length=trunc,
        max_target_length=1000,  # ignored
        src_lang=src_lang,
        tgt_lang=tgt_lang,
    )
    dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
    # One batch is enough. next() raises StopIteration on an empty loader, so
    # the test cannot pass without checking anything.
    batch = next(iter(dataloader))
    assert isinstance(batch, dict)
    assert batch["attention_mask"].shape == batch["input_ids"].shape
    # show that articles were trimmed.
    assert batch["input_ids"].shape[1] == trunc
    # show that targets are the same len
    assert batch["decoder_input_ids"].shape[1] == trunc
    # check language codes in correct place
    assert batch["decoder_input_ids"][0, 0].item() == tokenizer.lang_code_to_id[tgt_lang]
    assert batch["decoder_input_ids"][0, -1].item() == tokenizer.eos_token_id
    assert batch["input_ids"][0, -2].item() == tokenizer.eos_token_id
    assert batch["input_ids"][0, -1].item() == tokenizer.lang_code_to_id[src_lang]
@pytest.mark.parametrize(["tok"], [pytest.param(T5_TINY), pytest.param(BART_TINY), param(MARIAN_TINY)])
def test_summarization_dataset_truncation(tok):
tokenizer = AutoTokenizer.from_pretrained(tok)
tmp_dir = make_test_data_dir()
max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
trunc_target = 4
train_dataset = SummarizationDataset(
tokenizer,
data_dir=tmp_dir,
type_path="train",
max_source_length=20,
max_target_length=trunc_target,
tgt_lang="ro_RO",
train_dataset = Seq2SeqDataset(
tokenizer, data_dir=tmp_dir, type_path="train", max_source_length=20, max_target_length=trunc_target,
)
dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
for batch in dataloader:
@@ -286,3 +316,4 @@ def test_dataset(tok):
# show that targets were truncated
assert batch["decoder_input_ids"].shape[1] == trunc_target # Truncated
assert max_len_target > trunc_target # Truncated
break # No need to test every batch

View File

@@ -1,7 +1,9 @@
import itertools
import json
import linecache
import os
import pickle
import warnings
from logging import getLogger
from pathlib import Path
from typing import Callable, Dict, Iterable, List
@@ -13,50 +15,20 @@ from rouge_score import rouge_scorer, scoring
from sacrebleu import corpus_bleu
from torch import nn
from torch.utils.data import Dataset, Sampler
from tqdm import tqdm
from transformers import BartTokenizer
def encode_file(
tokenizer,
data_path,
max_length,
pad_to_max_length=True,
return_tensors="pt",
overwrite_cache=False,
prefix="",
tok_name="",
):
def encode_line(tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"):
extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {}
cache_path = Path(f"{data_path}_{tok_name}{max_length}.pt")
if not overwrite_cache and cache_path.exists():
try:
examples = torch.load(cache_path)
assert isinstance(examples, list)
return examples
except Exception:
print(f"failed to load from {cache_path}, retokenizing {data_path}")
data_path = Path(data_path)
lns = lmap(str.strip, data_path.open().readlines())
lns = [prefix + text for text in lns]
assert lns, f"found empty file at {data_path}"
examples = []
for text in tqdm(lns, desc=f"Tokenizing {data_path.name}"):
tokenized = tokenizer(
[text],
max_length=max_length,
padding="max_length" if pad_to_max_length else None,
truncation=True,
return_tensors=return_tensors,
**extra_kw,
)
assert tokenized.input_ids.shape[1] == max_length
examples.append(tokenized)
torch.save(lmap(dict, examples), cache_path.open("wb"))
return examples
return tokenizer(
[line],
max_length=max_length,
padding="max_length" if pad_to_max_length else None,
truncation=True,
return_tensors=return_tensors,
**extra_kw,
)
def lmap(f: Callable, x: Iterable) -> List:
@@ -80,73 +52,111 @@ def trim_batch(
return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask])
class SummarizationDataset(Dataset):
class Seq2SeqDataset(Dataset):
def __init__(
self,
tokenizer,
data_dir,
max_source_length,
max_target_length,
type_path="train",
max_source_length=1024,
max_target_length=56,
n_obs=None,
overwrite_cache=False,
prefix="",
src_lang=None,
tgt_lang=None,
prefix="",
):
super().__init__()
# FIXME: the rstrip logic strips all the chars, it seems.
tok_name = tokenizer.__class__.__name__.lower().rstrip("tokenizer")
if hasattr(tokenizer, "set_lang") and src_lang is not None:
tokenizer.set_lang(src_lang) # HACK: only applies to mbart
self.source = encode_file(
tokenizer,
os.path.join(data_dir, type_path + ".source"),
max_source_length,
overwrite_cache=overwrite_cache,
prefix=prefix,
tok_name=tok_name,
)
tgt_path = os.path.join(data_dir, type_path + ".target")
if hasattr(tokenizer, "set_lang"):
assert tgt_lang is not None, "--tgt_lang must be passed to build a translation"
tokenizer.set_lang(tgt_lang) # HACK: only applies to mbart
self.target = encode_file(
tokenizer, tgt_path, max_target_length, overwrite_cache=overwrite_cache, tok_name=tok_name
)
self.src_file = Path(data_dir).joinpath(type_path + ".source")
self.tgt_file = Path(data_dir).joinpath(type_path + ".target")
self.src_lens = self.get_char_lens(self.src_file)
self.max_source_length = max_source_length
self.max_target_length = max_target_length
assert min(self.src_lens) > 0, f"found empty line in {self.src_file}"
self.tokenizer = tokenizer
self.prefix = prefix
if n_obs is not None:
self.source = self.source[:n_obs]
self.target = self.target[:n_obs]
self.pad_token_id = tokenizer.pad_token_id
self.src_lens = self.src_lens[:n_obs]
self.pad_token_id = self.tokenizer.pad_token_id
self.src_lang = src_lang
self.tgt_lang = tgt_lang
def __len__(self):
return len(self.source)
return len(self.src_lens)
def __getitem__(self, index):
source_ids = self.source[index]["input_ids"].squeeze()
target_ids = self.target[index]["input_ids"].squeeze()
src_mask = self.source[index]["attention_mask"].squeeze()
return {"input_ids": source_ids, "attention_mask": src_mask, "decoder_input_ids": target_ids}
def __getitem__(self, index) -> Dict[str, torch.Tensor]:
index = index + 1 # linecache starts at 1
source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
assert source_line, f"empty source line for index {index}"
assert tgt_line, f"empty tgt line for index {index}"
source_inputs = encode_line(self.tokenizer, source_line, self.max_source_length)
target_inputs = encode_line(self.tokenizer, tgt_line, self.max_target_length)
source_ids = source_inputs["input_ids"].squeeze()
target_ids = target_inputs["input_ids"].squeeze()
src_mask = source_inputs["attention_mask"].squeeze()
return {
"input_ids": source_ids,
"attention_mask": src_mask,
"decoder_input_ids": target_ids,
}
@staticmethod
def trim_seq2seq_batch(batch, pad_token_id):
def get_char_lens(data_file):
return [len(x) for x in Path(data_file).open().readlines()]
@staticmethod
def trim_seq2seq_batch(batch, pad_token_id) -> tuple:
y = trim_batch(batch["decoder_input_ids"], pad_token_id)
source_ids, source_mask = trim_batch(batch["input_ids"], pad_token_id, attention_mask=batch["attention_mask"])
return source_ids, source_mask, y
def collate_fn(self, batch) -> dict:
def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
input_ids = torch.stack([x["input_ids"] for x in batch])
masks = torch.stack([x["attention_mask"] for x in batch])
target_ids = torch.stack([x["decoder_input_ids"] for x in batch])
pad_token_id = self.pad_token_id
y = trim_batch(target_ids, pad_token_id)
source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks)
batch = {"input_ids": source_ids, "attention_mask": source_mask, "decoder_input_ids": y}
batch = {
"input_ids": source_ids,
"attention_mask": source_mask,
"decoder_input_ids": y,
}
return batch
def make_sortish_sampler(self, batch_size):
lens = [x["input_ids"].ne(self.pad_token_id).sum() for x in self.source]
return SortishSampler(lens, batch_size)
return SortishSampler(self.src_lens, batch_size)
class MBartDataset(Seq2SeqDataset):
    """Seq2SeqDataset variant that yields raw text and tokenizes per batch.

    ``__getitem__`` returns the untokenized source/target lines, and
    ``collate_fn`` passes the whole batch to
    ``tokenizer.prepare_translation_batch`` together with ``src_lang`` and
    ``tgt_lang``. Only ``max_source_length`` is forwarded (as ``max_length``);
    ``max_target_length`` is not used by this class.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``Seq2SeqDataset`` and warn if the two max lengths differ."""
        super().__init__(*args, **kwargs)
        if self.max_source_length != self.max_target_length:
            warnings.warn(
                f"Mbart will ignore max_target_length = {self.max_target_length} and use {self.max_source_length} for both sides."
            )

    def __getitem__(self, index) -> Dict[str, str]:
        """Return the raw text pair for the 0-based example ``index``.

        Returns:
            ``{"src_texts": <prefix + source line>, "tgt_texts": <target line>}``
            with trailing newlines removed.

        Raises:
            AssertionError: if either line is empty.
        """
        index = index + 1  # linecache starts at 1
        source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
        tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
        assert source_line, f"empty source line for index {index}"
        assert tgt_line, f"empty tgt line for index {index}"
        # Keys must match what collate_fn reads: the source goes to the encoder
        # side and the target to the decoder side (they were previously swapped).
        return {
            "src_texts": source_line,
            "tgt_texts": tgt_line,
        }

    def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
        """Tokenize a list of ``__getitem__`` outputs into one model-ready batch.

        Returns the ``.data`` of the encoding produced by
        ``tokenizer.prepare_translation_batch``.
        """
        batch_encoding = self.tokenizer.prepare_translation_batch(
            [x["src_texts"] for x in batch],
            src_lang=self.src_lang,
            tgt_texts=[x["tgt_texts"] for x in batch],
            tgt_lang=self.tgt_lang,
            max_length=self.max_source_length,
        )
        return batch_encoding.data
class SortishSampler(Sampler):