fixing transfo eval script

This commit is contained in:
thomwolf
2019-02-06 16:22:17 +01:00
parent 973926431e
commit ed47cb6cba
2 changed files with 3 additions and 2 deletions

View File

@@ -111,7 +111,7 @@ def evaluate(eval_iter):
mems = tuple()
for idx, (data, target, seq_len) in enumerate(eval_iter):
ret = model(data, target, *mems)
loss, mems = ret[0], ret[1:]
loss, mems = ret
loss = loss.mean()
total_loss += seq_len * loss.item()
total_len += seq_len