fix bug/typos
This commit is contained in:
@@ -86,7 +86,7 @@ def transformerXLModel(*args, **kwargs):
|
|||||||
# We can re-use the memory cells in a subsequent call to attend a longer context
|
# We can re-use the memory cells in a subsequent call to attend a longer context
|
||||||
>>> with torch.no_grad():
|
>>> with torch.no_grad():
|
||||||
hidden_states_1, mems_1 = model(tokens_tensor_1)
|
hidden_states_1, mems_1 = model(tokens_tensor_1)
|
||||||
hidden_states_2, past = model(tokens_tensor_2, past=past)
|
hidden_states_2, mems_2 = model(tokens_tensor_2, mems=mems_1)
|
||||||
"""
|
"""
|
||||||
model = TransfoXLModel.from_pretrained(*args, **kwargs)
|
model = TransfoXLModel.from_pretrained(*args, **kwargs)
|
||||||
return model
|
return model
|
||||||
@@ -121,7 +121,7 @@ def transformerXLLMHeadModel(*args, **kwargs):
|
|||||||
# We can re-use the memory cells in a subsequent call to attend a longer context
|
# We can re-use the memory cells in a subsequent call to attend a longer context
|
||||||
>>> with torch.no_grad():
|
>>> with torch.no_grad():
|
||||||
predictions_1, mems_1 = model(tokens_tensor_1)
|
predictions_1, mems_1 = model(tokens_tensor_1)
|
||||||
predictions_2, past = model(tokens_tensor_2, past=past)
|
predictions_2, mems_2 = model(tokens_tensor_2, mems=mems_1)
|
||||||
|
|
||||||
# Get the predicted last token
|
# Get the predicted last token
|
||||||
>>> predicted_index = torch.argmax(predictions_2[0, -1, :]).item()
|
>>> predicted_index = torch.argmax(predictions_2[0, -1, :]).item()
|
||||||
|
|||||||
Reference in New Issue
Block a user