examples testing
This commit is contained in:
@@ -4,7 +4,9 @@ import argparse
|
||||
import logging
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
from tqdm import trange
|
||||
|
||||
from pytorch_pretrained_bert import GPT2LMHeadModel, GPT2Tokenizer
|
||||
|
||||
@@ -23,18 +25,20 @@ def top_k_logits(logits, k):
|
||||
def sample_sequence(model, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0, device='cuda'):
|
||||
if start_token is None:
|
||||
assert context is not None, 'Specify exactly one of start_token and context!'
|
||||
context = torch.tensor(context, device=device)
|
||||
context = torch.tensor(context, device=device, dtype=torch.long)
|
||||
else:
|
||||
assert context is None, 'Specify exactly one of start_token and context!'
|
||||
context = torch.full((batch_size, 1), start_token, device=device)
|
||||
context = torch.full((batch_size, 1), start_token, device=device, dtype=torch.long)
|
||||
prev = context
|
||||
output = context
|
||||
past = None
|
||||
with torch.no_grad():
|
||||
for i in range(length):
|
||||
for i in trange(length):
|
||||
logits, past = model(prev, past=past)
|
||||
logits = logits[:, -1, :] / temperature
|
||||
logits = top_k_logits(logits, k=top_k)
|
||||
prev = torch.multinomial(logits, 1)
|
||||
log_probs = F.softmax(logits, dim=-1)
|
||||
prev = torch.multinomial(log_probs, num_samples=1)
|
||||
output = torch.cat((output, prev), dim=1)
|
||||
return output
|
||||
|
||||
@@ -57,6 +61,8 @@ def sample_model():
|
||||
|
||||
enc = GPT2Tokenizer.from_pretrained(args.model_name_or_path)
|
||||
model = GPT2LMHeadModel.from_pretrained(args.model_name_or_path)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
if args.length == -1:
|
||||
args.length = model.config.n_ctx
|
||||
@@ -71,6 +77,7 @@ def sample_model():
|
||||
batch_size=args.batch_size,
|
||||
temperature=args.temperature, top_k=args.top_k, device=device
|
||||
)
|
||||
out = out.tolist()
|
||||
for i in range(args.batch_size):
|
||||
generated += args.batch_size
|
||||
text = enc.decode(out[i])
|
||||
|
||||
@@ -2,8 +2,10 @@
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from tqdm import trange
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
|
||||
from pytorch_pretrained_bert import GPT2LMHeadModel, GPT2Tokenizer
|
||||
@@ -23,18 +25,20 @@ def top_k_logits(logits, k):
|
||||
def sample_sequence(model, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0, device='cuda'):
|
||||
if start_token is None:
|
||||
assert context is not None, 'Specify exactly one of start_token and context!'
|
||||
context = torch.tensor(context, device=device)
|
||||
context = torch.tensor(context, device=device, dtype=torch.long).unsqueeze(0).repeat(batch_size, 1)
|
||||
else:
|
||||
assert context is None, 'Specify exactly one of start_token and context!'
|
||||
context = torch.full((batch_size, 1), start_token, device=device)
|
||||
context = torch.full((batch_size, 1), start_token, device=device, dtype=torch.long)
|
||||
prev = context
|
||||
output = context
|
||||
past = None
|
||||
with torch.no_grad():
|
||||
for i in range(length):
|
||||
for i in trange(length):
|
||||
logits, past = model(prev, past=past)
|
||||
logits = logits[:, -1, :] / temperature
|
||||
logits = top_k_logits(logits, k=top_k)
|
||||
prev = torch.multinomial(logits, 1)
|
||||
log_probs = F.softmax(logits, dim=-1)
|
||||
prev = torch.multinomial(log_probs, num_samples=1)
|
||||
output = torch.cat((output, prev), dim=1)
|
||||
return output
|
||||
|
||||
@@ -50,7 +54,7 @@ def interact_model():
|
||||
args = parser.parse_args()
|
||||
print(args)
|
||||
|
||||
if args.batch_size is None:
|
||||
if args.batch_size == -1:
|
||||
args.batch_size = 1
|
||||
assert args.nsamples % args.batch_size == 0
|
||||
|
||||
@@ -61,6 +65,8 @@ def interact_model():
|
||||
|
||||
enc = GPT2Tokenizer.from_pretrained(args.model_name_or_path)
|
||||
model = GPT2LMHeadModel.from_pretrained(args.model_name_or_path)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
if args.length == -1:
|
||||
args.length = model.config.n_ctx // 2
|
||||
@@ -81,7 +87,7 @@ def interact_model():
|
||||
batch_size=args.batch_size,
|
||||
temperature=args.temperature, top_k=args.top_k, device=device
|
||||
)
|
||||
out = out[:, len(context_tokens):]
|
||||
out = out[:, len(context_tokens):].tolist()
|
||||
for i in range(args.batch_size):
|
||||
generated += 1
|
||||
text = enc.decode(out[i])
|
||||
|
||||
Reference in New Issue
Block a user