wip examples

This commit is contained in:
thomwolf
2019-07-12 11:28:52 +02:00
parent 7442956361
commit 762ded9b1c
3 changed files with 23 additions and 20 deletions

View File

@@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Run BERT on SQuAD.""" """ Finetuning a question-answering Bert model on SQuAD."""
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function

View File

@@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""BERT finetuning runner.""" """ Finetuning a classification model (Bert, XLM, XLNet,...) on GLUE."""
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function

View File

@@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Run BERT on SQuAD.""" """ Finetuning a question-answering model (Bert, XLM, XLNet,...) on SQuAD."""
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
@@ -21,7 +21,6 @@ import argparse
import logging import logging
import os import os
import random import random
import sys
from io import open from io import open
import numpy as np import numpy as np
@@ -33,31 +32,35 @@ from tqdm import tqdm, trange
from tensorboardX import SummaryWriter from tensorboardX import SummaryWriter
from pytorch_transformers import (BertForQuestionAnswering, XLNetForQuestionAnswering, from pytorch_transformers import (WEIGHTS_NAME, BertConfig,
XLMForQuestionAnswering, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BertForQuestionAnswering, BertTokenizer,
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP, XLM_PRETRAINED_MODEL_ARCHIVE_MAP) XLMConfig, XLMForQuestionAnswering,
from pytorch_transformers import (BertTokenizer, XLNetTokenizer, XLMTokenizer, XLNetConfig,
XLMTokenizer) XLNetForQuestionAnswering,
XLNetTokenizer)
from pytorch_transformers import AdamW, WarmupLinearSchedule
from utils_squad import read_squad_examples, convert_examples_to_features, RawResult, write_predictions from utils_squad import read_squad_examples, convert_examples_to_features, RawResult, write_predictions
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
ALL_MODELS = sum((tuple(m.keys()) for m in (BERT_PRETRAINED_MODEL_ARCHIVE_MAP, ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) \
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP, for conf in (BertConfig, XLNetConfig, XLMConfig)), ())
XLM_PRETRAINED_MODEL_ARCHIVE_MAP)), ())
MODEL_CLASSES = { MODEL_CLASSES = {
'bert': BertForQuestionAnswering, 'bert': (BertConfig, BertForQuestionAnswering, BertTokenizer),
'xlnet': XLNetForQuestionAnswering, 'xlnet': (XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer),
'xlm': XLMForQuestionAnswering, 'xlm': (XLMConfig, XLMForQuestionAnswering, XLMTokenizer),
} }
TOKENIZER_CLASSES = { def set_seed(args):
'bert': BertTokenizer, random.seed(args.seed)
'xlnet': XLNetTokenizer, np.random.seed(args.seed)
'xlm': XLMTokenizer, torch.manual_seed(args.seed)
} if args.n_gpu > 0:
torch.cuda.manual_seed_all(args.seed)
def train(args, train_dataset, model): def train(args, train_dataset, model):
""" Train the model """ """ Train the model """