diff --git a/examples/run_pplm.py b/examples/run_pplm.py index 9ddd42681e..57bed3890f 100644 --- a/examples/run_pplm.py +++ b/examples/run_pplm.py @@ -121,17 +121,17 @@ def perturb_past( accumulated_hidden=None, grad_norms=None, stepsize=0.01, + one_hot_bows_vectors=None, classifier=None, class_label=None, - one_hot_bows_vectors=None, loss_type=0, num_iterations=3, - kl_scale=0.01, - window_length=0, horizon_length=1, + window_length=0, decay=False, gamma=1.5, - device='cuda' + kl_scale=0.01, + device='cuda', ): # Generate inital perturbed past grad_accumulator = [ @@ -351,8 +351,7 @@ def get_classifier( def get_bag_of_words_indices(bag_of_words_ids_or_paths: List[str], tokenizer) -> \ - List[ - List[List[int]]]: + List[List[List[int]]]: bow_indices = [] for id_or_path in bag_of_words_ids_or_paths: if id_or_path in BAG_OF_WORDS_ARCHIVE_MAP: @@ -388,22 +387,22 @@ def full_text_generation( context=None, num_samples=1, device="cuda", - sample=False, + bag_of_words=None, discrim=None, class_label=None, - bag_of_words=None, length=100, - grad_length=10000, stepsize=0.02, - num_iterations=3, temperature=1.0, - gm_scale=0.9, - kl_scale=0.01, top_k=10, - window_length=0, + sample=False, + num_iterations=3, + grad_length=10000, horizon_length=1, + window_length=0, decay=False, gamma=1.5, + gm_scale=0.9, + kl_scale=0.01, **kwargs ): classifier, class_id = get_classifier( @@ -454,24 +453,24 @@ def full_text_generation( tokenizer=tokenizer, context=context, device=device, - sample=sample, perturb=True, bow_indices=bow_indices, classifier=classifier, class_label=class_id, loss_type=loss_type, length=length, - grad_length=grad_length, stepsize=stepsize, - num_iterations=num_iterations, temperature=temperature, - gm_scale=gm_scale, - kl_scale=kl_scale, top_k=top_k, - window_length=window_length, + sample=sample, + num_iterations=num_iterations, + grad_length=grad_length, horizon_length=horizon_length, + window_length=window_length, decay=decay, gamma=gamma, + gm_scale=gm_scale, + kl_scale=kl_scale, ) pert_gen_tok_texts.append(pert_gen_tok_text) if classifier is not None: @@ -490,24 +489,24 @@ def generate_text_pplm( context=None, past=None, device="cuda", - sample=False, perturb=True, + bow_indices=None, classifier=None, class_label=None, - bow_indices=None, loss_type=0, length=100, - grad_length=10000, stepsize=0.02, - num_iterations=3, temperature=1.0, - gm_scale=0.9, - kl_scale=0.01, top_k=10, - window_length=0, + sample=False, + num_iterations=3, + grad_length=10000, horizon_length=1, + window_length=0, decay=False, gamma=1.5, + gm_scale=0.9, + kl_scale=0.01, ): output_so_far = ( torch.tensor(context, device=device, dtype=torch.long).unsqueeze(0) @@ -561,17 +560,17 @@ def generate_text_pplm( accumulated_hidden=accumulated_hidden, grad_norms=grad_norms, stepsize=current_stepsize, + one_hot_bows_vectors=one_hot_bows_vectors, classifier=classifier, class_label=class_label, - one_hot_bows_vectors=one_hot_bows_vectors, loss_type=loss_type, num_iterations=num_iterations, - kl_scale=kl_scale, - window_length=window_length, horizon_length=horizon_length, + window_length=window_length, decay=decay, gamma=gamma, - device=device + kl_scale=kl_scale, + device=device, ) loss_in_time.append(loss_this_iter) else: @@ -685,7 +684,7 @@ def run_pplm_example( pretrained_model = DISCRIMINATOR_MODELS_PARAMS[discrim][ "pretrained_model" ] - print("discrim = {}, setting pretrained_model " + print("discrim = {}, pretrained_model set " "to discriminator's = {}".format(discrim, pretrained_model)) # load pretrained model @@ -810,6 +809,20 @@ if __name__ == '__main__': default="gpt2-medium", help="pretrained model name or path to local checkpoint", ) + parser.add_argument( + "--cond_text", type=str, default="The lake", + help="Prefix texts to condition on" + ) + parser.add_argument( + "--uncond", action="store_true", + help="Generate from end-of-text as prefix" + ) + parser.add_argument( + "--num_samples", + type=int, + default=1, + help="Number of samples to generate from the modified latents", + ) parser.add_argument( "--bag_of_words", "-B", @@ -837,40 +850,16 @@ if __name__ == '__main__': default=-1, help="Class label used for the discriminator", ) - parser.add_argument("--stepsize", type=float, default=0.02) parser.add_argument("--length", type=int, default=100) - parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--stepsize", type=float, default=0.02) parser.add_argument("--temperature", type=float, default=1.0) parser.add_argument("--top_k", type=int, default=10) - parser.add_argument("--gm_scale", type=float, default=0.9) - parser.add_argument("--kl_scale", type=float, default=0.01) - parser.add_argument("--no_cuda", action="store_true", help="no cuda") parser.add_argument( "--sample", action="store_true", help="Generate from end-of-text as prefix" ) - parser.add_argument( - "--uncond", action="store_true", - help="Generate from end-of-text as prefix" - ) - parser.add_argument( - "--cond_text", type=str, default="The lake", - help="Prefix texts to condition on" - ) parser.add_argument("--num_iterations", type=int, default=3) parser.add_argument("--grad_length", type=int, default=10000) - parser.add_argument( - "--num_samples", - type=int, - default=1, - help="Number of samples to generate from the modified latents", - ) - parser.add_argument( - "--horizon_length", - type=int, - default=1, - help="Length of future to optimize over", - ) parser.add_argument( "--window_length", type=int, @@ -878,9 +867,19 @@ if __name__ == '__main__': help="Length of past which is being optimized; " "0 corresponds to infinite window length", ) + parser.add_argument( + "--horizon_length", + type=int, + default=1, + help="Length of future to optimize over", + ) parser.add_argument("--decay", action="store_true", help="whether to decay or not") parser.add_argument("--gamma", type=float, default=1.5) + parser.add_argument("--gm_scale", type=float, default=0.9) + parser.add_argument("--kl_scale", type=float, default=0.01) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--no_cuda", action="store_true", help="no cuda") parser.add_argument("--colorama", action="store_true", help="colors keywords")