updating squad for compatibility with XLNet
This commit is contained in:
@@ -41,7 +41,9 @@ from pytorch_transformers import (WEIGHTS_NAME, BertConfig,
|
||||
|
||||
from pytorch_transformers import AdamW, WarmupLinearSchedule
|
||||
|
||||
from utils_squad import read_squad_examples, convert_examples_to_features, RawResult, write_predictions
|
||||
from utils_squad import (read_squad_examples, convert_examples_to_features,
|
||||
RawResult, write_predictions,
|
||||
RawResultExtended, write_predictions_extended)
|
||||
|
||||
# The follwing import is the official SQuAD evaluation script (2.0).
|
||||
# You can remove it from the dependencies if you are using this script outside of the library
|
||||
@@ -66,6 +68,8 @@ def set_seed(args):
|
||||
if args.n_gpu > 0:
|
||||
torch.cuda.manual_seed_all(args.seed)
|
||||
|
||||
def to_list(tensor):
|
||||
return tensor.detach().cpu().tolist()
|
||||
|
||||
def train(args, train_dataset, model, tokenizer):
|
||||
""" Train the model """
|
||||
@@ -118,10 +122,13 @@ def train(args, train_dataset, model, tokenizer):
|
||||
model.train()
|
||||
batch = tuple(t.to(args.device) for t in batch)
|
||||
inputs = {'input_ids': batch[0],
|
||||
'token_type_ids': batch[1] if args.model_type in ['bert', 'xlnet'] else None, # XLM don't use segment_ids
|
||||
'token_type_ids': None if args.model_type == 'xlm' else batch[1], # XLM don't use segment_ids
|
||||
'attention_mask': batch[2],
|
||||
'start_positions': batch[3],
|
||||
'end_positions': batch[4]}
|
||||
if args.model_type in ['xlnet', 'xlm']:
|
||||
inputs.update({'cls_index': batch[5],
|
||||
'p_mask': batch[6]})
|
||||
ouputs = model(**inputs)
|
||||
loss = ouputs[0] # model outputs are always tuple in pytorch-transformers (see doc)
|
||||
|
||||
@@ -197,31 +204,50 @@ def evaluate(args, model, tokenizer, prefix=""):
|
||||
for batch in tqdm(eval_dataloader, desc="Evaluating"):
|
||||
model.eval()
|
||||
batch = tuple(t.to(args.device) for t in batch)
|
||||
example_indices = batch[3]
|
||||
with torch.no_grad():
|
||||
inputs = {'input_ids': batch[0],
|
||||
'token_type_ids': batch[1] if args.model_type in ['bert', 'xlnet'] else None, # XLM don't use segment_ids
|
||||
'attention_mask': batch[2]}
|
||||
'token_type_ids': None if args.model_type == 'xlm' else batch[1], # XLM don't use segment_ids
|
||||
'attention_mask': batch[2]}
|
||||
example_indices = batch[3]
|
||||
if args.model_type in ['xlnet', 'xlm']:
|
||||
inputs.update({'cls_index': batch[4],
|
||||
'p_mask': batch[5]})
|
||||
outputs = model(**inputs)
|
||||
batch_start_logits, batch_end_logits = outputs[:2]
|
||||
|
||||
for i, example_index in enumerate(example_indices):
|
||||
start_logits = batch_start_logits[i].detach().cpu().tolist()
|
||||
end_logits = batch_end_logits[i].detach().cpu().tolist()
|
||||
eval_feature = features[example_index.item()]
|
||||
unique_id = int(eval_feature.unique_id)
|
||||
all_results.append(RawResult(unique_id=unique_id,
|
||||
start_logits=start_logits,
|
||||
end_logits=end_logits))
|
||||
if args.model_type in ['xlnet', 'xlm']:
|
||||
# XLNet uses a more complex post-processing procedure
|
||||
result = RawResultExtended(unique_id = unique_id,
|
||||
start_top_log_probs = to_list(outputs[0][i]),
|
||||
start_top_index = to_list(outputs[1][i]),
|
||||
end_top_log_probs = to_list(outputs[2][i]),
|
||||
end_top_index = to_list(outputs[3][i]),
|
||||
cls_logits = to_list(outputs[4][i]))
|
||||
else:
|
||||
result = RawResult(unique_id = unique_id,
|
||||
start_logits = to_list(outputs[0][i]),
|
||||
end_logits = to_list(outputs[1][i]))
|
||||
all_results.append(result)
|
||||
|
||||
# Compute predictions
|
||||
output_prediction_file = os.path.join(args.output_dir, "predictions_{}.json".format(prefix))
|
||||
output_nbest_file = os.path.join(args.output_dir, "nbest_predictions_{}.json".format(prefix))
|
||||
output_null_log_odds_file = os.path.join(args.output_dir, "null_odds_{}.json".format(prefix))
|
||||
write_predictions(examples, features, all_results, args.n_best_size, args.max_answer_length,
|
||||
args.do_lower_case, output_prediction_file, output_nbest_file,
|
||||
output_null_log_odds_file, args.verbose_logging,
|
||||
args.version_2_with_negative, args.null_score_diff_threshold)
|
||||
|
||||
if args.model_type in ['xlnet', 'xlm']:
|
||||
# XLNet uses a more complex post-processing procedure
|
||||
write_predictions_extended(examples, features, all_results, args.n_best_size,
|
||||
args.max_answer_length, output_prediction_file,
|
||||
output_nbest_file, output_null_log_odds_file, args.predict_file,
|
||||
args.start_n_top, args.end_n_top, args.version_2_with_negative)
|
||||
else:
|
||||
write_predictions(examples, features, all_results, args.n_best_size,
|
||||
args.max_answer_length, args.do_lower_case, output_prediction_file,
|
||||
output_nbest_file, output_null_log_odds_file, args.verbose_logging,
|
||||
args.version_2_with_negative, args.null_score_diff_threshold)
|
||||
|
||||
# Evaluate with the official SQuAD script
|
||||
evaluate_options = EVAL_OPTS(data_file=args.predict_file,
|
||||
@@ -244,8 +270,8 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
|
||||
else:
|
||||
logger.info("Creating features from dataset file at %s", input_file)
|
||||
examples = read_squad_examples(input_file=input_file,
|
||||
is_training=not evaluate,
|
||||
version_2_with_negative=args.version_2_with_negative)
|
||||
is_training=not evaluate,
|
||||
version_2_with_negative=args.version_2_with_negative)
|
||||
features = convert_examples_to_features(examples=examples,
|
||||
tokenizer=tokenizer,
|
||||
max_seq_length=args.max_seq_length,
|
||||
@@ -260,13 +286,18 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
|
||||
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
|
||||
all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
|
||||
all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
|
||||
all_cls_index = torch.tensor([f.cls_index for f in features], dtype=torch.long)
|
||||
all_p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float)
|
||||
if evaluate:
|
||||
all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
|
||||
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index)
|
||||
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
|
||||
all_example_index, all_cls_index, all_p_mask)
|
||||
else:
|
||||
all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long)
|
||||
all_end_positions = torch.tensor([f.end_position for f in features], dtype=torch.long)
|
||||
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_start_positions, all_end_positions)
|
||||
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
|
||||
all_start_positions, all_end_positions,
|
||||
all_cls_index, all_p_mask)
|
||||
|
||||
if output_examples:
|
||||
return dataset, examples, features
|
||||
|
||||
Reference in New Issue
Block a user