Merge branch 'master' into python_2
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -121,5 +121,5 @@ dmypy.json
|
|||||||
# TF code
|
# TF code
|
||||||
tensorflow_code
|
tensorflow_code
|
||||||
|
|
||||||
# models
|
# Models
|
||||||
models
|
models
|
||||||
18
README.md
18
README.md
@@ -53,14 +53,14 @@ python -m pytest -sv tests/
|
|||||||
This package comprises the following classes that can be imported in Python and are detailed in the [Doc](#doc) section of this readme:
|
This package comprises the following classes that can be imported in Python and are detailed in the [Doc](#doc) section of this readme:
|
||||||
|
|
||||||
- Eight PyTorch models (`torch.nn.Module`) for Bert with pre-trained weights (in the [`modeling.py`](./pytorch_pretrained_bert/modeling.py) file):
|
- Eight PyTorch models (`torch.nn.Module`) for Bert with pre-trained weights (in the [`modeling.py`](./pytorch_pretrained_bert/modeling.py) file):
|
||||||
- [`BertModel`](./pytorch_pretrained_bert/modeling.py#L537) - raw BERT Transformer model (**fully pre-trained**),
|
- [`BertModel`](./pytorch_pretrained_bert/modeling.py#L556) - raw BERT Transformer model (**fully pre-trained**),
|
||||||
- [`BertForMaskedLM`](./pytorch_pretrained_bert/modeling.py#L691) - BERT Transformer with the pre-trained masked language modeling head on top (**fully pre-trained**),
|
- [`BertForMaskedLM`](./pytorch_pretrained_bert/modeling.py#L710) - BERT Transformer with the pre-trained masked language modeling head on top (**fully pre-trained**),
|
||||||
- [`BertForNextSentencePrediction`](./pytorch_pretrained_bert/modeling.py#L752) - BERT Transformer with the pre-trained next sentence prediction classifier on top (**fully pre-trained**),
|
- [`BertForNextSentencePrediction`](./pytorch_pretrained_bert/modeling.py#L771) - BERT Transformer with the pre-trained next sentence prediction classifier on top (**fully pre-trained**),
|
||||||
- [`BertForPreTraining`](./pytorch_pretrained_bert/modeling.py#L620) - BERT Transformer with masked language modeling head and next sentence prediction classifier on top (**fully pre-trained**),
|
- [`BertForPreTraining`](./pytorch_pretrained_bert/modeling.py#L639) - BERT Transformer with masked language modeling head and next sentence prediction classifier on top (**fully pre-trained**),
|
||||||
- [`BertForSequenceClassification`](./pytorch_pretrained_bert/modeling.py#L814) - BERT Transformer with a sequence classification head on top (BERT Transformer is **pre-trained**, the sequence classification head **is only initialized and has to be trained**),
|
- [`BertForSequenceClassification`](./pytorch_pretrained_bert/modeling.py#L833) - BERT Transformer with a sequence classification head on top (BERT Transformer is **pre-trained**, the sequence classification head **is only initialized and has to be trained**),
|
||||||
- [`BertForMultipleChoice`](./pytorch_pretrained_bert/modeling.py#L880) - BERT Transformer with a multiple choice head on top (used for task like Swag) (BERT Transformer is **pre-trained**, the multiple choice classification head **is only initialized and has to be trained**),
|
- [`BertForMultipleChoice`](./pytorch_pretrained_bert/modeling.py#L899) - BERT Transformer with a multiple choice head on top (used for task like Swag) (BERT Transformer is **pre-trained**, the multiple choice classification head **is only initialized and has to be trained**),
|
||||||
- [`BertForTokenClassification`](./pytorch_pretrained_bert/modeling.py#L949) - BERT Transformer with a token classification head on top (BERT Transformer is **pre-trained**, the token classification head **is only initialized and has to be trained**),
|
- [`BertForTokenClassification`](./pytorch_pretrained_bert/modeling.py#L969) - BERT Transformer with a token classification head on top (BERT Transformer is **pre-trained**, the token classification head **is only initialized and has to be trained**),
|
||||||
- [`BertForQuestionAnswering`](./pytorch_pretrained_bert/modeling.py#L1015) - BERT Transformer with a token classification head on top (BERT Transformer is **pre-trained**, the token classification head **is only initialized and has to be trained**).
|
- [`BertForQuestionAnswering`](./pytorch_pretrained_bert/modeling.py#L1034) - BERT Transformer with a token classification head on top (BERT Transformer is **pre-trained**, the token classification head **is only initialized and has to be trained**).
|
||||||
|
|
||||||
- Three PyTorch models (`torch.nn.Module`) for OpenAI with pre-trained weights (in the [`modeling_openai.py`](./pytorch_pretrained_bert/modeling_openai.py) file):
|
- Three PyTorch models (`torch.nn.Module`) for OpenAI with pre-trained weights (in the [`modeling_openai.py`](./pytorch_pretrained_bert/modeling_openai.py) file):
|
||||||
- [`OpenAIGPTModel`](./pytorch_pretrained_bert/modeling_openai.py#L537) - raw OpenAI GPT Transformer model (**fully pre-trained**),
|
- [`OpenAIGPTModel`](./pytorch_pretrained_bert/modeling_openai.py#L537) - raw OpenAI GPT Transformer model (**fully pre-trained**),
|
||||||
@@ -94,7 +94,7 @@ The repository further comprises:
|
|||||||
- [`run_classifier.py`](./examples/run_classifier.py) - Show how to fine-tune an instance of `BertForSequenceClassification` on GLUE's MRPC task,
|
- [`run_classifier.py`](./examples/run_classifier.py) - Show how to fine-tune an instance of `BertForSequenceClassification` on GLUE's MRPC task,
|
||||||
- [`run_squad.py`](./examples/run_squad.py) - Show how to fine-tune an instance of `BertForQuestionAnswering` on SQuAD v1.0 task.
|
- [`run_squad.py`](./examples/run_squad.py) - Show how to fine-tune an instance of `BertForQuestionAnswering` on SQuAD v1.0 task.
|
||||||
- [`run_swag.py`](./examples/run_swag.py) - Show how to fine-tune an instance of `BertForMultipleChoice` on Swag task.
|
- [`run_swag.py`](./examples/run_swag.py) - Show how to fine-tune an instance of `BertForMultipleChoice` on Swag task.
|
||||||
- [`run_lm_finetuning`](./examples/run_lm_finetuning.py) - Show how to fine-tune an instance of `BertForPretraining' on a target text corpus.
|
- [`run_lm_finetuning.py`](./examples/run_lm_finetuning.py) - Show how to fine-tune an instance of `BertForPretraining' on a target text corpus.
|
||||||
|
|
||||||
These examples are detailed in the [Examples](#examples) section of this readme.
|
These examples are detailed in the [Examples](#examples) section of this readme.
|
||||||
|
|
||||||
|
|||||||
@@ -34,8 +34,8 @@ from tqdm import tqdm, trange
|
|||||||
|
|
||||||
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
|
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
|
||||||
from pytorch_pretrained_bert.modeling import BertForSequenceClassification
|
from pytorch_pretrained_bert.modeling import BertForSequenceClassification
|
||||||
from pytorch_pretrained_bert.optimization import BertAdam
|
|
||||||
from pytorch_pretrained_bert.tokenization import BertTokenizer
|
from pytorch_pretrained_bert.tokenization import BertTokenizer
|
||||||
|
from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear
|
||||||
|
|
||||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||||
datefmt = '%m/%d/%Y %H:%M:%S',
|
datefmt = '%m/%d/%Y %H:%M:%S',
|
||||||
@@ -299,11 +299,6 @@ def accuracy(out, labels):
|
|||||||
outputs = np.argmax(out, axis=1)
|
outputs = np.argmax(out, axis=1)
|
||||||
return np.sum(outputs == labels)
|
return np.sum(outputs == labels)
|
||||||
|
|
||||||
def warmup_linear(x, warmup=0.002):
|
|
||||||
if x < warmup:
|
|
||||||
return x/warmup
|
|
||||||
return 1.0 - x
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
@@ -419,7 +414,7 @@ def main():
|
|||||||
raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
|
raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
|
||||||
args.gradient_accumulation_steps))
|
args.gradient_accumulation_steps))
|
||||||
|
|
||||||
args.train_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)
|
args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps
|
||||||
|
|
||||||
random.seed(args.seed)
|
random.seed(args.seed)
|
||||||
np.random.seed(args.seed)
|
np.random.seed(args.seed)
|
||||||
@@ -447,11 +442,13 @@ def main():
|
|||||||
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
|
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
|
||||||
|
|
||||||
train_examples = None
|
train_examples = None
|
||||||
num_train_steps = None
|
num_train_optimization_steps = None
|
||||||
if args.do_train:
|
if args.do_train:
|
||||||
train_examples = processor.get_train_examples(args.data_dir)
|
train_examples = processor.get_train_examples(args.data_dir)
|
||||||
num_train_steps = int(
|
num_train_optimization_steps = int(
|
||||||
len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs)
|
len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs
|
||||||
|
if args.local_rank != -1:
|
||||||
|
num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()
|
||||||
|
|
||||||
# Prepare model
|
# Prepare model
|
||||||
model = BertForSequenceClassification.from_pretrained(args.bert_model,
|
model = BertForSequenceClassification.from_pretrained(args.bert_model,
|
||||||
@@ -477,9 +474,6 @@ def main():
|
|||||||
{'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
|
{'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
|
||||||
{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
|
{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
|
||||||
]
|
]
|
||||||
t_total = num_train_steps
|
|
||||||
if args.local_rank != -1:
|
|
||||||
t_total = t_total // torch.distributed.get_world_size()
|
|
||||||
if args.fp16:
|
if args.fp16:
|
||||||
try:
|
try:
|
||||||
from apex.optimizers import FP16_Optimizer
|
from apex.optimizers import FP16_Optimizer
|
||||||
@@ -500,7 +494,7 @@ def main():
|
|||||||
optimizer = BertAdam(optimizer_grouped_parameters,
|
optimizer = BertAdam(optimizer_grouped_parameters,
|
||||||
lr=args.learning_rate,
|
lr=args.learning_rate,
|
||||||
warmup=args.warmup_proportion,
|
warmup=args.warmup_proportion,
|
||||||
t_total=t_total)
|
t_total=num_train_optimization_steps)
|
||||||
|
|
||||||
global_step = 0
|
global_step = 0
|
||||||
nb_tr_steps = 0
|
nb_tr_steps = 0
|
||||||
@@ -511,7 +505,7 @@ def main():
|
|||||||
logger.info("***** Running training *****")
|
logger.info("***** Running training *****")
|
||||||
logger.info(" Num examples = %d", len(train_examples))
|
logger.info(" Num examples = %d", len(train_examples))
|
||||||
logger.info(" Batch size = %d", args.train_batch_size)
|
logger.info(" Batch size = %d", args.train_batch_size)
|
||||||
logger.info(" Num steps = %d", num_train_steps)
|
logger.info(" Num steps = %d", num_train_optimization_steps)
|
||||||
all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
|
all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
|
||||||
all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
|
all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
|
||||||
all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
|
all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
|
||||||
@@ -545,10 +539,12 @@ def main():
|
|||||||
nb_tr_examples += input_ids.size(0)
|
nb_tr_examples += input_ids.size(0)
|
||||||
nb_tr_steps += 1
|
nb_tr_steps += 1
|
||||||
if (step + 1) % args.gradient_accumulation_steps == 0:
|
if (step + 1) % args.gradient_accumulation_steps == 0:
|
||||||
# modify learning rate with special warm up BERT uses
|
if args.fp16:
|
||||||
lr_this_step = args.learning_rate * warmup_linear(global_step/t_total, args.warmup_proportion)
|
# modify learning rate with special warm up BERT uses
|
||||||
for param_group in optimizer.param_groups:
|
# if args.fp16 is False, BertAdam is used that handles this automatically
|
||||||
param_group['lr'] = lr_this_step
|
lr_this_step = args.learning_rate * warmup_linear(global_step/num_train_optimization_steps, args.warmup_proportion)
|
||||||
|
for param_group in optimizer.param_groups:
|
||||||
|
param_group['lr'] = lr_this_step
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
global_step += 1
|
global_step += 1
|
||||||
|
|||||||
@@ -30,8 +30,11 @@ from torch.utils.data.distributed import DistributedSampler
|
|||||||
from tqdm import tqdm, trange
|
from tqdm import tqdm, trange
|
||||||
|
|
||||||
from pytorch_pretrained_bert.modeling import BertForPreTraining
|
from pytorch_pretrained_bert.modeling import BertForPreTraining
|
||||||
from pytorch_pretrained_bert.optimization import BertAdam
|
|
||||||
from pytorch_pretrained_bert.tokenization import BertTokenizer
|
from pytorch_pretrained_bert.tokenization import BertTokenizer
|
||||||
|
from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear
|
||||||
|
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
import random
|
||||||
|
|
||||||
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||||
datefmt='%m/%d/%Y %H:%M:%S',
|
datefmt='%m/%d/%Y %H:%M:%S',
|
||||||
@@ -39,12 +42,6 @@ logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def warmup_linear(x, warmup=0.002):
|
|
||||||
if x < warmup:
|
|
||||||
return x/warmup
|
|
||||||
return 1.0 - x
|
|
||||||
|
|
||||||
|
|
||||||
class BERTDataset(Dataset):
|
class BERTDataset(Dataset):
|
||||||
def __init__(self, corpus_path, tokenizer, seq_len, encoding="utf-8", corpus_lines=None, on_memory=True):
|
def __init__(self, corpus_path, tokenizer, seq_len, encoding="utf-8", corpus_lines=None, on_memory=True):
|
||||||
self.vocab = tokenizer.vocab
|
self.vocab = tokenizer.vocab
|
||||||
@@ -136,11 +133,11 @@ class BERTDataset(Dataset):
|
|||||||
# transform sample to features
|
# transform sample to features
|
||||||
cur_features = convert_example_to_features(cur_example, self.seq_len, self.tokenizer)
|
cur_features = convert_example_to_features(cur_example, self.seq_len, self.tokenizer)
|
||||||
|
|
||||||
cur_tensors = {"input_ids": torch.tensor(cur_features.input_ids),
|
cur_tensors = (torch.tensor(cur_features.input_ids),
|
||||||
"input_mask": torch.tensor(cur_features.input_mask),
|
torch.tensor(cur_features.input_mask),
|
||||||
"segment_ids": torch.tensor(cur_features.segment_ids),
|
torch.tensor(cur_features.segment_ids),
|
||||||
"lm_label_ids": torch.tensor(cur_features.lm_label_ids),
|
torch.tensor(cur_features.lm_label_ids),
|
||||||
"is_next": torch.tensor(cur_features.is_next)}
|
torch.tensor(cur_features.is_next))
|
||||||
|
|
||||||
return cur_tensors
|
return cur_tensors
|
||||||
|
|
||||||
@@ -325,8 +322,8 @@ def convert_example_to_features(example, max_seq_length, tokenizer):
|
|||||||
# Account for [CLS], [SEP], [SEP] with "- 3"
|
# Account for [CLS], [SEP], [SEP] with "- 3"
|
||||||
_truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
|
_truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
|
||||||
|
|
||||||
t1_random, t1_label = random_word(tokens_a, tokenizer)
|
tokens_a, t1_label = random_word(tokens_a, tokenizer)
|
||||||
t2_random, t2_label = random_word(tokens_b, tokenizer)
|
tokens_b, t2_label = random_word(tokens_b, tokenizer)
|
||||||
# concatenate lm labels and account for CLS, SEP, SEP
|
# concatenate lm labels and account for CLS, SEP, SEP
|
||||||
lm_label_ids = ([-1] + t1_label + [-1] + t2_label + [-1])
|
lm_label_ids = ([-1] + t1_label + [-1] + t2_label + [-1])
|
||||||
|
|
||||||
@@ -459,6 +456,9 @@ def main():
|
|||||||
parser.add_argument("--on_memory",
|
parser.add_argument("--on_memory",
|
||||||
action='store_true',
|
action='store_true',
|
||||||
help="Whether to load train samples into memory or use disk")
|
help="Whether to load train samples into memory or use disk")
|
||||||
|
parser.add_argument("--do_lower_case",
|
||||||
|
action='store_true',
|
||||||
|
help="Whether to lower case the input text. True for uncased models, False for cased models.")
|
||||||
parser.add_argument("--local_rank",
|
parser.add_argument("--local_rank",
|
||||||
type=int,
|
type=int,
|
||||||
default=-1,
|
default=-1,
|
||||||
@@ -498,7 +498,7 @@ def main():
|
|||||||
raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
|
raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
|
||||||
args.gradient_accumulation_steps))
|
args.gradient_accumulation_steps))
|
||||||
|
|
||||||
args.train_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)
|
args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps
|
||||||
|
|
||||||
random.seed(args.seed)
|
random.seed(args.seed)
|
||||||
np.random.seed(args.seed)
|
np.random.seed(args.seed)
|
||||||
@@ -517,13 +517,15 @@ def main():
|
|||||||
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
|
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
|
||||||
|
|
||||||
#train_examples = None
|
#train_examples = None
|
||||||
num_train_steps = None
|
num_train_optimization_steps = None
|
||||||
if args.do_train:
|
if args.do_train:
|
||||||
print("Loading Train Dataset", args.train_file)
|
print("Loading Train Dataset", args.train_file)
|
||||||
train_dataset = BERTDataset(args.train_file, tokenizer, seq_len=args.max_seq_length,
|
train_dataset = BERTDataset(args.train_file, tokenizer, seq_len=args.max_seq_length,
|
||||||
corpus_lines=None, on_memory=args.on_memory)
|
corpus_lines=None, on_memory=args.on_memory)
|
||||||
num_train_steps = int(
|
num_train_optimization_steps = int(
|
||||||
len(train_dataset) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs)
|
len(train_dataset) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs
|
||||||
|
if args.local_rank != -1:
|
||||||
|
num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()
|
||||||
|
|
||||||
# Prepare model
|
# Prepare model
|
||||||
model = BertForPreTraining.from_pretrained(args.bert_model)
|
model = BertForPreTraining.from_pretrained(args.bert_model)
|
||||||
@@ -546,6 +548,7 @@ def main():
|
|||||||
{'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
|
{'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
|
||||||
{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
|
{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
|
||||||
]
|
]
|
||||||
|
|
||||||
if args.fp16:
|
if args.fp16:
|
||||||
try:
|
try:
|
||||||
from apex.optimizers import FP16_Optimizer
|
from apex.optimizers import FP16_Optimizer
|
||||||
@@ -566,14 +569,14 @@ def main():
|
|||||||
optimizer = BertAdam(optimizer_grouped_parameters,
|
optimizer = BertAdam(optimizer_grouped_parameters,
|
||||||
lr=args.learning_rate,
|
lr=args.learning_rate,
|
||||||
warmup=args.warmup_proportion,
|
warmup=args.warmup_proportion,
|
||||||
t_total=num_train_steps)
|
t_total=num_train_optimization_steps)
|
||||||
|
|
||||||
global_step = 0
|
global_step = 0
|
||||||
if args.do_train:
|
if args.do_train:
|
||||||
logger.info("***** Running training *****")
|
logger.info("***** Running training *****")
|
||||||
logger.info(" Num examples = %d", len(train_dataset))
|
logger.info(" Num examples = %d", len(train_dataset))
|
||||||
logger.info(" Batch size = %d", args.train_batch_size)
|
logger.info(" Batch size = %d", args.train_batch_size)
|
||||||
logger.info(" Num steps = %d", num_train_steps)
|
logger.info(" Num steps = %d", num_train_optimization_steps)
|
||||||
|
|
||||||
if args.local_rank == -1:
|
if args.local_rank == -1:
|
||||||
train_sampler = RandomSampler(train_dataset)
|
train_sampler = RandomSampler(train_dataset)
|
||||||
@@ -588,7 +591,7 @@ def main():
|
|||||||
tr_loss = 0
|
tr_loss = 0
|
||||||
nb_tr_examples, nb_tr_steps = 0, 0
|
nb_tr_examples, nb_tr_steps = 0, 0
|
||||||
for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
|
for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
|
||||||
batch = tuple(t.to(device) for t in batch.values())
|
batch = tuple(t.to(device) for t in batch)
|
||||||
input_ids, input_mask, segment_ids, lm_label_ids, is_next = batch
|
input_ids, input_mask, segment_ids, lm_label_ids, is_next = batch
|
||||||
loss = model(input_ids, segment_ids, input_mask, lm_label_ids, is_next)
|
loss = model(input_ids, segment_ids, input_mask, lm_label_ids, is_next)
|
||||||
if n_gpu > 1:
|
if n_gpu > 1:
|
||||||
@@ -603,20 +606,22 @@ def main():
|
|||||||
nb_tr_examples += input_ids.size(0)
|
nb_tr_examples += input_ids.size(0)
|
||||||
nb_tr_steps += 1
|
nb_tr_steps += 1
|
||||||
if (step + 1) % args.gradient_accumulation_steps == 0:
|
if (step + 1) % args.gradient_accumulation_steps == 0:
|
||||||
# modify learning rate with special warm up BERT uses
|
if args.fp16:
|
||||||
lr_this_step = args.learning_rate * warmup_linear(global_step/num_train_steps, args.warmup_proportion)
|
# modify learning rate with special warm up BERT uses
|
||||||
for param_group in optimizer.param_groups:
|
# if args.fp16 is False, BertAdam is used that handles this automatically
|
||||||
param_group['lr'] = lr_this_step
|
lr_this_step = args.learning_rate * warmup_linear(global_step/num_train_optimization_steps, args.warmup_proportion)
|
||||||
|
for param_group in optimizer.param_groups:
|
||||||
|
param_group['lr'] = lr_this_step
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
global_step += 1
|
global_step += 1
|
||||||
|
|
||||||
|
# Save a trained model
|
||||||
logger.info("** ** * Saving fine - tuned model ** ** * ")
|
logger.info("** ** * Saving fine - tuned model ** ** * ")
|
||||||
|
model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self
|
||||||
output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
|
output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
|
||||||
if n_gpu > 1:
|
if args.do_train:
|
||||||
torch.save(model.module.bert.state_dict(), output_model_file)
|
torch.save(model_to_save.state_dict(), output_model_file)
|
||||||
else:
|
|
||||||
torch.save(model.bert.state_dict(), output_model_file)
|
|
||||||
|
|
||||||
|
|
||||||
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
|
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ from tqdm import tqdm, trange
|
|||||||
|
|
||||||
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
|
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
|
||||||
from pytorch_pretrained_bert.modeling import BertForQuestionAnswering
|
from pytorch_pretrained_bert.modeling import BertForQuestionAnswering
|
||||||
from pytorch_pretrained_bert.optimization import BertAdam
|
from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear
|
||||||
from pytorch_pretrained_bert.tokenization import (BasicTokenizer,
|
from pytorch_pretrained_bert.tokenization import (BasicTokenizer,
|
||||||
BertTokenizer,
|
BertTokenizer,
|
||||||
whitespace_tokenize)
|
whitespace_tokenize)
|
||||||
@@ -53,7 +53,10 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class SquadExample(object):
|
class SquadExample(object):
|
||||||
"""A single training/test example for the Squad dataset."""
|
"""
|
||||||
|
A single training/test example for the Squad dataset.
|
||||||
|
For examples without an answer, the start and end position are -1.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
qas_id,
|
qas_id,
|
||||||
@@ -61,13 +64,15 @@ class SquadExample(object):
|
|||||||
doc_tokens,
|
doc_tokens,
|
||||||
orig_answer_text=None,
|
orig_answer_text=None,
|
||||||
start_position=None,
|
start_position=None,
|
||||||
end_position=None):
|
end_position=None,
|
||||||
|
is_impossible=None):
|
||||||
self.qas_id = qas_id
|
self.qas_id = qas_id
|
||||||
self.question_text = question_text
|
self.question_text = question_text
|
||||||
self.doc_tokens = doc_tokens
|
self.doc_tokens = doc_tokens
|
||||||
self.orig_answer_text = orig_answer_text
|
self.orig_answer_text = orig_answer_text
|
||||||
self.start_position = start_position
|
self.start_position = start_position
|
||||||
self.end_position = end_position
|
self.end_position = end_position
|
||||||
|
self.is_impossible = is_impossible
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return self.__repr__()
|
return self.__repr__()
|
||||||
@@ -82,6 +87,8 @@ class SquadExample(object):
|
|||||||
s += ", start_position: %d" % (self.start_position)
|
s += ", start_position: %d" % (self.start_position)
|
||||||
if self.start_position:
|
if self.start_position:
|
||||||
s += ", end_position: %d" % (self.end_position)
|
s += ", end_position: %d" % (self.end_position)
|
||||||
|
if self.start_position:
|
||||||
|
s += ", is_impossible: %r" % (self.is_impossible)
|
||||||
return s
|
return s
|
||||||
|
|
||||||
|
|
||||||
@@ -99,7 +106,8 @@ class InputFeatures(object):
|
|||||||
input_mask,
|
input_mask,
|
||||||
segment_ids,
|
segment_ids,
|
||||||
start_position=None,
|
start_position=None,
|
||||||
end_position=None):
|
end_position=None,
|
||||||
|
is_impossible=None):
|
||||||
self.unique_id = unique_id
|
self.unique_id = unique_id
|
||||||
self.example_index = example_index
|
self.example_index = example_index
|
||||||
self.doc_span_index = doc_span_index
|
self.doc_span_index = doc_span_index
|
||||||
@@ -111,9 +119,10 @@ class InputFeatures(object):
|
|||||||
self.segment_ids = segment_ids
|
self.segment_ids = segment_ids
|
||||||
self.start_position = start_position
|
self.start_position = start_position
|
||||||
self.end_position = end_position
|
self.end_position = end_position
|
||||||
|
self.is_impossible = is_impossible
|
||||||
|
|
||||||
|
|
||||||
def read_squad_examples(input_file, is_training):
|
def read_squad_examples(input_file, is_training, version_2_with_negative):
|
||||||
"""Read a SQuAD json file into a list of SquadExample."""
|
"""Read a SQuAD json file into a list of SquadExample."""
|
||||||
with open(input_file, "r", encoding='utf-8') as reader:
|
with open(input_file, "r", encoding='utf-8') as reader:
|
||||||
input_data = json.load(reader)["data"]
|
input_data = json.load(reader)["data"]
|
||||||
@@ -147,29 +156,37 @@ def read_squad_examples(input_file, is_training):
|
|||||||
start_position = None
|
start_position = None
|
||||||
end_position = None
|
end_position = None
|
||||||
orig_answer_text = None
|
orig_answer_text = None
|
||||||
|
is_impossible = False
|
||||||
if is_training:
|
if is_training:
|
||||||
if len(qa["answers"]) != 1:
|
if version_2_with_negative:
|
||||||
|
is_impossible = qa["is_impossible"]
|
||||||
|
if (len(qa["answers"]) != 1) and (not is_impossible):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"For training, each question should have exactly 1 answer.")
|
"For training, each question should have exactly 1 answer.")
|
||||||
answer = qa["answers"][0]
|
if not is_impossible:
|
||||||
orig_answer_text = answer["text"]
|
answer = qa["answers"][0]
|
||||||
answer_offset = answer["answer_start"]
|
orig_answer_text = answer["text"]
|
||||||
answer_length = len(orig_answer_text)
|
answer_offset = answer["answer_start"]
|
||||||
start_position = char_to_word_offset[answer_offset]
|
answer_length = len(orig_answer_text)
|
||||||
end_position = char_to_word_offset[answer_offset + answer_length - 1]
|
start_position = char_to_word_offset[answer_offset]
|
||||||
# Only add answers where the text can be exactly recovered from the
|
end_position = char_to_word_offset[answer_offset + answer_length - 1]
|
||||||
# document. If this CAN'T happen it's likely due to weird Unicode
|
# Only add answers where the text can be exactly recovered from the
|
||||||
# stuff so we will just skip the example.
|
# document. If this CAN'T happen it's likely due to weird Unicode
|
||||||
#
|
# stuff so we will just skip the example.
|
||||||
# Note that this means for training mode, every example is NOT
|
#
|
||||||
# guaranteed to be preserved.
|
# Note that this means for training mode, every example is NOT
|
||||||
actual_text = " ".join(doc_tokens[start_position:(end_position + 1)])
|
# guaranteed to be preserved.
|
||||||
cleaned_answer_text = " ".join(
|
actual_text = " ".join(doc_tokens[start_position:(end_position + 1)])
|
||||||
whitespace_tokenize(orig_answer_text))
|
cleaned_answer_text = " ".join(
|
||||||
if actual_text.find(cleaned_answer_text) == -1:
|
whitespace_tokenize(orig_answer_text))
|
||||||
logger.warning("Could not find answer: '%s' vs. '%s'",
|
if actual_text.find(cleaned_answer_text) == -1:
|
||||||
|
logger.warning("Could not find answer: '%s' vs. '%s'",
|
||||||
actual_text, cleaned_answer_text)
|
actual_text, cleaned_answer_text)
|
||||||
continue
|
continue
|
||||||
|
else:
|
||||||
|
start_position = -1
|
||||||
|
end_position = -1
|
||||||
|
orig_answer_text = ""
|
||||||
|
|
||||||
example = SquadExample(
|
example = SquadExample(
|
||||||
qas_id=qas_id,
|
qas_id=qas_id,
|
||||||
@@ -177,7 +194,8 @@ def read_squad_examples(input_file, is_training):
|
|||||||
doc_tokens=doc_tokens,
|
doc_tokens=doc_tokens,
|
||||||
orig_answer_text=orig_answer_text,
|
orig_answer_text=orig_answer_text,
|
||||||
start_position=start_position,
|
start_position=start_position,
|
||||||
end_position=end_position)
|
end_position=end_position,
|
||||||
|
is_impossible=is_impossible)
|
||||||
examples.append(example)
|
examples.append(example)
|
||||||
return examples
|
return examples
|
||||||
|
|
||||||
@@ -207,7 +225,10 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
|
|||||||
|
|
||||||
tok_start_position = None
|
tok_start_position = None
|
||||||
tok_end_position = None
|
tok_end_position = None
|
||||||
if is_training:
|
if is_training and example.is_impossible:
|
||||||
|
tok_start_position = -1
|
||||||
|
tok_end_position = -1
|
||||||
|
if is_training and not example.is_impossible:
|
||||||
tok_start_position = orig_to_tok_index[example.start_position]
|
tok_start_position = orig_to_tok_index[example.start_position]
|
||||||
if example.end_position < len(example.doc_tokens) - 1:
|
if example.end_position < len(example.doc_tokens) - 1:
|
||||||
tok_end_position = orig_to_tok_index[example.end_position + 1] - 1
|
tok_end_position = orig_to_tok_index[example.end_position + 1] - 1
|
||||||
@@ -279,20 +300,25 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
|
|||||||
|
|
||||||
start_position = None
|
start_position = None
|
||||||
end_position = None
|
end_position = None
|
||||||
if is_training:
|
if is_training and not example.is_impossible:
|
||||||
# For training, if our document chunk does not contain an annotation
|
# For training, if our document chunk does not contain an annotation
|
||||||
# we throw it out, since there is nothing to predict.
|
# we throw it out, since there is nothing to predict.
|
||||||
doc_start = doc_span.start
|
doc_start = doc_span.start
|
||||||
doc_end = doc_span.start + doc_span.length - 1
|
doc_end = doc_span.start + doc_span.length - 1
|
||||||
if (example.start_position < doc_start or
|
out_of_span = False
|
||||||
example.end_position < doc_start or
|
if not (tok_start_position >= doc_start and
|
||||||
example.start_position > doc_end or example.end_position > doc_end):
|
tok_end_position <= doc_end):
|
||||||
continue
|
out_of_span = True
|
||||||
|
if out_of_span:
|
||||||
doc_offset = len(query_tokens) + 2
|
start_position = 0
|
||||||
start_position = tok_start_position - doc_start + doc_offset
|
end_position = 0
|
||||||
end_position = tok_end_position - doc_start + doc_offset
|
else:
|
||||||
|
doc_offset = len(query_tokens) + 2
|
||||||
|
start_position = tok_start_position - doc_start + doc_offset
|
||||||
|
end_position = tok_end_position - doc_start + doc_offset
|
||||||
|
if is_training and example.is_impossible:
|
||||||
|
start_position = 0
|
||||||
|
end_position = 0
|
||||||
if example_index < 20:
|
if example_index < 20:
|
||||||
logger.info("*** Example ***")
|
logger.info("*** Example ***")
|
||||||
logger.info("unique_id: %s" % (unique_id))
|
logger.info("unique_id: %s" % (unique_id))
|
||||||
@@ -309,7 +335,9 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
|
|||||||
"input_mask: %s" % " ".join([str(x) for x in input_mask]))
|
"input_mask: %s" % " ".join([str(x) for x in input_mask]))
|
||||||
logger.info(
|
logger.info(
|
||||||
"segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
|
"segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
|
||||||
if is_training:
|
if is_training and example.is_impossible:
|
||||||
|
logger.info("impossible example")
|
||||||
|
if is_training and not example.is_impossible:
|
||||||
answer_text = " ".join(tokens[start_position:(end_position + 1)])
|
answer_text = " ".join(tokens[start_position:(end_position + 1)])
|
||||||
logger.info("start_position: %d" % (start_position))
|
logger.info("start_position: %d" % (start_position))
|
||||||
logger.info("end_position: %d" % (end_position))
|
logger.info("end_position: %d" % (end_position))
|
||||||
@@ -328,7 +356,8 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
|
|||||||
input_mask=input_mask,
|
input_mask=input_mask,
|
||||||
segment_ids=segment_ids,
|
segment_ids=segment_ids,
|
||||||
start_position=start_position,
|
start_position=start_position,
|
||||||
end_position=end_position))
|
end_position=end_position,
|
||||||
|
is_impossible=example.is_impossible))
|
||||||
unique_id += 1
|
unique_id += 1
|
||||||
|
|
||||||
return features
|
return features
|
||||||
@@ -408,15 +437,15 @@ def _check_is_max_context(doc_spans, cur_span_index, position):
|
|||||||
return cur_span_index == best_span_index
|
return cur_span_index == best_span_index
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
RawResult = collections.namedtuple("RawResult",
|
RawResult = collections.namedtuple("RawResult",
|
||||||
["unique_id", "start_logits", "end_logits"])
|
["unique_id", "start_logits", "end_logits"])
|
||||||
|
|
||||||
|
|
||||||
def write_predictions(all_examples, all_features, all_results, n_best_size,
|
def write_predictions(all_examples, all_features, all_results, n_best_size,
|
||||||
max_answer_length, do_lower_case, output_prediction_file,
|
max_answer_length, do_lower_case, output_prediction_file,
|
||||||
output_nbest_file, verbose_logging):
|
output_nbest_file, output_null_log_odds_file, verbose_logging,
|
||||||
"""Write final predictions to the json file."""
|
version_2_with_negative, null_score_diff_threshold):
|
||||||
|
"""Write final predictions to the json file and log-odds of null if needed."""
|
||||||
logger.info("Writing predictions to: %s" % (output_prediction_file))
|
logger.info("Writing predictions to: %s" % (output_prediction_file))
|
||||||
logger.info("Writing nbest to: %s" % (output_nbest_file))
|
logger.info("Writing nbest to: %s" % (output_nbest_file))
|
||||||
|
|
||||||
@@ -434,15 +463,29 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
|
|||||||
|
|
||||||
all_predictions = collections.OrderedDict()
|
all_predictions = collections.OrderedDict()
|
||||||
all_nbest_json = collections.OrderedDict()
|
all_nbest_json = collections.OrderedDict()
|
||||||
|
scores_diff_json = collections.OrderedDict()
|
||||||
|
|
||||||
for (example_index, example) in enumerate(all_examples):
|
for (example_index, example) in enumerate(all_examples):
|
||||||
features = example_index_to_features[example_index]
|
features = example_index_to_features[example_index]
|
||||||
|
|
||||||
prelim_predictions = []
|
prelim_predictions = []
|
||||||
|
# keep track of the minimum score of null start+end of position 0
|
||||||
|
score_null = 1000000 # large and positive
|
||||||
|
min_null_feature_index = 0 # the paragraph slice with min mull score
|
||||||
|
null_start_logit = 0 # the start logit at the slice with min null score
|
||||||
|
null_end_logit = 0 # the end logit at the slice with min null score
|
||||||
for (feature_index, feature) in enumerate(features):
|
for (feature_index, feature) in enumerate(features):
|
||||||
result = unique_id_to_result[feature.unique_id]
|
result = unique_id_to_result[feature.unique_id]
|
||||||
|
|
||||||
start_indexes = _get_best_indexes(result.start_logits, n_best_size)
|
start_indexes = _get_best_indexes(result.start_logits, n_best_size)
|
||||||
end_indexes = _get_best_indexes(result.end_logits, n_best_size)
|
end_indexes = _get_best_indexes(result.end_logits, n_best_size)
|
||||||
|
# if we could have irrelevant answers, get the min score of irrelevant
|
||||||
|
if version_2_with_negative:
|
||||||
|
feature_null_score = result.start_logits[0] + result.end_logits[0]
|
||||||
|
if feature_null_score < score_null:
|
||||||
|
score_null = feature_null_score
|
||||||
|
min_null_feature_index = feature_index
|
||||||
|
null_start_logit = result.start_logits[0]
|
||||||
|
null_end_logit = result.end_logits[0]
|
||||||
for start_index in start_indexes:
|
for start_index in start_indexes:
|
||||||
for end_index in end_indexes:
|
for end_index in end_indexes:
|
||||||
# We could hypothetically create invalid predictions, e.g., predict
|
# We could hypothetically create invalid predictions, e.g., predict
|
||||||
@@ -470,7 +513,14 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
|
|||||||
end_index=end_index,
|
end_index=end_index,
|
||||||
start_logit=result.start_logits[start_index],
|
start_logit=result.start_logits[start_index],
|
||||||
end_logit=result.end_logits[end_index]))
|
end_logit=result.end_logits[end_index]))
|
||||||
|
if version_2_with_negative:
|
||||||
|
prelim_predictions.append(
|
||||||
|
_PrelimPrediction(
|
||||||
|
feature_index=min_null_feature_index,
|
||||||
|
start_index=0,
|
||||||
|
end_index=0,
|
||||||
|
start_logit=null_start_logit,
|
||||||
|
end_logit=null_end_logit))
|
||||||
prelim_predictions = sorted(
|
prelim_predictions = sorted(
|
||||||
prelim_predictions,
|
prelim_predictions,
|
||||||
key=lambda x: (x.start_logit + x.end_logit),
|
key=lambda x: (x.start_logit + x.end_logit),
|
||||||
@@ -485,33 +535,44 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
|
|||||||
if len(nbest) >= n_best_size:
|
if len(nbest) >= n_best_size:
|
||||||
break
|
break
|
||||||
feature = features[pred.feature_index]
|
feature = features[pred.feature_index]
|
||||||
|
if pred.start_index > 0: # this is a non-null prediction
|
||||||
|
tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)]
|
||||||
|
orig_doc_start = feature.token_to_orig_map[pred.start_index]
|
||||||
|
orig_doc_end = feature.token_to_orig_map[pred.end_index]
|
||||||
|
orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)]
|
||||||
|
tok_text = " ".join(tok_tokens)
|
||||||
|
|
||||||
tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)]
|
# De-tokenize WordPieces that have been split off.
|
||||||
orig_doc_start = feature.token_to_orig_map[pred.start_index]
|
tok_text = tok_text.replace(" ##", "")
|
||||||
orig_doc_end = feature.token_to_orig_map[pred.end_index]
|
tok_text = tok_text.replace("##", "")
|
||||||
orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)]
|
|
||||||
tok_text = " ".join(tok_tokens)
|
|
||||||
|
|
||||||
# De-tokenize WordPieces that have been split off.
|
# Clean whitespace
|
||||||
tok_text = tok_text.replace(" ##", "")
|
tok_text = tok_text.strip()
|
||||||
tok_text = tok_text.replace("##", "")
|
tok_text = " ".join(tok_text.split())
|
||||||
|
orig_text = " ".join(orig_tokens)
|
||||||
|
|
||||||
# Clean whitespace
|
final_text = get_final_text(tok_text, orig_text, do_lower_case, verbose_logging)
|
||||||
tok_text = tok_text.strip()
|
if final_text in seen_predictions:
|
||||||
tok_text = " ".join(tok_text.split())
|
continue
|
||||||
orig_text = " ".join(orig_tokens)
|
|
||||||
|
|
||||||
final_text = get_final_text(tok_text, orig_text, do_lower_case, verbose_logging)
|
seen_predictions[final_text] = True
|
||||||
if final_text in seen_predictions:
|
else:
|
||||||
continue
|
final_text = ""
|
||||||
|
seen_predictions[final_text] = True
|
||||||
|
|
||||||
seen_predictions[final_text] = True
|
|
||||||
nbest.append(
|
nbest.append(
|
||||||
_NbestPrediction(
|
_NbestPrediction(
|
||||||
text=final_text,
|
text=final_text,
|
||||||
start_logit=pred.start_logit,
|
start_logit=pred.start_logit,
|
||||||
end_logit=pred.end_logit))
|
end_logit=pred.end_logit))
|
||||||
|
# if we didn't include the empty option in the n-best, include it
|
||||||
|
if version_2_with_negative:
|
||||||
|
if "" not in seen_predictions:
|
||||||
|
nbest.append(
|
||||||
|
_NbestPrediction(
|
||||||
|
text="",
|
||||||
|
start_logit=null_start_logit,
|
||||||
|
end_logit=null_end_logit))
|
||||||
# In very rare edge cases we could have no valid predictions. So we
|
# In very rare edge cases we could have no valid predictions. So we
|
||||||
# just create a nonce prediction in this case to avoid failure.
|
# just create a nonce prediction in this case to avoid failure.
|
||||||
if not nbest:
|
if not nbest:
|
||||||
@@ -521,8 +582,12 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
|
|||||||
assert len(nbest) >= 1
|
assert len(nbest) >= 1
|
||||||
|
|
||||||
total_scores = []
|
total_scores = []
|
||||||
|
best_non_null_entry = None
|
||||||
for entry in nbest:
|
for entry in nbest:
|
||||||
total_scores.append(entry.start_logit + entry.end_logit)
|
total_scores.append(entry.start_logit + entry.end_logit)
|
||||||
|
if not best_non_null_entry:
|
||||||
|
if entry.text:
|
||||||
|
best_non_null_entry = entry
|
||||||
|
|
||||||
probs = _compute_softmax(total_scores)
|
probs = _compute_softmax(total_scores)
|
||||||
|
|
||||||
@@ -537,8 +602,18 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
|
|||||||
|
|
||||||
assert len(nbest_json) >= 1
|
assert len(nbest_json) >= 1
|
||||||
|
|
||||||
all_predictions[example.qas_id] = nbest_json[0]["text"]
|
if not version_2_with_negative:
|
||||||
all_nbest_json[example.qas_id] = nbest_json
|
all_predictions[example.qas_id] = nbest_json[0]["text"]
|
||||||
|
else:
|
||||||
|
# predict "" iff the null score - the score of best non-null > threshold
|
||||||
|
score_diff = score_null - best_non_null_entry.start_logit - (
|
||||||
|
best_non_null_entry.end_logit)
|
||||||
|
scores_diff_json[example.qas_id] = score_diff
|
||||||
|
if score_diff > null_score_diff_threshold:
|
||||||
|
all_predictions[example.qas_id] = ""
|
||||||
|
else:
|
||||||
|
all_predictions[example.qas_id] = best_non_null_entry.text
|
||||||
|
all_nbest_json[example.qas_id] = nbest_json
|
||||||
|
|
||||||
with open(output_prediction_file, "w") as writer:
|
with open(output_prediction_file, "w") as writer:
|
||||||
writer.write(json.dumps(all_predictions, indent=4) + "\n")
|
writer.write(json.dumps(all_predictions, indent=4) + "\n")
|
||||||
@@ -546,6 +621,10 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
|
|||||||
with open(output_nbest_file, "w") as writer:
|
with open(output_nbest_file, "w") as writer:
|
||||||
writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
|
writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
|
||||||
|
|
||||||
|
if version_2_with_negative:
|
||||||
|
with open(output_null_log_odds_file, "w") as writer:
|
||||||
|
writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
|
||||||
|
|
||||||
|
|
||||||
def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
|
def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
|
||||||
"""Project the tokenized prediction back to the original text."""
|
"""Project the tokenized prediction back to the original text."""
|
||||||
@@ -608,7 +687,7 @@ def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
|
|||||||
if len(orig_ns_text) != len(tok_ns_text):
|
if len(orig_ns_text) != len(tok_ns_text):
|
||||||
if verbose_logging:
|
if verbose_logging:
|
||||||
logger.info("Length not equal after stripping spaces: '%s' vs '%s'",
|
logger.info("Length not equal after stripping spaces: '%s' vs '%s'",
|
||||||
orig_ns_text, tok_ns_text)
|
orig_ns_text, tok_ns_text)
|
||||||
return orig_text
|
return orig_text
|
||||||
|
|
||||||
# We then project the characters in `pred_text` back to `orig_text` using
|
# We then project the characters in `pred_text` back to `orig_text` using
|
||||||
@@ -677,11 +756,6 @@ def _compute_softmax(scores):
|
|||||||
probs.append(score / total_sum)
|
probs.append(score / total_sum)
|
||||||
return probs
|
return probs
|
||||||
|
|
||||||
def warmup_linear(x, warmup=0.002):
|
|
||||||
if x < warmup:
|
|
||||||
return x/warmup
|
|
||||||
return 1.0 - x
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
@@ -713,7 +787,7 @@ def main():
|
|||||||
parser.add_argument("--num_train_epochs", default=3.0, type=float,
|
parser.add_argument("--num_train_epochs", default=3.0, type=float,
|
||||||
help="Total number of training epochs to perform.")
|
help="Total number of training epochs to perform.")
|
||||||
parser.add_argument("--warmup_proportion", default=0.1, type=float,
|
parser.add_argument("--warmup_proportion", default=0.1, type=float,
|
||||||
help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10% "
|
help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% "
|
||||||
"of training.")
|
"of training.")
|
||||||
parser.add_argument("--n_best_size", default=20, type=int,
|
parser.add_argument("--n_best_size", default=20, type=int,
|
||||||
help="The total number of n-best predictions to generate in the nbest_predictions.json "
|
help="The total number of n-best predictions to generate in the nbest_predictions.json "
|
||||||
@@ -750,7 +824,12 @@ def main():
|
|||||||
help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
|
help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
|
||||||
"0 (default value): dynamic loss scaling.\n"
|
"0 (default value): dynamic loss scaling.\n"
|
||||||
"Positive power of 2: static loss scaling value.\n")
|
"Positive power of 2: static loss scaling value.\n")
|
||||||
|
parser.add_argument('--version_2_with_negative',
|
||||||
|
action='store_true',
|
||||||
|
help='If true, the SQuAD examples contain some that do not have an answer.')
|
||||||
|
parser.add_argument('--null_score_diff_threshold',
|
||||||
|
type=float, default=0.0,
|
||||||
|
help="If null_score - best_non_null is greater than the threshold predict null.")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if args.local_rank == -1 or args.no_cuda:
|
if args.local_rank == -1 or args.no_cuda:
|
||||||
@@ -769,7 +848,7 @@ def main():
|
|||||||
raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
|
raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
|
||||||
args.gradient_accumulation_steps))
|
args.gradient_accumulation_steps))
|
||||||
|
|
||||||
args.train_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)
|
args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps
|
||||||
|
|
||||||
random.seed(args.seed)
|
random.seed(args.seed)
|
||||||
np.random.seed(args.seed)
|
np.random.seed(args.seed)
|
||||||
@@ -789,7 +868,7 @@ def main():
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"If `do_predict` is True, then `predict_file` must be specified.")
|
"If `do_predict` is True, then `predict_file` must be specified.")
|
||||||
|
|
||||||
if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
|
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train:
|
||||||
raise ValueError("Output directory () already exists and is not empty.")
|
raise ValueError("Output directory () already exists and is not empty.")
|
||||||
if not os.path.exists(args.output_dir):
|
if not os.path.exists(args.output_dir):
|
||||||
os.makedirs(args.output_dir)
|
os.makedirs(args.output_dir)
|
||||||
@@ -797,12 +876,14 @@ def main():
|
|||||||
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
|
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
|
||||||
|
|
||||||
train_examples = None
|
train_examples = None
|
||||||
num_train_steps = None
|
num_train_optimization_steps = None
|
||||||
if args.do_train:
|
if args.do_train:
|
||||||
train_examples = read_squad_examples(
|
train_examples = read_squad_examples(
|
||||||
input_file=args.train_file, is_training=True)
|
input_file=args.train_file, is_training=True, version_2_with_negative=args.version_2_with_negative)
|
||||||
num_train_steps = int(
|
num_train_optimization_steps = int(
|
||||||
len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs)
|
len(train_dataset) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs
|
||||||
|
if args.local_rank != -1:
|
||||||
|
num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()
|
||||||
|
|
||||||
# Prepare model
|
# Prepare model
|
||||||
model = BertForQuestionAnswering.from_pretrained(args.bert_model,
|
model = BertForQuestionAnswering.from_pretrained(args.bert_model,
|
||||||
@@ -834,12 +915,9 @@ def main():
|
|||||||
{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
|
{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
|
||||||
]
|
]
|
||||||
|
|
||||||
t_total = num_train_steps
|
|
||||||
if args.local_rank != -1:
|
|
||||||
t_total = t_total // torch.distributed.get_world_size()
|
|
||||||
if args.fp16:
|
if args.fp16:
|
||||||
try:
|
try:
|
||||||
from apex.optimizers import FP16_Optimizer
|
from apex.optimizer import FP16_Optimizer
|
||||||
from apex.optimizers import FusedAdam
|
from apex.optimizers import FusedAdam
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
|
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
|
||||||
@@ -856,7 +934,7 @@ def main():
|
|||||||
optimizer = BertAdam(optimizer_grouped_parameters,
|
optimizer = BertAdam(optimizer_grouped_parameters,
|
||||||
lr=args.learning_rate,
|
lr=args.learning_rate,
|
||||||
warmup=args.warmup_proportion,
|
warmup=args.warmup_proportion,
|
||||||
t_total=t_total)
|
t_total=num_train_optimization_steps)
|
||||||
|
|
||||||
global_step = 0
|
global_step = 0
|
||||||
if args.do_train:
|
if args.do_train:
|
||||||
@@ -882,7 +960,7 @@ def main():
|
|||||||
logger.info(" Num orig examples = %d", len(train_examples))
|
logger.info(" Num orig examples = %d", len(train_examples))
|
||||||
logger.info(" Num split examples = %d", len(train_features))
|
logger.info(" Num split examples = %d", len(train_features))
|
||||||
logger.info(" Batch size = %d", args.train_batch_size)
|
logger.info(" Batch size = %d", args.train_batch_size)
|
||||||
logger.info(" Num steps = %d", num_train_steps)
|
logger.info(" Num steps = %d", num_train_optimization_steps)
|
||||||
all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
|
all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
|
||||||
all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
|
all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
|
||||||
all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
|
all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
|
||||||
@@ -913,10 +991,12 @@ def main():
|
|||||||
else:
|
else:
|
||||||
loss.backward()
|
loss.backward()
|
||||||
if (step + 1) % args.gradient_accumulation_steps == 0:
|
if (step + 1) % args.gradient_accumulation_steps == 0:
|
||||||
# modify learning rate with special warm up BERT uses
|
if args.fp16:
|
||||||
lr_this_step = args.learning_rate * warmup_linear(global_step/t_total, args.warmup_proportion)
|
# modify learning rate with special warm up BERT uses
|
||||||
for param_group in optimizer.param_groups:
|
# if args.fp16 is False, BertAdam is used and handles this automatically
|
||||||
param_group['lr'] = lr_this_step
|
lr_this_step = args.learning_rate * warmup_linear(global_step/num_train_optimization_steps, args.warmup_proportion)
|
||||||
|
for param_group in optimizer.param_groups:
|
||||||
|
param_group['lr'] = lr_this_step
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
global_step += 1
|
global_step += 1
|
||||||
@@ -924,16 +1004,19 @@ def main():
|
|||||||
# Save a trained model
|
# Save a trained model
|
||||||
model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self
|
model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self
|
||||||
output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
|
output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
|
||||||
torch.save(model_to_save.state_dict(), output_model_file)
|
if args.do_train:
|
||||||
|
torch.save(model_to_save.state_dict(), output_model_file)
|
||||||
|
# Load a trained model that you have fine-tuned
|
||||||
|
model_state_dict = torch.load(output_model_file)
|
||||||
|
model = BertForQuestionAnswering.from_pretrained(args.bert_model, state_dict=model_state_dict)
|
||||||
|
else:
|
||||||
|
model = BertForQuestionAnswering.from_pretrained(args.bert_model)
|
||||||
|
|
||||||
# Load a trained model that you have fine-tuned
|
|
||||||
model_state_dict = torch.load(output_model_file)
|
|
||||||
model = BertForQuestionAnswering.from_pretrained(args.bert_model, state_dict=model_state_dict)
|
|
||||||
model.to(device)
|
model.to(device)
|
||||||
|
|
||||||
if args.do_predict and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
if args.do_predict and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
||||||
eval_examples = read_squad_examples(
|
eval_examples = read_squad_examples(
|
||||||
input_file=args.predict_file, is_training=False)
|
input_file=args.predict_file, is_training=False, version_2_with_negative=args.version_2_with_negative)
|
||||||
eval_features = convert_examples_to_features(
|
eval_features = convert_examples_to_features(
|
||||||
examples=eval_examples,
|
examples=eval_examples,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
@@ -977,10 +1060,12 @@ def main():
|
|||||||
end_logits=end_logits))
|
end_logits=end_logits))
|
||||||
output_prediction_file = os.path.join(args.output_dir, "predictions.json")
|
output_prediction_file = os.path.join(args.output_dir, "predictions.json")
|
||||||
output_nbest_file = os.path.join(args.output_dir, "nbest_predictions.json")
|
output_nbest_file = os.path.join(args.output_dir, "nbest_predictions.json")
|
||||||
|
output_null_log_odds_file = os.path.join(args.output_dir, "null_odds.json")
|
||||||
write_predictions(eval_examples, eval_features, all_results,
|
write_predictions(eval_examples, eval_features, all_results,
|
||||||
args.n_best_size, args.max_answer_length,
|
args.n_best_size, args.max_answer_length,
|
||||||
args.do_lower_case, output_prediction_file,
|
args.do_lower_case, output_prediction_file,
|
||||||
output_nbest_file, args.verbose_logging)
|
output_nbest_file, output_null_log_odds_file, args.verbose_logging,
|
||||||
|
args.version_2_with_negative, args.null_score_diff_threshold)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ from tqdm import tqdm, trange
|
|||||||
|
|
||||||
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
|
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
|
||||||
from pytorch_pretrained_bert.modeling import BertForMultipleChoice
|
from pytorch_pretrained_bert.modeling import BertForMultipleChoice
|
||||||
from pytorch_pretrained_bert.optimization import BertAdam
|
from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear
|
||||||
from pytorch_pretrained_bert.tokenization import BertTokenizer
|
from pytorch_pretrained_bert.tokenization import BertTokenizer
|
||||||
|
|
||||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||||
@@ -240,11 +240,6 @@ def select_field(features, field):
|
|||||||
for feature in features
|
for feature in features
|
||||||
]
|
]
|
||||||
|
|
||||||
def warmup_linear(x, warmup=0.002):
|
|
||||||
if x < warmup:
|
|
||||||
return x/warmup
|
|
||||||
return 1.0 - x
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
@@ -343,7 +338,7 @@ def main():
|
|||||||
raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
|
raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
|
||||||
args.gradient_accumulation_steps))
|
args.gradient_accumulation_steps))
|
||||||
|
|
||||||
args.train_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)
|
args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps
|
||||||
|
|
||||||
random.seed(args.seed)
|
random.seed(args.seed)
|
||||||
np.random.seed(args.seed)
|
np.random.seed(args.seed)
|
||||||
@@ -362,11 +357,13 @@ def main():
|
|||||||
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
|
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
|
||||||
|
|
||||||
train_examples = None
|
train_examples = None
|
||||||
num_train_steps = None
|
num_train_optimization_steps = None
|
||||||
if args.do_train:
|
if args.do_train:
|
||||||
train_examples = read_swag_examples(os.path.join(args.data_dir, 'train.csv'), is_training = True)
|
train_examples = read_swag_examples(os.path.join(args.data_dir, 'train.csv'), is_training = True)
|
||||||
num_train_steps = int(
|
num_train_optimization_steps = int(
|
||||||
len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs)
|
len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs
|
||||||
|
if args.local_rank != -1:
|
||||||
|
num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()
|
||||||
|
|
||||||
# Prepare model
|
# Prepare model
|
||||||
model = BertForMultipleChoice.from_pretrained(args.bert_model,
|
model = BertForMultipleChoice.from_pretrained(args.bert_model,
|
||||||
@@ -397,9 +394,6 @@ def main():
|
|||||||
{'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
|
{'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
|
||||||
{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
|
{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
|
||||||
]
|
]
|
||||||
t_total = num_train_steps
|
|
||||||
if args.local_rank != -1:
|
|
||||||
t_total = t_total // torch.distributed.get_world_size()
|
|
||||||
if args.fp16:
|
if args.fp16:
|
||||||
try:
|
try:
|
||||||
from apex.optimizers import FP16_Optimizer
|
from apex.optimizers import FP16_Optimizer
|
||||||
@@ -419,7 +413,7 @@ def main():
|
|||||||
optimizer = BertAdam(optimizer_grouped_parameters,
|
optimizer = BertAdam(optimizer_grouped_parameters,
|
||||||
lr=args.learning_rate,
|
lr=args.learning_rate,
|
||||||
warmup=args.warmup_proportion,
|
warmup=args.warmup_proportion,
|
||||||
t_total=t_total)
|
t_total=num_train_optimization_steps)
|
||||||
|
|
||||||
global_step = 0
|
global_step = 0
|
||||||
if args.do_train:
|
if args.do_train:
|
||||||
@@ -428,7 +422,7 @@ def main():
|
|||||||
logger.info("***** Running training *****")
|
logger.info("***** Running training *****")
|
||||||
logger.info(" Num examples = %d", len(train_examples))
|
logger.info(" Num examples = %d", len(train_examples))
|
||||||
logger.info(" Batch size = %d", args.train_batch_size)
|
logger.info(" Batch size = %d", args.train_batch_size)
|
||||||
logger.info(" Num steps = %d", num_train_steps)
|
logger.info(" Num steps = %d", num_train_optimization_steps)
|
||||||
all_input_ids = torch.tensor(select_field(train_features, 'input_ids'), dtype=torch.long)
|
all_input_ids = torch.tensor(select_field(train_features, 'input_ids'), dtype=torch.long)
|
||||||
all_input_mask = torch.tensor(select_field(train_features, 'input_mask'), dtype=torch.long)
|
all_input_mask = torch.tensor(select_field(train_features, 'input_mask'), dtype=torch.long)
|
||||||
all_segment_ids = torch.tensor(select_field(train_features, 'segment_ids'), dtype=torch.long)
|
all_segment_ids = torch.tensor(select_field(train_features, 'segment_ids'), dtype=torch.long)
|
||||||
@@ -465,10 +459,12 @@ def main():
|
|||||||
else:
|
else:
|
||||||
loss.backward()
|
loss.backward()
|
||||||
if (step + 1) % args.gradient_accumulation_steps == 0:
|
if (step + 1) % args.gradient_accumulation_steps == 0:
|
||||||
# modify learning rate with special warm up BERT uses
|
if args.fp16:
|
||||||
lr_this_step = args.learning_rate * warmup_linear(global_step/t_total, args.warmup_proportion)
|
# modify learning rate with special warm up BERT uses
|
||||||
for param_group in optimizer.param_groups:
|
# if args.fp16 is False, BertAdam is used that handles this automatically
|
||||||
param_group['lr'] = lr_this_step
|
lr_this_step = args.learning_rate * warmup_linear(global_step/num_train_optimization_steps, args.warmup_proportion)
|
||||||
|
for param_group in optimizer.param_groups:
|
||||||
|
param_group['lr'] = lr_this_step
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
global_step += 1
|
global_step += 1
|
||||||
|
|||||||
@@ -1067,7 +1067,7 @@ class BertForTokenClassification(BertPreTrainedModel):
|
|||||||
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
|
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
|
||||||
input sequence length in the current batch. It's the mask that we typically use for attention when
|
input sequence length in the current batch. It's the mask that we typically use for attention when
|
||||||
a batch has varying length sentences.
|
a batch has varying length sentences.
|
||||||
`labels`: labels for the classification output: torch.LongTensor of shape [batch_size]
|
`labels`: labels for the classification output: torch.LongTensor of shape [batch_size, sequence_length]
|
||||||
with indices selected in [0, ..., num_labels].
|
with indices selected in [0, ..., num_labels].
|
||||||
|
|
||||||
Outputs:
|
Outputs:
|
||||||
@@ -1107,7 +1107,14 @@ class BertForTokenClassification(BertPreTrainedModel):
|
|||||||
|
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
loss_fct = CrossEntropyLoss()
|
loss_fct = CrossEntropyLoss()
|
||||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
# Only keep active parts of the loss
|
||||||
|
if attention_mask is not None:
|
||||||
|
active_loss = attention_mask.view(-1) == 1
|
||||||
|
active_logits = logits.view(-1, self.num_labels)[active_loss]
|
||||||
|
active_labels = labels.view(-1)[active_loss]
|
||||||
|
loss = loss_fct(active_logits, active_labels)
|
||||||
|
else:
|
||||||
|
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||||
return loss
|
return loss
|
||||||
else:
|
else:
|
||||||
return logits
|
return logits
|
||||||
|
|||||||
@@ -74,7 +74,8 @@ def whitespace_tokenize(text):
|
|||||||
class BertTokenizer(object):
|
class BertTokenizer(object):
|
||||||
"""Runs end-to-end tokenization: punctuation splitting + wordpiece"""
|
"""Runs end-to-end tokenization: punctuation splitting + wordpiece"""
|
||||||
|
|
||||||
def __init__(self, vocab_file, do_lower_case=True, max_len=None):
|
def __init__(self, vocab_file, do_lower_case=True, max_len=None,
|
||||||
|
never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
|
||||||
if not os.path.isfile(vocab_file):
|
if not os.path.isfile(vocab_file):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
|
"Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
|
||||||
@@ -82,7 +83,8 @@ class BertTokenizer(object):
|
|||||||
self.vocab = load_vocab(vocab_file)
|
self.vocab = load_vocab(vocab_file)
|
||||||
self.ids_to_tokens = collections.OrderedDict(
|
self.ids_to_tokens = collections.OrderedDict(
|
||||||
[(ids, tok) for tok, ids in self.vocab.items()])
|
[(ids, tok) for tok, ids in self.vocab.items()])
|
||||||
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
|
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case,
|
||||||
|
never_split=never_split)
|
||||||
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
|
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
|
||||||
self.max_len = max_len if max_len is not None else int(1e12)
|
self.max_len = max_len if max_len is not None else int(1e12)
|
||||||
|
|
||||||
@@ -155,13 +157,16 @@ class BertTokenizer(object):
|
|||||||
class BasicTokenizer(object):
|
class BasicTokenizer(object):
|
||||||
"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
|
"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
|
||||||
|
|
||||||
def __init__(self, do_lower_case=True):
|
def __init__(self,
|
||||||
|
do_lower_case=True,
|
||||||
|
never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
|
||||||
"""Constructs a BasicTokenizer.
|
"""Constructs a BasicTokenizer.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
do_lower_case: Whether to lower case the input.
|
do_lower_case: Whether to lower case the input.
|
||||||
"""
|
"""
|
||||||
self.do_lower_case = do_lower_case
|
self.do_lower_case = do_lower_case
|
||||||
|
self.never_split = never_split
|
||||||
|
|
||||||
def tokenize(self, text):
|
def tokenize(self, text):
|
||||||
"""Tokenizes a piece of text."""
|
"""Tokenizes a piece of text."""
|
||||||
@@ -176,7 +181,7 @@ class BasicTokenizer(object):
|
|||||||
orig_tokens = whitespace_tokenize(text)
|
orig_tokens = whitespace_tokenize(text)
|
||||||
split_tokens = []
|
split_tokens = []
|
||||||
for token in orig_tokens:
|
for token in orig_tokens:
|
||||||
if self.do_lower_case:
|
if self.do_lower_case and token not in self.never_split:
|
||||||
token = token.lower()
|
token = token.lower()
|
||||||
token = self._run_strip_accents(token)
|
token = self._run_strip_accents(token)
|
||||||
split_tokens.extend(self._run_split_on_punc(token))
|
split_tokens.extend(self._run_split_on_punc(token))
|
||||||
@@ -197,6 +202,8 @@ class BasicTokenizer(object):
|
|||||||
|
|
||||||
def _run_split_on_punc(self, text):
|
def _run_split_on_punc(self, text):
|
||||||
"""Splits punctuation on a piece of text."""
|
"""Splits punctuation on a piece of text."""
|
||||||
|
if text in self.never_split:
|
||||||
|
return [text]
|
||||||
chars = list(text)
|
chars = list(text)
|
||||||
i = 0
|
i = 0
|
||||||
start_new_word = True
|
start_new_word = True
|
||||||
|
|||||||
Reference in New Issue
Block a user