Removed global variable device
This commit is contained in:
committed by
Julien Chaumond
parent
7fd54b55a3
commit
75904dae66
@@ -25,7 +25,6 @@ from transformers import GPT2Tokenizer, GPT2LMHeadModel
|
|||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
np.random.seed(0)
|
np.random.seed(0)
|
||||||
EPSILON = 1e-10
|
EPSILON = 1e-10
|
||||||
device = "cpu"
|
|
||||||
example_sentence = "This is incredible! I love it, this is the best chicken I have ever had."
|
example_sentence = "This is incredible! I love it, this is the best chicken I have ever had."
|
||||||
max_length_seq = 100
|
max_length_seq = 100
|
||||||
|
|
||||||
@@ -55,7 +54,8 @@ class Discriminator(torch.nn.Module):
|
|||||||
self,
|
self,
|
||||||
class_size,
|
class_size,
|
||||||
pretrained_model="gpt2-medium",
|
pretrained_model="gpt2-medium",
|
||||||
cached_mode=False
|
cached_mode=False,
|
||||||
|
device='cpu'
|
||||||
):
|
):
|
||||||
super(Discriminator, self).__init__()
|
super(Discriminator, self).__init__()
|
||||||
self.tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model)
|
self.tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model)
|
||||||
@@ -66,6 +66,7 @@ class Discriminator(torch.nn.Module):
|
|||||||
embed_size=self.embed_size
|
embed_size=self.embed_size
|
||||||
)
|
)
|
||||||
self.cached_mode = cached_mode
|
self.cached_mode = cached_mode
|
||||||
|
self.device = device
|
||||||
|
|
||||||
def get_classifier(self):
|
def get_classifier(self):
|
||||||
return self.classifier_head
|
return self.classifier_head
|
||||||
@@ -78,7 +79,7 @@ class Discriminator(torch.nn.Module):
|
|||||||
def avg_representation(self, x):
|
def avg_representation(self, x):
|
||||||
mask = x.ne(0).unsqueeze(2).repeat(
|
mask = x.ne(0).unsqueeze(2).repeat(
|
||||||
1, 1, self.embed_size
|
1, 1, self.embed_size
|
||||||
).float().to(device).detach()
|
).float().to(self.device).detach()
|
||||||
hidden, _ = self.encoder.transformer(x)
|
hidden, _ = self.encoder.transformer(x)
|
||||||
masked_hidden = hidden * mask
|
masked_hidden = hidden * mask
|
||||||
avg_hidden = torch.sum(masked_hidden, dim=1) / (
|
avg_hidden = torch.sum(masked_hidden, dim=1) / (
|
||||||
@@ -88,9 +89,9 @@ class Discriminator(torch.nn.Module):
|
|||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
if self.cached_mode:
|
if self.cached_mode:
|
||||||
avg_hidden = x.to(device)
|
avg_hidden = x.to(self.device)
|
||||||
else:
|
else:
|
||||||
avg_hidden = self.avg_representation(x.to(device))
|
avg_hidden = self.avg_representation(x.to(self.device))
|
||||||
|
|
||||||
logits = self.classifier_head(avg_hidden)
|
logits = self.classifier_head(avg_hidden)
|
||||||
probs = F.log_softmax(logits, dim=-1)
|
probs = F.log_softmax(logits, dim=-1)
|
||||||
@@ -152,7 +153,7 @@ def cached_collate_fn(data):
|
|||||||
|
|
||||||
|
|
||||||
def train_epoch(data_loader, discriminator, optimizer,
|
def train_epoch(data_loader, discriminator, optimizer,
|
||||||
epoch=0, log_interval=10):
|
epoch=0, log_interval=10, device='cpu'):
|
||||||
samples_so_far = 0
|
samples_so_far = 0
|
||||||
discriminator.train_custom()
|
discriminator.train_custom()
|
||||||
for batch_idx, (input_t, target_t) in enumerate(data_loader):
|
for batch_idx, (input_t, target_t) in enumerate(data_loader):
|
||||||
@@ -177,7 +178,7 @@ def train_epoch(data_loader, discriminator, optimizer,
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def evaluate_performance(data_loader, discriminator):
|
def evaluate_performance(data_loader, discriminator, device='cpu'):
|
||||||
discriminator.eval()
|
discriminator.eval()
|
||||||
test_loss = 0
|
test_loss = 0
|
||||||
correct = 0
|
correct = 0
|
||||||
@@ -202,7 +203,7 @@ def evaluate_performance(data_loader, discriminator):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def predict(input_sentence, model, classes, cached=False):
|
def predict(input_sentence, model, classes, cached=False, device='cpu'):
|
||||||
input_t = model.tokenizer.encode(input_sentence)
|
input_t = model.tokenizer.encode(input_sentence)
|
||||||
input_t = torch.tensor([input_t], dtype=torch.long, device=device)
|
input_t = torch.tensor([input_t], dtype=torch.long, device=device)
|
||||||
if cached:
|
if cached:
|
||||||
@@ -216,7 +217,8 @@ def predict(input_sentence, model, classes, cached=False):
|
|||||||
))
|
))
|
||||||
|
|
||||||
|
|
||||||
def get_cached_data_loader(dataset, batch_size, discriminator, shuffle=False):
|
def get_cached_data_loader(dataset, batch_size, discriminator,
|
||||||
|
shuffle=False, device='cpu'):
|
||||||
data_loader = torch.utils.data.DataLoader(dataset=dataset,
|
data_loader = torch.utils.data.DataLoader(dataset=dataset,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
collate_fn=collate_fn)
|
collate_fn=collate_fn)
|
||||||
@@ -244,7 +246,6 @@ def train_discriminator(
|
|||||||
dataset, dataset_fp=None, pretrained_model="gpt2-medium",
|
dataset, dataset_fp=None, pretrained_model="gpt2-medium",
|
||||||
epochs=10, batch_size=64, log_interval=10,
|
epochs=10, batch_size=64, log_interval=10,
|
||||||
save_model=False, cached=False, no_cuda=False):
|
save_model=False, cached=False, no_cuda=False):
|
||||||
global device
|
|
||||||
device = "cuda" if torch.cuda.is_available() and not no_cuda else "cpu"
|
device = "cuda" if torch.cuda.is_available() and not no_cuda else "cpu"
|
||||||
|
|
||||||
print("Preprocessing {} dataset...".format(dataset))
|
print("Preprocessing {} dataset...".format(dataset))
|
||||||
@@ -258,7 +259,8 @@ def train_discriminator(
|
|||||||
discriminator = Discriminator(
|
discriminator = Discriminator(
|
||||||
class_size=len(idx2class),
|
class_size=len(idx2class),
|
||||||
pretrained_model=pretrained_model,
|
pretrained_model=pretrained_model,
|
||||||
cached_mode=cached
|
cached_mode=cached,
|
||||||
|
device=device
|
||||||
).to(device)
|
).to(device)
|
||||||
|
|
||||||
text = torchtext_data.Field()
|
text = torchtext_data.Field()
|
||||||
@@ -309,7 +311,8 @@ def train_discriminator(
|
|||||||
discriminator = Discriminator(
|
discriminator = Discriminator(
|
||||||
class_size=len(idx2class),
|
class_size=len(idx2class),
|
||||||
pretrained_model=pretrained_model,
|
pretrained_model=pretrained_model,
|
||||||
cached_mode=cached
|
cached_mode=cached,
|
||||||
|
device=device
|
||||||
).to(device)
|
).to(device)
|
||||||
|
|
||||||
with open("datasets/clickbait/clickbait_train_prefix.txt") as f:
|
with open("datasets/clickbait/clickbait_train_prefix.txt") as f:
|
||||||
@@ -368,7 +371,8 @@ def train_discriminator(
|
|||||||
discriminator = Discriminator(
|
discriminator = Discriminator(
|
||||||
class_size=len(idx2class),
|
class_size=len(idx2class),
|
||||||
pretrained_model=pretrained_model,
|
pretrained_model=pretrained_model,
|
||||||
cached_mode=cached
|
cached_mode=cached,
|
||||||
|
device=device
|
||||||
).to(device)
|
).to(device)
|
||||||
|
|
||||||
x = []
|
x = []
|
||||||
@@ -431,7 +435,8 @@ def train_discriminator(
|
|||||||
discriminator = Discriminator(
|
discriminator = Discriminator(
|
||||||
class_size=len(idx2class),
|
class_size=len(idx2class),
|
||||||
pretrained_model=pretrained_model,
|
pretrained_model=pretrained_model,
|
||||||
cached_mode=cached
|
cached_mode=cached,
|
||||||
|
device=device
|
||||||
).to(device)
|
).to(device)
|
||||||
|
|
||||||
x = []
|
x = []
|
||||||
@@ -494,11 +499,12 @@ def train_discriminator(
|
|||||||
start = time.time()
|
start = time.time()
|
||||||
|
|
||||||
train_loader = get_cached_data_loader(
|
train_loader = get_cached_data_loader(
|
||||||
train_dataset, batch_size, discriminator, shuffle=True
|
train_dataset, batch_size, discriminator,
|
||||||
|
shuffle=True, device=device
|
||||||
)
|
)
|
||||||
|
|
||||||
test_loader = get_cached_data_loader(
|
test_loader = get_cached_data_loader(
|
||||||
test_dataset, batch_size, discriminator
|
test_dataset, batch_size, discriminator, device=device
|
||||||
)
|
)
|
||||||
|
|
||||||
end = time.time()
|
end = time.time()
|
||||||
@@ -529,18 +535,21 @@ def train_discriminator(
|
|||||||
data_loader=train_loader,
|
data_loader=train_loader,
|
||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
epoch=epoch,
|
epoch=epoch,
|
||||||
log_interval=log_interval
|
log_interval=log_interval,
|
||||||
|
device=device
|
||||||
)
|
)
|
||||||
evaluate_performance(
|
evaluate_performance(
|
||||||
data_loader=test_loader,
|
data_loader=test_loader,
|
||||||
discriminator=discriminator
|
discriminator=discriminator,
|
||||||
|
device=device
|
||||||
)
|
)
|
||||||
|
|
||||||
end = time.time()
|
end = time.time()
|
||||||
print("Epoch took: {:.3f}s".format(end - start))
|
print("Epoch took: {:.3f}s".format(end - start))
|
||||||
|
|
||||||
print("\nExample prediction")
|
print("\nExample prediction")
|
||||||
predict(example_sentence, discriminator, idx2class, cached)
|
predict(example_sentence, discriminator, idx2class,
|
||||||
|
cached=cached, device=device)
|
||||||
|
|
||||||
if save_model:
|
if save_model:
|
||||||
# torch.save(discriminator.state_dict(),
|
# torch.save(discriminator.state_dict(),
|
||||||
|
|||||||
Reference in New Issue
Block a user