From 762ded9b1c92f9cef2aa08c907e0f9b11b43e37a Mon Sep 17 00:00:00 2001 From: thomwolf Date: Fri, 12 Jul 2019 11:28:52 +0200 Subject: [PATCH] wip examples --- examples/run_bert_squad.py | 2 +- examples/run_glue.py | 2 +- examples/run_squad.py | 39 ++++++++++++++++++++------------------ 3 files changed, 23 insertions(+), 20 deletions(-) diff --git a/examples/run_bert_squad.py b/examples/run_bert_squad.py index c3fdb06316..e5ba1b3b95 100644 --- a/examples/run_bert_squad.py +++ b/examples/run_bert_squad.py @@ -13,7 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Run BERT on SQuAD.""" +""" Finetuning a question-answering Bert model on SQuAD.""" from __future__ import absolute_import, division, print_function diff --git a/examples/run_glue.py b/examples/run_glue.py index 7e615804c1..6f96a23476 100644 --- a/examples/run_glue.py +++ b/examples/run_glue.py @@ -13,7 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""BERT finetuning runner.""" +""" Finetuning a classification model (Bert, XLM, XLNet,...) on GLUE.""" from __future__ import absolute_import, division, print_function diff --git a/examples/run_squad.py b/examples/run_squad.py index 7f063109e3..3d3d964687 100644 --- a/examples/run_squad.py +++ b/examples/run_squad.py @@ -13,7 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Run BERT on SQuAD.""" +""" Finetuning a question-answering model (Bert, XLM, XLNet,...) on SQuAD.""" from __future__ import absolute_import, division, print_function @@ -21,7 +21,6 @@ import argparse import logging import os import random -import sys from io import open import numpy as np @@ -33,31 +32,35 @@ from tqdm import tqdm, trange from tensorboardX import SummaryWriter -from pytorch_transformers import (BertForQuestionAnswering, XLNetForQuestionAnswering, - XLMForQuestionAnswering, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, - XLNET_PRETRAINED_MODEL_ARCHIVE_MAP, XLM_PRETRAINED_MODEL_ARCHIVE_MAP) -from pytorch_transformers import (BertTokenizer, XLNetTokenizer, - XLMTokenizer) +from pytorch_transformers import (WEIGHTS_NAME, BertConfig, + BertForQuestionAnswering, BertTokenizer, + XLMConfig, XLMForQuestionAnswering, + XLMTokenizer, XLNetConfig, + XLNetForQuestionAnswering, + XLNetTokenizer) + +from pytorch_transformers import AdamW, WarmupLinearSchedule from utils_squad import read_squad_examples, convert_examples_to_features, RawResult, write_predictions logger = logging.getLogger(__name__) -ALL_MODELS = sum((tuple(m.keys()) for m in (BERT_PRETRAINED_MODEL_ARCHIVE_MAP, - XLNET_PRETRAINED_MODEL_ARCHIVE_MAP, - XLM_PRETRAINED_MODEL_ARCHIVE_MAP)), ()) +ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) \ + for conf in (BertConfig, XLNetConfig, XLMConfig)), ()) MODEL_CLASSES = { - 'bert': BertForQuestionAnswering, - 'xlnet': XLNetForQuestionAnswering, - 'xlm': XLMForQuestionAnswering, + 'bert': (BertConfig, BertForQuestionAnswering, BertTokenizer), + 'xlnet': (XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer), + 'xlm': (XLMConfig, XLMForQuestionAnswering, XLMTokenizer), } -TOKENIZER_CLASSES = { - 'bert': BertTokenizer, - 'xlnet': XLNetTokenizer, - 'xlm': XLMTokenizer, -} +def set_seed(args): + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + if args.n_gpu > 0: + torch.cuda.manual_seed_all(args.seed) + def train(args, train_dataset, model): """ Train the model """