fix example - masking

This commit is contained in:
thomwolf
2019-02-18 10:50:30 +01:00
parent fbb248a2e4
commit 690a0dbf36
2 changed files with 24 additions and 24 deletions

View File

@@ -22,7 +22,7 @@ def top_k_logits(logits, k):
min_values = values[:, -1]
return torch.where(logits < min_values, torch.ones_like(logits, dtype=logits.dtype) * -1e10, logits)
def sample_sequence(model, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0, device='cuda'):
def sample_sequence(model, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0, device='cuda', sample=True):
if start_token is None:
assert context is not None, 'Specify exactly one of start_token and context!'
context = torch.tensor(context, device=device, dtype=torch.long).unsqueeze(0).repeat(batch_size, 1)
@@ -38,11 +38,14 @@ def sample_sequence(model, length, start_token=None, batch_size=None, context=No
logits = logits[:, -1, :] / temperature
logits = top_k_logits(logits, k=top_k)
log_probs = F.softmax(logits, dim=-1)
prev = torch.multinomial(log_probs, num_samples=1)
if sample:
prev = torch.multinomial(log_probs, num_samples=1)
else:
_, prev = torch.topk(log_probs, k=1, dim=-1)
output = torch.cat((output, prev), dim=1)
return output
def interact_model():
def run_model():
parser = argparse.ArgumentParser()
parser.add_argument('--model_name_or_path', type=str, default='gpt2', help='pretrained model name or path to local checkpoint')
parser.add_argument("--seed", type=int, default=0)
@@ -51,6 +54,7 @@ def interact_model():
parser.add_argument("--length", type=int, default=-1)
parser.add_argument("--temperature", type=int, default=1)
parser.add_argument("--top_k", type=int, default=0)
parser.add_argument('--unconditional', action='store_true', help='If true, unconditional generation.')
args = parser.parse_args()
print(args)
@@ -73,17 +77,19 @@ def interact_model():
elif args.length > model.config.n_ctx:
raise ValueError("Can't get samples longer than window size: %s" % model.config.n_ctx)
while True:
raw_text = input("Model prompt >>> ")
while not raw_text:
print('Prompt should not be empty!')
while not args.unconditional:
if not args.unconditional:
raw_text = input("Model prompt >>> ")
context_tokens = enc.encode(raw_text)
while not raw_text:
print('Prompt should not be empty!')
raw_text = input("Model prompt >>> ")
context_tokens = enc.encode(raw_text)
generated = 0
for _ in range(args.nsamples // args.batch_size):
out = sample_sequence(
model=model, length=args.length,
context=context_tokens,
context=context_tokens if not args.unconditional else None,
start_token=enc.encoder['<|endoftext|>'] if args.unconditional else None,
batch_size=args.batch_size,
temperature=args.temperature, top_k=args.top_k, device=device
)
@@ -96,5 +102,4 @@ def interact_model():
print("=" * 80)
if __name__ == '__main__':
interact_model()
run_model()