examples testing
This commit is contained in:
@@ -4,7 +4,9 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from tqdm import trange
|
||||||
|
|
||||||
from pytorch_pretrained_bert import GPT2LMHeadModel, GPT2Tokenizer
|
from pytorch_pretrained_bert import GPT2LMHeadModel, GPT2Tokenizer
|
||||||
|
|
||||||
@@ -23,18 +25,20 @@ def top_k_logits(logits, k):
|
|||||||
def sample_sequence(model, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0, device='cuda'):
|
def sample_sequence(model, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0, device='cuda'):
|
||||||
if start_token is None:
|
if start_token is None:
|
||||||
assert context is not None, 'Specify exactly one of start_token and context!'
|
assert context is not None, 'Specify exactly one of start_token and context!'
|
||||||
context = torch.tensor(context, device=device)
|
context = torch.tensor(context, device=device, dtype=torch.long)
|
||||||
else:
|
else:
|
||||||
assert context is None, 'Specify exactly one of start_token and context!'
|
assert context is None, 'Specify exactly one of start_token and context!'
|
||||||
context = torch.full((batch_size, 1), start_token, device=device)
|
context = torch.full((batch_size, 1), start_token, device=device, dtype=torch.long)
|
||||||
prev = context
|
prev = context
|
||||||
output = context
|
output = context
|
||||||
|
past = None
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for i in range(length):
|
for i in trange(length):
|
||||||
logits, past = model(prev, past=past)
|
logits, past = model(prev, past=past)
|
||||||
logits = logits[:, -1, :] / temperature
|
logits = logits[:, -1, :] / temperature
|
||||||
logits = top_k_logits(logits, k=top_k)
|
logits = top_k_logits(logits, k=top_k)
|
||||||
prev = torch.multinomial(logits, 1)
|
log_probs = F.softmax(logits, dim=-1)
|
||||||
|
prev = torch.multinomial(log_probs, num_samples=1)
|
||||||
output = torch.cat((output, prev), dim=1)
|
output = torch.cat((output, prev), dim=1)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
@@ -57,6 +61,8 @@ def sample_model():
|
|||||||
|
|
||||||
enc = GPT2Tokenizer.from_pretrained(args.model_name_or_path)
|
enc = GPT2Tokenizer.from_pretrained(args.model_name_or_path)
|
||||||
model = GPT2LMHeadModel.from_pretrained(args.model_name_or_path)
|
model = GPT2LMHeadModel.from_pretrained(args.model_name_or_path)
|
||||||
|
model.to(device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
if args.length == -1:
|
if args.length == -1:
|
||||||
args.length = model.config.n_ctx
|
args.length = model.config.n_ctx
|
||||||
@@ -71,6 +77,7 @@ def sample_model():
|
|||||||
batch_size=args.batch_size,
|
batch_size=args.batch_size,
|
||||||
temperature=args.temperature, top_k=args.top_k, device=device
|
temperature=args.temperature, top_k=args.top_k, device=device
|
||||||
)
|
)
|
||||||
|
out = out.tolist()
|
||||||
for i in range(args.batch_size):
|
for i in range(args.batch_size):
|
||||||
generated += args.batch_size
|
generated += args.batch_size
|
||||||
text = enc.decode(out[i])
|
text = enc.decode(out[i])
|
||||||
|
|||||||
@@ -2,8 +2,10 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
|
from tqdm import trange
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from pytorch_pretrained_bert import GPT2LMHeadModel, GPT2Tokenizer
|
from pytorch_pretrained_bert import GPT2LMHeadModel, GPT2Tokenizer
|
||||||
@@ -23,18 +25,20 @@ def top_k_logits(logits, k):
|
|||||||
def sample_sequence(model, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0, device='cuda'):
|
def sample_sequence(model, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0, device='cuda'):
|
||||||
if start_token is None:
|
if start_token is None:
|
||||||
assert context is not None, 'Specify exactly one of start_token and context!'
|
assert context is not None, 'Specify exactly one of start_token and context!'
|
||||||
context = torch.tensor(context, device=device)
|
context = torch.tensor(context, device=device, dtype=torch.long).unsqueeze(0).repeat(batch_size, 1)
|
||||||
else:
|
else:
|
||||||
assert context is None, 'Specify exactly one of start_token and context!'
|
assert context is None, 'Specify exactly one of start_token and context!'
|
||||||
context = torch.full((batch_size, 1), start_token, device=device)
|
context = torch.full((batch_size, 1), start_token, device=device, dtype=torch.long)
|
||||||
prev = context
|
prev = context
|
||||||
output = context
|
output = context
|
||||||
|
past = None
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for i in range(length):
|
for i in trange(length):
|
||||||
logits, past = model(prev, past=past)
|
logits, past = model(prev, past=past)
|
||||||
logits = logits[:, -1, :] / temperature
|
logits = logits[:, -1, :] / temperature
|
||||||
logits = top_k_logits(logits, k=top_k)
|
logits = top_k_logits(logits, k=top_k)
|
||||||
prev = torch.multinomial(logits, 1)
|
log_probs = F.softmax(logits, dim=-1)
|
||||||
|
prev = torch.multinomial(log_probs, num_samples=1)
|
||||||
output = torch.cat((output, prev), dim=1)
|
output = torch.cat((output, prev), dim=1)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
@@ -50,7 +54,7 @@ def interact_model():
|
|||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
print(args)
|
print(args)
|
||||||
|
|
||||||
if args.batch_size is None:
|
if args.batch_size == -1:
|
||||||
args.batch_size = 1
|
args.batch_size = 1
|
||||||
assert args.nsamples % args.batch_size == 0
|
assert args.nsamples % args.batch_size == 0
|
||||||
|
|
||||||
@@ -61,6 +65,8 @@ def interact_model():
|
|||||||
|
|
||||||
enc = GPT2Tokenizer.from_pretrained(args.model_name_or_path)
|
enc = GPT2Tokenizer.from_pretrained(args.model_name_or_path)
|
||||||
model = GPT2LMHeadModel.from_pretrained(args.model_name_or_path)
|
model = GPT2LMHeadModel.from_pretrained(args.model_name_or_path)
|
||||||
|
model.to(device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
if args.length == -1:
|
if args.length == -1:
|
||||||
args.length = model.config.n_ctx // 2
|
args.length = model.config.n_ctx // 2
|
||||||
@@ -81,7 +87,7 @@ def interact_model():
|
|||||||
batch_size=args.batch_size,
|
batch_size=args.batch_size,
|
||||||
temperature=args.temperature, top_k=args.top_k, device=device
|
temperature=args.temperature, top_k=args.top_k, device=device
|
||||||
)
|
)
|
||||||
out = out[:, len(context_tokens):]
|
out = out[:, len(context_tokens):].tolist()
|
||||||
for i in range(args.batch_size):
|
for i in range(args.batch_size):
|
||||||
generated += 1
|
generated += 1
|
||||||
text = enc.decode(out[i])
|
text = enc.decode(out[i])
|
||||||
|
|||||||
@@ -244,10 +244,10 @@ class Attention(nn.Module):
|
|||||||
key = self.split_heads(key, k=True)
|
key = self.split_heads(key, k=True)
|
||||||
value = self.split_heads(value)
|
value = self.split_heads(value)
|
||||||
if layer_past is not None:
|
if layer_past is not None:
|
||||||
past_key, past_value = layer_past[0], layer_past[1]
|
past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1] # transpose to have same shapes
|
||||||
key = torch.cat((past_key, key), dim=-2)
|
key = torch.cat((past_key, key), dim=-1)
|
||||||
value = torch.cat((past_value, value), dim=-2)
|
value = torch.cat((past_value, value), dim=-2)
|
||||||
present = torch.stack((key, value))
|
present = torch.stack((key.transpose(-2, -1), value))
|
||||||
a = self._attn(query, key, value)
|
a = self._attn(query, key, value)
|
||||||
a = self.merge_heads(a)
|
a = self.merge_heads(a)
|
||||||
a = self.c_proj(a)
|
a = self.c_proj(a)
|
||||||
@@ -278,7 +278,7 @@ class Block(nn.Module):
|
|||||||
self.mlp = MLP(4 * nx, config)
|
self.mlp = MLP(4 * nx, config)
|
||||||
|
|
||||||
def forward(self, x, layer_past=None):
|
def forward(self, x, layer_past=None):
|
||||||
a, present = self.attn(self.ln_1(x), layer_past=past)
|
a, present = self.attn(self.ln_1(x), layer_past=layer_past)
|
||||||
x = x + a
|
x = x + a
|
||||||
m = self.mlp(self.ln_2(x))
|
m = self.mlp(self.ln_2(x))
|
||||||
x = x + m
|
x = x + m
|
||||||
@@ -531,7 +531,7 @@ class GPT2Model(GPT2PreTrainedModel):
|
|||||||
past_length = 0
|
past_length = 0
|
||||||
past = [None] * len(self.h)
|
past = [None] * len(self.h)
|
||||||
else:
|
else:
|
||||||
past[0][0].size(-2)
|
past_length = past[0][0].size(-2)
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
position_ids = torch.arange(past_length, input_ids.size(-1) + past_length, dtype=torch.long, device=input_ids.device)
|
position_ids = torch.arange(past_length, input_ids.size(-1) + past_length, dtype=torch.long, device=input_ids.device)
|
||||||
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
|
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
|
||||||
|
|||||||
Reference in New Issue
Block a user