updating model loading and adding special tokens ids
This commit is contained in:
@@ -6,14 +6,13 @@ import logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased')
|
||||
model = XLNetModel.from_pretrained('xlnet-large-cased')
|
||||
model = XLNetLMHeadModel.from_pretrained('xlnet-large-cased')
|
||||
model = XLNetLMHeadModel.from_pretrained('xlnet-large-cased', attn_type='uni')
|
||||
|
||||
tokens = tokenizer.encode('I am very ')
|
||||
tokens = tokenizer.encode('I am very happy')
|
||||
for i in range(len(tokens), 20):
|
||||
mask = torch.tensor([[[0.0] * i + [1.0]]])
|
||||
logits, _ = model(torch.tensor([tokens + [0]]),
|
||||
perm_mask=mask.expand(-1, i+1, -1),
|
||||
# perm_mask=mask.expand(-1, i+1, -1),
|
||||
target_mapping=mask,
|
||||
inp_q=mask.squeeze(1))
|
||||
output = torch.multinomial(F.softmax(logits[0, 0, :]), 1)
|
||||
|
||||
@@ -730,12 +730,17 @@ class XLNetPreTrainedModel(nn.Module):
|
||||
|
||||
# Load config
|
||||
config = XLNetConfig.from_json_file(resolved_config_file)
|
||||
logger.info("Model config {}".format(config))
|
||||
|
||||
# Update config with kwargs if needed
|
||||
for key, value in kwargs:
|
||||
to_remove = []
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(config, key):
|
||||
setattr(config, key, value)
|
||||
to_remove.append(key)
|
||||
for key in to_remove:
|
||||
kwargs.pop(key, None)
|
||||
|
||||
logger.info("Model config {}".format(config))
|
||||
|
||||
# Instantiate model.
|
||||
model = cls(config, *inputs, **kwargs)
|
||||
|
||||
@@ -36,7 +36,29 @@ PRETRAINED_VOCAB_ARCHIVE_MAP = {
|
||||
VOCAB_NAME = 'spiece.model'
|
||||
SPECIAL_TOKENS_NAME = 'special_tokens.txt'
|
||||
|
||||
SPIECE_UNDERLINE = '▁'
|
||||
SPIECE_UNDERLINE = u'▁'
|
||||
|
||||
# Tokens
|
||||
special_symbols = {
|
||||
"<unk>" : 0,
|
||||
"<s>" : 1,
|
||||
"</s>" : 2,
|
||||
"<cls>" : 3,
|
||||
"<sep>" : 4,
|
||||
"<pad>" : 5,
|
||||
"<mask>" : 6,
|
||||
"<eod>" : 7,
|
||||
"<eop>" : 8,
|
||||
}
|
||||
|
||||
VOCAB_SIZE = 32000
|
||||
UNK_ID = special_symbols["<unk>"]
|
||||
CLS_ID = special_symbols["<cls>"]
|
||||
SEP_ID = special_symbols["<sep>"]
|
||||
MASK_ID = special_symbols["<mask>"]
|
||||
EOD_ID = special_symbols["<eod>"]
|
||||
|
||||
# Segments (not really needed)
|
||||
SEG_ID_A = 0
|
||||
SEG_ID_B = 1
|
||||
SEG_ID_CLS = 2
|
||||
|
||||
1
xlnet
Submodule
1
xlnet
Submodule
Submodule xlnet added at cbdedecbc7
Reference in New Issue
Block a user