fix data processing in script
This commit is contained in:
@@ -58,12 +58,12 @@ class TextDataset(Dataset):
|
||||
[2] https://github.com/abisee/cnn-dailymail/
|
||||
"""
|
||||
|
||||
def __init_(self, tokenizer_src, tokenizer_tgt, data_dir="", block_size=512):
|
||||
def __init__(self, tokenizer, prefix='train', data_dir="", block_size=512):
|
||||
assert os.path.isdir(data_dir)
|
||||
|
||||
# Load features that have already been computed if present
|
||||
cached_features_file = os.path.join(
|
||||
data_dir, "cached_lm_{}_{}".format(block_size, data_dir)
|
||||
data_dir, "cached_lm_{}_{}".format(block_size, prefix)
|
||||
)
|
||||
if os.path.exists(cached_features_file):
|
||||
logger.info("Loading features from cached file %s", cached_features_file)
|
||||
@@ -72,7 +72,7 @@ class TextDataset(Dataset):
|
||||
return
|
||||
|
||||
logger.info("Creating features from dataset at %s", data_dir)
|
||||
|
||||
self.examples = []
|
||||
datasets = ["cnn", "dailymail"]
|
||||
for dataset in datasets:
|
||||
path_to_stories = os.path.join(data_dir, dataset, "stories")
|
||||
@@ -91,21 +91,17 @@ class TextDataset(Dataset):
|
||||
except IndexError: # skip ill-formed stories
|
||||
continue
|
||||
|
||||
story = tokenizer_src.convert_tokens_to_ids(
|
||||
tokenizer_src.tokenize(story)
|
||||
)
|
||||
story = tokenizer.encode(story)
|
||||
story_seq = _fit_to_block_size(story, block_size)
|
||||
|
||||
summary = tokenizer_tgt.convert_tokens_to_ids(
|
||||
tokenizer_tgt.tokenize(summary)
|
||||
)
|
||||
summary = tokenizer.encode(summary)
|
||||
summary_seq = _fit_to_block_size(summary, block_size)
|
||||
|
||||
self.examples.append((story_seq, summary_seq))
|
||||
|
||||
logger.info("Saving features into cache file %s", cached_features_file)
|
||||
with open(cached_features_file, "wb") as sink:
|
||||
pickle.dump(self.examples, sink, protocole=pickle.HIGHEST_PROTOCOL)
|
||||
pickle.dump(self.examples, sink, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.examples)
|
||||
@@ -169,11 +165,11 @@ def _fit_to_block_size(sequence, block_size):
|
||||
if len(sequence) > block_size:
|
||||
return sequence[:block_size]
|
||||
else:
|
||||
return sequence.extend([-1] * [block_size - len(sequence)])
|
||||
return sequence.extend([-1] * (block_size - len(sequence)))
|
||||
|
||||
|
||||
def load_and_cache_examples(args, tokenizer_src, tokenizer_tgt):
|
||||
dataset = TextDataset(tokenizer_src, tokenizer_tgt, file_path=args.data_dir)
|
||||
def load_and_cache_examples(args, tokenizer):
|
||||
dataset = TextDataset(tokenizer, data_dir=args.data_dir)
|
||||
return dataset
|
||||
|
||||
|
||||
@@ -293,29 +289,17 @@ def main():
|
||||
"--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--decoder_name_or_path",
|
||||
"--model_name_or_path",
|
||||
default="bert-base-cased",
|
||||
type=str,
|
||||
help="The model checkpoint to initialize the decoder's weights with.",
|
||||
help="The model checkpoint to initialize the encoder and decoder's weights with.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--decoder_type",
|
||||
"--model_type",
|
||||
default="bert",
|
||||
type=str,
|
||||
help="The decoder architecture to be fine-tuned.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--encoder_name_or_path",
|
||||
default="bert-base-cased",
|
||||
type=str,
|
||||
help="The model checkpoint to initialize the encoder's weights with.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--encoder_type",
|
||||
default="bert",
|
||||
type=str,
|
||||
help="The encoder architecture to be fine-tuned.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--learning_rate",
|
||||
default=5e-5,
|
||||
@@ -346,7 +330,7 @@ def main():
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.encoder_type != "bert" or args.decoder_type != "bert":
|
||||
if args.model_type != "bert":
|
||||
raise ValueError(
|
||||
"Only the BERT architecture is currently supported for seq2seq."
|
||||
)
|
||||
@@ -358,11 +342,8 @@ def main():
|
||||
set_seed(args)
|
||||
|
||||
# Load pretrained model and tokenizer
|
||||
encoder_tokenizer_class = AutoTokenizer.from_pretrained(args.encoder_name_or_path)
|
||||
decoder_tokenizer_class = AutoTokenizer.from_pretrained(args.decoder_name_or_path)
|
||||
model = Model2Model.from_pretrained(
|
||||
args.encoder_name_or_path, args.decoder_name_or_path
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
|
||||
model = Model2Model.from_pretrained(args.model_name_or_path)
|
||||
# model.to(device)
|
||||
|
||||
logger.info("Training/evaluation parameters %s", args)
|
||||
|
||||
Reference in New Issue
Block a user