diff --git a/examples/run_pplm.py b/examples/run_pplm.py index 77758759d9..0d6b0d635d 100644 --- a/examples/run_pplm.py +++ b/examples/run_pplm.py @@ -84,9 +84,11 @@ DISCRIMINATOR_MODELS_PARAMS = { } -def to_var(x, requires_grad=False, volatile=False): - if torch.cuda.is_available(): +def to_var(x, requires_grad=False, volatile=False, device='cuda'): + if torch.cuda.is_available() and device == 'cuda': x = x.cuda() + elif device != 'cuda': + x = x.to(device) return Variable(x, requires_grad=requires_grad, volatile=volatile) @@ -182,7 +184,7 @@ def perturb_past( for i in range(num_iterations): print("Iteration ", i + 1) curr_perturbation = [ - to_var(torch.from_numpy(p_), requires_grad=True) + to_var(torch.from_numpy(p_), requires_grad=True, device=device) for p_ in grad_accumulator ] @@ -290,7 +292,7 @@ def perturb_past( # apply the accumulated perturbations to the past grad_accumulator = [ - to_var(torch.from_numpy(p_), requires_grad=True) + to_var(torch.from_numpy(p_), requires_grad=True, device=device) for p_ in grad_accumulator ] pert_past = list(map(add, past, grad_accumulator)) @@ -300,7 +302,7 @@ def perturb_past( def get_classifier( name: Optional[str], label_class: Union[str, int], - device: Union[str, torch.device] + device: str ) -> Tuple[Optional[ClassificationHead], Optional[int]]: if name is None: return None, None @@ -355,16 +357,16 @@ def get_bag_of_words_indices(bag_of_words_ids_or_paths: List[str]) -> List[ return bow_indices -def build_bows_one_hot_vectors(bow_indices): +def build_bows_one_hot_vectors(bow_indices, device='cuda'): if bow_indices is None: return None one_hot_bows_vectors = [] for single_bow in bow_indices: single_bow = list(filter(lambda x: len(x) <= 1, single_bow)) - single_bow = torch.tensor(single_bow).cuda() + single_bow = torch.tensor(single_bow).to(device) num_words = single_bow.shape[0] - one_hot_bow = torch.zeros(num_words, TOKENIZER.vocab_size).cuda() + one_hot_bow = torch.zeros(num_words, TOKENIZER.vocab_size).to(device) one_hot_bow.scatter_(1, single_bow, 1) one_hot_bows_vectors.append(one_hot_bow) return one_hot_bows_vectors @@ -425,7 +427,8 @@ def full_text_generation( length=length, perturb=False ) - torch.cuda.empty_cache() + if device == 'cuda': + torch.cuda.empty_cache() pert_gen_tok_texts = [] discrim_losses = [] @@ -460,7 +463,8 @@ def full_text_generation( discrim_losses.append(discrim_loss.data.cpu().numpy()) losses_in_time.append(loss_in_time) - torch.cuda.empty_cache() + if device == 'cuda': + torch.cuda.empty_cache() return unpert_gen_tok_text, pert_gen_tok_texts, discrim_losses, losses_in_time @@ -496,7 +500,7 @@ def generate_text_pplm( ) # collect one hot vectors for bags of words - one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices) + one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices, device) grad_norms = None last = None @@ -563,7 +567,7 @@ def generate_text_pplm( if classifier is not None: ce_loss = torch.nn.CrossEntropyLoss() prediction = classifier(torch.mean(unpert_last_hidden, dim=1)) - label = torch.tensor([label_class], device='cuda', + label = torch.tensor([label_class], device=device, dtype=torch.long) unpert_discrim_loss = ce_loss(prediction, label) print(