wip examples
This commit is contained in:
@@ -13,7 +13,7 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Run BERT on SQuAD."""
|
""" Finetuning a question-answering Bert model on SQuAD."""
|
||||||
|
|
||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
|
|||||||
@@ -13,7 +13,7 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""BERT finetuning runner."""
|
""" Finetuning a classification model (Bert, XLM, XLNet,...) on GLUE."""
|
||||||
|
|
||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
|
|||||||
@@ -13,7 +13,7 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Run BERT on SQuAD."""
|
""" Finetuning a question-answering model (Bert, XLM, XLNet,...) on SQuAD."""
|
||||||
|
|
||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
@@ -21,7 +21,6 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import sys
|
|
||||||
from io import open
|
from io import open
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -33,31 +32,35 @@ from tqdm import tqdm, trange
|
|||||||
|
|
||||||
from tensorboardX import SummaryWriter
|
from tensorboardX import SummaryWriter
|
||||||
|
|
||||||
from pytorch_transformers import (BertForQuestionAnswering, XLNetForQuestionAnswering,
|
from pytorch_transformers import (WEIGHTS_NAME, BertConfig,
|
||||||
XLMForQuestionAnswering, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
BertForQuestionAnswering, BertTokenizer,
|
||||||
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP, XLM_PRETRAINED_MODEL_ARCHIVE_MAP)
|
XLMConfig, XLMForQuestionAnswering,
|
||||||
from pytorch_transformers import (BertTokenizer, XLNetTokenizer,
|
XLMTokenizer, XLNetConfig,
|
||||||
XLMTokenizer)
|
XLNetForQuestionAnswering,
|
||||||
|
XLNetTokenizer)
|
||||||
|
|
||||||
|
from pytorch_transformers import AdamW, WarmupLinearSchedule
|
||||||
|
|
||||||
from utils_squad import read_squad_examples, convert_examples_to_features, RawResult, write_predictions
|
from utils_squad import read_squad_examples, convert_examples_to_features, RawResult, write_predictions
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
ALL_MODELS = sum((tuple(m.keys()) for m in (BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) \
|
||||||
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
|
for conf in (BertConfig, XLNetConfig, XLMConfig)), ())
|
||||||
XLM_PRETRAINED_MODEL_ARCHIVE_MAP)), ())
|
|
||||||
|
|
||||||
MODEL_CLASSES = {
|
MODEL_CLASSES = {
|
||||||
'bert': BertForQuestionAnswering,
|
'bert': (BertConfig, BertForQuestionAnswering, BertTokenizer),
|
||||||
'xlnet': XLNetForQuestionAnswering,
|
'xlnet': (XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer),
|
||||||
'xlm': XLMForQuestionAnswering,
|
'xlm': (XLMConfig, XLMForQuestionAnswering, XLMTokenizer),
|
||||||
}
|
}
|
||||||
|
|
||||||
TOKENIZER_CLASSES = {
|
def set_seed(args):
|
||||||
'bert': BertTokenizer,
|
random.seed(args.seed)
|
||||||
'xlnet': XLNetTokenizer,
|
np.random.seed(args.seed)
|
||||||
'xlm': XLMTokenizer,
|
torch.manual_seed(args.seed)
|
||||||
}
|
if args.n_gpu > 0:
|
||||||
|
torch.cuda.manual_seed_all(args.seed)
|
||||||
|
|
||||||
|
|
||||||
def train(args, train_dataset, model):
|
def train(args, train_dataset, model):
|
||||||
""" Train the model """
|
""" Train the model """
|
||||||
|
|||||||
Reference in New Issue
Block a user