Black 20 release
This commit is contained in:
@@ -698,7 +698,9 @@ def run_pplm_example(
|
||||
for word_id in pert_gen_tok_text.tolist()[0]:
|
||||
if word_id in bow_word_ids:
|
||||
pert_gen_text += "{}{}{}".format(
|
||||
colorama.Fore.RED, tokenizer.decode([word_id]), colorama.Style.RESET_ALL,
|
||||
colorama.Fore.RED,
|
||||
tokenizer.decode([word_id]),
|
||||
colorama.Style.RESET_ALL,
|
||||
)
|
||||
else:
|
||||
pert_gen_text += tokenizer.decode([word_id])
|
||||
@@ -729,7 +731,10 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--cond_text", type=str, default="The lake", help="Prefix texts to condition on")
|
||||
parser.add_argument("--uncond", action="store_true", help="Generate from end-of-text as prefix")
|
||||
parser.add_argument(
|
||||
"--num_samples", type=int, default=1, help="Number of samples to generate from the modified latents",
|
||||
"--num_samples",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of samples to generate from the modified latents",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bag_of_words",
|
||||
@@ -751,13 +756,22 @@ if __name__ == "__main__":
|
||||
help="Discriminator to use",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--discrim_weights", type=str, default=None, help="Weights for the generic discriminator",
|
||||
"--discrim_weights",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Weights for the generic discriminator",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--discrim_meta", type=str, default=None, help="Meta information for the generic discriminator",
|
||||
"--discrim_meta",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Meta information for the generic discriminator",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--class_label", type=int, default=-1, help="Class label used for the discriminator",
|
||||
"--class_label",
|
||||
type=int,
|
||||
default=-1,
|
||||
help="Class label used for the discriminator",
|
||||
)
|
||||
parser.add_argument("--length", type=int, default=100)
|
||||
parser.add_argument("--stepsize", type=float, default=0.02)
|
||||
@@ -773,7 +787,10 @@ if __name__ == "__main__":
|
||||
help="Length of past which is being optimized; 0 corresponds to infinite window length",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--horizon_length", type=int, default=1, help="Length of future to optimize over",
|
||||
"--horizon_length",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Length of future to optimize over",
|
||||
)
|
||||
parser.add_argument("--decay", action="store_true", help="whether to decay or not")
|
||||
parser.add_argument("--gamma", type=float, default=1.5)
|
||||
@@ -783,7 +800,10 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--no_cuda", action="store_true", help="no cuda")
|
||||
parser.add_argument("--colorama", action="store_true", help="colors keywords")
|
||||
parser.add_argument(
|
||||
"--repetition_penalty", type=float, default=1.0, help="Penalize repetition. More than 1.0 -> less repetition",
|
||||
"--repetition_penalty",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Penalize repetition. More than 1.0 -> less repetition",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
Reference in New Issue
Block a user