updating model loading and adding special tokens ids

This commit is contained in:
thomwolf
2019-06-21 23:23:37 +02:00
parent ebd2cb8d74
commit 181075635d
4 changed files with 34 additions and 7 deletions

View File

@@ -6,14 +6,13 @@ import logging
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased') tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased')
model = XLNetModel.from_pretrained('xlnet-large-cased') model = XLNetLMHeadModel.from_pretrained('xlnet-large-cased', attn_type='uni')
model = XLNetLMHeadModel.from_pretrained('xlnet-large-cased')
tokens = tokenizer.encode('I am very ') tokens = tokenizer.encode('I am very happy')
for i in range(len(tokens), 20): for i in range(len(tokens), 20):
mask = torch.tensor([[[0.0] * i + [1.0]]]) mask = torch.tensor([[[0.0] * i + [1.0]]])
logits, _ = model(torch.tensor([tokens + [0]]), logits, _ = model(torch.tensor([tokens + [0]]),
perm_mask=mask.expand(-1, i+1, -1), # perm_mask=mask.expand(-1, i+1, -1),
target_mapping=mask, target_mapping=mask,
inp_q=mask.squeeze(1)) inp_q=mask.squeeze(1))
output = torch.multinomial(F.softmax(logits[0, 0, :]), 1) output = torch.multinomial(F.softmax(logits[0, 0, :]), 1)

View File

@@ -730,12 +730,17 @@ class XLNetPreTrainedModel(nn.Module):
# Load config # Load config
config = XLNetConfig.from_json_file(resolved_config_file) config = XLNetConfig.from_json_file(resolved_config_file)
logger.info("Model config {}".format(config))
# Update config with kwargs if needed # Update config with kwargs if needed
for key, value in kwargs: to_remove = []
for key, value in kwargs.items():
if hasattr(config, key): if hasattr(config, key):
setattr(config, key, value) setattr(config, key, value)
to_remove.append(key)
for key in to_remove:
kwargs.pop(key, None)
logger.info("Model config {}".format(config))
# Instantiate model. # Instantiate model.
model = cls(config, *inputs, **kwargs) model = cls(config, *inputs, **kwargs)

View File

@@ -36,7 +36,29 @@ PRETRAINED_VOCAB_ARCHIVE_MAP = {
VOCAB_NAME = 'spiece.model' VOCAB_NAME = 'spiece.model'
SPECIAL_TOKENS_NAME = 'special_tokens.txt' SPECIAL_TOKENS_NAME = 'special_tokens.txt'
SPIECE_UNDERLINE = '' SPIECE_UNDERLINE = u''
# Tokens
special_symbols = {
"<unk>" : 0,
"<s>" : 1,
"</s>" : 2,
"<cls>" : 3,
"<sep>" : 4,
"<pad>" : 5,
"<mask>" : 6,
"<eod>" : 7,
"<eop>" : 8,
}
VOCAB_SIZE = 32000
UNK_ID = special_symbols["<unk>"]
CLS_ID = special_symbols["<cls>"]
SEP_ID = special_symbols["<sep>"]
MASK_ID = special_symbols["<mask>"]
EOD_ID = special_symbols["<eod>"]
# Segments (not really needed)
SEG_ID_A = 0 SEG_ID_A = 0
SEG_ID_B = 1 SEG_ID_B = 1
SEG_ID_CLS = 2 SEG_ID_CLS = 2

1
xlnet Submodule

Submodule xlnet added at cbdedecbc7