mems not splitted

This commit is contained in:
thomwolf
2019-02-09 16:14:31 +01:00
parent 43b9af0cac
commit f4a07a392c

View File

@@ -102,7 +102,7 @@ def main():
with torch.no_grad():
mems = None
for idx, (data, target, seq_len) in enumerate(eval_iter):
ret = model(data, target, *mems)
ret = model(data, target, mems)
loss, mems = ret
loss = loss.mean()
total_loss += seq_len * loss.item()