diff --git a/examples/pplm/run_pplm.py b/examples/pplm/run_pplm.py index 8c405b56ad..b334a0098c 100644 --- a/examples/pplm/run_pplm.py +++ b/examples/pplm/run_pplm.py @@ -344,6 +344,7 @@ def full_text_generation( gamma=1.5, gm_scale=0.9, kl_scale=0.01, + repetition_penalty=1.0, **kwargs ): classifier, class_id = get_classifier(discrim, class_label, device) @@ -368,7 +369,14 @@ def full_text_generation( raise Exception("Specify either a bag of words or a discriminator") unpert_gen_tok_text, _, _ = generate_text_pplm( - model=model, tokenizer=tokenizer, context=context, device=device, length=length, sample=sample, perturb=False + model=model, + tokenizer=tokenizer, + context=context, + device=device, + length=length, + sample=sample, + perturb=False, + repetition_penalty=repetition_penalty, ) if device == "cuda": torch.cuda.empty_cache() @@ -401,6 +409,7 @@ def full_text_generation( gamma=gamma, gm_scale=gm_scale, kl_scale=kl_scale, + repetition_penalty=repetition_penalty, ) pert_gen_tok_texts.append(pert_gen_tok_text) if classifier is not None: @@ -437,6 +446,7 @@ def generate_text_pplm( gamma=1.5, gm_scale=0.9, kl_scale=0.01, + repetition_penalty=1.0, ): output_so_far = None if context: @@ -508,6 +518,13 @@ def generate_text_pplm( pert_logits, past, pert_all_hidden = model(last, past=pert_past) pert_logits = pert_logits[:, -1, :] / temperature # + SMALL_CONST + + for token_idx in set(output_so_far[0].tolist()): + if pert_logits[0, token_idx] < 0: + pert_logits[0, token_idx] *= repetition_penalty + else: + pert_logits[0, token_idx] /= repetition_penalty + pert_probs = F.softmax(pert_logits, dim=-1) if classifier is not None: @@ -588,6 +605,7 @@ def run_pplm_example( seed=0, no_cuda=False, colorama=False, + repetition_penalty=1.0, ): # set Random seed torch.manual_seed(seed) @@ -655,6 +673,7 @@ def run_pplm_example( gamma=gamma, gm_scale=gm_scale, kl_scale=kl_scale, + repetition_penalty=repetition_penalty, ) # untokenize unperturbed text @@ -767,6 +786,9 @@ if __name__ == "__main__": parser.add_argument("--seed", type=int, default=0) parser.add_argument("--no_cuda", action="store_true", help="no cuda") parser.add_argument("--colorama", action="store_true", help="colors keywords") + parser.add_argument( + "--repetition_penalty", type=float, default=1.0, help="Penalize repetition. More than 1.0 -> less repetition", + ) args = parser.parse_args() run_pplm_example(**vars(args))