Update model loading and add special token IDs
This commit is contained in:
@@ -6,14 +6,13 @@ import logging
|
|||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased')
|
tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased')
|
||||||
model = XLNetModel.from_pretrained('xlnet-large-cased')
|
model = XLNetLMHeadModel.from_pretrained('xlnet-large-cased', attn_type='uni')
|
||||||
model = XLNetLMHeadModel.from_pretrained('xlnet-large-cased')
|
|
||||||
|
|
||||||
tokens = tokenizer.encode('I am very ')
|
tokens = tokenizer.encode('I am very happy')
|
||||||
for i in range(len(tokens), 20):
|
for i in range(len(tokens), 20):
|
||||||
mask = torch.tensor([[[0.0] * i + [1.0]]])
|
mask = torch.tensor([[[0.0] * i + [1.0]]])
|
||||||
logits, _ = model(torch.tensor([tokens + [0]]),
|
logits, _ = model(torch.tensor([tokens + [0]]),
|
||||||
perm_mask=mask.expand(-1, i+1, -1),
|
# perm_mask=mask.expand(-1, i+1, -1),
|
||||||
target_mapping=mask,
|
target_mapping=mask,
|
||||||
inp_q=mask.squeeze(1))
|
inp_q=mask.squeeze(1))
|
||||||
output = torch.multinomial(F.softmax(logits[0, 0, :]), 1)
|
output = torch.multinomial(F.softmax(logits[0, 0, :]), 1)
|
||||||
|
|||||||
@@ -730,12 +730,17 @@ class XLNetPreTrainedModel(nn.Module):
|
|||||||
|
|
||||||
# Load config
|
# Load config
|
||||||
config = XLNetConfig.from_json_file(resolved_config_file)
|
config = XLNetConfig.from_json_file(resolved_config_file)
|
||||||
logger.info("Model config {}".format(config))
|
|
||||||
|
|
||||||
# Update config with kwargs if needed
|
# Update config with kwargs if needed
|
||||||
for key, value in kwargs:
|
to_remove = []
|
||||||
|
for key, value in kwargs.items():
|
||||||
if hasattr(config, key):
|
if hasattr(config, key):
|
||||||
setattr(config, key, value)
|
setattr(config, key, value)
|
||||||
|
to_remove.append(key)
|
||||||
|
for key in to_remove:
|
||||||
|
kwargs.pop(key, None)
|
||||||
|
|
||||||
|
logger.info("Model config {}".format(config))
|
||||||
|
|
||||||
# Instantiate model.
|
# Instantiate model.
|
||||||
model = cls(config, *inputs, **kwargs)
|
model = cls(config, *inputs, **kwargs)
|
||||||
|
|||||||
@@ -36,7 +36,29 @@ PRETRAINED_VOCAB_ARCHIVE_MAP = {
|
|||||||
VOCAB_NAME = 'spiece.model'
|
VOCAB_NAME = 'spiece.model'
|
||||||
SPECIAL_TOKENS_NAME = 'special_tokens.txt'
|
SPECIAL_TOKENS_NAME = 'special_tokens.txt'
|
||||||
|
|
||||||
SPIECE_UNDERLINE = '▁'
|
SPIECE_UNDERLINE = u'▁'
|
||||||
|
|
||||||
|
# Tokens
|
||||||
|
special_symbols = {
|
||||||
|
"<unk>" : 0,
|
||||||
|
"<s>" : 1,
|
||||||
|
"</s>" : 2,
|
||||||
|
"<cls>" : 3,
|
||||||
|
"<sep>" : 4,
|
||||||
|
"<pad>" : 5,
|
||||||
|
"<mask>" : 6,
|
||||||
|
"<eod>" : 7,
|
||||||
|
"<eop>" : 8,
|
||||||
|
}
|
||||||
|
|
||||||
|
VOCAB_SIZE = 32000
|
||||||
|
UNK_ID = special_symbols["<unk>"]
|
||||||
|
CLS_ID = special_symbols["<cls>"]
|
||||||
|
SEP_ID = special_symbols["<sep>"]
|
||||||
|
MASK_ID = special_symbols["<mask>"]
|
||||||
|
EOD_ID = special_symbols["<eod>"]
|
||||||
|
|
||||||
|
# Segments (not really needed)
|
||||||
SEG_ID_A = 0
|
SEG_ID_A = 0
|
||||||
SEG_ID_B = 1
|
SEG_ID_B = 1
|
||||||
SEG_ID_CLS = 2
|
SEG_ID_CLS = 2
|
||||||
|
|||||||
1
xlnet
Submodule
1
xlnet
Submodule
Submodule xlnet added at cbdedecbc7
Reference in New Issue
Block a user