From 164c794eb356917237512a9755f26f3caf0a6255 Mon Sep 17 00:00:00 2001 From: Lysandre Date: Wed, 8 Jan 2020 16:33:23 +0100 Subject: [PATCH] New SQuAD API for distillation script --- .../distillation/run_squad_w_distillation.py | 127 ++++++++---------- 1 file changed, 59 insertions(+), 68 deletions(-) diff --git a/examples/distillation/run_squad_w_distillation.py b/examples/distillation/run_squad_w_distillation.py index e5a2265ed6..91d3802f6b 100644 --- a/examples/distillation/run_squad_w_distillation.py +++ b/examples/distillation/run_squad_w_distillation.py @@ -15,7 +15,6 @@ # limitations under the License. """ This is the exact same script as `examples/run_squad.py` (as of 2019, October 4th) with an additional and optional step of distillation.""" - import argparse import glob import logging @@ -26,7 +25,7 @@ import numpy as np import torch import torch.nn as nn import torch.nn.functional as F -from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset +from torch.utils.data import DataLoader, RandomSampler, SequentialSampler from torch.utils.data.distributed import DistributedSampler from tqdm import tqdm, trange @@ -46,22 +45,14 @@ from transformers import ( XLNetForQuestionAnswering, XLNetTokenizer, get_linear_schedule_with_warmup, + squad_convert_examples_to_features, ) - -from ..utils_squad import ( - RawResult, - RawResultExtended, - convert_examples_to_features, - read_squad_examples, - write_predictions, - write_predictions_extended, +from transformers.data.metrics.squad_metrics import ( + compute_predictions_log_probs, + compute_predictions_logits, + squad_evaluate, ) - -# The follwing import is the official SQuAD evaluation script (2.0). -# You can remove it from the dependencies if you are using this script outside of the library -# We've added it here for automated tests (see examples/test_examples.py file) -from ..utils_squad_evaluate import EVAL_OPTS -from ..utils_squad_evaluate import main as evaluate_on_squad +from transformers.data.processors.squad import SquadResult, SquadV1Processor, SquadV2Processor try: @@ -69,7 +60,6 @@ try: except ImportError: from tensorboardX import SummaryWriter - logger = logging.getLogger(__name__) ALL_MODELS = sum( @@ -294,20 +284,31 @@ def evaluate(args, model, tokenizer, prefix=""): for i, example_index in enumerate(example_indices): eval_feature = features[example_index.item()] unique_id = int(eval_feature.unique_id) - if args.model_type in ["xlnet", "xlm"]: - # XLNet uses a more complex post-processing procedure - result = RawResultExtended( - unique_id=unique_id, - start_top_log_probs=to_list(outputs[0][i]), - start_top_index=to_list(outputs[1][i]), - end_top_log_probs=to_list(outputs[2][i]), - end_top_index=to_list(outputs[3][i]), - cls_logits=to_list(outputs[4][i]), + + output = [to_list(output[i]) for output in outputs] + + # Some models (XLNet, XLM) use 5 arguments for their predictions, while the other "simpler" + # models only use two. + if len(output) >= 5: + start_logits = output[0] + start_top_index = output[1] + end_logits = output[2] + end_top_index = output[3] + cls_logits = output[4] + + result = SquadResult( + unique_id, + start_logits, + end_logits, + start_top_index=start_top_index, + end_top_index=end_top_index, + cls_logits=cls_logits, ) + else: - result = RawResult( - unique_id=unique_id, start_logits=to_list(outputs[0][i]), end_logits=to_list(outputs[1][i]) - ) + start_logits, end_logits = output + result = SquadResult(unique_id, start_logits, end_logits) + all_results.append(result) # Compute predictions @@ -320,7 +321,7 @@ def evaluate(args, model, tokenizer, prefix=""): if args.model_type in ["xlnet", "xlm"]: # XLNet uses a more complex post-processing procedure - write_predictions_extended( + predictions = compute_predictions_log_probs( examples, features, all_results, @@ -337,7 +338,7 @@ def evaluate(args, model, tokenizer, prefix=""): args.verbose_logging, ) else: - write_predictions( + predictions = compute_predictions_logits( examples, features, all_results, @@ -350,13 +351,11 @@ def evaluate(args, model, tokenizer, prefix=""): args.verbose_logging, args.version_2_with_negative, args.null_score_diff_threshold, + tokenizer, ) - # Evaluate with the official SQuAD script - evaluate_options = EVAL_OPTS( - data_file=args.predict_file, pred_file=output_prediction_file, na_prob_file=output_null_log_odds_file - ) - results = evaluate_on_squad(evaluate_options) + # Compute the F1 and exact scores. + results = squad_evaluate(examples, predictions) return results @@ -368,59 +367,51 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal input_file = args.predict_file if evaluate else args.train_file cached_features_file = os.path.join( os.path.dirname(input_file), - "cached_{}_{}_{}".format( + "cached_distillation_{}_{}_{}".format( "dev" if evaluate else "train", list(filter(None, args.model_name_or_path.split("/"))).pop(), str(args.max_seq_length), ), ) - if os.path.exists(cached_features_file) and not args.overwrite_cache and not output_examples: + if os.path.exists(cached_features_file) and not args.overwrite_cache: logger.info("Loading features from cached file %s", cached_features_file) - features = torch.load(cached_features_file) + features_and_dataset = torch.load(cached_features_file) + + try: + features, dataset, examples = ( + features_and_dataset["features"], + features_and_dataset["dataset"], + features_and_dataset["examples"], + ) + except KeyError: + raise DeprecationWarning( + "You seem to be loading features from an older version of this script please delete the " + "file %s in order for it to be created again" % cached_features_file + ) else: logger.info("Creating features from dataset file at %s", input_file) - examples = read_squad_examples( - input_file=input_file, is_training=not evaluate, version_2_with_negative=args.version_2_with_negative - ) - features = convert_examples_to_features( + processor = SquadV2Processor() if args.version_2_with_negative else SquadV1Processor() + if evaluate: + examples = processor.get_dev_examples(None, filename=args.predict_file) + else: + examples = processor.get_train_examples(None, filename=args.train_file) + + features, dataset = squad_convert_examples_to_features( examples=examples, tokenizer=tokenizer, max_seq_length=args.max_seq_length, doc_stride=args.doc_stride, max_query_length=args.max_query_length, is_training=not evaluate, + return_dataset="pt", ) if args.local_rank in [-1, 0]: logger.info("Saving features into cached file %s", cached_features_file) - torch.save(features, cached_features_file) + torch.save({"features": features, "dataset": dataset, "examples": examples}, cached_features_file) if args.local_rank == 0 and not evaluate: torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache - # Convert to Tensors and build dataset - all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) - all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long) - all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long) - all_cls_index = torch.tensor([f.cls_index for f in features], dtype=torch.long) - all_p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float) - if evaluate: - all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long) - dataset = TensorDataset( - all_input_ids, all_input_mask, all_segment_ids, all_example_index, all_cls_index, all_p_mask - ) - else: - all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long) - all_end_positions = torch.tensor([f.end_position for f in features], dtype=torch.long) - dataset = TensorDataset( - all_input_ids, - all_input_mask, - all_segment_ids, - all_start_positions, - all_end_positions, - all_cls_index, - all_p_mask, - ) - if output_examples: return dataset, examples, features return dataset