Add caching in run_classifier and apply various fixes to the examples

This commit is contained in:
thomwolf
2019-06-18 15:58:22 +02:00
parent e6e5f19257
commit 15ebd67d4e
5 changed files with 665 additions and 624 deletions

View File

@@ -18,10 +18,7 @@
from __future__ import absolute_import, division, print_function
import argparse
import collections
import json
import logging
import math
import os
import random
import sys
@@ -301,9 +298,6 @@ def main():
else:
loss.backward()
if (step + 1) % args.gradient_accumulation_steps == 0:
if args.local_rank in [-1, 0]:
tb_writer.add_scalar('lr', optimizer.get_lr()[0], global_step)
tb_writer.add_scalar('loss', loss.item(), global_step)
if args.fp16:
# modify learning rate with special warm up BERT uses
# if args.fp16 is False, BertAdam is used and handles this automatically
@@ -313,6 +307,9 @@ def main():
optimizer.step()
optimizer.zero_grad()
global_step += 1
if args.local_rank in [-1, 0]:
tb_writer.add_scalar('lr', optimizer.get_lr()[0], global_step)
tb_writer.add_scalar('loss', loss.item(), global_step)
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
# Save a trained model, configuration and tokenizer