[examples] SummarizationDataset cleanup (#3451)
This commit is contained in:
@@ -19,30 +19,20 @@ class BartSystem(BaseTransformer):
|
|||||||
mode = "language-modeling"
|
mode = "language-modeling"
|
||||||
|
|
||||||
def __init__(self, hparams):
|
def __init__(self, hparams):
|
||||||
super(BartSystem, self).__init__(hparams, num_labels=None, mode=self.mode)
|
super().__init__(hparams, num_labels=None, mode=self.mode, output_past=False)
|
||||||
|
|
||||||
def forward(
|
def forward(self, input_ids, attention_mask=None, decoder_input_ids=None, lm_labels=None):
|
||||||
self, input_ids, attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, lm_labels=None
|
|
||||||
):
|
|
||||||
return self.model(
|
return self.model(
|
||||||
input_ids,
|
input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, lm_labels=lm_labels,
|
||||||
attention_mask=attention_mask,
|
|
||||||
decoder_input_ids=decoder_input_ids,
|
|
||||||
decoder_attention_mask=decoder_attention_mask,
|
|
||||||
lm_labels=lm_labels,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _step(self, batch):
|
def _step(self, batch):
|
||||||
y = batch["target_ids"]
|
pad_token_id = self.tokenizer.pad_token_id
|
||||||
|
source_ids, source_mask, y = batch["source_ids"], batch["source_mask"], batch["target_ids"]
|
||||||
y_ids = y[:, :-1].contiguous()
|
y_ids = y[:, :-1].contiguous()
|
||||||
lm_labels = y[:, 1:].clone()
|
lm_labels = y[:, 1:].clone()
|
||||||
lm_labels[y[:, 1:] == self.tokenizer.pad_token_id] = -100
|
lm_labels[y[:, 1:] == pad_token_id] = -100
|
||||||
outputs = self(
|
outputs = self(source_ids, attention_mask=source_mask, decoder_input_ids=y_ids, lm_labels=lm_labels,)
|
||||||
input_ids=batch["source_ids"],
|
|
||||||
attention_mask=batch["source_mask"],
|
|
||||||
decoder_input_ids=y_ids,
|
|
||||||
lm_labels=lm_labels,
|
|
||||||
)
|
|
||||||
|
|
||||||
loss = outputs[0]
|
loss = outputs[0]
|
||||||
|
|
||||||
@@ -64,9 +54,13 @@ class BartSystem(BaseTransformer):
|
|||||||
return {"avg_val_loss": avg_loss, "log": tensorboard_logs}
|
return {"avg_val_loss": avg_loss, "log": tensorboard_logs}
|
||||||
|
|
||||||
def test_step(self, batch, batch_idx):
|
def test_step(self, batch, batch_idx):
|
||||||
|
# NOTE: this generation will not use the cache.
|
||||||
|
pad_token_id = self.tokenizer.pad_token_id
|
||||||
|
source_ids, source_mask, y = SummarizationDataset.trim_seq2seq_batch(batch, pad_token_id)
|
||||||
|
# NOTE: these kwargs get more speed and lower quality summaries than those in evaluate_cnn.py.
|
||||||
generated_ids = self.model.generate(
|
generated_ids = self.model.generate(
|
||||||
batch["source_ids"],
|
source_ids,
|
||||||
attention_mask=batch["source_mask"],
|
source_mask,
|
||||||
num_beams=1,
|
num_beams=1,
|
||||||
max_length=80,
|
max_length=80,
|
||||||
repetition_penalty=2.5,
|
repetition_penalty=2.5,
|
||||||
@@ -77,10 +71,7 @@ class BartSystem(BaseTransformer):
|
|||||||
self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True)
|
self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True)
|
||||||
for g in generated_ids
|
for g in generated_ids
|
||||||
]
|
]
|
||||||
target = [
|
target = [self.tokenizer.decode(t, skip_special_tokens=True, clean_up_tokenization_spaces=True) for t in y]
|
||||||
self.tokenizer.decode(t, skip_special_tokens=True, clean_up_tokenization_spaces=True)
|
|
||||||
for t in batch["target_ids"]
|
|
||||||
]
|
|
||||||
loss = self._step(batch)
|
loss = self._step(batch)
|
||||||
|
|
||||||
return {"val_loss": loss, "preds": preds, "target": target}
|
return {"val_loss": loss, "preds": preds, "target": target}
|
||||||
@@ -101,11 +92,21 @@ class BartSystem(BaseTransformer):
|
|||||||
|
|
||||||
return self.test_end(outputs)
|
return self.test_end(outputs)
|
||||||
|
|
||||||
def train_dataloader(self):
|
@property
|
||||||
train_dataset = SummarizationDataset(
|
def dataset_kwargs(self):
|
||||||
self.tokenizer, data_dir=self.hparams.data_dir, type_path="train", block_size=self.hparams.max_seq_length
|
return dict(
|
||||||
|
data_dir=self.hparams.data_dir,
|
||||||
|
max_source_length=self.hparams.max_source_length,
|
||||||
|
max_target_length=self.hparams.max_target_length,
|
||||||
)
|
)
|
||||||
dataloader = DataLoader(train_dataset, batch_size=self.hparams.train_batch_size)
|
|
||||||
|
def get_dataloader(self, type_path: str, batch_size: int) -> DataLoader:
|
||||||
|
dataset = SummarizationDataset(self.tokenizer, type_path=type_path, **self.dataset_kwargs)
|
||||||
|
dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=dataset.collate_fn)
|
||||||
|
return dataloader
|
||||||
|
|
||||||
|
def train_dataloader(self) -> DataLoader:
|
||||||
|
dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size)
|
||||||
t_total = (
|
t_total = (
|
||||||
(len(dataloader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.n_gpu)))
|
(len(dataloader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.n_gpu)))
|
||||||
// self.hparams.gradient_accumulation_steps
|
// self.hparams.gradient_accumulation_steps
|
||||||
@@ -117,29 +118,30 @@ class BartSystem(BaseTransformer):
|
|||||||
self.lr_scheduler = scheduler
|
self.lr_scheduler = scheduler
|
||||||
return dataloader
|
return dataloader
|
||||||
|
|
||||||
def val_dataloader(self):
|
def val_dataloader(self) -> DataLoader:
|
||||||
val_dataset = SummarizationDataset(
|
return self.get_dataloader("val", batch_size=self.hparams.eval_batch_size)
|
||||||
self.tokenizer, data_dir=self.hparams.data_dir, type_path="val", block_size=self.hparams.max_seq_length
|
|
||||||
)
|
|
||||||
return DataLoader(val_dataset, batch_size=self.hparams.eval_batch_size)
|
|
||||||
|
|
||||||
def test_dataloader(self):
|
def test_dataloader(self) -> DataLoader:
|
||||||
test_dataset = SummarizationDataset(
|
return self.get_dataloader("test", batch_size=self.hparams.eval_batch_size)
|
||||||
self.tokenizer, data_dir=self.hparams.data_dir, type_path="test", block_size=self.hparams.max_seq_length
|
|
||||||
)
|
|
||||||
return DataLoader(test_dataset, batch_size=self.hparams.eval_batch_size)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_model_specific_args(parser, root_dir):
|
def add_model_specific_args(parser, root_dir):
|
||||||
BaseTransformer.add_model_specific_args(parser, root_dir)
|
BaseTransformer.add_model_specific_args(parser, root_dir)
|
||||||
# Add BART specific options
|
# Add BART specific options
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max_seq_length",
|
"--max_source_length",
|
||||||
default=1024,
|
default=1024,
|
||||||
type=int,
|
type=int,
|
||||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||||
"than this will be truncated, sequences shorter will be padded.",
|
"than this will be truncated, sequences shorter will be padded.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max_target_length",
|
||||||
|
default=56,
|
||||||
|
type=int,
|
||||||
|
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||||
|
"than this will be truncated, sequences shorter will be padded.",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--data_dir",
|
"--data_dir",
|
||||||
@@ -158,7 +160,7 @@ if __name__ == "__main__":
|
|||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# If output_dir not provided, a folder will be generated in pwd
|
# If output_dir not provided, a folder will be generated in pwd
|
||||||
if args.output_dir is None:
|
if not args.output_dir:
|
||||||
args.output_dir = os.path.join("./results", f"{args.task}_{args.model_type}_{time.strftime('%Y%m%d_%H%M%S')}",)
|
args.output_dir = os.path.join("./results", f"{args.task}_{args.model_type}_{time.strftime('%Y%m%d_%H%M%S')}",)
|
||||||
os.makedirs(args.output_dir)
|
os.makedirs(args.output_dir)
|
||||||
|
|
||||||
|
|||||||
@@ -5,28 +5,57 @@ import unittest
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
from transformers import BartTokenizer
|
||||||
|
|
||||||
from .evaluate_cnn import run_generate
|
from .evaluate_cnn import run_generate
|
||||||
|
from .utils import SummarizationDataset
|
||||||
|
|
||||||
|
|
||||||
articles = [" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."]
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.DEBUG)
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
|
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
|
||||||
|
def _dump_articles(path: Path, articles: list):
|
||||||
|
with path.open("w") as f:
|
||||||
|
f.write("\n".join(articles))
|
||||||
|
|
||||||
|
|
||||||
class TestBartExamples(unittest.TestCase):
|
class TestBartExamples(unittest.TestCase):
|
||||||
def test_bart_cnn_cli(self):
|
def test_bart_cnn_cli(self):
|
||||||
stream_handler = logging.StreamHandler(sys.stdout)
|
stream_handler = logging.StreamHandler(sys.stdout)
|
||||||
logger.addHandler(stream_handler)
|
logger.addHandler(stream_handler)
|
||||||
tmp = Path(tempfile.gettempdir()) / "utest_generations_bart_sum.hypo"
|
tmp = Path(tempfile.gettempdir()) / "utest_generations_bart_sum.hypo"
|
||||||
with tmp.open("w") as f:
|
|
||||||
f.write("\n".join(articles))
|
|
||||||
|
|
||||||
output_file_name = Path(tempfile.gettempdir()) / "utest_output_bart_sum.hypo"
|
output_file_name = Path(tempfile.gettempdir()) / "utest_output_bart_sum.hypo"
|
||||||
|
articles = [" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."]
|
||||||
|
_dump_articles(tmp, articles)
|
||||||
testargs = ["evaluate_cnn.py", str(tmp), str(output_file_name), "sshleifer/bart-tiny-random"]
|
testargs = ["evaluate_cnn.py", str(tmp), str(output_file_name), "sshleifer/bart-tiny-random"]
|
||||||
|
|
||||||
with patch.object(sys, "argv", testargs):
|
with patch.object(sys, "argv", testargs):
|
||||||
run_generate()
|
run_generate()
|
||||||
self.assertTrue(Path(output_file_name).exists())
|
self.assertTrue(output_file_name.exists())
|
||||||
|
|
||||||
|
def test_bart_summarization_dataset(self):
|
||||||
|
tmp_dir = Path(tempfile.gettempdir())
|
||||||
|
articles = [" Sam ate lunch today", "Sams lunch ingredients"]
|
||||||
|
summaries = ["A very interesting story about what I ate for lunch.", "Avocado, celery, turkey, coffee"]
|
||||||
|
_dump_articles((tmp_dir / "train.source"), articles)
|
||||||
|
_dump_articles((tmp_dir / "train.target"), summaries)
|
||||||
|
tokenizer = BartTokenizer.from_pretrained("bart-large")
|
||||||
|
max_len_source = max(len(tokenizer.encode(a)) for a in articles)
|
||||||
|
max_len_target = max(len(tokenizer.encode(a)) for a in summaries)
|
||||||
|
trunc_target = 4
|
||||||
|
train_dataset = SummarizationDataset(
|
||||||
|
tokenizer, data_dir=tmp_dir, type_path="train", max_source_length=20, max_target_length=trunc_target,
|
||||||
|
)
|
||||||
|
dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
|
||||||
|
for batch in dataloader:
|
||||||
|
self.assertEqual(batch["source_mask"].shape, batch["source_ids"].shape)
|
||||||
|
# show that articles were trimmed.
|
||||||
|
self.assertEqual(batch["source_ids"].shape[1], max_len_source)
|
||||||
|
self.assertGreater(20, batch["source_ids"].shape[1]) # trimmed significantly
|
||||||
|
|
||||||
|
# show that targets were truncated
|
||||||
|
self.assertEqual(batch["target_ids"].shape[1], trunc_target) # Truncated
|
||||||
|
self.assertGreater(max_len_target, trunc_target) # Truncated
|
||||||
|
|||||||
@@ -1,35 +1,35 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
|
import torch
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
|
from transformers.tokenization_utils import trim_batch
|
||||||
|
|
||||||
|
|
||||||
|
def encode_file(tokenizer, data_path, max_length, pad_to_max_length=True, return_tensors="pt"):
|
||||||
|
examples = []
|
||||||
|
with open(data_path, "r") as f:
|
||||||
|
for text in f.readlines():
|
||||||
|
tokenized = tokenizer.batch_encode_plus(
|
||||||
|
[text], max_length=max_length, pad_to_max_length=pad_to_max_length, return_tensors=return_tensors,
|
||||||
|
)
|
||||||
|
examples.append(tokenized)
|
||||||
|
return examples
|
||||||
|
|
||||||
|
|
||||||
class SummarizationDataset(Dataset):
|
class SummarizationDataset(Dataset):
|
||||||
def __init__(self, tokenizer, data_dir="./cnn-dailymail/cnn_dm/", type_path="train", block_size=1024):
|
def __init__(
|
||||||
super(SummarizationDataset,).__init__()
|
self,
|
||||||
|
tokenizer,
|
||||||
|
data_dir="./cnn-dailymail/cnn_dm/",
|
||||||
|
type_path="train",
|
||||||
|
max_source_length=1024,
|
||||||
|
max_target_length=56,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
|
self.source = encode_file(tokenizer, os.path.join(data_dir, type_path + ".source"), max_source_length)
|
||||||
self.source = []
|
self.target = encode_file(tokenizer, os.path.join(data_dir, type_path + ".target"), max_target_length)
|
||||||
self.target = []
|
|
||||||
|
|
||||||
print("loading " + type_path + " source.")
|
|
||||||
|
|
||||||
with open(os.path.join(data_dir, type_path + ".source"), "r") as f:
|
|
||||||
for text in f.readlines(): # each text is a line and a full story
|
|
||||||
tokenized = tokenizer.batch_encode_plus(
|
|
||||||
[text], max_length=block_size, pad_to_max_length=True, return_tensors="pt"
|
|
||||||
)
|
|
||||||
self.source.append(tokenized)
|
|
||||||
f.close()
|
|
||||||
|
|
||||||
print("loading " + type_path + " target.")
|
|
||||||
|
|
||||||
with open(os.path.join(data_dir, type_path + ".target"), "r") as f:
|
|
||||||
for text in f.readlines(): # each text is a line and a summary
|
|
||||||
tokenized = tokenizer.batch_encode_plus(
|
|
||||||
[text], max_length=56, pad_to_max_length=True, return_tensors="pt"
|
|
||||||
)
|
|
||||||
self.target.append(tokenized)
|
|
||||||
f.close()
|
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.source)
|
return len(self.source)
|
||||||
@@ -37,7 +37,20 @@ class SummarizationDataset(Dataset):
|
|||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
source_ids = self.source[index]["input_ids"].squeeze()
|
source_ids = self.source[index]["input_ids"].squeeze()
|
||||||
target_ids = self.target[index]["input_ids"].squeeze()
|
target_ids = self.target[index]["input_ids"].squeeze()
|
||||||
|
src_mask = self.source[index]["attention_mask"].squeeze()
|
||||||
src_mask = self.source[index]["attention_mask"].squeeze() # might need to squeeze
|
|
||||||
|
|
||||||
return {"source_ids": source_ids, "source_mask": src_mask, "target_ids": target_ids}
|
return {"source_ids": source_ids, "source_mask": src_mask, "target_ids": target_ids}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def trim_seq2seq_batch(batch, pad_token_id):
|
||||||
|
y = trim_batch(batch["target_ids"], pad_token_id)
|
||||||
|
source_ids, source_mask = trim_batch(batch["source_ids"], pad_token_id, attention_mask=batch["source_mask"])
|
||||||
|
return source_ids, source_mask, y
|
||||||
|
|
||||||
|
def collate_fn(self, batch):
|
||||||
|
input_ids = torch.stack([x["source_ids"] for x in batch])
|
||||||
|
masks = torch.stack([x["source_mask"] for x in batch])
|
||||||
|
target_ids = torch.stack([x["target_ids"] for x in batch])
|
||||||
|
pad_token_id = self.tokenizer.pad_token_id
|
||||||
|
y = trim_batch(target_ids, pad_token_id)
|
||||||
|
source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks)
|
||||||
|
return {"source_ids": source_ids, "source_mask": source_mask, "target_ids": y}
|
||||||
|
|||||||
@@ -47,27 +47,29 @@ def set_seed(args):
|
|||||||
|
|
||||||
|
|
||||||
class BaseTransformer(pl.LightningModule):
|
class BaseTransformer(pl.LightningModule):
|
||||||
def __init__(self, hparams, num_labels=None, mode="base"):
|
def __init__(self, hparams, num_labels=None, mode="base", **config_kwargs):
|
||||||
"Initialize a model."
|
"Initialize a model."
|
||||||
|
|
||||||
super(BaseTransformer, self).__init__()
|
super(BaseTransformer, self).__init__()
|
||||||
self.hparams = hparams
|
self.hparams = hparams
|
||||||
|
cache_dir = self.hparams.cache_dir if self.hparams.cache_dir else None
|
||||||
self.hparams.model_type = self.hparams.model_type.lower()
|
self.hparams.model_type = self.hparams.model_type.lower()
|
||||||
config = AutoConfig.from_pretrained(
|
config = AutoConfig.from_pretrained(
|
||||||
self.hparams.config_name if self.hparams.config_name else self.hparams.model_name_or_path,
|
self.hparams.config_name if self.hparams.config_name else self.hparams.model_name_or_path,
|
||||||
**({"num_labels": num_labels} if num_labels is not None else {}),
|
**({"num_labels": num_labels} if num_labels is not None else {}),
|
||||||
cache_dir=self.hparams.cache_dir if self.hparams.cache_dir else None,
|
cache_dir=cache_dir,
|
||||||
|
**config_kwargs,
|
||||||
)
|
)
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
self.hparams.tokenizer_name if self.hparams.tokenizer_name else self.hparams.model_name_or_path,
|
self.hparams.tokenizer_name if self.hparams.tokenizer_name else self.hparams.model_name_or_path,
|
||||||
do_lower_case=self.hparams.do_lower_case,
|
do_lower_case=self.hparams.do_lower_case,
|
||||||
cache_dir=self.hparams.cache_dir if self.hparams.cache_dir else None,
|
cache_dir=cache_dir,
|
||||||
)
|
)
|
||||||
model = MODEL_MODES[mode].from_pretrained(
|
model = MODEL_MODES[mode].from_pretrained(
|
||||||
self.hparams.model_name_or_path,
|
self.hparams.model_name_or_path,
|
||||||
from_tf=bool(".ckpt" in self.hparams.model_name_or_path),
|
from_tf=bool(".ckpt" in self.hparams.model_name_or_path),
|
||||||
config=config,
|
config=config,
|
||||||
cache_dir=self.hparams.cache_dir if self.hparams.cache_dir else None,
|
cache_dir=cache_dir,
|
||||||
)
|
)
|
||||||
self.config, self.tokenizer, self.model = config, tokenizer, model
|
self.config, self.tokenizer, self.model = config, tokenizer, model
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user