updating examples and doc
This commit is contained in:
@@ -13,7 +13,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Finetuning a question-answering model (Bert, XLM, XLNet,...) on SQuAD."""
|
||||
""" Finetuning the library models for question-answering on SQuAD (Bert, XLM, XLNet)."""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
@@ -21,7 +21,7 @@ import argparse
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
from io import open
|
||||
import glob
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -43,6 +43,9 @@ from pytorch_transformers import AdamW, WarmupLinearSchedule
|
||||
|
||||
from utils_squad import read_squad_examples, convert_examples_to_features, RawResult, write_predictions
|
||||
|
||||
# The follwing import is the official SQuAD evaluation script (2.0).
|
||||
# You can remove it from the dependencies if you are using this script outside of the library
|
||||
# We've added it here for automated tests (see examples/test_examples.py file)
|
||||
from utils_squad_evaluate import EVAL_OPTS, main as evaluate_on_squad
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -123,7 +126,7 @@ def train(args, train_dataset, model, tokenizer):
|
||||
loss = ouputs[0] # model outputs are always tuple in pytorch-transformers (see doc)
|
||||
|
||||
if args.n_gpu > 1:
|
||||
loss = loss.mean() # mean() to average on multi-gpu parallel training
|
||||
loss = loss.mean() # mean() to average on multi-gpu parallel (not distributed) training
|
||||
if args.gradient_accumulation_steps > 1:
|
||||
loss = loss / args.gradient_accumulation_steps
|
||||
|
||||
@@ -169,6 +172,9 @@ def train(args, train_dataset, model, tokenizer):
|
||||
train_iterator.close()
|
||||
break
|
||||
|
||||
if args.local_rank in [-1, 0]:
|
||||
tb_writer.close()
|
||||
|
||||
return global_step, tr_loss / global_step
|
||||
|
||||
|
||||
@@ -208,16 +214,16 @@ def evaluate(args, model, tokenizer, prefix=""):
|
||||
start_logits=start_logits,
|
||||
end_logits=end_logits))
|
||||
|
||||
# Compute predictions
|
||||
output_prediction_file = os.path.join(args.output_dir, "predictions_{}.json".format(prefix))
|
||||
output_nbest_file = os.path.join(args.output_dir, "nbest_predictions_{}.json".format(prefix))
|
||||
output_null_log_odds_file = os.path.join(args.output_dir, "null_odds_{}.json".format(prefix))
|
||||
all_predictions = write_predictions(examples, features, all_results,
|
||||
args.n_best_size, args.max_answer_length,
|
||||
args.do_lower_case, output_prediction_file,
|
||||
output_nbest_file, output_null_log_odds_file,
|
||||
args.verbose_logging, args.version_2_with_negative,
|
||||
args.null_score_diff_threshold)
|
||||
write_predictions(examples, features, all_results, args.n_best_size, args.max_answer_length,
|
||||
args.do_lower_case, output_prediction_file, output_nbest_file,
|
||||
output_null_log_odds_file, args.verbose_logging,
|
||||
args.version_2_with_negative, args.null_score_diff_threshold)
|
||||
|
||||
# Evaluate with the official SQuAD script
|
||||
evaluate_options = EVAL_OPTS(data_file=args.predict_file,
|
||||
pred_file=output_prediction_file,
|
||||
na_prob_file=output_null_log_odds_file)
|
||||
@@ -432,7 +438,7 @@ def main():
|
||||
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
||||
|
||||
|
||||
# Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
|
||||
# Save the trained model and the tokenizer
|
||||
if args.local_rank == -1 or torch.distributed.get_rank() == 0:
|
||||
# Create output directory if needed
|
||||
if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
|
||||
@@ -454,22 +460,30 @@ def main():
|
||||
model.to(args.device)
|
||||
|
||||
|
||||
# Evaluation
|
||||
# Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory
|
||||
results = {}
|
||||
if args.do_eval and args.local_rank in [-1, 0]:
|
||||
checkpoints = [args.output_dir]
|
||||
if args.eval_all_checkpoints:
|
||||
checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
|
||||
logging.getLogger("pytorch_transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging
|
||||
logging.getLogger("pytorch_transformers.modeling_utils").setLevel(logging.WARN) # Reduce model loading logs
|
||||
|
||||
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
||||
|
||||
for checkpoint in checkpoints:
|
||||
# Reload the model
|
||||
global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
|
||||
model = model_class.from_pretrained(checkpoint)
|
||||
model.to(args.device)
|
||||
|
||||
# Evaluate
|
||||
result = evaluate(args, model, tokenizer, prefix=global_step)
|
||||
|
||||
result = dict((k + ('_{}'.format(global_step) if global_step else ''), v) for k, v in result.items())
|
||||
results.update(result)
|
||||
|
||||
logger.info("Results: {}".format(results))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user