add model.zero_grad()
This commit is contained in:
@@ -531,6 +531,7 @@ def main():
|
|||||||
loss, _ = model(input_ids, segment_ids, input_mask, label_ids)
|
loss, _ = model(input_ids, segment_ids, input_mask, label_ids)
|
||||||
total_tr_loss += loss.item()
|
total_tr_loss += loss.item()
|
||||||
nb_tr_examples += input_ids.size(0)
|
nb_tr_examples += input_ids.size(0)
|
||||||
|
model.zero_grad()
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
global_step += 1
|
global_step += 1
|
||||||
|
|||||||
@@ -856,6 +856,7 @@ def main():
|
|||||||
|
|
||||||
logger.info("HHHHH Forward")
|
logger.info("HHHHH Forward")
|
||||||
loss, _ = model(input_ids, segment_ids, input_mask, start_positions, end_positions)
|
loss, _ = model(input_ids, segment_ids, input_mask, start_positions, end_positions)
|
||||||
|
model.zero_grad()
|
||||||
logger.info("HHHHH Backward")
|
logger.info("HHHHH Backward")
|
||||||
loss.backward()
|
loss.backward()
|
||||||
logger.info("HHHHH Loading data")
|
logger.info("HHHHH Loading data")
|
||||||
|
|||||||
Reference in New Issue
Block a user