wip examples
This commit is contained in:
@@ -13,7 +13,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Run BERT on SQuAD."""
|
||||
""" Finetuning a question-answering Bert model on SQuAD."""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""BERT finetuning runner."""
|
||||
""" Finetuning a classification model (Bert, XLM, XLNet,...) on GLUE."""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Run BERT on SQuAD."""
|
||||
""" Finetuning a question-answering model (Bert, XLM, XLNet,...) on SQuAD."""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
@@ -21,7 +21,6 @@ import argparse
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
from io import open
|
||||
|
||||
import numpy as np
|
||||
@@ -33,31 +32,35 @@ from tqdm import tqdm, trange
|
||||
|
||||
from tensorboardX import SummaryWriter
|
||||
|
||||
from pytorch_transformers import (BertForQuestionAnswering, XLNetForQuestionAnswering,
|
||||
XLMForQuestionAnswering, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP, XLM_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
from pytorch_transformers import (BertTokenizer, XLNetTokenizer,
|
||||
XLMTokenizer)
|
||||
from pytorch_transformers import (WEIGHTS_NAME, BertConfig,
|
||||
BertForQuestionAnswering, BertTokenizer,
|
||||
XLMConfig, XLMForQuestionAnswering,
|
||||
XLMTokenizer, XLNetConfig,
|
||||
XLNetForQuestionAnswering,
|
||||
XLNetTokenizer)
|
||||
|
||||
from pytorch_transformers import AdamW, WarmupLinearSchedule
|
||||
|
||||
from utils_squad import read_squad_examples, convert_examples_to_features, RawResult, write_predictions
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ALL_MODELS = sum((tuple(m.keys()) for m in (BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
XLM_PRETRAINED_MODEL_ARCHIVE_MAP)), ())
|
||||
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) \
|
||||
for conf in (BertConfig, XLNetConfig, XLMConfig)), ())
|
||||
|
||||
MODEL_CLASSES = {
|
||||
'bert': BertForQuestionAnswering,
|
||||
'xlnet': XLNetForQuestionAnswering,
|
||||
'xlm': XLMForQuestionAnswering,
|
||||
'bert': (BertConfig, BertForQuestionAnswering, BertTokenizer),
|
||||
'xlnet': (XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer),
|
||||
'xlm': (XLMConfig, XLMForQuestionAnswering, XLMTokenizer),
|
||||
}
|
||||
|
||||
TOKENIZER_CLASSES = {
|
||||
'bert': BertTokenizer,
|
||||
'xlnet': XLNetTokenizer,
|
||||
'xlm': XLMTokenizer,
|
||||
}
|
||||
def set_seed(args):
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
if args.n_gpu > 0:
|
||||
torch.cuda.manual_seed_all(args.seed)
|
||||
|
||||
|
||||
def train(args, train_dataset, model):
|
||||
""" Train the model """
|
||||
|
||||
Reference in New Issue
Block a user