From 2db1e2f415ec276f02e16a463e79d42cac955295 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Thu, 18 Jun 2020 20:34:48 -0400 Subject: [PATCH] [cleanup] remove redundant code in SummarizationDataset (#5119) --- examples/summarization/utils.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/examples/summarization/utils.py b/examples/summarization/utils.py index a375d823ed..d59e733ab0 100644 --- a/examples/summarization/utils.py +++ b/examples/summarization/utils.py @@ -13,8 +13,6 @@ from torch import nn from torch.utils.data import Dataset, Sampler from tqdm import tqdm -from transformers import BartTokenizer - def encode_file( tokenizer, @@ -85,7 +83,7 @@ class SummarizationDataset(Dataset): prefix="", ): super().__init__() - tok_name = "T5" if not isinstance(tokenizer, BartTokenizer) else "" + tok_name = tokenizer.__class__.__name__.lower().rstrip("tokenizer") self.source = encode_file( tokenizer, os.path.join(data_dir, type_path + ".source"), @@ -94,16 +92,10 @@ class SummarizationDataset(Dataset): prefix=prefix, tok_name=tok_name, ) - if type_path == "train": - tgt_path = os.path.join(data_dir, type_path + ".target") - else: - tgt_path = os.path.join(data_dir, type_path + ".target") - + tgt_path = os.path.join(data_dir, type_path + ".target") self.target = encode_file( tokenizer, tgt_path, max_target_length, overwrite_cache=overwrite_cache, tok_name=tok_name ) - self.source = encode_file(tokenizer, os.path.join(data_dir, type_path + ".source"), max_source_length) - self.target = encode_file(tokenizer, os.path.join(data_dir, type_path + ".target"), max_target_length) if n_obs is not None: self.source = self.source[:n_obs] self.target = self.target[:n_obs]