run_pplm now works on CPU. Output is identical to before (when using GPU).
This commit is contained in:
@@ -84,9 +84,11 @@ DISCRIMINATOR_MODELS_PARAMS = {
|
||||
}
|
||||
|
||||
|
||||
def to_var(x, requires_grad=False, volatile=False):
|
||||
if torch.cuda.is_available():
|
||||
def to_var(x, requires_grad=False, volatile=False, device='cuda'):
|
||||
if torch.cuda.is_available() and device == 'cuda':
|
||||
x = x.cuda()
|
||||
elif device != 'cuda':
|
||||
x = x.to(device)
|
||||
return Variable(x, requires_grad=requires_grad, volatile=volatile)
|
||||
|
||||
|
||||
@@ -182,7 +184,7 @@ def perturb_past(
|
||||
for i in range(num_iterations):
|
||||
print("Iteration ", i + 1)
|
||||
curr_perturbation = [
|
||||
to_var(torch.from_numpy(p_), requires_grad=True)
|
||||
to_var(torch.from_numpy(p_), requires_grad=True, device=device)
|
||||
for p_ in grad_accumulator
|
||||
]
|
||||
|
||||
@@ -290,7 +292,7 @@ def perturb_past(
|
||||
|
||||
# apply the accumulated perturbations to the past
|
||||
grad_accumulator = [
|
||||
to_var(torch.from_numpy(p_), requires_grad=True)
|
||||
to_var(torch.from_numpy(p_), requires_grad=True, device=device)
|
||||
for p_ in grad_accumulator
|
||||
]
|
||||
pert_past = list(map(add, past, grad_accumulator))
|
||||
@@ -300,7 +302,7 @@ def perturb_past(
|
||||
|
||||
def get_classifier(
|
||||
name: Optional[str], label_class: Union[str, int],
|
||||
device: Union[str, torch.device]
|
||||
device: str
|
||||
) -> Tuple[Optional[ClassificationHead], Optional[int]]:
|
||||
if name is None:
|
||||
return None, None
|
||||
@@ -355,16 +357,16 @@ def get_bag_of_words_indices(bag_of_words_ids_or_paths: List[str]) -> List[
|
||||
return bow_indices
|
||||
|
||||
|
||||
def build_bows_one_hot_vectors(bow_indices):
|
||||
def build_bows_one_hot_vectors(bow_indices, device='cuda'):
|
||||
if bow_indices is None:
|
||||
return None
|
||||
|
||||
one_hot_bows_vectors = []
|
||||
for single_bow in bow_indices:
|
||||
single_bow = list(filter(lambda x: len(x) <= 1, single_bow))
|
||||
single_bow = torch.tensor(single_bow).cuda()
|
||||
single_bow = torch.tensor(single_bow).to(device)
|
||||
num_words = single_bow.shape[0]
|
||||
one_hot_bow = torch.zeros(num_words, TOKENIZER.vocab_size).cuda()
|
||||
one_hot_bow = torch.zeros(num_words, TOKENIZER.vocab_size).to(device)
|
||||
one_hot_bow.scatter_(1, single_bow, 1)
|
||||
one_hot_bows_vectors.append(one_hot_bow)
|
||||
return one_hot_bows_vectors
|
||||
@@ -425,7 +427,8 @@ def full_text_generation(
|
||||
length=length,
|
||||
perturb=False
|
||||
)
|
||||
torch.cuda.empty_cache()
|
||||
if device == 'cuda':
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
pert_gen_tok_texts = []
|
||||
discrim_losses = []
|
||||
@@ -460,7 +463,8 @@ def full_text_generation(
|
||||
discrim_losses.append(discrim_loss.data.cpu().numpy())
|
||||
losses_in_time.append(loss_in_time)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
if device == 'cuda':
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return unpert_gen_tok_text, pert_gen_tok_texts, discrim_losses, losses_in_time
|
||||
|
||||
@@ -496,7 +500,7 @@ def generate_text_pplm(
|
||||
)
|
||||
|
||||
# collect one hot vectors for bags of words
|
||||
one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices)
|
||||
one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices, device)
|
||||
|
||||
grad_norms = None
|
||||
last = None
|
||||
@@ -563,7 +567,7 @@ def generate_text_pplm(
|
||||
if classifier is not None:
|
||||
ce_loss = torch.nn.CrossEntropyLoss()
|
||||
prediction = classifier(torch.mean(unpert_last_hidden, dim=1))
|
||||
label = torch.tensor([label_class], device='cuda',
|
||||
label = torch.tensor([label_class], device=device,
|
||||
dtype=torch.long)
|
||||
unpert_discrim_loss = ce_loss(prediction, label)
|
||||
print(
|
||||
|
||||
Reference in New Issue
Block a user