diff --git a/examples/run_seq2seq_finetuning.py b/examples/run_seq2seq_finetuning.py index ad6d126165..38dcb2d005 100644 --- a/examples/run_seq2seq_finetuning.py +++ b/examples/run_seq2seq_finetuning.py @@ -58,12 +58,12 @@ class TextDataset(Dataset): [2] https://github.com/abisee/cnn-dailymail/ """ - def __init_(self, tokenizer_src, tokenizer_tgt, data_dir="", block_size=512): + def __init__(self, tokenizer, prefix='train', data_dir="", block_size=512): assert os.path.isdir(data_dir) # Load features that have already been computed if present cached_features_file = os.path.join( - data_dir, "cached_lm_{}_{}".format(block_size, data_dir) + data_dir, "cached_lm_{}_{}".format(block_size, prefix) ) if os.path.exists(cached_features_file): logger.info("Loading features from cached file %s", cached_features_file) @@ -72,7 +72,7 @@ class TextDataset(Dataset): return logger.info("Creating features from dataset at %s", data_dir) - + self.examples = [] datasets = ["cnn", "dailymail"] for dataset in datasets: path_to_stories = os.path.join(data_dir, dataset, "stories") @@ -91,21 +91,17 @@ class TextDataset(Dataset): except IndexError: # skip ill-formed stories continue - story = tokenizer_src.convert_tokens_to_ids( - tokenizer_src.tokenize(story) - ) + story = tokenizer.encode(story) story_seq = _fit_to_block_size(story, block_size) - summary = tokenizer_tgt.convert_tokens_to_ids( - tokenizer_tgt.tokenize(summary) - ) + summary = tokenizer.encode(summary) summary_seq = _fit_to_block_size(summary, block_size) self.examples.append((story_seq, summary_seq)) logger.info("Saving features into cache file %s", cached_features_file) with open(cached_features_file, "wb") as sink: - pickle.dump(self.examples, sink, protocole=pickle.HIGHEST_PROTOCOL) + pickle.dump(self.examples, sink, protocol=pickle.HIGHEST_PROTOCOL) def __len__(self): return len(self.examples) @@ -169,11 +165,11 @@ def _fit_to_block_size(sequence, block_size): if len(sequence) > block_size: return sequence[:block_size] else: - return sequence.extend([-1] * [block_size - len(sequence)]) + return sequence.extend([-1] * (block_size - len(sequence))) -def load_and_cache_examples(args, tokenizer_src, tokenizer_tgt): - dataset = TextDataset(tokenizer_src, tokenizer_tgt, file_path=args.data_dir) +def load_and_cache_examples(args, tokenizer): + dataset = TextDataset(tokenizer, data_dir=args.data_dir) return dataset @@ -293,29 +289,17 @@ def main(): "--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer." ) parser.add_argument( - "--decoder_name_or_path", + "--model_name_or_path", default="bert-base-cased", type=str, - help="The model checkpoint to initialize the decoder's weights with.", + help="The model checkpoint to initialize the encoder and decoder's weights with.", ) parser.add_argument( - "--decoder_type", + "--model_type", default="bert", type=str, help="The decoder architecture to be fine-tuned.", ) - parser.add_argument( - "--encoder_name_or_path", - default="bert-base-cased", - type=str, - help="The model checkpoint to initialize the encoder's weights with.", - ) - parser.add_argument( - "--encoder_type", - default="bert", - type=str, - help="The encoder architecture to be fine-tuned.", - ) parser.add_argument( "--learning_rate", default=5e-5, @@ -346,7 +330,7 @@ def main(): ) args = parser.parse_args() - if args.encoder_type != "bert" or args.decoder_type != "bert": + if args.model_type != "bert": raise ValueError( "Only the BERT architecture is currently supported for seq2seq." ) @@ -358,11 +342,8 @@ def main(): set_seed(args) # Load pretrained model and tokenizer - encoder_tokenizer_class = AutoTokenizer.from_pretrained(args.encoder_name_or_path) - decoder_tokenizer_class = AutoTokenizer.from_pretrained(args.decoder_name_or_path) - model = Model2Model.from_pretrained( - args.encoder_name_or_path, args.decoder_name_or_path - ) + tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path) + model = Model2Model.from_pretrained(args.model_name_or_path) # model.to(device) logger.info("Training/evaluation parameters %s", args)