more versatile model loading
This commit is contained in:
@@ -606,7 +606,9 @@ class BertPreTrainedModel(nn.Module):
|
||||
for name, child in module._modules.items():
|
||||
if child is not None:
|
||||
load(child, prefix + name + '.')
|
||||
start_prefix = 'bert.' if not hasattr(model, 'bert') and any(s.startwith('bert.') for s in state_dict.keys()) else ''
|
||||
start_prefix = ''
|
||||
if not hasattr(model, 'bert') and any(s.startswith('bert.') for s in state_dict.keys()):
|
||||
start_prefix = 'bert.'
|
||||
load(model, prefix=start_prefix)
|
||||
if len(missing_keys) > 0:
|
||||
logger.info("Weights of {} not initialized from pretrained model: {}".format(
|
||||
|
||||
@@ -120,6 +120,7 @@ class OpenAIGPTConfig(object):
|
||||
self,
|
||||
vocab_size_or_config_json_file=40478,
|
||||
n_special=0,
|
||||
n_positions=512,
|
||||
n_ctx=512,
|
||||
n_embd=768,
|
||||
n_layer=12,
|
||||
@@ -135,7 +136,8 @@ class OpenAIGPTConfig(object):
|
||||
Args:
|
||||
vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `OpenAIGPTModel` or a configuration json file.
|
||||
n_special: The number of special tokens to learn during fine-tuning ('[SEP]', '[CLF]', ...)
|
||||
n_ctx: Number of positional embeddings.
|
||||
n_positions: Number of positional embeddings.
|
||||
n_ctx: Size of the causal mask (usually same as n_positions).
|
||||
n_embd: Dimensionality of the embeddings and hidden states.
|
||||
n_layer: Number of hidden layers in the Transformer encoder.
|
||||
n_head: Number of attention heads for each attention layer in
|
||||
@@ -159,6 +161,7 @@ class OpenAIGPTConfig(object):
|
||||
self.vocab_size = vocab_size_or_config_json_file
|
||||
self.n_special = n_special
|
||||
self.n_ctx = n_ctx
|
||||
self.n_positions = n_positions
|
||||
self.n_embd = n_embd
|
||||
self.n_layer = n_layer
|
||||
self.n_head = n_head
|
||||
@@ -175,7 +178,7 @@ class OpenAIGPTConfig(object):
|
||||
|
||||
@property
|
||||
def total_num_embeddings(self):
|
||||
return self.vocab_size + self.n_special + self.n_ctx
|
||||
return self.vocab_size + self.n_special + self.n_positions
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, json_object):
|
||||
@@ -234,7 +237,7 @@ class Attention(nn.Module):
|
||||
n_state = nx # in Attention: n_state=768 (nx=n_embd)
|
||||
# [switch nx => n_state from Block to Attention to keep identical to TF implem]
|
||||
assert n_state % config.n_head == 0
|
||||
self.register_buffer("b", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
|
||||
self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
|
||||
self.n_head = config.n_head
|
||||
self.split_size = n_state
|
||||
self.scale = scale
|
||||
@@ -247,9 +250,9 @@ class Attention(nn.Module):
|
||||
w = torch.matmul(q, k)
|
||||
if self.scale:
|
||||
w = w / math.sqrt(v.size(-1))
|
||||
# w = w * self.b + -1e9 * (1 - self.b) # TF implem method: mask_attn_weights
|
||||
# w = w * self.bias + -1e9 * (1 - self.bias) # TF implem method: mask_attn_weights
|
||||
# XD: self.b may be larger than w, so we need to crop it
|
||||
b = self.b[:, :, : w.size(-2), : w.size(-1)]
|
||||
b = self.bias[:, :, : w.size(-2), : w.size(-1)]
|
||||
w = w * b + -1e9 * (1 - b)
|
||||
|
||||
w = nn.Softmax(dim=-1)(w)
|
||||
@@ -474,10 +477,12 @@ class OpenAIGPTPreTrainedModel(nn.Module):
|
||||
new_keys = []
|
||||
for key in state_dict.keys():
|
||||
new_key = None
|
||||
if "gamma" in key:
|
||||
new_key = key.replace("gamma", "weight")
|
||||
if "beta" in key:
|
||||
new_key = key.replace("beta", "bias")
|
||||
if key.endswith(".g"):
|
||||
new_key = key[:-2] + ".weight"
|
||||
elif key.endswith(".b"):
|
||||
new_key = key[:-2] + ".bias"
|
||||
elif key.endswith(".w"):
|
||||
new_key = key[:-2] + ".weight"
|
||||
if new_key:
|
||||
old_keys.append(key)
|
||||
new_keys.append(new_key)
|
||||
@@ -502,7 +507,8 @@ class OpenAIGPTPreTrainedModel(nn.Module):
|
||||
if child is not None:
|
||||
load(child, prefix + name + ".")
|
||||
|
||||
if hasattr(model, "transformer") and all(not s.startwith('transformer.') for s in state_dict.keys()):
|
||||
start_model = model
|
||||
if hasattr(model, "transformer") and all(not s.startswith('transformer.') for s in state_dict.keys()):
|
||||
start_model = model.transformer
|
||||
load(start_model, prefix="")
|
||||
|
||||
@@ -541,7 +547,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
|
||||
total_num_embeddings - 1] ______________________
|
||||
|
||||
where total_num_embeddings can be obtained as config.total_num_embeddings and is:
|
||||
total_num_embeddings = config.vocab_size + config.n_special + config.n_ctx
|
||||
total_num_embeddings = config.vocab_size + config.n_special + config.n_positions
|
||||
You should use the associate indices to index the embeddings.
|
||||
|
||||
The special embeddings ([SEP], [CLS]...) are not pre-trained and need to be trained during the fine-tuning if you use them.
|
||||
@@ -554,7 +560,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
|
||||
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] (or more generally [d_1, ..., d_n, sequence_length]
|
||||
were d_1 ... d_n are arbitrary dimensions) with the word BPE token indices selected in the range [0, config.vocab_size[
|
||||
`position_ids`: an optional torch.LongTensor with the same shape as input_ids
|
||||
with the position indices (selected in the range [config.vocab_size + config.n_special, config.vocab_size + config.n_special + config.n_ctx - 1[.
|
||||
with the position indices (selected in the range [config.vocab_size + config.n_special, config.vocab_size + config.n_special + config.n_positions - 1[.
|
||||
`token_type_ids`: an optional torch.LongTensor with the same shape as input_ids
|
||||
You can use it to add a third embedding (the previous two being the word and position embeddings)
|
||||
to each token in the sentence.
|
||||
@@ -578,7 +584,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
|
||||
|
||||
def __init__(self, config):
|
||||
super(OpenAIGPTModel, self).__init__(config)
|
||||
total_embeddings_size = config.vocab_size + config.n_special + config.n_ctx
|
||||
total_embeddings_size = config.vocab_size + config.n_special + config.n_positions
|
||||
self.embed = nn.Embedding(total_embeddings_size, config.n_embd)
|
||||
self.drop = nn.Dropout(config.embd_pdrop)
|
||||
block = Block(config.n_ctx, config, scale=True)
|
||||
@@ -598,7 +604,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
|
||||
self.init_weights(self.embed)
|
||||
# Copy word and positional embeddings from the previous weights
|
||||
self.embed.weight.data[: self.config.vocab_size, :] = old_embed.weight.data[: self.config.vocab_size, :]
|
||||
self.embed.weight.data[-self.config.n_ctx :, :] = old_embed.weight.data[-self.config.n_ctx :, :]
|
||||
self.embed.weight.data[-self.config.n_positions :, :] = old_embed.weight.data[-self.config.n_positions :, :]
|
||||
|
||||
def forward(self, input_ids, position_ids=None, token_type_ids=None):
|
||||
if position_ids is None:
|
||||
@@ -645,7 +651,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
|
||||
total_num_embeddings - 1] ______________________
|
||||
|
||||
where total_num_embeddings can be obtained as config.total_num_embeddings and is:
|
||||
total_num_embeddings = config.vocab_size + config.n_special + config.n_ctx
|
||||
total_num_embeddings = config.vocab_size + config.n_special + config.n_positions
|
||||
You should use these indices to index the word, special and position embeddings.
|
||||
|
||||
The special embeddings ([SEP], [CLS]...) are not pre-trained and need to be trained during the fine-tuning if you use them.
|
||||
@@ -658,7 +664,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
|
||||
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] (or more generally [d_1, ..., d_n, sequence_length]
|
||||
were d_1 ... d_n are arbitrary dimensions) with the word BPE token indices selected in the range [0, config.vocab_size[
|
||||
`position_ids`: an optional torch.LongTensor with the same shape as input_ids
|
||||
with the position indices (selected in the range [config.vocab_size + config.n_special, config.vocab_size + config.n_special + config.n_ctx - 1[.
|
||||
with the position indices (selected in the range [config.vocab_size + config.n_special, config.vocab_size + config.n_special + config.n_positions - 1[.
|
||||
`token_type_ids`: an optional torch.LongTensor with the same shape as input_ids
|
||||
You can use it to add a third embedding (the previous two being the word and position embeddings)
|
||||
to each token in the sentence.
|
||||
@@ -725,7 +731,7 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
|
||||
total_num_embeddings - 1] ______________________
|
||||
|
||||
where total_num_embeddings can be obtained as config.total_num_embeddings and is:
|
||||
total_num_embeddings = config.vocab_size + config.n_special + config.n_ctx
|
||||
total_num_embeddings = config.vocab_size + config.n_special + config.n_positions
|
||||
You should use these indices to index the word, special and position embeddings.
|
||||
|
||||
The special embeddings ([SEP], [CLS]...) are not pre-trained and need to be trained during the fine-tuning if you use them.
|
||||
@@ -741,7 +747,7 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
|
||||
with a value of 1 were the last hidden state is (usually the [CLS] token) and 0 otherwise.
|
||||
`position_ids`: an optional torch.LongTensor with the same shape as input_ids
|
||||
with the position indices (selected in the range [config.vocab_size + config.n_special,
|
||||
config.vocab_size + config.n_special + config.n_ctx - 1[.
|
||||
config.vocab_size + config.n_special + config.n_positions - 1[.
|
||||
`token_type_ids`: an optional torch.LongTensor with the same shape as input_ids
|
||||
You can use it to add a third embedding (the previous two being the word and position embeddings)
|
||||
to each token in the sentence.
|
||||
|
||||
Reference in New Issue
Block a user