Fix args.gradient_accumulation_steps used before assigment.
This commit is contained in:
@@ -404,6 +404,10 @@ def main():
|
|||||||
type=int,
|
type=int,
|
||||||
default=42,
|
default=42,
|
||||||
help="random seed for initialization")
|
help="random seed for initialization")
|
||||||
|
parser.add_argument('--gradient_accumulation_steps',
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="Number of updates steps to accumualte before performing a backward/update pass.")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
processors = {
|
processors = {
|
||||||
@@ -469,7 +473,7 @@ def main():
|
|||||||
|
|
||||||
model = BertForSequenceClassification(bert_config, len(label_list))
|
model = BertForSequenceClassification(bert_config, len(label_list))
|
||||||
if args.init_checkpoint is not None:
|
if args.init_checkpoint is not None:
|
||||||
model.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
|
model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
|
||||||
model.to(device)
|
model.to(device)
|
||||||
|
|
||||||
if args.local_rank != -1:
|
if args.local_rank != -1:
|
||||||
|
|||||||
@@ -739,6 +739,10 @@ def main():
|
|||||||
type=int,
|
type=int,
|
||||||
default=42,
|
default=42,
|
||||||
help="random seed for initialization")
|
help="random seed for initialization")
|
||||||
|
parser.add_argument('--gradient_accumulation_steps',
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="Number of updates steps to accumualte before performing a backward/update pass.")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user