New SQuAD API for distillation script
This commit is contained in:
@@ -15,7 +15,6 @@
|
||||
# limitations under the License.
|
||||
""" This is the exact same script as `examples/run_squad.py` (as of 2019, October 4th) with an additional and optional step of distillation."""
|
||||
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import logging
|
||||
@@ -26,7 +25,7 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
|
||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
@@ -46,22 +45,14 @@ from transformers import (
|
||||
XLNetForQuestionAnswering,
|
||||
XLNetTokenizer,
|
||||
get_linear_schedule_with_warmup,
|
||||
squad_convert_examples_to_features,
|
||||
)
|
||||
|
||||
from ..utils_squad import (
|
||||
RawResult,
|
||||
RawResultExtended,
|
||||
convert_examples_to_features,
|
||||
read_squad_examples,
|
||||
write_predictions,
|
||||
write_predictions_extended,
|
||||
from transformers.data.metrics.squad_metrics import (
|
||||
compute_predictions_log_probs,
|
||||
compute_predictions_logits,
|
||||
squad_evaluate,
|
||||
)
|
||||
|
||||
# The follwing import is the official SQuAD evaluation script (2.0).
|
||||
# You can remove it from the dependencies if you are using this script outside of the library
|
||||
# We've added it here for automated tests (see examples/test_examples.py file)
|
||||
from ..utils_squad_evaluate import EVAL_OPTS
|
||||
from ..utils_squad_evaluate import main as evaluate_on_squad
|
||||
from transformers.data.processors.squad import SquadResult, SquadV1Processor, SquadV2Processor
|
||||
|
||||
|
||||
try:
|
||||
@@ -69,7 +60,6 @@ try:
|
||||
except ImportError:
|
||||
from tensorboardX import SummaryWriter
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ALL_MODELS = sum(
|
||||
@@ -294,20 +284,31 @@ def evaluate(args, model, tokenizer, prefix=""):
|
||||
for i, example_index in enumerate(example_indices):
|
||||
eval_feature = features[example_index.item()]
|
||||
unique_id = int(eval_feature.unique_id)
|
||||
if args.model_type in ["xlnet", "xlm"]:
|
||||
# XLNet uses a more complex post-processing procedure
|
||||
result = RawResultExtended(
|
||||
unique_id=unique_id,
|
||||
start_top_log_probs=to_list(outputs[0][i]),
|
||||
start_top_index=to_list(outputs[1][i]),
|
||||
end_top_log_probs=to_list(outputs[2][i]),
|
||||
end_top_index=to_list(outputs[3][i]),
|
||||
cls_logits=to_list(outputs[4][i]),
|
||||
|
||||
output = [to_list(output[i]) for output in outputs]
|
||||
|
||||
# Some models (XLNet, XLM) use 5 arguments for their predictions, while the other "simpler"
|
||||
# models only use two.
|
||||
if len(output) >= 5:
|
||||
start_logits = output[0]
|
||||
start_top_index = output[1]
|
||||
end_logits = output[2]
|
||||
end_top_index = output[3]
|
||||
cls_logits = output[4]
|
||||
|
||||
result = SquadResult(
|
||||
unique_id,
|
||||
start_logits,
|
||||
end_logits,
|
||||
start_top_index=start_top_index,
|
||||
end_top_index=end_top_index,
|
||||
cls_logits=cls_logits,
|
||||
)
|
||||
|
||||
else:
|
||||
result = RawResult(
|
||||
unique_id=unique_id, start_logits=to_list(outputs[0][i]), end_logits=to_list(outputs[1][i])
|
||||
)
|
||||
start_logits, end_logits = output
|
||||
result = SquadResult(unique_id, start_logits, end_logits)
|
||||
|
||||
all_results.append(result)
|
||||
|
||||
# Compute predictions
|
||||
@@ -320,7 +321,7 @@ def evaluate(args, model, tokenizer, prefix=""):
|
||||
|
||||
if args.model_type in ["xlnet", "xlm"]:
|
||||
# XLNet uses a more complex post-processing procedure
|
||||
write_predictions_extended(
|
||||
predictions = compute_predictions_log_probs(
|
||||
examples,
|
||||
features,
|
||||
all_results,
|
||||
@@ -337,7 +338,7 @@ def evaluate(args, model, tokenizer, prefix=""):
|
||||
args.verbose_logging,
|
||||
)
|
||||
else:
|
||||
write_predictions(
|
||||
predictions = compute_predictions_logits(
|
||||
examples,
|
||||
features,
|
||||
all_results,
|
||||
@@ -350,13 +351,11 @@ def evaluate(args, model, tokenizer, prefix=""):
|
||||
args.verbose_logging,
|
||||
args.version_2_with_negative,
|
||||
args.null_score_diff_threshold,
|
||||
tokenizer,
|
||||
)
|
||||
|
||||
# Evaluate with the official SQuAD script
|
||||
evaluate_options = EVAL_OPTS(
|
||||
data_file=args.predict_file, pred_file=output_prediction_file, na_prob_file=output_null_log_odds_file
|
||||
)
|
||||
results = evaluate_on_squad(evaluate_options)
|
||||
# Compute the F1 and exact scores.
|
||||
results = squad_evaluate(examples, predictions)
|
||||
return results
|
||||
|
||||
|
||||
@@ -368,59 +367,51 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
|
||||
input_file = args.predict_file if evaluate else args.train_file
|
||||
cached_features_file = os.path.join(
|
||||
os.path.dirname(input_file),
|
||||
"cached_{}_{}_{}".format(
|
||||
"cached_distillation_{}_{}_{}".format(
|
||||
"dev" if evaluate else "train",
|
||||
list(filter(None, args.model_name_or_path.split("/"))).pop(),
|
||||
str(args.max_seq_length),
|
||||
),
|
||||
)
|
||||
if os.path.exists(cached_features_file) and not args.overwrite_cache and not output_examples:
|
||||
if os.path.exists(cached_features_file) and not args.overwrite_cache:
|
||||
logger.info("Loading features from cached file %s", cached_features_file)
|
||||
features = torch.load(cached_features_file)
|
||||
features_and_dataset = torch.load(cached_features_file)
|
||||
|
||||
try:
|
||||
features, dataset, examples = (
|
||||
features_and_dataset["features"],
|
||||
features_and_dataset["dataset"],
|
||||
features_and_dataset["examples"],
|
||||
)
|
||||
except KeyError:
|
||||
raise DeprecationWarning(
|
||||
"You seem to be loading features from an older version of this script please delete the "
|
||||
"file %s in order for it to be created again" % cached_features_file
|
||||
)
|
||||
else:
|
||||
logger.info("Creating features from dataset file at %s", input_file)
|
||||
examples = read_squad_examples(
|
||||
input_file=input_file, is_training=not evaluate, version_2_with_negative=args.version_2_with_negative
|
||||
)
|
||||
features = convert_examples_to_features(
|
||||
processor = SquadV2Processor() if args.version_2_with_negative else SquadV1Processor()
|
||||
if evaluate:
|
||||
examples = processor.get_dev_examples(None, filename=args.predict_file)
|
||||
else:
|
||||
examples = processor.get_train_examples(None, filename=args.train_file)
|
||||
|
||||
features, dataset = squad_convert_examples_to_features(
|
||||
examples=examples,
|
||||
tokenizer=tokenizer,
|
||||
max_seq_length=args.max_seq_length,
|
||||
doc_stride=args.doc_stride,
|
||||
max_query_length=args.max_query_length,
|
||||
is_training=not evaluate,
|
||||
return_dataset="pt",
|
||||
)
|
||||
if args.local_rank in [-1, 0]:
|
||||
logger.info("Saving features into cached file %s", cached_features_file)
|
||||
torch.save(features, cached_features_file)
|
||||
torch.save({"features": features, "dataset": dataset, "examples": examples}, cached_features_file)
|
||||
|
||||
if args.local_rank == 0 and not evaluate:
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
|
||||
|
||||
# Convert to Tensors and build dataset
|
||||
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
|
||||
all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
|
||||
all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
|
||||
all_cls_index = torch.tensor([f.cls_index for f in features], dtype=torch.long)
|
||||
all_p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float)
|
||||
if evaluate:
|
||||
all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
|
||||
dataset = TensorDataset(
|
||||
all_input_ids, all_input_mask, all_segment_ids, all_example_index, all_cls_index, all_p_mask
|
||||
)
|
||||
else:
|
||||
all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long)
|
||||
all_end_positions = torch.tensor([f.end_position for f in features], dtype=torch.long)
|
||||
dataset = TensorDataset(
|
||||
all_input_ids,
|
||||
all_input_mask,
|
||||
all_segment_ids,
|
||||
all_start_positions,
|
||||
all_end_positions,
|
||||
all_cls_index,
|
||||
all_p_mask,
|
||||
)
|
||||
|
||||
if output_examples:
|
||||
return dataset, examples, features
|
||||
return dataset
|
||||
|
||||
Reference in New Issue
Block a user