Black 20 release

This commit is contained in:
Lysandre
2020-08-26 17:20:22 +02:00
parent e78c110338
commit a75c64d80c
191 changed files with 4807 additions and 3503 deletions

View File

@@ -698,7 +698,9 @@ def run_pplm_example(
for word_id in pert_gen_tok_text.tolist()[0]:
if word_id in bow_word_ids:
pert_gen_text += "{}{}{}".format(
colorama.Fore.RED, tokenizer.decode([word_id]), colorama.Style.RESET_ALL,
colorama.Fore.RED,
tokenizer.decode([word_id]),
colorama.Style.RESET_ALL,
)
else:
pert_gen_text += tokenizer.decode([word_id])
@@ -729,7 +731,10 @@ if __name__ == "__main__":
parser.add_argument("--cond_text", type=str, default="The lake", help="Prefix texts to condition on")
parser.add_argument("--uncond", action="store_true", help="Generate from end-of-text as prefix")
parser.add_argument(
"--num_samples", type=int, default=1, help="Number of samples to generate from the modified latents",
"--num_samples",
type=int,
default=1,
help="Number of samples to generate from the modified latents",
)
parser.add_argument(
"--bag_of_words",
@@ -751,13 +756,22 @@ if __name__ == "__main__":
help="Discriminator to use",
)
parser.add_argument(
"--discrim_weights", type=str, default=None, help="Weights for the generic discriminator",
"--discrim_weights",
type=str,
default=None,
help="Weights for the generic discriminator",
)
parser.add_argument(
"--discrim_meta", type=str, default=None, help="Meta information for the generic discriminator",
"--discrim_meta",
type=str,
default=None,
help="Meta information for the generic discriminator",
)
parser.add_argument(
"--class_label", type=int, default=-1, help="Class label used for the discriminator",
"--class_label",
type=int,
default=-1,
help="Class label used for the discriminator",
)
parser.add_argument("--length", type=int, default=100)
parser.add_argument("--stepsize", type=float, default=0.02)
@@ -773,7 +787,10 @@ if __name__ == "__main__":
help="Length of past which is being optimized; 0 corresponds to infinite window length",
)
parser.add_argument(
"--horizon_length", type=int, default=1, help="Length of future to optimize over",
"--horizon_length",
type=int,
default=1,
help="Length of future to optimize over",
)
parser.add_argument("--decay", action="store_true", help="whether to decay or not")
parser.add_argument("--gamma", type=float, default=1.5)
@@ -783,7 +800,10 @@ if __name__ == "__main__":
parser.add_argument("--no_cuda", action="store_true", help="no cuda")
parser.add_argument("--colorama", action="store_true", help="colors keywords")
parser.add_argument(
"--repetition_penalty", type=float, default=1.0, help="Penalize repetition. More than 1.0 -> less repetition",
"--repetition_penalty",
type=float,
default=1.0,
help="Penalize repetition. More than 1.0 -> less repetition",
)
args = parser.parse_args()