Now run_pplm works on cpu. Identical output as before (when using gpu).
This commit is contained in:
@@ -84,9 +84,11 @@ DISCRIMINATOR_MODELS_PARAMS = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def to_var(x, requires_grad=False, volatile=False):
|
def to_var(x, requires_grad=False, volatile=False, device='cuda'):
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available() and device == 'cuda':
|
||||||
x = x.cuda()
|
x = x.cuda()
|
||||||
|
elif device != 'cuda':
|
||||||
|
x = x.to(device)
|
||||||
return Variable(x, requires_grad=requires_grad, volatile=volatile)
|
return Variable(x, requires_grad=requires_grad, volatile=volatile)
|
||||||
|
|
||||||
|
|
||||||
@@ -182,7 +184,7 @@ def perturb_past(
|
|||||||
for i in range(num_iterations):
|
for i in range(num_iterations):
|
||||||
print("Iteration ", i + 1)
|
print("Iteration ", i + 1)
|
||||||
curr_perturbation = [
|
curr_perturbation = [
|
||||||
to_var(torch.from_numpy(p_), requires_grad=True)
|
to_var(torch.from_numpy(p_), requires_grad=True, device=device)
|
||||||
for p_ in grad_accumulator
|
for p_ in grad_accumulator
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -290,7 +292,7 @@ def perturb_past(
|
|||||||
|
|
||||||
# apply the accumulated perturbations to the past
|
# apply the accumulated perturbations to the past
|
||||||
grad_accumulator = [
|
grad_accumulator = [
|
||||||
to_var(torch.from_numpy(p_), requires_grad=True)
|
to_var(torch.from_numpy(p_), requires_grad=True, device=device)
|
||||||
for p_ in grad_accumulator
|
for p_ in grad_accumulator
|
||||||
]
|
]
|
||||||
pert_past = list(map(add, past, grad_accumulator))
|
pert_past = list(map(add, past, grad_accumulator))
|
||||||
@@ -300,7 +302,7 @@ def perturb_past(
|
|||||||
|
|
||||||
def get_classifier(
|
def get_classifier(
|
||||||
name: Optional[str], label_class: Union[str, int],
|
name: Optional[str], label_class: Union[str, int],
|
||||||
device: Union[str, torch.device]
|
device: str
|
||||||
) -> Tuple[Optional[ClassificationHead], Optional[int]]:
|
) -> Tuple[Optional[ClassificationHead], Optional[int]]:
|
||||||
if name is None:
|
if name is None:
|
||||||
return None, None
|
return None, None
|
||||||
@@ -355,16 +357,16 @@ def get_bag_of_words_indices(bag_of_words_ids_or_paths: List[str]) -> List[
|
|||||||
return bow_indices
|
return bow_indices
|
||||||
|
|
||||||
|
|
||||||
def build_bows_one_hot_vectors(bow_indices):
|
def build_bows_one_hot_vectors(bow_indices, device='cuda'):
|
||||||
if bow_indices is None:
|
if bow_indices is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
one_hot_bows_vectors = []
|
one_hot_bows_vectors = []
|
||||||
for single_bow in bow_indices:
|
for single_bow in bow_indices:
|
||||||
single_bow = list(filter(lambda x: len(x) <= 1, single_bow))
|
single_bow = list(filter(lambda x: len(x) <= 1, single_bow))
|
||||||
single_bow = torch.tensor(single_bow).cuda()
|
single_bow = torch.tensor(single_bow).to(device)
|
||||||
num_words = single_bow.shape[0]
|
num_words = single_bow.shape[0]
|
||||||
one_hot_bow = torch.zeros(num_words, TOKENIZER.vocab_size).cuda()
|
one_hot_bow = torch.zeros(num_words, TOKENIZER.vocab_size).to(device)
|
||||||
one_hot_bow.scatter_(1, single_bow, 1)
|
one_hot_bow.scatter_(1, single_bow, 1)
|
||||||
one_hot_bows_vectors.append(one_hot_bow)
|
one_hot_bows_vectors.append(one_hot_bow)
|
||||||
return one_hot_bows_vectors
|
return one_hot_bows_vectors
|
||||||
@@ -425,7 +427,8 @@ def full_text_generation(
|
|||||||
length=length,
|
length=length,
|
||||||
perturb=False
|
perturb=False
|
||||||
)
|
)
|
||||||
torch.cuda.empty_cache()
|
if device == 'cuda':
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
pert_gen_tok_texts = []
|
pert_gen_tok_texts = []
|
||||||
discrim_losses = []
|
discrim_losses = []
|
||||||
@@ -460,7 +463,8 @@ def full_text_generation(
|
|||||||
discrim_losses.append(discrim_loss.data.cpu().numpy())
|
discrim_losses.append(discrim_loss.data.cpu().numpy())
|
||||||
losses_in_time.append(loss_in_time)
|
losses_in_time.append(loss_in_time)
|
||||||
|
|
||||||
torch.cuda.empty_cache()
|
if device == 'cuda':
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
return unpert_gen_tok_text, pert_gen_tok_texts, discrim_losses, losses_in_time
|
return unpert_gen_tok_text, pert_gen_tok_texts, discrim_losses, losses_in_time
|
||||||
|
|
||||||
@@ -496,7 +500,7 @@ def generate_text_pplm(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# collect one hot vectors for bags of words
|
# collect one hot vectors for bags of words
|
||||||
one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices)
|
one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices, device)
|
||||||
|
|
||||||
grad_norms = None
|
grad_norms = None
|
||||||
last = None
|
last = None
|
||||||
@@ -563,7 +567,7 @@ def generate_text_pplm(
|
|||||||
if classifier is not None:
|
if classifier is not None:
|
||||||
ce_loss = torch.nn.CrossEntropyLoss()
|
ce_loss = torch.nn.CrossEntropyLoss()
|
||||||
prediction = classifier(torch.mean(unpert_last_hidden, dim=1))
|
prediction = classifier(torch.mean(unpert_last_hidden, dim=1))
|
||||||
label = torch.tensor([label_class], device='cuda',
|
label = torch.tensor([label_class], device=device,
|
||||||
dtype=torch.long)
|
dtype=torch.long)
|
||||||
unpert_discrim_loss = ce_loss(prediction, label)
|
unpert_discrim_loss = ce_loss(prediction, label)
|
||||||
print(
|
print(
|
||||||
|
|||||||
Reference in New Issue
Block a user