Added support for generic discriminators
This commit is contained in:
@@ -14,17 +14,16 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
# TODO: add code for training a custom discriminator
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Example command with bag of words:
|
Example command with bag of words:
|
||||||
python examples/run_pplm.py -B space --cond_text "The president" --length 100 --gamma 1.5 --num_iterations 3 --num_samples 10 --stepsize 0.01 --window_length 5 --kl_scale 0.01 --gm_scale 0.95
|
python examples/run_pplm.py -B space --cond_text "The president" --length 100 --gamma 1.5 --num_iterations 3 --num_samples 10 --stepsize 0.01 --window_length 5 --kl_scale 0.01 --gm_scale 0.95
|
||||||
|
|
||||||
Example command with discriminator:
|
Example command with discriminator:
|
||||||
python examples/run_pplm.py -D sentiment --label_class 3 --cond_text "The lake" --length 10 --gamma 1.0 --num_iterations 30 --num_samples 10 --stepsize 0.01 --kl_scale 0.01 --gm_scale 0.95
|
python examples/run_pplm.py -D sentiment --class_label 3 --cond_text "The lake" --length 10 --gamma 1.0 --num_iterations 30 --num_samples 10 --stepsize 0.01 --kl_scale 0.01 --gm_scale 0.95
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import json
|
||||||
from operator import add
|
from operator import add
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
@@ -121,7 +120,7 @@ def perturb_past(
|
|||||||
grad_norms=None,
|
grad_norms=None,
|
||||||
stepsize=0.01,
|
stepsize=0.01,
|
||||||
classifier=None,
|
classifier=None,
|
||||||
label_class=None,
|
class_label=None,
|
||||||
one_hot_bows_vectors=None,
|
one_hot_bows_vectors=None,
|
||||||
loss_type=0,
|
loss_type=0,
|
||||||
num_iterations=3,
|
num_iterations=3,
|
||||||
@@ -230,7 +229,7 @@ def perturb_past(
|
|||||||
prediction = classifier(new_accumulated_hidden /
|
prediction = classifier(new_accumulated_hidden /
|
||||||
(curr_length + 1 + horizon_length))
|
(curr_length + 1 + horizon_length))
|
||||||
|
|
||||||
label = torch.tensor([label_class], device=device,
|
label = torch.tensor([class_label], device=device,
|
||||||
dtype=torch.long)
|
dtype=torch.long)
|
||||||
discrim_loss = ce_loss(prediction, label)
|
discrim_loss = ce_loss(prediction, label)
|
||||||
print(" pplm_discrim_loss:", discrim_loss.data.cpu().numpy())
|
print(" pplm_discrim_loss:", discrim_loss.data.cpu().numpy())
|
||||||
@@ -244,7 +243,8 @@ def perturb_past(
|
|||||||
unpert_probs + SMALL_CONST *
|
unpert_probs + SMALL_CONST *
|
||||||
(unpert_probs <= SMALL_CONST).float().to(device).detach()
|
(unpert_probs <= SMALL_CONST).float().to(device).detach()
|
||||||
)
|
)
|
||||||
correction = SMALL_CONST * (probs <= SMALL_CONST).float().to(device).detach()
|
correction = SMALL_CONST * (probs <= SMALL_CONST).float().to(
|
||||||
|
device).detach()
|
||||||
corrected_probs = probs + correction.detach()
|
corrected_probs = probs + correction.detach()
|
||||||
kl_loss = kl_scale * (
|
kl_loss = kl_scale * (
|
||||||
(corrected_probs * (corrected_probs / unpert_probs).log()).sum()
|
(corrected_probs * (corrected_probs / unpert_probs).log()).sum()
|
||||||
@@ -273,7 +273,8 @@ def perturb_past(
|
|||||||
# normalize gradients
|
# normalize gradients
|
||||||
grad = [
|
grad = [
|
||||||
-stepsize *
|
-stepsize *
|
||||||
(p_.grad * window_mask / grad_norms[index] ** gamma).data.cpu().numpy()
|
(p_.grad * window_mask / grad_norms[
|
||||||
|
index] ** gamma).data.cpu().numpy()
|
||||||
for index, p_ in enumerate(curr_perturbation)
|
for index, p_ in enumerate(curr_perturbation)
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -301,7 +302,7 @@ def perturb_past(
|
|||||||
|
|
||||||
|
|
||||||
def get_classifier(
|
def get_classifier(
|
||||||
name: Optional[str], label_class: Union[str, int],
|
name: Optional[str], class_label: Union[str, int],
|
||||||
device: str
|
device: str
|
||||||
) -> Tuple[Optional[ClassificationHead], Optional[int]]:
|
) -> Tuple[Optional[ClassificationHead], Optional[int]]:
|
||||||
if name is None:
|
if name is None:
|
||||||
@@ -312,26 +313,29 @@ def get_classifier(
|
|||||||
class_size=params['class_size'],
|
class_size=params['class_size'],
|
||||||
embed_size=params['embed_size']
|
embed_size=params['embed_size']
|
||||||
).to(device)
|
).to(device)
|
||||||
|
if "url" in params:
|
||||||
resolved_archive_file = cached_path(params["url"])
|
resolved_archive_file = cached_path(params["url"])
|
||||||
|
else:
|
||||||
|
resolved_archive_file = params["path"]
|
||||||
classifier.load_state_dict(
|
classifier.load_state_dict(
|
||||||
torch.load(resolved_archive_file, map_location=device))
|
torch.load(resolved_archive_file, map_location=device))
|
||||||
classifier.eval()
|
classifier.eval()
|
||||||
|
|
||||||
if isinstance(label_class, str):
|
if isinstance(class_label, str):
|
||||||
if label_class in params["class_vocab"]:
|
if class_label in params["class_vocab"]:
|
||||||
label_id = params["class_vocab"][label_class]
|
label_id = params["class_vocab"][class_label]
|
||||||
else:
|
else:
|
||||||
label_id = params["default_class"]
|
label_id = params["default_class"]
|
||||||
print("label_class {} not in class_vocab".format(label_class))
|
print("class_label {} not in class_vocab".format(class_label))
|
||||||
print("available values are: {}".format(params["class_vocab"]))
|
print("available values are: {}".format(params["class_vocab"]))
|
||||||
print("using default class {}".format(label_id))
|
print("using default class {}".format(label_id))
|
||||||
|
|
||||||
elif isinstance(label_class, int):
|
elif isinstance(class_label, int):
|
||||||
if label_class in set(params["class_vocab"].values()):
|
if class_label in set(params["class_vocab"].values()):
|
||||||
label_id = label_class
|
label_id = class_label
|
||||||
else:
|
else:
|
||||||
label_id = params["default_class"]
|
label_id = params["default_class"]
|
||||||
print("label_class {} not in class_vocab".format(label_class))
|
print("class_label {} not in class_vocab".format(class_label))
|
||||||
print("available values are: {}".format(params["class_vocab"]))
|
print("available values are: {}".format(params["class_vocab"]))
|
||||||
print("using default class {}".format(label_id))
|
print("using default class {}".format(label_id))
|
||||||
|
|
||||||
@@ -379,7 +383,7 @@ def full_text_generation(
|
|||||||
device="cuda",
|
device="cuda",
|
||||||
sample=True,
|
sample=True,
|
||||||
discrim=None,
|
discrim=None,
|
||||||
label_class=None,
|
class_label=None,
|
||||||
bag_of_words=None,
|
bag_of_words=None,
|
||||||
length=100,
|
length=100,
|
||||||
grad_length=10000,
|
grad_length=10000,
|
||||||
@@ -397,7 +401,7 @@ def full_text_generation(
|
|||||||
):
|
):
|
||||||
classifier, class_id = get_classifier(
|
classifier, class_id = get_classifier(
|
||||||
discrim,
|
discrim,
|
||||||
label_class,
|
class_label,
|
||||||
device
|
device
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -443,7 +447,7 @@ def full_text_generation(
|
|||||||
perturb=True,
|
perturb=True,
|
||||||
bow_indices=bow_indices,
|
bow_indices=bow_indices,
|
||||||
classifier=classifier,
|
classifier=classifier,
|
||||||
label_class=class_id,
|
class_label=class_id,
|
||||||
loss_type=loss_type,
|
loss_type=loss_type,
|
||||||
length=length,
|
length=length,
|
||||||
grad_length=grad_length,
|
grad_length=grad_length,
|
||||||
@@ -477,7 +481,7 @@ def generate_text_pplm(
|
|||||||
sample=True,
|
sample=True,
|
||||||
perturb=True,
|
perturb=True,
|
||||||
classifier=None,
|
classifier=None,
|
||||||
label_class=None,
|
class_label=None,
|
||||||
bow_indices=None,
|
bow_indices=None,
|
||||||
loss_type=0,
|
loss_type=0,
|
||||||
length=100,
|
length=100,
|
||||||
@@ -545,7 +549,7 @@ def generate_text_pplm(
|
|||||||
grad_norms=grad_norms,
|
grad_norms=grad_norms,
|
||||||
stepsize=current_stepsize,
|
stepsize=current_stepsize,
|
||||||
classifier=classifier,
|
classifier=classifier,
|
||||||
label_class=label_class,
|
class_label=class_label,
|
||||||
one_hot_bows_vectors=one_hot_bows_vectors,
|
one_hot_bows_vectors=one_hot_bows_vectors,
|
||||||
loss_type=loss_type,
|
loss_type=loss_type,
|
||||||
num_iterations=num_iterations,
|
num_iterations=num_iterations,
|
||||||
@@ -567,7 +571,7 @@ def generate_text_pplm(
|
|||||||
if classifier is not None:
|
if classifier is not None:
|
||||||
ce_loss = torch.nn.CrossEntropyLoss()
|
ce_loss = torch.nn.CrossEntropyLoss()
|
||||||
prediction = classifier(torch.mean(unpert_last_hidden, dim=1))
|
prediction = classifier(torch.mean(unpert_last_hidden, dim=1))
|
||||||
label = torch.tensor([label_class], device=device,
|
label = torch.tensor([class_label], device=device,
|
||||||
dtype=torch.long)
|
dtype=torch.long)
|
||||||
unpert_discrim_loss = ce_loss(prediction, label)
|
unpert_discrim_loss = ce_loss(prediction, label)
|
||||||
print(
|
print(
|
||||||
@@ -613,6 +617,20 @@ def generate_text_pplm(
|
|||||||
return output_so_far, unpert_discrim_loss, loss_in_time
|
return output_so_far, unpert_discrim_loss, loss_in_time
|
||||||
|
|
||||||
|
|
||||||
|
def set_generic_model_params(discrim_weights, discrim_meta):
|
||||||
|
if discrim_weights is None:
|
||||||
|
raise ValueError('When using a generic discriminator, '
|
||||||
|
'discrim_weights need to be specified')
|
||||||
|
if discrim_meta is None:
|
||||||
|
raise ValueError('When using a generic discriminator, '
|
||||||
|
'discrim_meta need to be specified')
|
||||||
|
|
||||||
|
with open(discrim_meta, 'r') as discrim_meta_file:
|
||||||
|
meta = json.load(discrim_meta_file)
|
||||||
|
meta['path'] = discrim_weights
|
||||||
|
DISCRIMINATOR_MODELS_PARAMS['generic'] = meta
|
||||||
|
|
||||||
|
|
||||||
def run_model():
|
def run_model():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@@ -636,11 +654,15 @@ def run_model():
|
|||||||
"-D",
|
"-D",
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
choices=("clickbait", "sentiment", "toxicity"),
|
choices=("clickbait", "sentiment", "toxicity", "generic"),
|
||||||
help="Discriminator to use for loss-type 2",
|
help="Discriminator to use",
|
||||||
)
|
)
|
||||||
|
parser.add_argument('--discrim_weights', type=str, default=None,
|
||||||
|
help='Weights for the generic discriminator')
|
||||||
|
parser.add_argument('--discrim_meta', type=str, default=None,
|
||||||
|
help='Meta information for the generic discriminator')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--label_class",
|
"--class_label",
|
||||||
type=int,
|
type=int,
|
||||||
default=-1,
|
default=-1,
|
||||||
help="Class label used for the discriminator",
|
help="Class label used for the discriminator",
|
||||||
@@ -697,6 +719,9 @@ def run_model():
|
|||||||
# set the device
|
# set the device
|
||||||
device = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"
|
device = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"
|
||||||
|
|
||||||
|
if args.discrim == 'generic':
|
||||||
|
set_generic_model_params(args.discrim_weights, args.discrim_meta)
|
||||||
|
|
||||||
# load pretrained model
|
# load pretrained model
|
||||||
model = GPT2LMHeadModel.from_pretrained(
|
model = GPT2LMHeadModel.from_pretrained(
|
||||||
args.model_path,
|
args.model_path,
|
||||||
|
|||||||
Reference in New Issue
Block a user