various fixes
This commit is contained in:
@@ -412,7 +412,8 @@ class BertForSequenceClassification(nn.Module):
|
|||||||
model = modeling.BertModel(config, num_labels)
|
model = modeling.BertModel(config, num_labels)
|
||||||
logits = model(input_ids, token_type_ids, input_mask)
|
logits = model(input_ids, token_type_ids, input_mask)
|
||||||
```
|
```
|
||||||
""" def __init__(self, config, num_labels):
|
"""
|
||||||
|
def __init__(self, config, num_labels):
|
||||||
super(BertForSequenceClassification, self).__init__()
|
super(BertForSequenceClassification, self).__init__()
|
||||||
self.bert = BertModel(config)
|
self.bert = BertModel(config)
|
||||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
|||||||
@@ -73,8 +73,8 @@ parser.add_argument("--init_checkpoint",
|
|||||||
type = str,
|
type = str,
|
||||||
help = "Initial checkpoint (usually from a pre-trained BERT model).")
|
help = "Initial checkpoint (usually from a pre-trained BERT model).")
|
||||||
parser.add_argument("--do_lower_case",
|
parser.add_argument("--do_lower_case",
|
||||||
default = True,
|
default = False,
|
||||||
type = bool,
|
action='store_true',
|
||||||
help = "Whether to lower case the input text. Should be True for uncased models and False for cased models.")
|
help = "Whether to lower case the input text. Should be True for uncased models and False for cased models.")
|
||||||
parser.add_argument("--max_seq_length",
|
parser.add_argument("--max_seq_length",
|
||||||
default = 128,
|
default = 128,
|
||||||
@@ -84,11 +84,11 @@ parser.add_argument("--max_seq_length",
|
|||||||
"than this will be padded.")
|
"than this will be padded.")
|
||||||
parser.add_argument("--do_train",
|
parser.add_argument("--do_train",
|
||||||
default = False,
|
default = False,
|
||||||
type = bool,
|
action='store_true',
|
||||||
help = "Whether to run training.")
|
help = "Whether to run training.")
|
||||||
parser.add_argument("--do_eval",
|
parser.add_argument("--do_eval",
|
||||||
default = False,
|
default = False,
|
||||||
type = bool,
|
action='store_true',
|
||||||
help = "Whether to run eval on the dev set.")
|
help = "Whether to run eval on the dev set.")
|
||||||
parser.add_argument("--train_batch_size",
|
parser.add_argument("--train_batch_size",
|
||||||
default = 32,
|
default = 32,
|
||||||
@@ -117,7 +117,7 @@ parser.add_argument("--save_checkpoints_steps",
|
|||||||
help = "How often to save the model checkpoint.")
|
help = "How often to save the model checkpoint.")
|
||||||
parser.add_argument("--no_cuda",
|
parser.add_argument("--no_cuda",
|
||||||
default = False,
|
default = False,
|
||||||
type = bool,
|
action='store_true',
|
||||||
help = "Whether not to use CUDA when available")
|
help = "Whether not to use CUDA when available")
|
||||||
parser.add_argument("--local_rank",
|
parser.add_argument("--local_rank",
|
||||||
type=int,
|
type=int,
|
||||||
@@ -490,6 +490,7 @@ def main():
|
|||||||
warmup=args.warmup_proportion,
|
warmup=args.warmup_proportion,
|
||||||
t_total=num_train_steps)
|
t_total=num_train_steps)
|
||||||
|
|
||||||
|
global_step = 0
|
||||||
if args.do_train:
|
if args.do_train:
|
||||||
train_features = convert_examples_to_features(
|
train_features = convert_examples_to_features(
|
||||||
train_examples, label_list, args.max_seq_length, tokenizer)
|
train_examples, label_list, args.max_seq_length, tokenizer)
|
||||||
@@ -511,7 +512,6 @@ def main():
|
|||||||
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
|
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
|
||||||
|
|
||||||
model.train()
|
model.train()
|
||||||
global_step = 0
|
|
||||||
for epoch in args.num_train_epochs:
|
for epoch in args.num_train_epochs:
|
||||||
for input_ids, input_mask, segment_ids, label_ids in train_dataloader:
|
for input_ids, input_mask, segment_ids, label_ids in train_dataloader:
|
||||||
input_ids = input_ids.to(device)
|
input_ids = input_ids.to(device)
|
||||||
@@ -552,9 +552,11 @@ def main():
|
|||||||
input_ids = input_ids.to(device)
|
input_ids = input_ids.to(device)
|
||||||
input_mask = input_mask.float().to(device)
|
input_mask = input_mask.float().to(device)
|
||||||
segment_ids = segment_ids.to(device)
|
segment_ids = segment_ids.to(device)
|
||||||
label_ids = label_ids.to(device)
|
|
||||||
|
|
||||||
tmp_eval_loss, logits = model(input_ids, segment_ids, input_mask, label_ids)
|
tmp_eval_loss, logits = model(input_ids, segment_ids, input_mask, label_ids)
|
||||||
|
|
||||||
|
logits = logits.detach().cpu().numpy()
|
||||||
|
label_ids = label_ids.to('cpu').numpy()
|
||||||
tmp_eval_accuracy = accuracy(logits, label_ids)
|
tmp_eval_accuracy = accuracy(logits, label_ids)
|
||||||
|
|
||||||
eval_loss += tmp_eval_loss.item()
|
eval_loss += tmp_eval_loss.item()
|
||||||
|
|||||||
Reference in New Issue
Block a user