wip
This commit is contained in:
@@ -20,7 +20,7 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import csv
|
import csv
|
||||||
import os
|
import os
|
||||||
# import modeling_pytorch
|
import modeling_pytorch
|
||||||
# import optimization
|
# import optimization
|
||||||
import tokenization_pytorch
|
import tokenization_pytorch
|
||||||
import torch
|
import torch
|
||||||
@@ -210,7 +210,40 @@ class DataProcessor(object):
|
|||||||
for line in reader:
|
for line in reader:
|
||||||
lines.append(line)
|
lines.append(line)
|
||||||
return lines
|
return lines
|
||||||
|
|
||||||
|
|
||||||
|
class MrpcProcessor(DataProcessor):
|
||||||
|
"""Processor for the MRPC data set (GLUE version)."""
|
||||||
|
|
||||||
|
def get_train_examples(self, data_dir):
|
||||||
|
"""See base class."""
|
||||||
|
print("LOOKING AT {}".format(os.path.join(data_dir, "train.tsv")))
|
||||||
|
return self._create_examples(
|
||||||
|
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
||||||
|
|
||||||
|
def get_dev_examples(self, data_dir):
|
||||||
|
"""See base class."""
|
||||||
|
return self._create_examples(
|
||||||
|
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
||||||
|
|
||||||
|
def get_labels(self):
|
||||||
|
"""See base class."""
|
||||||
|
return ["0", "1"]
|
||||||
|
|
||||||
|
def _create_examples(self, lines, set_type):
|
||||||
|
"""Creates examples for the training and dev sets."""
|
||||||
|
examples = []
|
||||||
|
for (i, line) in enumerate(lines):
|
||||||
|
if i == 0:
|
||||||
|
continue
|
||||||
|
guid = "%s-%s" % (set_type, i)
|
||||||
|
text_a = tokenization_pytorch.convert_to_unicode(line[3])
|
||||||
|
text_b = tokenization_pytorch.convert_to_unicode(line[4])
|
||||||
|
label = tokenization_pytorch.convert_to_unicode(line[0])
|
||||||
|
examples.append(
|
||||||
|
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
||||||
|
return examples
|
||||||
|
|
||||||
|
|
||||||
class MnliProcessor(DataProcessor):
|
class MnliProcessor(DataProcessor):
|
||||||
"""Processor for the MultiNLI data set (GLUE version)."""
|
"""Processor for the MultiNLI data set (GLUE version)."""
|
||||||
@@ -395,13 +428,13 @@ def create_model(bert_config, is_training, input_ids, input_mask, segment_ids,
|
|||||||
|
|
||||||
|
|
||||||
def model_fn_builder(bert_config, num_labels, init_checkpoint, learning_rate,
|
def model_fn_builder(bert_config, num_labels, init_checkpoint, learning_rate,
|
||||||
num_train_steps, num_warmup_steps,
|
num_train_steps, num_warmup_steps, use_gpu,
|
||||||
use_one_hot_embeddings):
|
use_one_hot_embeddings):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
### ATTENTION - I removed the `use_tpu` argument
|
### ATTENTION - I removed the `use_tpu` argument
|
||||||
|
|
||||||
|
|
||||||
def input_fn_builder(features, seq_length, is_training, drop_remainder):
|
def input_fn_builder(features, seq_length, is_training, eval_drop_remainder):
|
||||||
"""Creates an `input_fn` closure to be passed to TPUEstimator.""" ### ATTENTION - To rewrite ###
|
"""Creates an `input_fn` closure to be passed to TPUEstimator.""" ### ATTENTION - To rewrite ###
|
||||||
|
|
||||||
all_input_ids = []
|
all_input_ids = []
|
||||||
@@ -422,21 +455,17 @@ def input_fn_builder(features, seq_length, is_training, drop_remainder):
|
|||||||
num_examples = len(features)
|
num_examples = len(features)
|
||||||
|
|
||||||
device = torch.device("cuda") if args.use_gpu else torch.device("cpu")
|
device = torch.device("cuda") if args.use_gpu else torch.device("cpu")
|
||||||
d = {"input_ids":
|
d = torch.utils.data.TensorDataset({ ## BUG THIS IS NOT WORKING.... ###
|
||||||
torch.IntTensor(all_input_ids, device = device), #Requires_grad=False by default
|
"input_ids": torch.IntTensor(all_input_ids, device=device), #Requires_grad=False by default
|
||||||
"input_mask":
|
"input_mask": torch.IntTensor(all_input_mask, device=device),
|
||||||
torch.IntTensor(all_input_mask, device = device),
|
"segment_ids": torch.IntTensor(all_segment_ids, device=device),
|
||||||
"segment_ids":
|
"label_ids": torch.IntTensor(all_label_ids, device=device)
|
||||||
torch.IntTensor(all_segment_ids, device = device),
|
})
|
||||||
"label_ids":
|
|
||||||
torch.IntTensor(all_label_ids, device = device)
|
|
||||||
}
|
|
||||||
|
|
||||||
if is_training:
|
|
||||||
d = d.repeat()
|
|
||||||
d = d.shuffle(buffer_size=100)
|
|
||||||
|
|
||||||
d = d.batch(batch_size=batch_size, drop_remainder=drop_remainder)
|
shuffle = True if training else False
|
||||||
|
d = torch.utils.data.DataLoader(dataset=d, batch_size=batch_size,
|
||||||
|
shuffle=shuffle,drop_last=drop_remainder)
|
||||||
|
# Cf https://pytorch.org/tutorials/beginner/data_loading_tutorial.html
|
||||||
return d
|
return d
|
||||||
|
|
||||||
return input_fn
|
return input_fn
|
||||||
@@ -452,7 +481,7 @@ def main(_):
|
|||||||
if not args.do_train and not args.do_eval:
|
if not args.do_train and not args.do_eval:
|
||||||
raise ValueError("At least one of `do_train` or `do_eval` must be True.")
|
raise ValueError("At least one of `do_train` or `do_eval` must be True.")
|
||||||
|
|
||||||
bert_config = modeling.BertConfig.from_json_file(args.bert_config_file)
|
bert_config = modeling_pytorch.BertConfig.from_json_file(args.bert_config_file)
|
||||||
|
|
||||||
if args.max_seq_length > bert_config.max_position_embeddings:
|
if args.max_seq_length > bert_config.max_position_embeddings:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -461,7 +490,7 @@ def main(_):
|
|||||||
(args.max_seq_length, bert_config.max_position_embeddings))
|
(args.max_seq_length, bert_config.max_position_embeddings))
|
||||||
|
|
||||||
if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
|
if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
|
||||||
raise ConfigurationError(f"Output directory ({args.output_dir}) already exists and is "
|
raise ValueError(f"Output directory ({args.output_dir}) already exists and is "
|
||||||
f"not empty.")
|
f"not empty.")
|
||||||
os.makedirs(args.output_dir, exist_ok=True)
|
os.makedirs(args.output_dir, exist_ok=True)
|
||||||
|
|
||||||
@@ -474,7 +503,7 @@ def main(_):
|
|||||||
|
|
||||||
label_list = processor.get_labels()
|
label_list = processor.get_labels()
|
||||||
|
|
||||||
tokenizer = tokenization.FullTokenizer(
|
tokenizer = tokenization_pytorch.FullTokenizer(
|
||||||
vocab_file=args.vocab_file, do_lower_case=args.do_lower_case)
|
vocab_file=args.vocab_file, do_lower_case=args.do_lower_case)
|
||||||
|
|
||||||
# tpu_cluster_resolver = None
|
# tpu_cluster_resolver = None
|
||||||
@@ -514,13 +543,12 @@ def main(_):
|
|||||||
|
|
||||||
# If TPU is not available, this will fall back to normal Estimator on CPU
|
# If TPU is not available, this will fall back to normal Estimator on CPU
|
||||||
# or GPU. - TO DO
|
# or GPU. - TO DO
|
||||||
for batch in
|
# estimator = tf.contrib.tpu.TPUEstimator(
|
||||||
estimator = tf.contrib.tpu.TPUEstimator(
|
# use_tpu=args.use_tpu,
|
||||||
use_tpu=args.use_tpu,
|
# model_fn=model_fn,
|
||||||
model_fn=model_fn,
|
# config=run_config,
|
||||||
config=run_config,
|
# train_batch_size=args.train_batch_size,
|
||||||
train_batch_size=args.train_batch_size,
|
# eval_batch_size=args.eval_batch_size)
|
||||||
eval_batch_size=args.eval_batch_size)
|
|
||||||
|
|
||||||
if args.do_train:
|
if args.do_train:
|
||||||
train_features = convert_examples_to_features(
|
train_features = convert_examples_to_features(
|
||||||
@@ -529,21 +557,27 @@ def main(_):
|
|||||||
logger.info(" Num examples = %d", len(train_examples))
|
logger.info(" Num examples = %d", len(train_examples))
|
||||||
logger.info(" Batch size = %d", args.train_batch_size)
|
logger.info(" Batch size = %d", args.train_batch_size)
|
||||||
logger.info(" Num steps = %d", num_train_steps)
|
logger.info(" Num steps = %d", num_train_steps)
|
||||||
train_input_fn = input_fn_builder(
|
train_input = input_fn_builder(
|
||||||
features=train_features,
|
features=train_features,
|
||||||
seq_length=args.max_seq_length,
|
seq_length=args.max_seq_length,
|
||||||
is_training=True,
|
is_training=True,
|
||||||
drop_remainder=True)
|
drop_remainder=True)
|
||||||
estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)
|
# estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)
|
||||||
|
for batch_ix, batch in train_input:
|
||||||
|
output = model_fn(batch)
|
||||||
|
loss = output["loss"]
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if args.do_eval:
|
if args.do_eval:
|
||||||
eval_examples = processor.get_dev_examples(args.data_dir)
|
eval_examples = processor.get_dev_examples(args.data_dir)
|
||||||
eval_features = convert_examples_to_features(
|
eval_features = convert_examples_to_features(
|
||||||
eval_examples, label_list, args.max_seq_length, tokenizer)
|
eval_examples, label_list, args.max_seq_length, tokenizer)
|
||||||
|
|
||||||
tf.logging.info("***** Running evaluation *****")
|
logger.info("***** Running evaluation *****")
|
||||||
tf.logging.info(" Num examples = %d", len(eval_examples))
|
logger.info(" Num examples = %d", len(eval_examples))
|
||||||
tf.logging.info(" Batch size = %d", args.eval_batch_size)
|
logger.info(" Batch size = %d", args.eval_batch_size)
|
||||||
|
|
||||||
# This tells the estimator to run through the entire set.
|
# This tells the estimator to run through the entire set.
|
||||||
eval_steps = None
|
eval_steps = None
|
||||||
@@ -564,10 +598,10 @@ def main(_):
|
|||||||
result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)
|
result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)
|
||||||
|
|
||||||
output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
|
output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
|
||||||
with tf.gfile.GFile(output_eval_file, "w") as writer:
|
with open(output_eval_file, "w") as writer:
|
||||||
tf.logging.info("***** Eval results *****")
|
logger.info("***** Eval results *****")
|
||||||
for key in sorted(result.keys()):
|
for key in sorted(result.keys()):
|
||||||
tf.logging.info(" %s = %s", key, str(result[key]))
|
logger.info(" %s = %s", key, str(result[key]))
|
||||||
writer.write("%s = %s\n" % (key, str(result[key])))
|
writer.write("%s = %s\n" % (key, str(result[key])))
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Reference in New Issue
Block a user