use two different tokenizers for storyand summary
This commit is contained in:
@@ -26,7 +26,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
from transformers import BertTokenizer
|
from transformers import AutoTokenizer, Model2Model
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -57,7 +57,7 @@ class TextDataset(Dataset):
|
|||||||
[2] https://github.com/abisee/cnn-dailymail/
|
[2] https://github.com/abisee/cnn-dailymail/
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init_(self, tokenizer, data_dir="", block_size=512):
|
def __init_(self, tokenizer_src, tokenizer_tgt, data_dir="", block_size=512):
|
||||||
assert os.path.isdir(data_dir)
|
assert os.path.isdir(data_dir)
|
||||||
|
|
||||||
# Load features that have already been computed if present
|
# Load features that have already been computed if present
|
||||||
@@ -90,15 +90,13 @@ class TextDataset(Dataset):
|
|||||||
except IndexError: # skip ill-formed stories
|
except IndexError: # skip ill-formed stories
|
||||||
continue
|
continue
|
||||||
|
|
||||||
summary = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(summary))
|
story = tokenizer_src.convert_tokens_to_ids(tokenizer_src.tokenize(story))
|
||||||
summary_seq = _fit_to_block_size(summary, block_size)
|
|
||||||
|
|
||||||
story = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(story))
|
|
||||||
story_seq = _fit_to_block_size(story, block_size)
|
story_seq = _fit_to_block_size(story, block_size)
|
||||||
|
|
||||||
self.examples.append(
|
summary = tokenizer_tgt.convert_tokens_to_ids(tokenizer_tgt.tokenize(summary))
|
||||||
tokenizer.add_special_token_sequence_pair(story_seq, summary_seq)
|
summary_seq = _fit_to_block_size(summary, block_size)
|
||||||
)
|
|
||||||
|
self.examples.append((story_seq, summary_seq))
|
||||||
|
|
||||||
logger.info("Saving features into cache file %s", cached_features_file)
|
logger.info("Saving features into cache file %s", cached_features_file)
|
||||||
with open(cached_features_file, "wb") as sink:
|
with open(cached_features_file, "wb") as sink:
|
||||||
@@ -169,8 +167,8 @@ def _fit_to_block_size(sequence, block_size):
|
|||||||
return sequence.extend([-1] * [block_size - len(sequence)])
|
return sequence.extend([-1] * [block_size - len(sequence)])
|
||||||
|
|
||||||
|
|
||||||
def load_and_cache_examples(args, tokenizer):
|
def load_and_cache_examples(args, tokenizer_src, tokenizer_tgt):
|
||||||
dataset = TextDataset(tokenizer, file_path=args.data_dir)
|
dataset = TextDataset(tokenizer_src, tokenizer_tgt, file_path=args.data_dir)
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
@@ -205,14 +203,35 @@ def main():
|
|||||||
|
|
||||||
# Optional parameters
|
# Optional parameters
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--model_name_or_path",
|
"--decoder_name_or_path",
|
||||||
default="bert-base-cased",
|
default="bert-base-cased",
|
||||||
type=str,
|
type=str,
|
||||||
help="The model checkpoint for weights initialization.",
|
help="The model checkpoint to initialize the decoder's weights with.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--decoder_type",
|
||||||
|
default="bert",
|
||||||
|
type=str,
|
||||||
|
help="The decoder architecture to be fine-tuned.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--encoder_name_or_path",
|
||||||
|
default="bert-base-cased",
|
||||||
|
type=str,
|
||||||
|
help="The model checkpoint to initialize the encoder's weights with.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--encoder_type",
|
||||||
|
default="bert",
|
||||||
|
type=str,
|
||||||
|
help="The encoder architecture to be fine-tuned.",
|
||||||
)
|
)
|
||||||
parser.add_argument("--seed", default=42, type=int)
|
parser.add_argument("--seed", default=42, type=int)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.encoder_type != 'bert' or args.decoder_type != 'bert':
|
||||||
|
raise ValueError("Only the BERT architecture is currently supported for seq2seq.")
|
||||||
|
|
||||||
# Set up training device
|
# Set up training device
|
||||||
# device = torch.device("cpu")
|
# device = torch.device("cpu")
|
||||||
|
|
||||||
@@ -220,16 +239,15 @@ def main():
|
|||||||
set_seed(args)
|
set_seed(args)
|
||||||
|
|
||||||
# Load pretrained model and tokenizer
|
# Load pretrained model and tokenizer
|
||||||
tokenizer_class = BertTokenizer
|
encoder_tokenizer_class = AutoTokenizer.from_pretrained(args.encoder_name_or_path)
|
||||||
# config = config_class.from_pretrained(args.model_name_or_path)
|
decoder_tokenizer_class = AutoTokenizer.from_pretrained(args.decoder_name_or_path)
|
||||||
tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
|
model = Model2Model.from_pretrained(args.encoder_name_or_path, args.decoder_name_or_path)
|
||||||
# model = model_class.from_pretrained(args.model_name_or_path, config=config)
|
|
||||||
# model.to(device)
|
# model.to(device)
|
||||||
|
|
||||||
logger.info("Training/evaluation parameters %s", args)
|
logger.info("Training/evaluation parameters %s", args)
|
||||||
|
|
||||||
# Training
|
# Training
|
||||||
_ = load_and_cache_examples(args, tokenizer)
|
source, target = load_and_cache_examples(args, tokenizer)
|
||||||
# global_step, tr_loss = train(args, train_dataset, model, tokenizer)
|
# global_step, tr_loss = train(args, train_dataset, model, tokenizer)
|
||||||
# logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
# logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user