use two different tokenizers for story and summary
This commit is contained in:
@@ -26,7 +26,7 @@ import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from transformers import BertTokenizer
|
||||
from transformers import AutoTokenizer, Model2Model
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -57,7 +57,7 @@ class TextDataset(Dataset):
|
||||
[2] https://github.com/abisee/cnn-dailymail/
|
||||
"""
|
||||
|
||||
def __init_(self, tokenizer, data_dir="", block_size=512):
|
||||
def __init_(self, tokenizer_src, tokenizer_tgt, data_dir="", block_size=512):
|
||||
assert os.path.isdir(data_dir)
|
||||
|
||||
# Load features that have already been computed if present
|
||||
@@ -90,15 +90,13 @@ class TextDataset(Dataset):
|
||||
except IndexError: # skip ill-formed stories
|
||||
continue
|
||||
|
||||
summary = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(summary))
|
||||
summary_seq = _fit_to_block_size(summary, block_size)
|
||||
|
||||
story = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(story))
|
||||
story = tokenizer_src.convert_tokens_to_ids(tokenizer_src.tokenize(story))
|
||||
story_seq = _fit_to_block_size(story, block_size)
|
||||
|
||||
self.examples.append(
|
||||
tokenizer.add_special_token_sequence_pair(story_seq, summary_seq)
|
||||
)
|
||||
summary = tokenizer_tgt.convert_tokens_to_ids(tokenizer_tgt.tokenize(summary))
|
||||
summary_seq = _fit_to_block_size(summary, block_size)
|
||||
|
||||
self.examples.append((story_seq, summary_seq))
|
||||
|
||||
logger.info("Saving features into cache file %s", cached_features_file)
|
||||
with open(cached_features_file, "wb") as sink:
|
||||
@@ -169,8 +167,8 @@ def _fit_to_block_size(sequence, block_size):
|
||||
return sequence.extend([-1] * [block_size - len(sequence)])
|
||||
|
||||
|
||||
def load_and_cache_examples(args, tokenizer):
|
||||
dataset = TextDataset(tokenizer, file_path=args.data_dir)
|
||||
def load_and_cache_examples(args, tokenizer_src, tokenizer_tgt):
|
||||
dataset = TextDataset(tokenizer_src, tokenizer_tgt, file_path=args.data_dir)
|
||||
return dataset
|
||||
|
||||
|
||||
@@ -205,14 +203,35 @@ def main():
|
||||
|
||||
# Optional parameters
|
||||
parser.add_argument(
|
||||
"--model_name_or_path",
|
||||
"--decoder_name_or_path",
|
||||
default="bert-base-cased",
|
||||
type=str,
|
||||
help="The model checkpoint for weights initialization.",
|
||||
help="The model checkpoint to initialize the decoder's weights with.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--decoder_type",
|
||||
default="bert",
|
||||
type=str,
|
||||
help="The decoder architecture to be fine-tuned.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--encoder_name_or_path",
|
||||
default="bert-base-cased",
|
||||
type=str,
|
||||
help="The model checkpoint to initialize the encoder's weights with.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--encoder_type",
|
||||
default="bert",
|
||||
type=str,
|
||||
help="The encoder architecture to be fine-tuned.",
|
||||
)
|
||||
parser.add_argument("--seed", default=42, type=int)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.encoder_type != 'bert' or args.decoder_type != 'bert':
|
||||
raise ValueError("Only the BERT architecture is currently supported for seq2seq.")
|
||||
|
||||
# Set up training device
|
||||
# device = torch.device("cpu")
|
||||
|
||||
@@ -220,16 +239,15 @@ def main():
|
||||
set_seed(args)
|
||||
|
||||
# Load pretrained model and tokenizer
|
||||
tokenizer_class = BertTokenizer
|
||||
# config = config_class.from_pretrained(args.model_name_or_path)
|
||||
tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
|
||||
# model = model_class.from_pretrained(args.model_name_or_path, config=config)
|
||||
encoder_tokenizer_class = AutoTokenizer.from_pretrained(args.encoder_name_or_path)
|
||||
decoder_tokenizer_class = AutoTokenizer.from_pretrained(args.decoder_name_or_path)
|
||||
model = Model2Model.from_pretrained(args.encoder_name_or_path, args.decoder_name_or_path)
|
||||
# model.to(device)
|
||||
|
||||
logger.info("Training/evaluation parameters %s", args)
|
||||
|
||||
# Training
|
||||
_ = load_and_cache_examples(args, tokenizer)
|
||||
source, target = load_and_cache_examples(args, tokenizer)
|
||||
# global_step, tr_loss = train(args, train_dataset, model, tokenizer)
|
||||
# logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user