From fbb248a2e47063a1add70a887ef7b32fe4673180 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Mon, 18 Feb 2019 01:28:18 +0100 Subject: [PATCH] examples testing --- .../run_gpt2_generate_unconditional_samples.py | 15 +++++++++++---- ...run_gpt2_interactive_conditional_samples.py | 18 ++++++++++++------ pytorch_pretrained_bert/modeling_gpt2.py | 10 +++++----- 3 files changed, 28 insertions(+), 15 deletions(-) diff --git a/examples/run_gpt2_generate_unconditional_samples.py b/examples/run_gpt2_generate_unconditional_samples.py index 7300bb2f5e..58fb897279 100644 --- a/examples/run_gpt2_generate_unconditional_samples.py +++ b/examples/run_gpt2_generate_unconditional_samples.py @@ -4,7 +4,9 @@ import argparse import logging import torch +import torch.nn.functional as F import numpy as np +from tqdm import trange from pytorch_pretrained_bert import GPT2LMHeadModel, GPT2Tokenizer @@ -23,18 +25,20 @@ def top_k_logits(logits, k): def sample_sequence(model, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0, device='cuda'): if start_token is None: assert context is not None, 'Specify exactly one of start_token and context!' - context = torch.tensor(context, device=device) + context = torch.tensor(context, device=device, dtype=torch.long) else: assert context is None, 'Specify exactly one of start_token and context!' - context = torch.full((batch_size, 1), start_token, device=device) + context = torch.full((batch_size, 1), start_token, device=device, dtype=torch.long) prev = context output = context + past = None with torch.no_grad(): - for i in range(length): + for i in trange(length): logits, past = model(prev, past=past) logits = logits[:, -1, :] / temperature logits = top_k_logits(logits, k=top_k) - prev = torch.multinomial(logits, 1) + log_probs = F.softmax(logits, dim=-1) + prev = torch.multinomial(log_probs, num_samples=1) output = torch.cat((output, prev), dim=1) return output @@ -57,6 +61,8 @@ def sample_model(): enc = GPT2Tokenizer.from_pretrained(args.model_name_or_path) model = GPT2LMHeadModel.from_pretrained(args.model_name_or_path) + model.to(device) + model.eval() if args.length == -1: args.length = model.config.n_ctx @@ -71,6 +77,7 @@ def sample_model(): batch_size=args.batch_size, temperature=args.temperature, top_k=args.top_k, device=device ) + out = out.tolist() for i in range(args.batch_size): generated += args.batch_size text = enc.decode(out[i]) diff --git a/examples/run_gpt2_interactive_conditional_samples.py b/examples/run_gpt2_interactive_conditional_samples.py index e631864a27..b54ff94a43 100644 --- a/examples/run_gpt2_interactive_conditional_samples.py +++ b/examples/run_gpt2_interactive_conditional_samples.py @@ -2,8 +2,10 @@ import argparse import logging +from tqdm import trange import torch +import torch.nn.functional as F import numpy as np from pytorch_pretrained_bert import GPT2LMHeadModel, GPT2Tokenizer @@ -23,18 +25,20 @@ def top_k_logits(logits, k): def sample_sequence(model, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0, device='cuda'): if start_token is None: assert context is not None, 'Specify exactly one of start_token and context!' - context = torch.tensor(context, device=device) + context = torch.tensor(context, device=device, dtype=torch.long).unsqueeze(0).repeat(batch_size, 1) else: assert context is None, 'Specify exactly one of start_token and context!' - context = torch.full((batch_size, 1), start_token, device=device) + context = torch.full((batch_size, 1), start_token, device=device, dtype=torch.long) prev = context output = context + past = None with torch.no_grad(): - for i in range(length): + for i in trange(length): logits, past = model(prev, past=past) logits = logits[:, -1, :] / temperature logits = top_k_logits(logits, k=top_k) - prev = torch.multinomial(logits, 1) + log_probs = F.softmax(logits, dim=-1) + prev = torch.multinomial(log_probs, num_samples=1) output = torch.cat((output, prev), dim=1) return output @@ -50,7 +54,7 @@ def interact_model(): args = parser.parse_args() print(args) - if args.batch_size is None: + if args.batch_size == -1: args.batch_size = 1 assert args.nsamples % args.batch_size == 0 @@ -61,6 +65,8 @@ def interact_model(): enc = GPT2Tokenizer.from_pretrained(args.model_name_or_path) model = GPT2LMHeadModel.from_pretrained(args.model_name_or_path) + model.to(device) + model.eval() if args.length == -1: args.length = model.config.n_ctx // 2 @@ -81,7 +87,7 @@ def interact_model(): batch_size=args.batch_size, temperature=args.temperature, top_k=args.top_k, device=device ) - out = out[:, len(context_tokens):] + out = out[:, len(context_tokens):].tolist() for i in range(args.batch_size): generated += 1 text = enc.decode(out[i]) diff --git a/pytorch_pretrained_bert/modeling_gpt2.py b/pytorch_pretrained_bert/modeling_gpt2.py index eecb07aa6e..04aa60fe1f 100644 --- a/pytorch_pretrained_bert/modeling_gpt2.py +++ b/pytorch_pretrained_bert/modeling_gpt2.py @@ -244,10 +244,10 @@ class Attention(nn.Module): key = self.split_heads(key, k=True) value = self.split_heads(value) if layer_past is not None: - past_key, past_value = layer_past[0], layer_past[1] - key = torch.cat((past_key, key), dim=-2) + past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1] # transpose to have same shapes + key = torch.cat((past_key, key), dim=-1) value = torch.cat((past_value, value), dim=-2) - present = torch.stack((key, value)) + present = torch.stack((key.transpose(-2, -1), value)) a = self._attn(query, key, value) a = self.merge_heads(a) a = self.c_proj(a) @@ -278,7 +278,7 @@ class Block(nn.Module): self.mlp = MLP(4 * nx, config) def forward(self, x, layer_past=None): - a, present = self.attn(self.ln_1(x), layer_past=past) + a, present = self.attn(self.ln_1(x), layer_past=layer_past) x = x + a m = self.mlp(self.ln_2(x)) x = x + m @@ -531,7 +531,7 @@ class GPT2Model(GPT2PreTrainedModel): past_length = 0 past = [None] * len(self.h) else: - past[0][0].size(-2) + past_length = past[0][0].size(-2) if position_ids is None: position_ids = torch.arange(past_length, input_ids.size(-1) + past_length, dtype=torch.long, device=input_ids.device) position_ids = position_ids.unsqueeze(0).expand_as(input_ids)