Update quality tooling for formatting (#21480)

* Result of black 23.1

* Update target to Python 3.7

* Switch flake8 to ruff

* Configure isort

* Configure isort

* Apply isort with line limit

* Put the right black version

* adapt black in check copies

* Fix copies
This commit is contained in:
Sylvain Gugger
2023-02-06 18:10:56 -05:00
committed by GitHub
parent b7bb2b59f7
commit 6f79d26442
1211 changed files with 1532 additions and 2687 deletions

View File

@@ -30,10 +30,10 @@ from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from pplm_classification_head import ClassificationHead
from torch import nn
from tqdm import trange
from pplm_classification_head import ClassificationHead
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from transformers.file_utils import cached_path
@@ -345,7 +345,7 @@ def full_text_generation(
gm_scale=0.9,
kl_scale=0.01,
repetition_penalty=1.0,
**kwargs
**kwargs,
):
classifier, class_id = get_classifier(discrim, class_label, device)
@@ -463,7 +463,6 @@ def generate_text_pplm(
unpert_discrim_loss = 0
loss_in_time = []
for i in trange(length, ascii=True):
# Get past/probs for current output, except for last word
# Note that GPT takes 2 inputs: past + current_token
@@ -547,7 +546,6 @@ def generate_text_pplm(
# Fuse the modified model and original model
if perturb:
unpert_probs = nn.functional.softmax(unpert_logits[:, -1, :], dim=-1)
pert_probs = (pert_probs**gm_scale) * (unpert_probs ** (1 - gm_scale)) # + SMALL_CONST

View File

@@ -26,12 +26,12 @@ import torch
import torch.optim as optim
import torch.utils.data as data
from nltk.tokenize.treebank import TreebankWordDetokenizer
from pplm_classification_head import ClassificationHead
from torch import nn
from torchtext import data as torchtext_data
from torchtext import datasets
from tqdm import tqdm, trange
from pplm_classification_head import ClassificationHead
from transformers import GPT2LMHeadModel, GPT2Tokenizer