From ed47cb6cbaa8fb039117b67ee2d828231b24346c Mon Sep 17 00:00:00 2001 From: thomwolf Date: Wed, 6 Feb 2019 16:22:17 +0100 Subject: [PATCH] fixing transfo eval script --- examples/eval_transfo_xl.py | 2 +- pytorch_pretrained_bert/modeling_transfo_xl.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/eval_transfo_xl.py b/examples/eval_transfo_xl.py index 3326454ea1..9a0975f186 100644 --- a/examples/eval_transfo_xl.py +++ b/examples/eval_transfo_xl.py @@ -111,7 +111,7 @@ def evaluate(eval_iter): mems = tuple() for idx, (data, target, seq_len) in enumerate(eval_iter): ret = model(data, target, *mems) - loss, mems = ret[0], ret[1:] + loss, mems = ret loss = loss.mean() total_loss += seq_len * loss.item() total_len += seq_len diff --git a/pytorch_pretrained_bert/modeling_transfo_xl.py b/pytorch_pretrained_bert/modeling_transfo_xl.py index 000d7ac19b..53ebca6e92 100644 --- a/pytorch_pretrained_bert/modeling_transfo_xl.py +++ b/pytorch_pretrained_bert/modeling_transfo_xl.py @@ -1215,7 +1215,8 @@ class TransfoXLModel(TransfoXLPreTrainedModel): # So, have to initialize size(0) mems inside the model forward. # Moreover, have to return new_mems to allow nn.DataParallel to piece # them together. - if not mems: mems = self.init_mems(data) + if not mems: + mems = self.init_mems(data) hidden, new_mems = self._forward(data, mems=mems) if target is None: