From 47a06d88a00c59ea1fb54e92178b3f5d2e8e8973 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Thu, 17 Oct 2019 13:04:26 +0200 Subject: [PATCH] use two different tokenizers for story and summary --- examples/run_seq2seq_finetuning.py | 54 ++++++++++++++++++++---------- 1 file changed, 36 insertions(+), 18 deletions(-) diff --git a/examples/run_seq2seq_finetuning.py b/examples/run_seq2seq_finetuning.py index 94b29c3cd6..3e3cc34cb8 100644 --- a/examples/run_seq2seq_finetuning.py +++ b/examples/run_seq2seq_finetuning.py @@ -26,7 +26,7 @@ import numpy as np import torch from torch.utils.data import Dataset -from transformers import BertTokenizer +from transformers import AutoTokenizer, Model2Model logger = logging.getLogger(__name__) @@ -57,7 +57,7 @@ class TextDataset(Dataset): [2] https://github.com/abisee/cnn-dailymail/ """ - def __init_(self, tokenizer, data_dir="", block_size=512): + def __init_(self, tokenizer_src, tokenizer_tgt, data_dir="", block_size=512): assert os.path.isdir(data_dir) # Load features that have already been computed if present @@ -90,15 +90,13 @@ class TextDataset(Dataset): except IndexError: # skip ill-formed stories continue - summary = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(summary)) - summary_seq = _fit_to_block_size(summary, block_size) - - story = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(story)) + story = tokenizer_src.convert_tokens_to_ids(tokenizer_src.tokenize(story)) story_seq = _fit_to_block_size(story, block_size) - self.examples.append( - tokenizer.add_special_token_sequence_pair(story_seq, summary_seq) - ) + summary = tokenizer_tgt.convert_tokens_to_ids(tokenizer_tgt.tokenize(summary)) + summary_seq = _fit_to_block_size(summary, block_size) + + self.examples.append((story_seq, summary_seq)) logger.info("Saving features into cache file %s", cached_features_file) with open(cached_features_file, "wb") as sink: @@ -169,8 +167,8 @@ def _fit_to_block_size(sequence, block_size): return
sequence.extend([-1] * [block_size - len(sequence)]) -def load_and_cache_examples(args, tokenizer): - dataset = TextDataset(tokenizer, file_path=args.data_dir) +def load_and_cache_examples(args, tokenizer_src, tokenizer_tgt): + dataset = TextDataset(tokenizer_src, tokenizer_tgt, file_path=args.data_dir) return dataset @@ -205,14 +203,35 @@ def main(): # Optional parameters parser.add_argument( - "--model_name_or_path", + "--decoder_name_or_path", default="bert-base-cased", type=str, - help="The model checkpoint for weights initialization.", + help="The model checkpoint to initialize the decoder's weights with.", + ) + parser.add_argument( + "--decoder_type", + default="bert", + type=str, + help="The decoder architecture to be fine-tuned.", + ) + parser.add_argument( + "--encoder_name_or_path", + default="bert-base-cased", + type=str, + help="The model checkpoint to initialize the encoder's weights with.", + ) + parser.add_argument( + "--encoder_type", + default="bert", + type=str, + help="The encoder architecture to be fine-tuned.", ) parser.add_argument("--seed", default=42, type=int) args = parser.parse_args() + if args.encoder_type != 'bert' or args.decoder_type != 'bert': + raise ValueError("Only the BERT architecture is currently supported for seq2seq.") + # Set up training device # device = torch.device("cpu") @@ -220,16 +239,15 @@ def main(): set_seed(args) # Load pretrained model and tokenizer - tokenizer_class = BertTokenizer - # config = config_class.from_pretrained(args.model_name_or_path) - tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path) - # model = model_class.from_pretrained(args.model_name_or_path, config=config) + encoder_tokenizer_class = AutoTokenizer.from_pretrained(args.encoder_name_or_path) + decoder_tokenizer_class = AutoTokenizer.from_pretrained(args.decoder_name_or_path) + model = Model2Model.from_pretrained(args.encoder_name_or_path, args.decoder_name_or_path) # model.to(device) logger.info("Training/evaluation 
parameters %s", args) # Training - _ = load_and_cache_examples(args, tokenizer) + source, target = load_and_cache_examples(args, tokenizer) # global_step, tr_loss = train(args, train_dataset, model, tokenizer) # logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)