[examples] SummarizationDataset cleanup (#3451)
This commit is contained in:
@@ -1,35 +1,35 @@
|
||||
import os
|
||||
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from transformers.tokenization_utils import trim_batch
|
||||
|
||||
|
||||
def encode_file(tokenizer, data_path, max_length, pad_to_max_length=True, return_tensors="pt"):
|
||||
examples = []
|
||||
with open(data_path, "r") as f:
|
||||
for text in f.readlines():
|
||||
tokenized = tokenizer.batch_encode_plus(
|
||||
[text], max_length=max_length, pad_to_max_length=pad_to_max_length, return_tensors=return_tensors,
|
||||
)
|
||||
examples.append(tokenized)
|
||||
return examples
|
||||
|
||||
|
||||
class SummarizationDataset(Dataset):
|
||||
def __init__(self, tokenizer, data_dir="./cnn-dailymail/cnn_dm/", type_path="train", block_size=1024):
|
||||
super(SummarizationDataset,).__init__()
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer,
|
||||
data_dir="./cnn-dailymail/cnn_dm/",
|
||||
type_path="train",
|
||||
max_source_length=1024,
|
||||
max_target_length=56,
|
||||
):
|
||||
super().__init__()
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
self.source = []
|
||||
self.target = []
|
||||
|
||||
print("loading " + type_path + " source.")
|
||||
|
||||
with open(os.path.join(data_dir, type_path + ".source"), "r") as f:
|
||||
for text in f.readlines(): # each text is a line and a full story
|
||||
tokenized = tokenizer.batch_encode_plus(
|
||||
[text], max_length=block_size, pad_to_max_length=True, return_tensors="pt"
|
||||
)
|
||||
self.source.append(tokenized)
|
||||
f.close()
|
||||
|
||||
print("loading " + type_path + " target.")
|
||||
|
||||
with open(os.path.join(data_dir, type_path + ".target"), "r") as f:
|
||||
for text in f.readlines(): # each text is a line and a summary
|
||||
tokenized = tokenizer.batch_encode_plus(
|
||||
[text], max_length=56, pad_to_max_length=True, return_tensors="pt"
|
||||
)
|
||||
self.target.append(tokenized)
|
||||
f.close()
|
||||
self.source = encode_file(tokenizer, os.path.join(data_dir, type_path + ".source"), max_source_length)
|
||||
self.target = encode_file(tokenizer, os.path.join(data_dir, type_path + ".target"), max_target_length)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.source)
|
||||
@@ -37,7 +37,20 @@ class SummarizationDataset(Dataset):
|
||||
def __getitem__(self, index):
|
||||
source_ids = self.source[index]["input_ids"].squeeze()
|
||||
target_ids = self.target[index]["input_ids"].squeeze()
|
||||
|
||||
src_mask = self.source[index]["attention_mask"].squeeze() # might need to squeeze
|
||||
|
||||
src_mask = self.source[index]["attention_mask"].squeeze()
|
||||
return {"source_ids": source_ids, "source_mask": src_mask, "target_ids": target_ids}
|
||||
|
||||
@staticmethod
|
||||
def trim_seq2seq_batch(batch, pad_token_id):
|
||||
y = trim_batch(batch["target_ids"], pad_token_id)
|
||||
source_ids, source_mask = trim_batch(batch["source_ids"], pad_token_id, attention_mask=batch["source_mask"])
|
||||
return source_ids, source_mask, y
|
||||
|
||||
def collate_fn(self, batch):
|
||||
input_ids = torch.stack([x["source_ids"] for x in batch])
|
||||
masks = torch.stack([x["source_mask"] for x in batch])
|
||||
target_ids = torch.stack([x["target_ids"] for x in batch])
|
||||
pad_token_id = self.tokenizer.pad_token_id
|
||||
y = trim_batch(target_ids, pad_token_id)
|
||||
source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks)
|
||||
return {"source_ids": source_ids, "source_mask": source_mask, "target_ids": y}
|
||||
|
||||
Reference in New Issue
Block a user