cache in run_classifier + various fixes to the examples
This commit is contained in:
@@ -18,10 +18,7 @@
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import argparse
|
||||
import collections
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
@@ -301,9 +298,6 @@ def main():
|
||||
else:
|
||||
loss.backward()
|
||||
if (step + 1) % args.gradient_accumulation_steps == 0:
|
||||
if args.local_rank in [-1, 0]:
|
||||
tb_writer.add_scalar('lr', optimizer.get_lr()[0], global_step)
|
||||
tb_writer.add_scalar('loss', loss.item(), global_step)
|
||||
if args.fp16:
|
||||
# modify learning rate with special warm up BERT uses
|
||||
# if args.fp16 is False, BertAdam is used and handles this automatically
|
||||
@@ -313,6 +307,9 @@ def main():
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
global_step += 1
|
||||
if args.local_rank in [-1, 0]:
|
||||
tb_writer.add_scalar('lr', optimizer.get_lr()[0], global_step)
|
||||
tb_writer.add_scalar('loss', loss.item(), global_step)
|
||||
|
||||
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
||||
# Save a trained model, configuration and tokenizer
|
||||
|
||||
Reference in New Issue
Block a user