From 690a0dbf36b2f2dea76144d425705fcd5442087b Mon Sep 17 00:00:00 2001 From: thomwolf Date: Mon, 18 Feb 2019 10:50:30 +0100 Subject: [PATCH] fix example - masking --- ...ive_conditional_samples.py => run_gpt2.py} | 27 +++++++++++-------- pytorch_pretrained_bert/modeling_gpt2.py | 21 ++++++--------- 2 files changed, 24 insertions(+), 24 deletions(-) rename examples/{run_gpt2_interactive_conditional_samples.py => run_gpt2.py} (81%) diff --git a/examples/run_gpt2_interactive_conditional_samples.py b/examples/run_gpt2.py similarity index 81% rename from examples/run_gpt2_interactive_conditional_samples.py rename to examples/run_gpt2.py index b54ff94a43..4b34d82490 100644 --- a/examples/run_gpt2_interactive_conditional_samples.py +++ b/examples/run_gpt2.py @@ -22,7 +22,7 @@ def top_k_logits(logits, k): min_values = values[:, -1] return torch.where(logits < min_values, torch.ones_like(logits, dtype=logits.dtype) * -1e10, logits) -def sample_sequence(model, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0, device='cuda'): +def sample_sequence(model, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0, device='cuda', sample=True): if start_token is None: assert context is not None, 'Specify exactly one of start_token and context!' context = torch.tensor(context, device=device, dtype=torch.long).unsqueeze(0).repeat(batch_size, 1) @@ -38,11 +38,14 @@ def sample_sequence(model, length, start_token=None, batch_size=None, context=No logits = logits[:, -1, :] / temperature logits = top_k_logits(logits, k=top_k) log_probs = F.softmax(logits, dim=-1) - prev = torch.multinomial(log_probs, num_samples=1) + if sample: + prev = torch.multinomial(log_probs, num_samples=1) + else: + _, prev = torch.topk(log_probs, k=1, dim=-1) output = torch.cat((output, prev), dim=1) return output -def interact_model(): +def run_model(): parser = argparse.ArgumentParser() parser.add_argument('--model_name_or_path', type=str, default='gpt2', help='pretrained model name or path to local checkpoint') parser.add_argument("--seed", type=int, default=0) @@ -51,6 +54,7 @@ def interact_model(): parser.add_argument("--length", type=int, default=-1) parser.add_argument("--temperature", type=int, default=1) parser.add_argument("--top_k", type=int, default=0) + parser.add_argument('--unconditional', action='store_true', help='If true, unconditional generation.') args = parser.parse_args() print(args) @@ -73,17 +77,19 @@ def interact_model(): elif args.length > model.config.n_ctx: raise ValueError("Can't get samples longer than window size: %s" % model.config.n_ctx) - while True: - raw_text = input("Model prompt >>> ") - while not raw_text: - print('Prompt should not be empty!') + while not args.unconditional: + if not args.unconditional: raw_text = input("Model prompt >>> ") - context_tokens = enc.encode(raw_text) + while not raw_text: + print('Prompt should not be empty!') + raw_text = input("Model prompt >>> ") + context_tokens = enc.encode(raw_text) generated = 0 for _ in range(args.nsamples // args.batch_size): out = sample_sequence( model=model, length=args.length, - context=context_tokens, + context=context_tokens if not args.unconditional else None, + start_token=enc.encoder['<|endoftext|>'] if args.unconditional else None, batch_size=args.batch_size, temperature=args.temperature, top_k=args.top_k, device=device ) @@ -96,5 +102,4 @@ def interact_model(): print("=" * 80) if __name__ == '__main__': - interact_model() - + run_model() diff --git a/pytorch_pretrained_bert/modeling_gpt2.py b/pytorch_pretrained_bert/modeling_gpt2.py index 04aa60fe1f..b72fd4ac59 100644 --- a/pytorch_pretrained_bert/modeling_gpt2.py +++ b/pytorch_pretrained_bert/modeling_gpt2.py @@ -87,10 +87,6 @@ def load_tf_weights_in_gpt2(model, gpt2_checkpoint_path): if len(l) >= 2: num = int(l[1]) pointer = pointer[num] - if m_name[-11:] == '_embeddings': - pointer = getattr(pointer, 'weight') - elif m_name == 'kernel': - array = np.transpose(array) try: assert pointer.shape == array.shape except AssertionError as e: @@ -216,10 +212,9 @@ class Attention(nn.Module): w = torch.matmul(q, k) if self.scale: w = w / math.sqrt(v.size(-1)) - # w = w * self.bias + -1e9 * (1 - self.bias) # TF implem method: mask_attn_weights - # XD: self.b may be larger than w, so we need to crop it - b = self.bias[:, :, : w.size(-2), : w.size(-1)] - w = w * b + -1e10 * (1 - b) + nd, ns = w.size(-2), w.size(-1) + b = self.bias[:, :, ns-nd:ns, :ns] + w = w * b - 1e10 * (1 - b) w = nn.Softmax(dim=-1)(w) return torch.matmul(w, v) @@ -233,9 +228,9 @@ class Attention(nn.Module): new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head) x = x.view(*new_x_shape) # in Tensorflow implem: fct split_states if k: - return x.permute(0, 2, 3, 1) + return x.permute(0, 2, 3, 1) # (batch, head, head_features, seq_length) else: - return x.permute(0, 2, 1, 3) + return x.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) def forward(self, x, layer_past=None): x = self.c_attn(x) @@ -244,10 +239,10 @@ class Attention(nn.Module): key = self.split_heads(key, k=True) value = self.split_heads(value) if layer_past is not None: - past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1] # transpose to have same shapes + past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1] # transpose back cf below key = torch.cat((past_key, key), dim=-1) value = torch.cat((past_value, value), dim=-2) - present = torch.stack((key.transpose(-2, -1), value)) + present = torch.stack((key.transpose(-2, -1), value)) # transpose to have same shapes for stacking a = self._attn(query, key, value) a = self.merge_heads(a) a = self.c_proj(a) @@ -522,7 +517,7 @@ class GPT2Model(GPT2PreTrainedModel): self.wpe = nn.Embedding(config.n_positions, config.n_embd) block = Block(config.n_ctx, config, scale=True) self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)]) - self.ln_f = LayerNorm(config.n_embd) + self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) self.apply(self.init_weights)