Removed global variable device
This commit is contained in:
committed by
Julien Chaumond
parent
7fd54b55a3
commit
75904dae66
@@ -25,7 +25,6 @@ from transformers import GPT2Tokenizer, GPT2LMHeadModel
|
||||
torch.manual_seed(0)
|
||||
np.random.seed(0)
|
||||
EPSILON = 1e-10
|
||||
device = "cpu"
|
||||
example_sentence = "This is incredible! I love it, this is the best chicken I have ever had."
|
||||
max_length_seq = 100
|
||||
|
||||
@@ -55,7 +54,8 @@ class Discriminator(torch.nn.Module):
|
||||
self,
|
||||
class_size,
|
||||
pretrained_model="gpt2-medium",
|
||||
cached_mode=False
|
||||
cached_mode=False,
|
||||
device='cpu'
|
||||
):
|
||||
super(Discriminator, self).__init__()
|
||||
self.tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model)
|
||||
@@ -66,6 +66,7 @@ class Discriminator(torch.nn.Module):
|
||||
embed_size=self.embed_size
|
||||
)
|
||||
self.cached_mode = cached_mode
|
||||
self.device = device
|
||||
|
||||
def get_classifier(self):
|
||||
return self.classifier_head
|
||||
@@ -78,7 +79,7 @@ class Discriminator(torch.nn.Module):
|
||||
def avg_representation(self, x):
|
||||
mask = x.ne(0).unsqueeze(2).repeat(
|
||||
1, 1, self.embed_size
|
||||
).float().to(device).detach()
|
||||
).float().to(self.device).detach()
|
||||
hidden, _ = self.encoder.transformer(x)
|
||||
masked_hidden = hidden * mask
|
||||
avg_hidden = torch.sum(masked_hidden, dim=1) / (
|
||||
@@ -88,9 +89,9 @@ class Discriminator(torch.nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
if self.cached_mode:
|
||||
avg_hidden = x.to(device)
|
||||
avg_hidden = x.to(self.device)
|
||||
else:
|
||||
avg_hidden = self.avg_representation(x.to(device))
|
||||
avg_hidden = self.avg_representation(x.to(self.device))
|
||||
|
||||
logits = self.classifier_head(avg_hidden)
|
||||
probs = F.log_softmax(logits, dim=-1)
|
||||
@@ -152,7 +153,7 @@ def cached_collate_fn(data):
|
||||
|
||||
|
||||
def train_epoch(data_loader, discriminator, optimizer,
|
||||
epoch=0, log_interval=10):
|
||||
epoch=0, log_interval=10, device='cpu'):
|
||||
samples_so_far = 0
|
||||
discriminator.train_custom()
|
||||
for batch_idx, (input_t, target_t) in enumerate(data_loader):
|
||||
@@ -177,7 +178,7 @@ def train_epoch(data_loader, discriminator, optimizer,
|
||||
)
|
||||
|
||||
|
||||
def evaluate_performance(data_loader, discriminator):
|
||||
def evaluate_performance(data_loader, discriminator, device='cpu'):
|
||||
discriminator.eval()
|
||||
test_loss = 0
|
||||
correct = 0
|
||||
@@ -202,7 +203,7 @@ def evaluate_performance(data_loader, discriminator):
|
||||
)
|
||||
|
||||
|
||||
def predict(input_sentence, model, classes, cached=False):
|
||||
def predict(input_sentence, model, classes, cached=False, device='cpu'):
|
||||
input_t = model.tokenizer.encode(input_sentence)
|
||||
input_t = torch.tensor([input_t], dtype=torch.long, device=device)
|
||||
if cached:
|
||||
@@ -216,7 +217,8 @@ def predict(input_sentence, model, classes, cached=False):
|
||||
))
|
||||
|
||||
|
||||
def get_cached_data_loader(dataset, batch_size, discriminator, shuffle=False):
|
||||
def get_cached_data_loader(dataset, batch_size, discriminator,
|
||||
shuffle=False, device='cpu'):
|
||||
data_loader = torch.utils.data.DataLoader(dataset=dataset,
|
||||
batch_size=batch_size,
|
||||
collate_fn=collate_fn)
|
||||
@@ -244,7 +246,6 @@ def train_discriminator(
|
||||
dataset, dataset_fp=None, pretrained_model="gpt2-medium",
|
||||
epochs=10, batch_size=64, log_interval=10,
|
||||
save_model=False, cached=False, no_cuda=False):
|
||||
global device
|
||||
device = "cuda" if torch.cuda.is_available() and not no_cuda else "cpu"
|
||||
|
||||
print("Preprocessing {} dataset...".format(dataset))
|
||||
@@ -258,7 +259,8 @@ def train_discriminator(
|
||||
discriminator = Discriminator(
|
||||
class_size=len(idx2class),
|
||||
pretrained_model=pretrained_model,
|
||||
cached_mode=cached
|
||||
cached_mode=cached,
|
||||
device=device
|
||||
).to(device)
|
||||
|
||||
text = torchtext_data.Field()
|
||||
@@ -309,7 +311,8 @@ def train_discriminator(
|
||||
discriminator = Discriminator(
|
||||
class_size=len(idx2class),
|
||||
pretrained_model=pretrained_model,
|
||||
cached_mode=cached
|
||||
cached_mode=cached,
|
||||
device=device
|
||||
).to(device)
|
||||
|
||||
with open("datasets/clickbait/clickbait_train_prefix.txt") as f:
|
||||
@@ -368,7 +371,8 @@ def train_discriminator(
|
||||
discriminator = Discriminator(
|
||||
class_size=len(idx2class),
|
||||
pretrained_model=pretrained_model,
|
||||
cached_mode=cached
|
||||
cached_mode=cached,
|
||||
device=device
|
||||
).to(device)
|
||||
|
||||
x = []
|
||||
@@ -431,7 +435,8 @@ def train_discriminator(
|
||||
discriminator = Discriminator(
|
||||
class_size=len(idx2class),
|
||||
pretrained_model=pretrained_model,
|
||||
cached_mode=cached
|
||||
cached_mode=cached,
|
||||
device=device
|
||||
).to(device)
|
||||
|
||||
x = []
|
||||
@@ -494,11 +499,12 @@ def train_discriminator(
|
||||
start = time.time()
|
||||
|
||||
train_loader = get_cached_data_loader(
|
||||
train_dataset, batch_size, discriminator, shuffle=True
|
||||
train_dataset, batch_size, discriminator,
|
||||
shuffle=True, device=device
|
||||
)
|
||||
|
||||
test_loader = get_cached_data_loader(
|
||||
test_dataset, batch_size, discriminator
|
||||
test_dataset, batch_size, discriminator, device=device
|
||||
)
|
||||
|
||||
end = time.time()
|
||||
@@ -529,18 +535,21 @@ def train_discriminator(
|
||||
data_loader=train_loader,
|
||||
optimizer=optimizer,
|
||||
epoch=epoch,
|
||||
log_interval=log_interval
|
||||
log_interval=log_interval,
|
||||
device=device
|
||||
)
|
||||
evaluate_performance(
|
||||
data_loader=test_loader,
|
||||
discriminator=discriminator
|
||||
discriminator=discriminator,
|
||||
device=device
|
||||
)
|
||||
|
||||
end = time.time()
|
||||
print("Epoch took: {:.3f}s".format(end - start))
|
||||
|
||||
print("\nExample prediction")
|
||||
predict(example_sentence, discriminator, idx2class, cached)
|
||||
predict(example_sentence, discriminator, idx2class,
|
||||
cached=cached, device=device)
|
||||
|
||||
if save_model:
|
||||
# torch.save(discriminator.state_dict(),
|
||||
|
||||
Reference in New Issue
Block a user