[cleanup] remove redundant code in SummarizationDataset (#5119)
This commit is contained in:
@@ -13,8 +13,6 @@ from torch import nn
|
|||||||
from torch.utils.data import Dataset, Sampler
|
from torch.utils.data import Dataset, Sampler
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from transformers import BartTokenizer
|
|
||||||
|
|
||||||
|
|
||||||
def encode_file(
|
def encode_file(
|
||||||
tokenizer,
|
tokenizer,
|
||||||
@@ -85,7 +83,7 @@ class SummarizationDataset(Dataset):
|
|||||||
prefix="",
|
prefix="",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
tok_name = "T5" if not isinstance(tokenizer, BartTokenizer) else ""
|
tok_name = tokenizer.__class__.__name__.lower().rstrip("tokenizer")
|
||||||
self.source = encode_file(
|
self.source = encode_file(
|
||||||
tokenizer,
|
tokenizer,
|
||||||
os.path.join(data_dir, type_path + ".source"),
|
os.path.join(data_dir, type_path + ".source"),
|
||||||
@@ -94,16 +92,10 @@ class SummarizationDataset(Dataset):
|
|||||||
prefix=prefix,
|
prefix=prefix,
|
||||||
tok_name=tok_name,
|
tok_name=tok_name,
|
||||||
)
|
)
|
||||||
if type_path == "train":
|
tgt_path = os.path.join(data_dir, type_path + ".target")
|
||||||
tgt_path = os.path.join(data_dir, type_path + ".target")
|
|
||||||
else:
|
|
||||||
tgt_path = os.path.join(data_dir, type_path + ".target")
|
|
||||||
|
|
||||||
self.target = encode_file(
|
self.target = encode_file(
|
||||||
tokenizer, tgt_path, max_target_length, overwrite_cache=overwrite_cache, tok_name=tok_name
|
tokenizer, tgt_path, max_target_length, overwrite_cache=overwrite_cache, tok_name=tok_name
|
||||||
)
|
)
|
||||||
self.source = encode_file(tokenizer, os.path.join(data_dir, type_path + ".source"), max_source_length)
|
|
||||||
self.target = encode_file(tokenizer, os.path.join(data_dir, type_path + ".target"), max_target_length)
|
|
||||||
if n_obs is not None:
|
if n_obs is not None:
|
||||||
self.source = self.source[:n_obs]
|
self.source = self.source[:n_obs]
|
||||||
self.target = self.target[:n_obs]
|
self.target = self.target[:n_obs]
|
||||||
|
|||||||
Reference in New Issue
Block a user