From 75904dae669249c9f5d4d4d57890fb6c537d1639 Mon Sep 17 00:00:00 2001 From: w4nderlust Date: Fri, 29 Nov 2019 18:51:27 -0800 Subject: [PATCH] Removed global variable device --- examples/run_pplm_discrim_train.py | 47 ++++++++++++++++++------------ 1 file changed, 28 insertions(+), 19 deletions(-) diff --git a/examples/run_pplm_discrim_train.py b/examples/run_pplm_discrim_train.py index fccfb14426..db081e1a17 100644 --- a/examples/run_pplm_discrim_train.py +++ b/examples/run_pplm_discrim_train.py @@ -25,7 +25,6 @@ from transformers import GPT2Tokenizer, GPT2LMHeadModel torch.manual_seed(0) np.random.seed(0) EPSILON = 1e-10 -device = "cpu" example_sentence = "This is incredible! I love it, this is the best chicken I have ever had." max_length_seq = 100 @@ -55,7 +54,8 @@ class Discriminator(torch.nn.Module): self, class_size, pretrained_model="gpt2-medium", - cached_mode=False + cached_mode=False, + device='cpu' ): super(Discriminator, self).__init__() self.tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model) @@ -66,6 +66,7 @@ class Discriminator(torch.nn.Module): embed_size=self.embed_size ) self.cached_mode = cached_mode + self.device = device def get_classifier(self): return self.classifier_head @@ -78,7 +79,7 @@ class Discriminator(torch.nn.Module): def avg_representation(self, x): mask = x.ne(0).unsqueeze(2).repeat( 1, 1, self.embed_size - ).float().to(device).detach() + ).float().to(self.device).detach() hidden, _ = self.encoder.transformer(x) masked_hidden = hidden * mask avg_hidden = torch.sum(masked_hidden, dim=1) / ( @@ -88,9 +89,9 @@ class Discriminator(torch.nn.Module): def forward(self, x): if self.cached_mode: - avg_hidden = x.to(device) + avg_hidden = x.to(self.device) else: - avg_hidden = self.avg_representation(x.to(device)) + avg_hidden = self.avg_representation(x.to(self.device)) logits = self.classifier_head(avg_hidden) probs = F.log_softmax(logits, dim=-1) @@ -152,7 +153,7 @@ def cached_collate_fn(data): def train_epoch(data_loader, discriminator, optimizer, - epoch=0, log_interval=10): + epoch=0, log_interval=10, device='cpu'): samples_so_far = 0 discriminator.train_custom() for batch_idx, (input_t, target_t) in enumerate(data_loader): @@ -177,7 +178,7 @@ def train_epoch(data_loader, discriminator, optimizer, ) -def evaluate_performance(data_loader, discriminator): +def evaluate_performance(data_loader, discriminator, device='cpu'): discriminator.eval() test_loss = 0 correct = 0 @@ -202,7 +203,7 @@ def evaluate_performance(data_loader, discriminator): ) -def predict(input_sentence, model, classes, cached=False): +def predict(input_sentence, model, classes, cached=False, device='cpu'): input_t = model.tokenizer.encode(input_sentence) input_t = torch.tensor([input_t], dtype=torch.long, device=device) if cached: @@ -216,7 +217,8 @@ def predict(input_sentence, model, classes, cached=False): )) -def get_cached_data_loader(dataset, batch_size, discriminator, shuffle=False): +def get_cached_data_loader(dataset, batch_size, discriminator, + shuffle=False, device='cpu'): data_loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, collate_fn=collate_fn) @@ -244,7 +246,6 @@ def train_discriminator( dataset, dataset_fp=None, pretrained_model="gpt2-medium", epochs=10, batch_size=64, log_interval=10, save_model=False, cached=False, no_cuda=False): - global device device = "cuda" if torch.cuda.is_available() and not no_cuda else "cpu" print("Preprocessing {} dataset...".format(dataset)) @@ -258,7 +259,8 @@ def train_discriminator( discriminator = Discriminator( class_size=len(idx2class), pretrained_model=pretrained_model, - cached_mode=cached + cached_mode=cached, + device=device ).to(device) text = torchtext_data.Field() @@ -309,7 +311,8 @@ def train_discriminator( discriminator = Discriminator( class_size=len(idx2class), pretrained_model=pretrained_model, - cached_mode=cached + cached_mode=cached, + device=device ).to(device) with open("datasets/clickbait/clickbait_train_prefix.txt") as f: @@ -368,7 +371,8 @@ def train_discriminator( discriminator = Discriminator( class_size=len(idx2class), pretrained_model=pretrained_model, - cached_mode=cached + cached_mode=cached, + device=device ).to(device) x = [] @@ -431,7 +435,8 @@ def train_discriminator( discriminator = Discriminator( class_size=len(idx2class), pretrained_model=pretrained_model, - cached_mode=cached + cached_mode=cached, + device=device ).to(device) x = [] @@ -494,11 +499,12 @@ def train_discriminator( start = time.time() train_loader = get_cached_data_loader( - train_dataset, batch_size, discriminator, shuffle=True + train_dataset, batch_size, discriminator, + shuffle=True, device=device ) test_loader = get_cached_data_loader( - test_dataset, batch_size, discriminator + test_dataset, batch_size, discriminator, device=device ) end = time.time() @@ -529,18 +535,21 @@ def train_discriminator( data_loader=train_loader, optimizer=optimizer, epoch=epoch, - log_interval=log_interval + log_interval=log_interval, + device=device ) evaluate_performance( data_loader=test_loader, - discriminator=discriminator + discriminator=discriminator, + device=device ) end = time.time() print("Epoch took: {:.3f}s".format(end - start)) print("\nExample prediction") - predict(example_sentence, discriminator, idx2class, cached) + predict(example_sentence, discriminator, idx2class, + cached=cached, device=device) if save_model: # torch.save(discriminator.state_dict(),