Update quality tooling for formatting (#21480)
* Result of black 23.1 * Update target to Python 3.7 * Switch flake8 to ruff * Configure isort * Configure isort * Apply isort with line limit * Put the right black version * adapt black in check copies * Fix copies
This commit is contained in:
@@ -30,10 +30,10 @@ from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from pplm_classification_head import ClassificationHead
|
||||
from torch import nn
|
||||
from tqdm import trange
|
||||
|
||||
from pplm_classification_head import ClassificationHead
|
||||
from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
||||
from transformers.file_utils import cached_path
|
||||
|
||||
@@ -345,7 +345,7 @@ def full_text_generation(
|
||||
gm_scale=0.9,
|
||||
kl_scale=0.01,
|
||||
repetition_penalty=1.0,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
classifier, class_id = get_classifier(discrim, class_label, device)
|
||||
|
||||
@@ -463,7 +463,6 @@ def generate_text_pplm(
|
||||
unpert_discrim_loss = 0
|
||||
loss_in_time = []
|
||||
for i in trange(length, ascii=True):
|
||||
|
||||
# Get past/probs for current output, except for last word
|
||||
# Note that GPT takes 2 inputs: past + current_token
|
||||
|
||||
@@ -547,7 +546,6 @@ def generate_text_pplm(
|
||||
|
||||
# Fuse the modified model and original model
|
||||
if perturb:
|
||||
|
||||
unpert_probs = nn.functional.softmax(unpert_logits[:, -1, :], dim=-1)
|
||||
|
||||
pert_probs = (pert_probs**gm_scale) * (unpert_probs ** (1 - gm_scale)) # + SMALL_CONST
|
||||
|
||||
@@ -26,12 +26,12 @@ import torch
|
||||
import torch.optim as optim
|
||||
import torch.utils.data as data
|
||||
from nltk.tokenize.treebank import TreebankWordDetokenizer
|
||||
from pplm_classification_head import ClassificationHead
|
||||
from torch import nn
|
||||
from torchtext import data as torchtext_data
|
||||
from torchtext import datasets
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
from pplm_classification_head import ClassificationHead
|
||||
from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user