update from_pretrained to load XLNetModel as well
This commit is contained in:
21
examples/generation_xlnet.py
Normal file
21
examples/generation_xlnet.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""Sample a text continuation from a pretrained XLNet LM head, one token per step.

Each step appends a dummy token id to the current sequence, asks the model to
predict that last position only, samples a token id from the resulting
distribution and appends it to the running token list.
"""
import logging

import torch
from torch.nn import functional as F

from pytorch_pretrained_bert import XLNetModel, XLNetLMHeadModel, XLNetTokenizer

logging.basicConfig(level=logging.INFO)

tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased')
# The bare XLNetModel is loaded from the same checkpoint and then immediately
# replaced by the LM-head model below; only the latter is used for generation.
model = XLNetModel.from_pretrained('xlnet-large-cased')
model = XLNetLMHeadModel.from_pretrained('xlnet-large-cased')
# Inference only: switch off dropout so sampling reflects the model's actual
# predictive distribution rather than a randomly perturbed one.
model.eval()

tokens = tokenizer.encode('I am very ')

# No gradients are needed for sampling; skip building autograd graphs.
with torch.no_grad():
    # `i` is the current sequence length; keep sampling until 20 tokens exist.
    for i in range(len(tokens), 20):
        # Shape [1, 1, i + 1]: zeros everywhere except a 1.0 on the last
        # (to-be-predicted) position.
        mask = torch.tensor([[[0.0] * i + [1.0]]])
        logits, _ = model(
            # Current tokens plus a dummy id 0 standing in for the next token.
            torch.tensor([tokens + [0]]),
            # [1, i + 1, i + 1]: a 1.0 in the last column, so no position
            # attends to the dummy token.
            perm_mask=mask.expand(-1, i + 1, -1),
            # A single prediction, placed on the last position.
            target_mapping=mask,
            # [1, i + 1]: marks the last position as the one being predicted.
            inp_q=mask.squeeze(1),
        )
        # logits[0, 0, :] is 1-D; dim=-1 makes the normalised axis explicit
        # instead of relying on the deprecated implicit choice.
        probabilities = F.softmax(logits[0, 0, :], dim=-1)
        output = torch.multinomial(probabilities, 1)
        tokens.append(output.item())
        # NOTE(review): the listing this was recovered from had lost its
        # indentation; printing after every step is assumed - confirm.
        print(tokenizer.decode(tokens))
|
||||
@@ -727,16 +727,24 @@ class XLNetPreTrainedModel(nn.Module):
|
||||
archive_file, resolved_archive_file))
|
||||
logger.info("loading configuration file {} from cache at {}".format(
|
||||
config_file, resolved_config_file))
|
||||
|
||||
# Load config
|
||||
config = XLNetConfig.from_json_file(resolved_config_file)
|
||||
logger.info("Model config {}".format(config))
|
||||
|
||||
# Update config with kwargs if needed
|
||||
for key, value in kwargs:
|
||||
if hasattr(config, key):
|
||||
setattr(config, key, value)
|
||||
|
||||
# Instantiate model.
|
||||
model = cls(config, *inputs, **kwargs)
|
||||
if state_dict is None and not from_tf:
|
||||
state_dict = torch.load(resolved_archive_file, map_location='cpu')
|
||||
if from_tf:
|
||||
# Directly load from a TensorFlow checkpoint
|
||||
return load_tf_weights_in_xlnet(model, resolved_archive_file)
|
||||
return load_tf_weights_in_xlnet(model, config, resolved_archive_file)
|
||||
|
||||
# Load from a PyTorch state_dict
|
||||
missing_keys = []
|
||||
unexpected_keys = []
|
||||
@@ -755,8 +763,8 @@ class XLNetPreTrainedModel(nn.Module):
|
||||
if child is not None:
|
||||
load(child, prefix + name + '.')
|
||||
start_prefix = ''
|
||||
if not hasattr(model, 'xlnet') and any(s.startswith('xlnet.') for s in state_dict.keys()):
|
||||
start_prefix = 'xlnet.'
|
||||
if not hasattr(model, 'transformer') and any(s.startswith('transformer') for s in state_dict.keys()):
|
||||
start_prefix = 'transformer.'
|
||||
load(model, prefix=start_prefix)
|
||||
if len(missing_keys) > 0:
|
||||
logger.info("Weights of {} not initialized from pretrained model: {}".format(
|
||||
@@ -989,10 +997,10 @@ class XLNetModel(XLNetPreTrainedModel):
|
||||
output_h = self.dropout(word_emb_k)
|
||||
if inp_q is not None:
|
||||
if target_mapping is not None:
|
||||
word_emb_q = mask_emb.expand(target_mapping.shape[0], bsz, -1)
|
||||
word_emb_q = self.mask_emb.expand(target_mapping.shape[0], bsz, -1)
|
||||
else:
|
||||
inp_q_ext = inp_q[:, :, None]
|
||||
word_emb_q = inp_q_ext * mask_emb + (1 - inp_q_ext) * word_emb_k
|
||||
word_emb_q = inp_q_ext * self.mask_emb + (1 - inp_q_ext) * word_emb_k
|
||||
output_g = self.dropout(word_emb_q)
|
||||
else:
|
||||
output_g = None
|
||||
@@ -1062,19 +1070,26 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
|
||||
This can be used to compute head importance metrics. Default: False
|
||||
|
||||
Inputs:
|
||||
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
|
||||
with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts
|
||||
`extract_features.py`, `run_classifier.py` and `run_squad.py`)
|
||||
`token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
|
||||
types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
|
||||
a `sentence B` token (see XLNet paper for more details).
|
||||
`attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
|
||||
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
|
||||
input sequence length in the current batch. It's the mask that we typically use for attention when
|
||||
a batch has varying length sentences.
|
||||
`output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`.
|
||||
`head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1.
|
||||
It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked.
|
||||
inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
|
||||
seg_id: int32 Tensor in shape [bsz, len], the input segment IDs.
|
||||
input_mask: [optional] float32 Tensor in shape [bsz, len], the input mask.
|
||||
0 for real tokens and 1 for padding.
|
||||
mems: [optional] a list of float32 Tensors in shape [mem_len, bsz, d_model], memory
|
||||
from previous batches. The length of the list equals n_layer.
|
||||
If None, no memory is used.
|
||||
perm_mask: [optional] float32 Tensor in shape [bsz, len, len].
|
||||
If perm_mask[k, i, j] = 0, i attends to j in batch k;
|
||||
if perm_mask[k, i, j] = 1, i does not attend to j in batch k.
|
||||
If None, each position attends to all the others.
|
||||
target_mapping: [optional] float32 Tensor in shape [bsz, num_predict, len].
|
||||
If target_mapping[k, i, j] = 1, the i-th predict in batch k is
|
||||
on the j-th token.
|
||||
Only used during pretraining for partial prediction.
|
||||
Set to None during finetuning.
|
||||
inp_q: [optional] float32 Tensor in shape [bsz, len].
|
||||
1 for tokens with losses and 0 for tokens without losses.
|
||||
Only used during pretraining for two-stream attention.
|
||||
Set to None during finetuning.
|
||||
|
||||
|
||||
Outputs: Tuple of (encoded_layers, pooled_output)
|
||||
|
||||
@@ -37,6 +37,11 @@ VOCAB_NAME = 'spiece.model'
|
||||
SPECIAL_TOKENS_NAME = 'special_tokens.txt'
|
||||
|
||||
SPIECE_UNDERLINE = '▁'
|
||||
SEG_ID_A = 0
|
||||
SEG_ID_B = 1
|
||||
SEG_ID_CLS = 2
|
||||
SEG_ID_SEP = 3
|
||||
SEG_ID_PAD = 4
|
||||
|
||||
class XLNetTokenizer(object):
|
||||
"""
|
||||
@@ -52,6 +57,16 @@ class XLNetTokenizer(object):
|
||||
if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
|
||||
vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
|
||||
special_tokens_file = None
|
||||
if '-cased' in pretrained_model_name_or_path and kwargs.get('do_lower_case', True):
|
||||
logger.warning("The pre-trained model you are loading is a cased model but you have not set "
|
||||
"`do_lower_case` to False. We are setting `do_lower_case=False` for you but "
|
||||
"you may want to check this behavior.")
|
||||
kwargs['do_lower_case'] = False
|
||||
elif '-cased' not in pretrained_model_name_or_path and not kwargs.get('do_lower_case', True):
|
||||
logger.warning("The pre-trained model you are loading is an uncased model but you have set "
|
||||
"`do_lower_case` to False. We are setting `do_lower_case=True` for you "
|
||||
"but you may want to check this behavior.")
|
||||
kwargs['do_lower_case'] = True
|
||||
else:
|
||||
vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
|
||||
special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME)
|
||||
|
||||
@@ -78,23 +78,30 @@ class XLNetModelTest(unittest.TestCase):
|
||||
input_ids_2 = XLNetModelTest.ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
segment_ids = XLNetModelTest.ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
|
||||
|
||||
# inp_k: int32 Tensor in shape [len, bsz], the input token IDs.
|
||||
# seg_id: int32 Tensor in shape [len, bsz], the input segment IDs.
|
||||
# input_mask: float32 Tensor in shape [len, bsz], the input mask.
|
||||
input_ids_q = XLNetModelTest.ids_tensor([self.batch_size, self.seq_length + 1], self.vocab_size)
|
||||
perm_mask = torch.zeros(self.batch_size, self.seq_length + 1, self.seq_length + 1, dtype=torch.float)
|
||||
perm_mask[:, :, -1] = 1.0 # Previous tokens don't see last token
|
||||
target_mapping = torch.zeros(self.batch_size, 1, self.seq_length + 1, dtype=torch.float)
|
||||
target_mapping[:, 0, -1] = 1.0 # predict last token
|
||||
inp_q = target_mapping[:, 0, :].clone() # predict last token
|
||||
|
||||
# inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
|
||||
# seg_id: int32 Tensor in shape [bsz, len], the input segment IDs.
|
||||
# input_mask: float32 Tensor in shape [bsz, len], the input mask.
|
||||
# 0 for real tokens and 1 for padding.
|
||||
# mems: a list of float32 Tensors in shape [mem_len, bsz, d_model], memory
|
||||
# mems: a list of float32 Tensors in shape [bsz, mem_len, d_model], memory
|
||||
# from previous batches. The length of the list equals n_layer.
|
||||
# If None, no memory is used.
|
||||
# perm_mask: float32 Tensor in shape [len, len, bsz].
|
||||
# If perm_mask[i, j, k] = 0, i attend to j in batch k;
|
||||
# if perm_mask[i, j, k] = 1, i does not attend to j in batch k.
|
||||
# perm_mask: float32 Tensor in shape [bsz, len, len].
|
||||
# If perm_mask[k, i, j] = 0, i attend to j in batch k;
|
||||
# if perm_mask[k, i, j] = 1, i does not attend to j in batch k.
|
||||
# If None, each position attends to all the others.
|
||||
# target_mapping: float32 Tensor in shape [num_predict, len, bsz].
|
||||
# If target_mapping[i, j, k] = 1, the i-th predict in batch k is
|
||||
# target_mapping: float32 Tensor in shape [bsz, num_predict, len].
|
||||
# If target_mapping[k, i, j] = 1, the i-th predict in batch k is
|
||||
# on the j-th token.
|
||||
# Only used during pretraining for partial prediction.
|
||||
# Set to None during finetuning.
|
||||
# inp_q: float32 Tensor in shape [len, bsz].
|
||||
# inp_q: float32 Tensor in shape [bsz, len].
|
||||
# 1 for tokens with losses and 0 for tokens without losses.
|
||||
# Only used during pretraining for two-stream attention.
|
||||
# Set to None during finetuning.
|
||||
@@ -121,30 +128,35 @@ class XLNetModelTest(unittest.TestCase):
|
||||
|
||||
config.update(run_config)
|
||||
|
||||
return (config, input_ids_1, input_ids_2, segment_ids, lm_labels)
|
||||
return (config, input_ids_1, input_ids_2, input_ids_q, perm_mask, target_mapping, inp_q, segment_ids, lm_labels)
|
||||
|
||||
def set_seed(self):
|
||||
random.seed(self.seed)
|
||||
torch.manual_seed(self.seed)
|
||||
|
||||
def create_transfo_xl_lm_head(self, config, input_ids_1, input_ids_2, segment_ids, lm_labels):
|
||||
def create_transfo_xl_lm_head(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, target_mapping, inp_q, segment_ids, lm_labels):
|
||||
model = XLNetLMHeadModel(config)
|
||||
model.eval()
|
||||
|
||||
loss_1, mems_1a = model(input_ids_1, seg_id=segment_ids, target=lm_labels)
|
||||
lm_logits_1, mems_1b = model(input_ids_1, seg_id=segment_ids)
|
||||
all_logits_1, mems_1b = model(input_ids_1, seg_id=segment_ids)
|
||||
|
||||
loss_2, mems_2a = model(input_ids_2, seg_id=segment_ids, target=lm_labels, mems=mems_1a)
|
||||
lm_logits_2, mems_2b = model(input_ids_2, seg_id=segment_ids, mems=mems_1b)
|
||||
all_logits_2, mems_2b = model(input_ids_2, seg_id=segment_ids, mems=mems_1b)
|
||||
|
||||
logits, _ = model(input_ids_q,
|
||||
perm_mask=perm_mask,
|
||||
target_mapping=target_mapping,
|
||||
inp_q=inp_q)
|
||||
|
||||
outputs = {
|
||||
"loss_1": loss_1,
|
||||
"mems_1a": mems_1a,
|
||||
"lm_logits_1": lm_logits_1,
|
||||
"all_logits_1": all_logits_1,
|
||||
"mems_1b": mems_1b,
|
||||
"loss_2": loss_2,
|
||||
"mems_2a": mems_2a,
|
||||
"lm_logits_2": lm_logits_2,
|
||||
"all_logits_2": all_logits_2,
|
||||
"mems_2b": mems_2b,
|
||||
}
|
||||
return outputs
|
||||
@@ -154,7 +166,7 @@ class XLNetModelTest(unittest.TestCase):
|
||||
list(result["loss_1"].size()),
|
||||
[])
|
||||
self.parent.assertListEqual(
|
||||
list(result["lm_logits_1"].size()),
|
||||
list(result["all_logits_1"].size()),
|
||||
[self.batch_size, self.seq_length, self.vocab_size])
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.size()) for mem in result["mems_1a"]),
|
||||
@@ -170,7 +182,7 @@ class XLNetModelTest(unittest.TestCase):
|
||||
list(result["loss_2"].size()),
|
||||
[])
|
||||
self.parent.assertListEqual(
|
||||
list(result["lm_logits_2"].size()),
|
||||
list(result["all_logits_2"].size()),
|
||||
[self.batch_size, self.seq_length, self.vocab_size])
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.size()) for mem in result["mems_2a"]),
|
||||
|
||||
Reference in New Issue
Block a user