resolving merge conflicts
This commit is contained in:
@@ -46,7 +46,10 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class SquadExample(object):
|
class SquadExample(object):
|
||||||
"""A single training/test example for the Squad dataset."""
|
"""
|
||||||
|
A single training/test example for the Squad dataset.
|
||||||
|
For examples without an answer, the start and end position are -1.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
qas_id,
|
qas_id,
|
||||||
@@ -54,13 +57,15 @@ class SquadExample(object):
|
|||||||
doc_tokens,
|
doc_tokens,
|
||||||
orig_answer_text=None,
|
orig_answer_text=None,
|
||||||
start_position=None,
|
start_position=None,
|
||||||
end_position=None):
|
end_position=None,
|
||||||
|
is_impossible=None):
|
||||||
self.qas_id = qas_id
|
self.qas_id = qas_id
|
||||||
self.question_text = question_text
|
self.question_text = question_text
|
||||||
self.doc_tokens = doc_tokens
|
self.doc_tokens = doc_tokens
|
||||||
self.orig_answer_text = orig_answer_text
|
self.orig_answer_text = orig_answer_text
|
||||||
self.start_position = start_position
|
self.start_position = start_position
|
||||||
self.end_position = end_position
|
self.end_position = end_position
|
||||||
|
self.is_impossible = is_impossible
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return self.__repr__()
|
return self.__repr__()
|
||||||
@@ -75,6 +80,8 @@ class SquadExample(object):
|
|||||||
s += ", start_position: %d" % (self.start_position)
|
s += ", start_position: %d" % (self.start_position)
|
||||||
if self.start_position:
|
if self.start_position:
|
||||||
s += ", end_position: %d" % (self.end_position)
|
s += ", end_position: %d" % (self.end_position)
|
||||||
|
if self.start_position:
|
||||||
|
s += ", is_impossible: %r" % (self.is_impossible)
|
||||||
return s
|
return s
|
||||||
|
|
||||||
|
|
||||||
@@ -92,7 +99,8 @@ class InputFeatures(object):
|
|||||||
input_mask,
|
input_mask,
|
||||||
segment_ids,
|
segment_ids,
|
||||||
start_position=None,
|
start_position=None,
|
||||||
end_position=None):
|
end_position=None,
|
||||||
|
is_impossible=None):
|
||||||
self.unique_id = unique_id
|
self.unique_id = unique_id
|
||||||
self.example_index = example_index
|
self.example_index = example_index
|
||||||
self.doc_span_index = doc_span_index
|
self.doc_span_index = doc_span_index
|
||||||
@@ -104,9 +112,10 @@ class InputFeatures(object):
|
|||||||
self.segment_ids = segment_ids
|
self.segment_ids = segment_ids
|
||||||
self.start_position = start_position
|
self.start_position = start_position
|
||||||
self.end_position = end_position
|
self.end_position = end_position
|
||||||
|
self.is_impossible = is_impossible
|
||||||
|
|
||||||
|
|
||||||
def read_squad_examples(input_file, is_training):
|
def read_squad_examples(input_file, is_training, version_2_with_negative):
|
||||||
"""Read a SQuAD json file into a list of SquadExample."""
|
"""Read a SQuAD json file into a list of SquadExample."""
|
||||||
with open(input_file, "r", encoding='utf-8') as reader:
|
with open(input_file, "r", encoding='utf-8') as reader:
|
||||||
input_data = json.load(reader)["data"]
|
input_data = json.load(reader)["data"]
|
||||||
@@ -140,10 +149,14 @@ def read_squad_examples(input_file, is_training):
|
|||||||
start_position = None
|
start_position = None
|
||||||
end_position = None
|
end_position = None
|
||||||
orig_answer_text = None
|
orig_answer_text = None
|
||||||
|
is_impossible = False
|
||||||
if is_training:
|
if is_training:
|
||||||
if len(qa["answers"]) != 1:
|
if version_2_with_negative:
|
||||||
|
is_impossible = qa["is_impossible"]
|
||||||
|
if (len(qa["answers"]) != 1) and (not is_impossible):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"For training, each question should have exactly 1 answer.")
|
"For training, each question should have exactly 1 answer.")
|
||||||
|
if not is_impossible:
|
||||||
answer = qa["answers"][0]
|
answer = qa["answers"][0]
|
||||||
orig_answer_text = answer["text"]
|
orig_answer_text = answer["text"]
|
||||||
answer_offset = answer["answer_start"]
|
answer_offset = answer["answer_start"]
|
||||||
@@ -163,6 +176,10 @@ def read_squad_examples(input_file, is_training):
|
|||||||
logger.warning("Could not find answer: '%s' vs. '%s'",
|
logger.warning("Could not find answer: '%s' vs. '%s'",
|
||||||
actual_text, cleaned_answer_text)
|
actual_text, cleaned_answer_text)
|
||||||
continue
|
continue
|
||||||
|
else:
|
||||||
|
start_position = -1
|
||||||
|
end_position = -1
|
||||||
|
orig_answer_text = ""
|
||||||
|
|
||||||
example = SquadExample(
|
example = SquadExample(
|
||||||
qas_id=qas_id,
|
qas_id=qas_id,
|
||||||
@@ -170,7 +187,8 @@ def read_squad_examples(input_file, is_training):
|
|||||||
doc_tokens=doc_tokens,
|
doc_tokens=doc_tokens,
|
||||||
orig_answer_text=orig_answer_text,
|
orig_answer_text=orig_answer_text,
|
||||||
start_position=start_position,
|
start_position=start_position,
|
||||||
end_position=end_position)
|
end_position=end_position,
|
||||||
|
is_impossible=is_impossible)
|
||||||
examples.append(example)
|
examples.append(example)
|
||||||
return examples
|
return examples
|
||||||
|
|
||||||
@@ -200,7 +218,10 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
|
|||||||
|
|
||||||
tok_start_position = None
|
tok_start_position = None
|
||||||
tok_end_position = None
|
tok_end_position = None
|
||||||
if is_training:
|
if is_training and example.is_impossible:
|
||||||
|
tok_start_position = -1
|
||||||
|
tok_end_position = -1
|
||||||
|
if is_training and not example.is_impossible:
|
||||||
tok_start_position = orig_to_tok_index[example.start_position]
|
tok_start_position = orig_to_tok_index[example.start_position]
|
||||||
if example.end_position < len(example.doc_tokens) - 1:
|
if example.end_position < len(example.doc_tokens) - 1:
|
||||||
tok_end_position = orig_to_tok_index[example.end_position + 1] - 1
|
tok_end_position = orig_to_tok_index[example.end_position + 1] - 1
|
||||||
@@ -272,20 +293,25 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
|
|||||||
|
|
||||||
start_position = None
|
start_position = None
|
||||||
end_position = None
|
end_position = None
|
||||||
if is_training:
|
if is_training and not example.is_impossible:
|
||||||
# For training, if our document chunk does not contain an annotation
|
# For training, if our document chunk does not contain an annotation
|
||||||
# we throw it out, since there is nothing to predict.
|
# we throw it out, since there is nothing to predict.
|
||||||
doc_start = doc_span.start
|
doc_start = doc_span.start
|
||||||
doc_end = doc_span.start + doc_span.length - 1
|
doc_end = doc_span.start + doc_span.length - 1
|
||||||
if (example.start_position < doc_start or
|
out_of_span = False
|
||||||
example.end_position < doc_start or
|
if not (tok_start_position >= doc_start and
|
||||||
example.start_position > doc_end or example.end_position > doc_end):
|
tok_end_position <= doc_end):
|
||||||
continue
|
out_of_span = True
|
||||||
|
if out_of_span:
|
||||||
|
start_position = 0
|
||||||
|
end_position = 0
|
||||||
|
else:
|
||||||
doc_offset = len(query_tokens) + 2
|
doc_offset = len(query_tokens) + 2
|
||||||
start_position = tok_start_position - doc_start + doc_offset
|
start_position = tok_start_position - doc_start + doc_offset
|
||||||
end_position = tok_end_position - doc_start + doc_offset
|
end_position = tok_end_position - doc_start + doc_offset
|
||||||
|
if is_training and example.is_impossible:
|
||||||
|
start_position = 0
|
||||||
|
end_position = 0
|
||||||
if example_index < 20:
|
if example_index < 20:
|
||||||
logger.info("*** Example ***")
|
logger.info("*** Example ***")
|
||||||
logger.info("unique_id: %s" % (unique_id))
|
logger.info("unique_id: %s" % (unique_id))
|
||||||
@@ -302,7 +328,9 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
|
|||||||
"input_mask: %s" % " ".join([str(x) for x in input_mask]))
|
"input_mask: %s" % " ".join([str(x) for x in input_mask]))
|
||||||
logger.info(
|
logger.info(
|
||||||
"segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
|
"segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
|
||||||
if is_training:
|
if is_training and example.is_impossible:
|
||||||
|
logger.info("impossible example")
|
||||||
|
if is_training and not example.is_impossible:
|
||||||
answer_text = " ".join(tokens[start_position:(end_position + 1)])
|
answer_text = " ".join(tokens[start_position:(end_position + 1)])
|
||||||
logger.info("start_position: %d" % (start_position))
|
logger.info("start_position: %d" % (start_position))
|
||||||
logger.info("end_position: %d" % (end_position))
|
logger.info("end_position: %d" % (end_position))
|
||||||
@@ -321,7 +349,8 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
|
|||||||
input_mask=input_mask,
|
input_mask=input_mask,
|
||||||
segment_ids=segment_ids,
|
segment_ids=segment_ids,
|
||||||
start_position=start_position,
|
start_position=start_position,
|
||||||
end_position=end_position))
|
end_position=end_position,
|
||||||
|
is_impossible=example.is_impossible))
|
||||||
unique_id += 1
|
unique_id += 1
|
||||||
|
|
||||||
return features
|
return features
|
||||||
@@ -401,15 +430,15 @@ def _check_is_max_context(doc_spans, cur_span_index, position):
|
|||||||
return cur_span_index == best_span_index
|
return cur_span_index == best_span_index
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
RawResult = collections.namedtuple("RawResult",
|
RawResult = collections.namedtuple("RawResult",
|
||||||
["unique_id", "start_logits", "end_logits"])
|
["unique_id", "start_logits", "end_logits"])
|
||||||
|
|
||||||
|
|
||||||
def write_predictions(all_examples, all_features, all_results, n_best_size,
|
def write_predictions(all_examples, all_features, all_results, n_best_size,
|
||||||
max_answer_length, do_lower_case, output_prediction_file,
|
max_answer_length, do_lower_case, output_prediction_file,
|
||||||
output_nbest_file, verbose_logging):
|
output_nbest_file, output_null_log_odds_file, verbose_logging,
|
||||||
"""Write final predictions to the json file."""
|
version_2_with_negative, null_score_diff_threshold):
|
||||||
|
"""Write final predictions to the json file and log-odds of null if needed."""
|
||||||
logger.info("Writing predictions to: %s" % (output_prediction_file))
|
logger.info("Writing predictions to: %s" % (output_prediction_file))
|
||||||
logger.info("Writing nbest to: %s" % (output_nbest_file))
|
logger.info("Writing nbest to: %s" % (output_nbest_file))
|
||||||
|
|
||||||
@@ -427,15 +456,29 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
|
|||||||
|
|
||||||
all_predictions = collections.OrderedDict()
|
all_predictions = collections.OrderedDict()
|
||||||
all_nbest_json = collections.OrderedDict()
|
all_nbest_json = collections.OrderedDict()
|
||||||
|
scores_diff_json = collections.OrderedDict()
|
||||||
|
|
||||||
for (example_index, example) in enumerate(all_examples):
|
for (example_index, example) in enumerate(all_examples):
|
||||||
features = example_index_to_features[example_index]
|
features = example_index_to_features[example_index]
|
||||||
|
|
||||||
prelim_predictions = []
|
prelim_predictions = []
|
||||||
|
# keep track of the minimum score of null start+end of position 0
|
||||||
|
score_null = 1000000 # large and positive
|
||||||
|
min_null_feature_index = 0 # the paragraph slice with min mull score
|
||||||
|
null_start_logit = 0 # the start logit at the slice with min null score
|
||||||
|
null_end_logit = 0 # the end logit at the slice with min null score
|
||||||
for (feature_index, feature) in enumerate(features):
|
for (feature_index, feature) in enumerate(features):
|
||||||
result = unique_id_to_result[feature.unique_id]
|
result = unique_id_to_result[feature.unique_id]
|
||||||
|
|
||||||
start_indexes = _get_best_indexes(result.start_logits, n_best_size)
|
start_indexes = _get_best_indexes(result.start_logits, n_best_size)
|
||||||
end_indexes = _get_best_indexes(result.end_logits, n_best_size)
|
end_indexes = _get_best_indexes(result.end_logits, n_best_size)
|
||||||
|
# if we could have irrelevant answers, get the min score of irrelevant
|
||||||
|
if version_2_with_negative:
|
||||||
|
feature_null_score = result.start_logits[0] + result.end_logits[0]
|
||||||
|
if feature_null_score < score_null:
|
||||||
|
score_null = feature_null_score
|
||||||
|
min_null_feature_index = feature_index
|
||||||
|
null_start_logit = result.start_logits[0]
|
||||||
|
null_end_logit = result.end_logits[0]
|
||||||
for start_index in start_indexes:
|
for start_index in start_indexes:
|
||||||
for end_index in end_indexes:
|
for end_index in end_indexes:
|
||||||
# We could hypothetically create invalid predictions, e.g., predict
|
# We could hypothetically create invalid predictions, e.g., predict
|
||||||
@@ -463,7 +506,14 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
|
|||||||
end_index=end_index,
|
end_index=end_index,
|
||||||
start_logit=result.start_logits[start_index],
|
start_logit=result.start_logits[start_index],
|
||||||
end_logit=result.end_logits[end_index]))
|
end_logit=result.end_logits[end_index]))
|
||||||
|
if version_2_with_negative:
|
||||||
|
prelim_predictions.append(
|
||||||
|
_PrelimPrediction(
|
||||||
|
feature_index=min_null_feature_index,
|
||||||
|
start_index=0,
|
||||||
|
end_index=0,
|
||||||
|
start_logit=null_start_logit,
|
||||||
|
end_logit=null_end_logit))
|
||||||
prelim_predictions = sorted(
|
prelim_predictions = sorted(
|
||||||
prelim_predictions,
|
prelim_predictions,
|
||||||
key=lambda x: (x.start_logit + x.end_logit),
|
key=lambda x: (x.start_logit + x.end_logit),
|
||||||
@@ -478,7 +528,7 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
|
|||||||
if len(nbest) >= n_best_size:
|
if len(nbest) >= n_best_size:
|
||||||
break
|
break
|
||||||
feature = features[pred.feature_index]
|
feature = features[pred.feature_index]
|
||||||
|
if pred.start_index > 0: # this is a non-null prediction
|
||||||
tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)]
|
tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)]
|
||||||
orig_doc_start = feature.token_to_orig_map[pred.start_index]
|
orig_doc_start = feature.token_to_orig_map[pred.start_index]
|
||||||
orig_doc_end = feature.token_to_orig_map[pred.end_index]
|
orig_doc_end = feature.token_to_orig_map[pred.end_index]
|
||||||
@@ -499,12 +549,23 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
seen_predictions[final_text] = True
|
seen_predictions[final_text] = True
|
||||||
|
else:
|
||||||
|
final_text = ""
|
||||||
|
seen_predictions[final_text] = True
|
||||||
|
|
||||||
nbest.append(
|
nbest.append(
|
||||||
_NbestPrediction(
|
_NbestPrediction(
|
||||||
text=final_text,
|
text=final_text,
|
||||||
start_logit=pred.start_logit,
|
start_logit=pred.start_logit,
|
||||||
end_logit=pred.end_logit))
|
end_logit=pred.end_logit))
|
||||||
|
# if we didn't include the empty option in the n-best, include it
|
||||||
|
if version_2_with_negative:
|
||||||
|
if "" not in seen_predictions:
|
||||||
|
nbest.append(
|
||||||
|
_NbestPrediction(
|
||||||
|
text="",
|
||||||
|
start_logit=null_start_logit,
|
||||||
|
end_logit=null_end_logit))
|
||||||
# In very rare edge cases we could have no valid predictions. So we
|
# In very rare edge cases we could have no valid predictions. So we
|
||||||
# just create a nonce prediction in this case to avoid failure.
|
# just create a nonce prediction in this case to avoid failure.
|
||||||
if not nbest:
|
if not nbest:
|
||||||
@@ -514,8 +575,12 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
|
|||||||
assert len(nbest) >= 1
|
assert len(nbest) >= 1
|
||||||
|
|
||||||
total_scores = []
|
total_scores = []
|
||||||
|
best_non_null_entry = None
|
||||||
for entry in nbest:
|
for entry in nbest:
|
||||||
total_scores.append(entry.start_logit + entry.end_logit)
|
total_scores.append(entry.start_logit + entry.end_logit)
|
||||||
|
if not best_non_null_entry:
|
||||||
|
if entry.text:
|
||||||
|
best_non_null_entry = entry
|
||||||
|
|
||||||
probs = _compute_softmax(total_scores)
|
probs = _compute_softmax(total_scores)
|
||||||
|
|
||||||
@@ -530,7 +595,17 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
|
|||||||
|
|
||||||
assert len(nbest_json) >= 1
|
assert len(nbest_json) >= 1
|
||||||
|
|
||||||
|
if not version_2_with_negative:
|
||||||
all_predictions[example.qas_id] = nbest_json[0]["text"]
|
all_predictions[example.qas_id] = nbest_json[0]["text"]
|
||||||
|
else:
|
||||||
|
# predict "" iff the null score - the score of best non-null > threshold
|
||||||
|
score_diff = score_null - best_non_null_entry.start_logit - (
|
||||||
|
best_non_null_entry.end_logit)
|
||||||
|
scores_diff_json[example.qas_id] = score_diff
|
||||||
|
if score_diff > null_score_diff_threshold:
|
||||||
|
all_predictions[example.qas_id] = ""
|
||||||
|
else:
|
||||||
|
all_predictions[example.qas_id] = best_non_null_entry.text
|
||||||
all_nbest_json[example.qas_id] = nbest_json
|
all_nbest_json[example.qas_id] = nbest_json
|
||||||
|
|
||||||
with open(output_prediction_file, "w") as writer:
|
with open(output_prediction_file, "w") as writer:
|
||||||
@@ -539,6 +614,10 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
|
|||||||
with open(output_nbest_file, "w") as writer:
|
with open(output_nbest_file, "w") as writer:
|
||||||
writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
|
writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
|
||||||
|
|
||||||
|
if version_2_with_negative:
|
||||||
|
with open(output_null_log_odds_file, "w") as writer:
|
||||||
|
writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
|
||||||
|
|
||||||
|
|
||||||
def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
|
def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
|
||||||
"""Project the tokenized prediction back to the original text."""
|
"""Project the tokenized prediction back to the original text."""
|
||||||
@@ -701,7 +780,7 @@ def main():
|
|||||||
parser.add_argument("--num_train_epochs", default=3.0, type=float,
|
parser.add_argument("--num_train_epochs", default=3.0, type=float,
|
||||||
help="Total number of training epochs to perform.")
|
help="Total number of training epochs to perform.")
|
||||||
parser.add_argument("--warmup_proportion", default=0.1, type=float,
|
parser.add_argument("--warmup_proportion", default=0.1, type=float,
|
||||||
help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10% "
|
help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% "
|
||||||
"of training.")
|
"of training.")
|
||||||
parser.add_argument("--n_best_size", default=20, type=int,
|
parser.add_argument("--n_best_size", default=20, type=int,
|
||||||
help="The total number of n-best predictions to generate in the nbest_predictions.json "
|
help="The total number of n-best predictions to generate in the nbest_predictions.json "
|
||||||
@@ -738,7 +817,12 @@ def main():
|
|||||||
help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
|
help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
|
||||||
"0 (default value): dynamic loss scaling.\n"
|
"0 (default value): dynamic loss scaling.\n"
|
||||||
"Positive power of 2: static loss scaling value.\n")
|
"Positive power of 2: static loss scaling value.\n")
|
||||||
|
parser.add_argument('--version_2_with_negative',
|
||||||
|
action='store_true',
|
||||||
|
help='If true, the SQuAD examples contain some that do not have an answer.')
|
||||||
|
parser.add_argument('--null_score_diff_threshold',
|
||||||
|
type=float, default=0.0,
|
||||||
|
help="If null_score - best_non_null is greater than the threshold predict null.")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if args.local_rank == -1 or args.no_cuda:
|
if args.local_rank == -1 or args.no_cuda:
|
||||||
@@ -787,9 +871,9 @@ def main():
|
|||||||
num_train_optimization_steps = None
|
num_train_optimization_steps = None
|
||||||
if args.do_train:
|
if args.do_train:
|
||||||
train_examples = read_squad_examples(
|
train_examples = read_squad_examples(
|
||||||
input_file=args.train_file, is_training=True)
|
input_file=args.train_file, is_training=True, version_2_with_negative=args.version_2_with_negative)
|
||||||
num_train_optimization_steps = int(
|
num_train_optimization_steps = int(
|
||||||
len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs
|
len(train_dataset) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs
|
||||||
if args.local_rank != -1:
|
if args.local_rank != -1:
|
||||||
num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()
|
num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()
|
||||||
|
|
||||||
@@ -825,7 +909,7 @@ def main():
|
|||||||
|
|
||||||
if args.fp16:
|
if args.fp16:
|
||||||
try:
|
try:
|
||||||
from apex.optimizers import FP16_Optimizer
|
from apex.optimizer import FP16_Optimizer
|
||||||
from apex.optimizers import FusedAdam
|
from apex.optimizers import FusedAdam
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
|
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
|
||||||
@@ -901,7 +985,7 @@ def main():
|
|||||||
if (step + 1) % args.gradient_accumulation_steps == 0:
|
if (step + 1) % args.gradient_accumulation_steps == 0:
|
||||||
if args.fp16:
|
if args.fp16:
|
||||||
# modify learning rate with special warm up BERT uses
|
# modify learning rate with special warm up BERT uses
|
||||||
# if args.fp16 is False, BertAdam is used that handles this automatically
|
# if args.fp16 is False, BertAdam is used and handles this automatically
|
||||||
lr_this_step = args.learning_rate * warmup_linear(global_step/num_train_optimization_steps, args.warmup_proportion)
|
lr_this_step = args.learning_rate * warmup_linear(global_step/num_train_optimization_steps, args.warmup_proportion)
|
||||||
for param_group in optimizer.param_groups:
|
for param_group in optimizer.param_groups:
|
||||||
param_group['lr'] = lr_this_step
|
param_group['lr'] = lr_this_step
|
||||||
@@ -914,7 +998,6 @@ def main():
|
|||||||
output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
|
output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
|
||||||
if args.do_train:
|
if args.do_train:
|
||||||
torch.save(model_to_save.state_dict(), output_model_file)
|
torch.save(model_to_save.state_dict(), output_model_file)
|
||||||
|
|
||||||
# Load a trained model that you have fine-tuned
|
# Load a trained model that you have fine-tuned
|
||||||
model_state_dict = torch.load(output_model_file)
|
model_state_dict = torch.load(output_model_file)
|
||||||
model = BertForQuestionAnswering.from_pretrained(args.bert_model, state_dict=model_state_dict)
|
model = BertForQuestionAnswering.from_pretrained(args.bert_model, state_dict=model_state_dict)
|
||||||
@@ -925,7 +1008,7 @@ def main():
|
|||||||
|
|
||||||
if args.do_predict and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
if args.do_predict and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
||||||
eval_examples = read_squad_examples(
|
eval_examples = read_squad_examples(
|
||||||
input_file=args.predict_file, is_training=False)
|
input_file=args.predict_file, is_training=False, version_2_with_negative=args.version_2_with_negative)
|
||||||
eval_features = convert_examples_to_features(
|
eval_features = convert_examples_to_features(
|
||||||
examples=eval_examples,
|
examples=eval_examples,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
@@ -969,10 +1052,12 @@ def main():
|
|||||||
end_logits=end_logits))
|
end_logits=end_logits))
|
||||||
output_prediction_file = os.path.join(args.output_dir, "predictions.json")
|
output_prediction_file = os.path.join(args.output_dir, "predictions.json")
|
||||||
output_nbest_file = os.path.join(args.output_dir, "nbest_predictions.json")
|
output_nbest_file = os.path.join(args.output_dir, "nbest_predictions.json")
|
||||||
|
output_null_log_odds_file = os.path.join(args.output_dir, "null_odds.json")
|
||||||
write_predictions(eval_examples, eval_features, all_results,
|
write_predictions(eval_examples, eval_features, all_results,
|
||||||
args.n_best_size, args.max_answer_length,
|
args.n_best_size, args.max_answer_length,
|
||||||
args.do_lower_case, output_prediction_file,
|
args.do_lower_case, output_prediction_file,
|
||||||
output_nbest_file, args.verbose_logging)
|
output_nbest_file, output_null_log_odds_file, args.verbose_logging,
|
||||||
|
args.version_2_with_negative, args.null_score_diff_threshold)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user