From 466a96543a46fb5328667415b7170e57611867c2 Mon Sep 17 00:00:00 2001 From: VictorSanh Date: Sat, 1 Jun 2019 17:28:56 -0400 Subject: [PATCH] fix bug/typos --- hubconfs/transformer_xl_hubconf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/hubconfs/transformer_xl_hubconf.py b/hubconfs/transformer_xl_hubconf.py index 0bf7710553..68ffb19fb5 100644 --- a/hubconfs/transformer_xl_hubconf.py +++ b/hubconfs/transformer_xl_hubconf.py @@ -86,7 +86,7 @@ def transformerXLModel(*args, **kwargs): # We can re-use the memory cells in a subsequent call to attend a longer context >>> with torch.no_grad(): hidden_states_1, mems_1 = model(tokens_tensor_1) - hidden_states_2, past = model(tokens_tensor_2, past=past) + hidden_states_2, mems_2 = model(tokens_tensor_2, mems=mems_1) """ model = TransfoXLModel.from_pretrained(*args, **kwargs) return model @@ -121,7 +121,7 @@ def transformerXLLMHeadModel(*args, **kwargs): # We can re-use the memory cells in a subsequent call to attend a longer context >>> with torch.no_grad(): predictions_1, mems_1 = model(tokens_tensor_1) - predictions_2, past = model(tokens_tensor_2, past=past) + predictions_2, mems_2 = model(tokens_tensor_2, mems=mems_1) # Get the predicted last token >>> predicted_index = torch.argmax(predictions_2[0, -1, :]).item()