Black 20 release
This commit is contained in:
@@ -698,7 +698,9 @@ def run_pplm_example(
|
||||
for word_id in pert_gen_tok_text.tolist()[0]:
|
||||
if word_id in bow_word_ids:
|
||||
pert_gen_text += "{}{}{}".format(
|
||||
colorama.Fore.RED, tokenizer.decode([word_id]), colorama.Style.RESET_ALL,
|
||||
colorama.Fore.RED,
|
||||
tokenizer.decode([word_id]),
|
||||
colorama.Style.RESET_ALL,
|
||||
)
|
||||
else:
|
||||
pert_gen_text += tokenizer.decode([word_id])
|
||||
@@ -729,7 +731,10 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--cond_text", type=str, default="The lake", help="Prefix texts to condition on")
|
||||
parser.add_argument("--uncond", action="store_true", help="Generate from end-of-text as prefix")
|
||||
parser.add_argument(
|
||||
"--num_samples", type=int, default=1, help="Number of samples to generate from the modified latents",
|
||||
"--num_samples",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of samples to generate from the modified latents",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bag_of_words",
|
||||
@@ -751,13 +756,22 @@ if __name__ == "__main__":
|
||||
help="Discriminator to use",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--discrim_weights", type=str, default=None, help="Weights for the generic discriminator",
|
||||
"--discrim_weights",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Weights for the generic discriminator",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--discrim_meta", type=str, default=None, help="Meta information for the generic discriminator",
|
||||
"--discrim_meta",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Meta information for the generic discriminator",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--class_label", type=int, default=-1, help="Class label used for the discriminator",
|
||||
"--class_label",
|
||||
type=int,
|
||||
default=-1,
|
||||
help="Class label used for the discriminator",
|
||||
)
|
||||
parser.add_argument("--length", type=int, default=100)
|
||||
parser.add_argument("--stepsize", type=float, default=0.02)
|
||||
@@ -773,7 +787,10 @@ if __name__ == "__main__":
|
||||
help="Length of past which is being optimized; 0 corresponds to infinite window length",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--horizon_length", type=int, default=1, help="Length of future to optimize over",
|
||||
"--horizon_length",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Length of future to optimize over",
|
||||
)
|
||||
parser.add_argument("--decay", action="store_true", help="whether to decay or not")
|
||||
parser.add_argument("--gamma", type=float, default=1.5)
|
||||
@@ -783,7 +800,10 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--no_cuda", action="store_true", help="no cuda")
|
||||
parser.add_argument("--colorama", action="store_true", help="colors keywords")
|
||||
parser.add_argument(
|
||||
"--repetition_penalty", type=float, default=1.0, help="Penalize repetition. More than 1.0 -> less repetition",
|
||||
"--repetition_penalty",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Penalize repetition. More than 1.0 -> less repetition",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -242,7 +242,12 @@ def train_discriminator(
|
||||
|
||||
text = torchtext_data.Field()
|
||||
label = torchtext_data.Field(sequential=False)
|
||||
train_data, val_data, test_data = datasets.SST.splits(text, label, fine_grained=True, train_subtrees=True,)
|
||||
train_data, val_data, test_data = datasets.SST.splits(
|
||||
text,
|
||||
label,
|
||||
fine_grained=True,
|
||||
train_subtrees=True,
|
||||
)
|
||||
|
||||
x = []
|
||||
y = []
|
||||
|
||||
@@ -41,7 +41,9 @@ from transformers import (
|
||||
|
||||
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO,
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO,
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -197,7 +199,10 @@ def main():
|
||||
args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()
|
||||
|
||||
logger.warning(
|
||||
"device: %s, n_gpu: %s, 16-bits training: %s", args.device, args.n_gpu, args.fp16,
|
||||
"device: %s, n_gpu: %s, 16-bits training: %s",
|
||||
args.device,
|
||||
args.n_gpu,
|
||||
args.fp16,
|
||||
)
|
||||
|
||||
set_seed(args)
|
||||
|
||||
Reference in New Issue
Block a user