fix data processing in script
This commit is contained in:
@@ -58,12 +58,12 @@ class TextDataset(Dataset):
|
|||||||
[2] https://github.com/abisee/cnn-dailymail/
|
[2] https://github.com/abisee/cnn-dailymail/
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init_(self, tokenizer_src, tokenizer_tgt, data_dir="", block_size=512):
|
def __init__(self, tokenizer, prefix='train', data_dir="", block_size=512):
|
||||||
assert os.path.isdir(data_dir)
|
assert os.path.isdir(data_dir)
|
||||||
|
|
||||||
# Load features that have already been computed if present
|
# Load features that have already been computed if present
|
||||||
cached_features_file = os.path.join(
|
cached_features_file = os.path.join(
|
||||||
data_dir, "cached_lm_{}_{}".format(block_size, data_dir)
|
data_dir, "cached_lm_{}_{}".format(block_size, prefix)
|
||||||
)
|
)
|
||||||
if os.path.exists(cached_features_file):
|
if os.path.exists(cached_features_file):
|
||||||
logger.info("Loading features from cached file %s", cached_features_file)
|
logger.info("Loading features from cached file %s", cached_features_file)
|
||||||
@@ -72,7 +72,7 @@ class TextDataset(Dataset):
|
|||||||
return
|
return
|
||||||
|
|
||||||
logger.info("Creating features from dataset at %s", data_dir)
|
logger.info("Creating features from dataset at %s", data_dir)
|
||||||
|
self.examples = []
|
||||||
datasets = ["cnn", "dailymail"]
|
datasets = ["cnn", "dailymail"]
|
||||||
for dataset in datasets:
|
for dataset in datasets:
|
||||||
path_to_stories = os.path.join(data_dir, dataset, "stories")
|
path_to_stories = os.path.join(data_dir, dataset, "stories")
|
||||||
@@ -91,21 +91,17 @@ class TextDataset(Dataset):
|
|||||||
except IndexError: # skip ill-formed stories
|
except IndexError: # skip ill-formed stories
|
||||||
continue
|
continue
|
||||||
|
|
||||||
story = tokenizer_src.convert_tokens_to_ids(
|
story = tokenizer.encode(story)
|
||||||
tokenizer_src.tokenize(story)
|
|
||||||
)
|
|
||||||
story_seq = _fit_to_block_size(story, block_size)
|
story_seq = _fit_to_block_size(story, block_size)
|
||||||
|
|
||||||
summary = tokenizer_tgt.convert_tokens_to_ids(
|
summary = tokenizer.encode(summary)
|
||||||
tokenizer_tgt.tokenize(summary)
|
|
||||||
)
|
|
||||||
summary_seq = _fit_to_block_size(summary, block_size)
|
summary_seq = _fit_to_block_size(summary, block_size)
|
||||||
|
|
||||||
self.examples.append((story_seq, summary_seq))
|
self.examples.append((story_seq, summary_seq))
|
||||||
|
|
||||||
logger.info("Saving features into cache file %s", cached_features_file)
|
logger.info("Saving features into cache file %s", cached_features_file)
|
||||||
with open(cached_features_file, "wb") as sink:
|
with open(cached_features_file, "wb") as sink:
|
||||||
pickle.dump(self.examples, sink, protocole=pickle.HIGHEST_PROTOCOL)
|
pickle.dump(self.examples, sink, protocol=pickle.HIGHEST_PROTOCOL)
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.examples)
|
return len(self.examples)
|
||||||
@@ -169,11 +165,11 @@ def _fit_to_block_size(sequence, block_size):
|
|||||||
if len(sequence) > block_size:
|
if len(sequence) > block_size:
|
||||||
return sequence[:block_size]
|
return sequence[:block_size]
|
||||||
else:
|
else:
|
||||||
return sequence.extend([-1] * [block_size - len(sequence)])
|
return sequence.extend([-1] * (block_size - len(sequence)))
|
||||||
|
|
||||||
|
|
||||||
def load_and_cache_examples(args, tokenizer_src, tokenizer_tgt):
|
def load_and_cache_examples(args, tokenizer):
|
||||||
dataset = TextDataset(tokenizer_src, tokenizer_tgt, file_path=args.data_dir)
|
dataset = TextDataset(tokenizer, data_dir=args.data_dir)
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
@@ -293,29 +289,17 @@ def main():
|
|||||||
"--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer."
|
"--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer."
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--decoder_name_or_path",
|
"--model_name_or_path",
|
||||||
default="bert-base-cased",
|
default="bert-base-cased",
|
||||||
type=str,
|
type=str,
|
||||||
help="The model checkpoint to initialize the decoder's weights with.",
|
help="The model checkpoint to initialize the encoder and decoder's weights with.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--decoder_type",
|
"--model_type",
|
||||||
default="bert",
|
default="bert",
|
||||||
type=str,
|
type=str,
|
||||||
help="The decoder architecture to be fine-tuned.",
|
help="The decoder architecture to be fine-tuned.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
"--encoder_name_or_path",
|
|
||||||
default="bert-base-cased",
|
|
||||||
type=str,
|
|
||||||
help="The model checkpoint to initialize the encoder's weights with.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--encoder_type",
|
|
||||||
default="bert",
|
|
||||||
type=str,
|
|
||||||
help="The encoder architecture to be fine-tuned.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--learning_rate",
|
"--learning_rate",
|
||||||
default=5e-5,
|
default=5e-5,
|
||||||
@@ -346,7 +330,7 @@ def main():
|
|||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if args.encoder_type != "bert" or args.decoder_type != "bert":
|
if args.model_type != "bert":
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Only the BERT architecture is currently supported for seq2seq."
|
"Only the BERT architecture is currently supported for seq2seq."
|
||||||
)
|
)
|
||||||
@@ -358,11 +342,8 @@ def main():
|
|||||||
set_seed(args)
|
set_seed(args)
|
||||||
|
|
||||||
# Load pretrained model and tokenizer
|
# Load pretrained model and tokenizer
|
||||||
encoder_tokenizer_class = AutoTokenizer.from_pretrained(args.encoder_name_or_path)
|
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
|
||||||
decoder_tokenizer_class = AutoTokenizer.from_pretrained(args.decoder_name_or_path)
|
model = Model2Model.from_pretrained(args.model_name_or_path)
|
||||||
model = Model2Model.from_pretrained(
|
|
||||||
args.encoder_name_or_path, args.decoder_name_or_path
|
|
||||||
)
|
|
||||||
# model.to(device)
|
# model.to(device)
|
||||||
|
|
||||||
logger.info("Training/evaluation parameters %s", args)
|
logger.info("Training/evaluation parameters %s", args)
|
||||||
|
|||||||
Reference in New Issue
Block a user