update from_pretrained to load XLNetModel as well

This commit is contained in:
thomwolf
2019-06-21 21:08:44 +02:00
parent 483cbc36a9
commit ebd2cb8d74
4 changed files with 99 additions and 36 deletions

View File

@@ -0,0 +1,21 @@
import torch
from torch.nn import functional as F
from pytorch_pretrained_bert import XLNetModel, XLNetLMHeadModel, XLNetTokenizer
import logging
logging.basicConfig(level=logging.INFO)
tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased')
model = XLNetModel.from_pretrained('xlnet-large-cased')
model = XLNetLMHeadModel.from_pretrained('xlnet-large-cased')
tokens = tokenizer.encode('I am very ')
for i in range(len(tokens), 20):
mask = torch.tensor([[[0.0] * i + [1.0]]])
logits, _ = model(torch.tensor([tokens + [0]]),
perm_mask=mask.expand(-1, i+1, -1),
target_mapping=mask,
inp_q=mask.squeeze(1))
output = torch.multinomial(F.softmax(logits[0, 0, :]), 1)
tokens.append(output.item())
print(tokenizer.decode(tokens))