[style] consistent nn. and nn.functional: part 4 examples (#12156)
* consistent nn. and nn.functional: p4 examples * restore
This commit is contained in:
@@ -17,7 +17,7 @@
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
from torch import nn
|
||||||
from torch.nn import CrossEntropyLoss, MSELoss
|
from torch.nn import CrossEntropyLoss, MSELoss
|
||||||
|
|
||||||
from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
|
from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
|
||||||
@@ -270,6 +270,7 @@ class AlbertForSequenceClassificationWithPabee(AlbertPreTrainedModel):
|
|||||||
|
|
||||||
from transformers import AlbertTokenizer
|
from transformers import AlbertTokenizer
|
||||||
from pabee import AlbertForSequenceClassificationWithPabee
|
from pabee import AlbertForSequenceClassificationWithPabee
|
||||||
|
from torch import nn
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
|
tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
|
||||||
|
|||||||
@@ -294,6 +294,7 @@ class BertForSequenceClassificationWithPabee(BertPreTrainedModel):
|
|||||||
|
|
||||||
from transformers import BertTokenizer, BertForSequenceClassification
|
from transformers import BertTokenizer, BertForSequenceClassification
|
||||||
from pabee import BertForSequenceClassificationWithPabee
|
from pabee import BertForSequenceClassificationWithPabee
|
||||||
|
from torch import nn
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ import random
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
from torch import nn
|
||||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
|
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
from tqdm import tqdm, trange
|
from tqdm import tqdm, trange
|
||||||
@@ -117,11 +118,11 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
|
|
||||||
# multi-gpu training (should be after apex fp16 initialization)
|
# multi-gpu training (should be after apex fp16 initialization)
|
||||||
if args.n_gpu > 1:
|
if args.n_gpu > 1:
|
||||||
model = torch.nn.DataParallel(model)
|
model = nn.DataParallel(model)
|
||||||
|
|
||||||
# Distributed training (should be after apex fp16 initialization)
|
# Distributed training (should be after apex fp16 initialization)
|
||||||
if args.local_rank != -1:
|
if args.local_rank != -1:
|
||||||
model = torch.nn.parallel.DistributedDataParallel(
|
model = nn.parallel.DistributedDataParallel(
|
||||||
model,
|
model,
|
||||||
device_ids=[args.local_rank],
|
device_ids=[args.local_rank],
|
||||||
output_device=args.local_rank,
|
output_device=args.local_rank,
|
||||||
@@ -203,9 +204,9 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
tr_loss += loss.item()
|
tr_loss += loss.item()
|
||||||
if (step + 1) % args.gradient_accumulation_steps == 0:
|
if (step + 1) % args.gradient_accumulation_steps == 0:
|
||||||
if args.fp16:
|
if args.fp16:
|
||||||
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
|
nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
|
||||||
else:
|
else:
|
||||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
|
nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
|
||||||
|
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
scheduler.step() # Update learning rate schedule
|
scheduler.step() # Update learning rate schedule
|
||||||
@@ -291,8 +292,8 @@ def evaluate(args, model, tokenizer, prefix="", patience=0):
|
|||||||
eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
|
eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
|
||||||
|
|
||||||
# multi-gpu eval
|
# multi-gpu eval
|
||||||
if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
|
if args.n_gpu > 1 and not isinstance(model, nn.DataParallel):
|
||||||
model = torch.nn.DataParallel(model)
|
model = nn.DataParallel(model)
|
||||||
|
|
||||||
# Eval!
|
# Eval!
|
||||||
logger.info("***** Running evaluation {} *****".format(prefix))
|
logger.info("***** Running evaluation {} *****".format(prefix))
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ from datetime import datetime
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
from torch import nn
|
||||||
from torch.utils.data import DataLoader, SequentialSampler, Subset
|
from torch.utils.data import DataLoader, SequentialSampler, Subset
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
@@ -415,11 +416,11 @@ def main():
|
|||||||
# Distributed and parallel training
|
# Distributed and parallel training
|
||||||
model.to(args.device)
|
model.to(args.device)
|
||||||
if args.local_rank != -1:
|
if args.local_rank != -1:
|
||||||
model = torch.nn.parallel.DistributedDataParallel(
|
model = nn.parallel.DistributedDataParallel(
|
||||||
model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
|
model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
|
||||||
)
|
)
|
||||||
elif args.n_gpu > 1:
|
elif args.n_gpu > 1:
|
||||||
model = torch.nn.DataParallel(model)
|
model = nn.DataParallel(model)
|
||||||
|
|
||||||
# Print/save training arguments
|
# Print/save training arguments
|
||||||
os.makedirs(args.output_dir, exist_ok=True)
|
os.makedirs(args.output_dir, exist_ok=True)
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from datetime import datetime
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
from torch import nn
|
||||||
from torch.utils.data import DataLoader, RandomSampler, TensorDataset
|
from torch.utils.data import DataLoader, RandomSampler, TensorDataset
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
@@ -352,11 +353,11 @@ def main():
|
|||||||
# Distributed and parallel training
|
# Distributed and parallel training
|
||||||
model.to(args.device)
|
model.to(args.device)
|
||||||
if args.local_rank != -1:
|
if args.local_rank != -1:
|
||||||
model = torch.nn.parallel.DistributedDataParallel(
|
model = nn.parallel.DistributedDataParallel(
|
||||||
model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
|
model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
|
||||||
)
|
)
|
||||||
elif args.n_gpu > 1:
|
elif args.n_gpu > 1:
|
||||||
model = torch.nn.DataParallel(model)
|
model = nn.DataParallel(model)
|
||||||
|
|
||||||
# Print/save training arguments
|
# Print/save training arguments
|
||||||
os.makedirs(args.output_dir, exist_ok=True)
|
os.makedirs(args.output_dir, exist_ok=True)
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import time
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
from torch import nn
|
||||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
|
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
from tqdm import tqdm, trange
|
from tqdm import tqdm, trange
|
||||||
@@ -135,11 +136,11 @@ def train(args, train_dataset, model, tokenizer, train_highway=False):
|
|||||||
|
|
||||||
# multi-gpu training (should be after apex fp16 initialization)
|
# multi-gpu training (should be after apex fp16 initialization)
|
||||||
if args.n_gpu > 1:
|
if args.n_gpu > 1:
|
||||||
model = torch.nn.DataParallel(model)
|
model = nn.DataParallel(model)
|
||||||
|
|
||||||
# Distributed training (should be after apex fp16 initialization)
|
# Distributed training (should be after apex fp16 initialization)
|
||||||
if args.local_rank != -1:
|
if args.local_rank != -1:
|
||||||
model = torch.nn.parallel.DistributedDataParallel(
|
model = nn.parallel.DistributedDataParallel(
|
||||||
model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
|
model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -190,9 +191,9 @@ def train(args, train_dataset, model, tokenizer, train_highway=False):
|
|||||||
tr_loss += loss.item()
|
tr_loss += loss.item()
|
||||||
if (step + 1) % args.gradient_accumulation_steps == 0:
|
if (step + 1) % args.gradient_accumulation_steps == 0:
|
||||||
if args.fp16:
|
if args.fp16:
|
||||||
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
|
nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
|
||||||
else:
|
else:
|
||||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
|
nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
|
||||||
|
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
scheduler.step() # Update learning rate schedule
|
scheduler.step() # Update learning rate schedule
|
||||||
@@ -255,7 +256,7 @@ def evaluate(args, model, tokenizer, prefix="", output_layer=-1, eval_highway=Fa
|
|||||||
|
|
||||||
# multi-gpu eval
|
# multi-gpu eval
|
||||||
if args.n_gpu > 1:
|
if args.n_gpu > 1:
|
||||||
model = torch.nn.DataParallel(model)
|
model = nn.DataParallel(model)
|
||||||
|
|
||||||
# Eval!
|
# Eval!
|
||||||
logger.info("***** Running evaluation {} *****".format(prefix))
|
logger.info("***** Running evaluation {} *****".format(prefix))
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||||
|
|
||||||
import torch.nn as nn
|
from torch import nn
|
||||||
from torch.nn import CrossEntropyLoss, MSELoss
|
from torch.nn import CrossEntropyLoss, MSELoss
|
||||||
|
|
||||||
from transformers import RobertaConfig
|
from transformers import RobertaConfig
|
||||||
|
|||||||
@@ -21,8 +21,7 @@ import time
|
|||||||
|
|
||||||
import psutil
|
import psutil
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
from torch import nn
|
||||||
import torch.nn.functional as F
|
|
||||||
from torch.optim import AdamW
|
from torch.optim import AdamW
|
||||||
from torch.utils.data import BatchSampler, DataLoader, RandomSampler
|
from torch.utils.data import BatchSampler, DataLoader, RandomSampler
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
@@ -412,8 +411,8 @@ class Distiller:
|
|||||||
|
|
||||||
loss_ce = (
|
loss_ce = (
|
||||||
self.ce_loss_fct(
|
self.ce_loss_fct(
|
||||||
F.log_softmax(s_logits_slct / self.temperature, dim=-1),
|
nn.functional.log_softmax(s_logits_slct / self.temperature, dim=-1),
|
||||||
F.softmax(t_logits_slct / self.temperature, dim=-1),
|
nn.functional.softmax(t_logits_slct / self.temperature, dim=-1),
|
||||||
)
|
)
|
||||||
* (self.temperature) ** 2
|
* (self.temperature) ** 2
|
||||||
)
|
)
|
||||||
@@ -492,9 +491,9 @@ class Distiller:
|
|||||||
self.iter()
|
self.iter()
|
||||||
if self.n_iter % self.params.gradient_accumulation_steps == 0:
|
if self.n_iter % self.params.gradient_accumulation_steps == 0:
|
||||||
if self.fp16:
|
if self.fp16:
|
||||||
torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.params.max_grad_norm)
|
nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.params.max_grad_norm)
|
||||||
else:
|
else:
|
||||||
torch.nn.utils.clip_grad_norm_(self.student.parameters(), self.params.max_grad_norm)
|
nn.utils.clip_grad_norm_(self.student.parameters(), self.params.max_grad_norm)
|
||||||
self.optimizer.step()
|
self.optimizer.step()
|
||||||
self.optimizer.zero_grad()
|
self.optimizer.zero_grad()
|
||||||
self.scheduler.step()
|
self.scheduler.step()
|
||||||
|
|||||||
@@ -24,8 +24,7 @@ import timeit
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
from torch import nn
|
||||||
import torch.nn.functional as F
|
|
||||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
|
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
from tqdm import tqdm, trange
|
from tqdm import tqdm, trange
|
||||||
@@ -138,11 +137,11 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
|
|||||||
|
|
||||||
# multi-gpu training (should be after apex fp16 initialization)
|
# multi-gpu training (should be after apex fp16 initialization)
|
||||||
if args.n_gpu > 1:
|
if args.n_gpu > 1:
|
||||||
model = torch.nn.DataParallel(model)
|
model = nn.DataParallel(model)
|
||||||
|
|
||||||
# Distributed training (should be after apex fp16 initialization)
|
# Distributed training (should be after apex fp16 initialization)
|
||||||
if args.local_rank != -1:
|
if args.local_rank != -1:
|
||||||
model = torch.nn.parallel.DistributedDataParallel(
|
model = nn.parallel.DistributedDataParallel(
|
||||||
model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
|
model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -232,15 +231,15 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
|
|||||||
loss_fct = nn.KLDivLoss(reduction="batchmean")
|
loss_fct = nn.KLDivLoss(reduction="batchmean")
|
||||||
loss_start = (
|
loss_start = (
|
||||||
loss_fct(
|
loss_fct(
|
||||||
F.log_softmax(start_logits_stu / args.temperature, dim=-1),
|
nn.functional.log_softmax(start_logits_stu / args.temperature, dim=-1),
|
||||||
F.softmax(start_logits_tea / args.temperature, dim=-1),
|
nn.functional.softmax(start_logits_tea / args.temperature, dim=-1),
|
||||||
)
|
)
|
||||||
* (args.temperature ** 2)
|
* (args.temperature ** 2)
|
||||||
)
|
)
|
||||||
loss_end = (
|
loss_end = (
|
||||||
loss_fct(
|
loss_fct(
|
||||||
F.log_softmax(end_logits_stu / args.temperature, dim=-1),
|
nn.functional.log_softmax(end_logits_stu / args.temperature, dim=-1),
|
||||||
F.softmax(end_logits_tea / args.temperature, dim=-1),
|
nn.functional.softmax(end_logits_tea / args.temperature, dim=-1),
|
||||||
)
|
)
|
||||||
* (args.temperature ** 2)
|
* (args.temperature ** 2)
|
||||||
)
|
)
|
||||||
@@ -262,9 +261,9 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
|
|||||||
tr_loss += loss.item()
|
tr_loss += loss.item()
|
||||||
if (step + 1) % args.gradient_accumulation_steps == 0:
|
if (step + 1) % args.gradient_accumulation_steps == 0:
|
||||||
if args.fp16:
|
if args.fp16:
|
||||||
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
|
nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
|
||||||
else:
|
else:
|
||||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
|
nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
|
||||||
|
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
scheduler.step() # Update learning rate schedule
|
scheduler.step() # Update learning rate schedule
|
||||||
@@ -326,8 +325,8 @@ def evaluate(args, model, tokenizer, prefix=""):
|
|||||||
eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
|
eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
|
||||||
|
|
||||||
# multi-gpu evaluate
|
# multi-gpu evaluate
|
||||||
if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
|
if args.n_gpu > 1 and not isinstance(model, nn.DataParallel):
|
||||||
model = torch.nn.DataParallel(model)
|
model = nn.DataParallel(model)
|
||||||
|
|
||||||
# Eval!
|
# Eval!
|
||||||
logger.info("***** Running evaluation {} *****".format(prefix))
|
logger.info("***** Running evaluation {} *****".format(prefix))
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import torch
|
|||||||
import torch.utils.checkpoint as checkpoint
|
import torch.utils.checkpoint as checkpoint
|
||||||
from elasticsearch import Elasticsearch # noqa: F401
|
from elasticsearch import Elasticsearch # noqa: F401
|
||||||
from elasticsearch.helpers import bulk, streaming_bulk # noqa: F401
|
from elasticsearch.helpers import bulk, streaming_bulk # noqa: F401
|
||||||
|
from torch import nn
|
||||||
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
|
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
@@ -116,14 +117,14 @@ class ELI5DatasetQARetriver(Dataset):
|
|||||||
return self.make_example(idx % self.data.num_rows)
|
return self.make_example(idx % self.data.num_rows)
|
||||||
|
|
||||||
|
|
||||||
class RetrievalQAEmbedder(torch.nn.Module):
|
class RetrievalQAEmbedder(nn.Module):
|
||||||
def __init__(self, sent_encoder, dim):
|
def __init__(self, sent_encoder, dim):
|
||||||
super(RetrievalQAEmbedder, self).__init__()
|
super(RetrievalQAEmbedder, self).__init__()
|
||||||
self.sent_encoder = sent_encoder
|
self.sent_encoder = sent_encoder
|
||||||
self.output_dim = 128
|
self.output_dim = 128
|
||||||
self.project_q = torch.nn.Linear(dim, self.output_dim, bias=False)
|
self.project_q = nn.Linear(dim, self.output_dim, bias=False)
|
||||||
self.project_a = torch.nn.Linear(dim, self.output_dim, bias=False)
|
self.project_a = nn.Linear(dim, self.output_dim, bias=False)
|
||||||
self.ce_loss = torch.nn.CrossEntropyLoss(reduction="mean")
|
self.ce_loss = nn.CrossEntropyLoss(reduction="mean")
|
||||||
|
|
||||||
def embed_sentences_checkpointed(self, input_ids, attention_mask, checkpoint_batch_size=-1):
|
def embed_sentences_checkpointed(self, input_ids, attention_mask, checkpoint_batch_size=-1):
|
||||||
# reproduces BERT forward pass with checkpointing
|
# reproduces BERT forward pass with checkpointing
|
||||||
|
|||||||
@@ -25,7 +25,6 @@ from typing import Dict, List, Tuple
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import functional as F
|
|
||||||
from torch.nn.modules.batchnorm import BatchNorm2d
|
from torch.nn.modules.batchnorm import BatchNorm2d
|
||||||
from torchvision.ops import RoIPool
|
from torchvision.ops import RoIPool
|
||||||
from torchvision.ops.boxes import batched_nms, nms
|
from torchvision.ops.boxes import batched_nms, nms
|
||||||
@@ -85,7 +84,7 @@ def pad_list_tensors(
|
|||||||
too_small = True
|
too_small = True
|
||||||
tensor_i = tensor_i.unsqueeze(-1)
|
tensor_i = tensor_i.unsqueeze(-1)
|
||||||
assert isinstance(tensor_i, torch.Tensor)
|
assert isinstance(tensor_i, torch.Tensor)
|
||||||
tensor_i = F.pad(
|
tensor_i = nn.functional.pad(
|
||||||
input=tensor_i,
|
input=tensor_i,
|
||||||
pad=(0, 0, 0, max_detections - preds_per_image[i]),
|
pad=(0, 0, 0, max_detections - preds_per_image[i]),
|
||||||
mode="constant",
|
mode="constant",
|
||||||
@@ -701,7 +700,7 @@ class RPNOutputs(object):
|
|||||||
|
|
||||||
|
|
||||||
# Main Classes
|
# Main Classes
|
||||||
class Conv2d(torch.nn.Conv2d):
|
class Conv2d(nn.Conv2d):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
norm = kwargs.pop("norm", None)
|
norm = kwargs.pop("norm", None)
|
||||||
activation = kwargs.pop("activation", None)
|
activation = kwargs.pop("activation", None)
|
||||||
@@ -712,9 +711,9 @@ class Conv2d(torch.nn.Conv2d):
|
|||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
if x.numel() == 0 and self.training:
|
if x.numel() == 0 and self.training:
|
||||||
assert not isinstance(self.norm, torch.nn.SyncBatchNorm)
|
assert not isinstance(self.norm, nn.SyncBatchNorm)
|
||||||
if x.numel() == 0:
|
if x.numel() == 0:
|
||||||
assert not isinstance(self.norm, torch.nn.GroupNorm)
|
assert not isinstance(self.norm, nn.GroupNorm)
|
||||||
output_shape = [
|
output_shape = [
|
||||||
(i + 2 * p - (di * (k - 1) + 1)) // s + 1
|
(i + 2 * p - (di * (k - 1) + 1)) // s + 1
|
||||||
for i, p, di, k, s in zip(
|
for i, p, di, k, s in zip(
|
||||||
@@ -752,7 +751,7 @@ class LastLevelMaxPool(nn.Module):
|
|||||||
self.in_feature = "p5"
|
self.in_feature = "p5"
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return [F.max_pool2d(x, kernel_size=1, stride=2, padding=0)]
|
return [nn.functional.max_pool2d(x, kernel_size=1, stride=2, padding=0)]
|
||||||
|
|
||||||
|
|
||||||
class LastLevelP6P7(nn.Module):
|
class LastLevelP6P7(nn.Module):
|
||||||
@@ -769,7 +768,7 @@ class LastLevelP6P7(nn.Module):
|
|||||||
|
|
||||||
def forward(self, c5):
|
def forward(self, c5):
|
||||||
p6 = self.p6(c5)
|
p6 = self.p6(c5)
|
||||||
p7 = self.p7(F.relu(p6))
|
p7 = self.p7(nn.functional.relu(p6))
|
||||||
return [p6, p7]
|
return [p6, p7]
|
||||||
|
|
||||||
|
|
||||||
@@ -790,11 +789,11 @@ class BasicStem(nn.Module):
|
|||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.conv1(x)
|
x = self.conv1(x)
|
||||||
x = F.relu_(x)
|
x = nn.functional.relu_(x)
|
||||||
if self.caffe_maxpool:
|
if self.caffe_maxpool:
|
||||||
x = F.max_pool2d(x, kernel_size=3, stride=2, padding=0, ceil_mode=True)
|
x = nn.functional.max_pool2d(x, kernel_size=3, stride=2, padding=0, ceil_mode=True)
|
||||||
else:
|
else:
|
||||||
x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
|
x = nn.functional.max_pool2d(x, kernel_size=3, stride=2, padding=1)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -881,10 +880,10 @@ class BottleneckBlock(ResNetBlockBase):
|
|||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
out = self.conv1(x)
|
out = self.conv1(x)
|
||||||
out = F.relu_(out)
|
out = nn.functional.relu_(out)
|
||||||
|
|
||||||
out = self.conv2(out)
|
out = self.conv2(out)
|
||||||
out = F.relu_(out)
|
out = nn.functional.relu_(out)
|
||||||
|
|
||||||
out = self.conv3(out)
|
out = self.conv3(out)
|
||||||
|
|
||||||
@@ -894,7 +893,7 @@ class BottleneckBlock(ResNetBlockBase):
|
|||||||
shortcut = x
|
shortcut = x
|
||||||
|
|
||||||
out += shortcut
|
out += shortcut
|
||||||
out = F.relu_(out)
|
out = nn.functional.relu_(out)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
@@ -1159,7 +1158,7 @@ class ROIOutputs(object):
|
|||||||
return boxes.view(num_pred, K * B).split(preds_per_image, dim=0)
|
return boxes.view(num_pred, K * B).split(preds_per_image, dim=0)
|
||||||
|
|
||||||
def _predict_objs(self, obj_logits, preds_per_image):
|
def _predict_objs(self, obj_logits, preds_per_image):
|
||||||
probs = F.softmax(obj_logits, dim=-1)
|
probs = nn.functional.softmax(obj_logits, dim=-1)
|
||||||
probs = probs.split(preds_per_image, dim=0)
|
probs = probs.split(preds_per_image, dim=0)
|
||||||
return probs
|
return probs
|
||||||
|
|
||||||
@@ -1490,7 +1489,7 @@ class RPNHead(nn.Module):
|
|||||||
pred_objectness_logits = []
|
pred_objectness_logits = []
|
||||||
pred_anchor_deltas = []
|
pred_anchor_deltas = []
|
||||||
for x in features:
|
for x in features:
|
||||||
t = F.relu(self.conv(x))
|
t = nn.functional.relu(self.conv(x))
|
||||||
pred_objectness_logits.append(self.objectness_logits(t))
|
pred_objectness_logits.append(self.objectness_logits(t))
|
||||||
pred_anchor_deltas.append(self.anchor_deltas(t))
|
pred_anchor_deltas.append(self.anchor_deltas(t))
|
||||||
return pred_objectness_logits, pred_anchor_deltas
|
return pred_objectness_logits, pred_anchor_deltas
|
||||||
@@ -1650,7 +1649,7 @@ class FastRCNNOutputLayers(nn.Module):
|
|||||||
cls_emb = self.cls_embedding(max_class) # [b] --> [b, 256]
|
cls_emb = self.cls_embedding(max_class) # [b] --> [b, 256]
|
||||||
roi_features = torch.cat([roi_features, cls_emb], -1) # [b, 2048] + [b, 256] --> [b, 2304]
|
roi_features = torch.cat([roi_features, cls_emb], -1) # [b, 2048] + [b, 256] --> [b, 2304]
|
||||||
roi_features = self.fc_attr(roi_features)
|
roi_features = self.fc_attr(roi_features)
|
||||||
roi_features = F.relu(roi_features)
|
roi_features = nn.functional.relu(roi_features)
|
||||||
attr_scores = self.attr_score(roi_features)
|
attr_scores = self.attr_score(roi_features)
|
||||||
return scores, attr_scores, proposal_deltas
|
return scores, attr_scores, proposal_deltas
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -20,8 +20,8 @@ from typing import Tuple
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
from utils import img_tensorize
|
from utils import img_tensorize
|
||||||
|
|
||||||
@@ -63,7 +63,9 @@ class ResizeShortestEdge:
|
|||||||
img = np.asarray(pil_image)
|
img = np.asarray(pil_image)
|
||||||
else:
|
else:
|
||||||
img = img.permute(2, 0, 1).unsqueeze(0) # 3, 0, 1) # hw(c) -> nchw
|
img = img.permute(2, 0, 1).unsqueeze(0) # 3, 0, 1) # hw(c) -> nchw
|
||||||
img = F.interpolate(img, (newh, neww), mode=self.interp_method, align_corners=False).squeeze(0)
|
img = nn.functional.interpolate(
|
||||||
|
img, (newh, neww), mode=self.interp_method, align_corners=False
|
||||||
|
).squeeze(0)
|
||||||
img_augs.append(img)
|
img_augs.append(img)
|
||||||
|
|
||||||
return img_augs
|
return img_augs
|
||||||
@@ -85,7 +87,7 @@ class Preprocess:
|
|||||||
max_size = tuple(max(s) for s in zip(*[img.shape for img in images]))
|
max_size = tuple(max(s) for s in zip(*[img.shape for img in images]))
|
||||||
image_sizes = [im.shape[-2:] for im in images]
|
image_sizes = [im.shape[-2:] for im in images]
|
||||||
images = [
|
images = [
|
||||||
F.pad(
|
nn.functional.pad(
|
||||||
im,
|
im,
|
||||||
[0, max_size[-1] - size[1], 0, max_size[-2] - size[0]],
|
[0, max_size[-1] - size[1], 0, max_size[-2] - size[0]],
|
||||||
value=self.pad_value,
|
value=self.pad_value,
|
||||||
|
|||||||
@@ -25,8 +25,8 @@ import random
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
|
||||||
from sklearn.metrics import f1_score
|
from sklearn.metrics import f1_score
|
||||||
|
from torch import nn
|
||||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
|
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
from tqdm import tqdm, trange
|
from tqdm import tqdm, trange
|
||||||
@@ -107,11 +107,11 @@ def train(args, train_dataset, model, tokenizer, criterion):
|
|||||||
|
|
||||||
# multi-gpu training (should be after apex fp16 initialization)
|
# multi-gpu training (should be after apex fp16 initialization)
|
||||||
if args.n_gpu > 1:
|
if args.n_gpu > 1:
|
||||||
model = torch.nn.DataParallel(model)
|
model = nn.DataParallel(model)
|
||||||
|
|
||||||
# Distributed training (should be after apex fp16 initialization)
|
# Distributed training (should be after apex fp16 initialization)
|
||||||
if args.local_rank != -1:
|
if args.local_rank != -1:
|
||||||
model = torch.nn.parallel.DistributedDataParallel(
|
model = nn.parallel.DistributedDataParallel(
|
||||||
model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
|
model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -166,9 +166,9 @@ def train(args, train_dataset, model, tokenizer, criterion):
|
|||||||
tr_loss += loss.item()
|
tr_loss += loss.item()
|
||||||
if (step + 1) % args.gradient_accumulation_steps == 0:
|
if (step + 1) % args.gradient_accumulation_steps == 0:
|
||||||
if args.fp16:
|
if args.fp16:
|
||||||
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
|
nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
|
||||||
else:
|
else:
|
||||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
|
nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
|
||||||
|
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
scheduler.step() # Update learning rate schedule
|
scheduler.step() # Update learning rate schedule
|
||||||
@@ -248,8 +248,8 @@ def evaluate(args, model, tokenizer, criterion, prefix=""):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# multi-gpu eval
|
# multi-gpu eval
|
||||||
if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
|
if args.n_gpu > 1 and not isinstance(model, nn.DataParallel):
|
||||||
model = torch.nn.DataParallel(model)
|
model = nn.DataParallel(model)
|
||||||
|
|
||||||
# Eval!
|
# Eval!
|
||||||
logger.info("***** Running evaluation {} *****".format(prefix))
|
logger.info("***** Running evaluation {} *****".format(prefix))
|
||||||
|
|||||||
@@ -19,10 +19,10 @@ import os
|
|||||||
from collections import Counter
|
from collections import Counter
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
|
||||||
import torchvision
|
import torchvision
|
||||||
import torchvision.transforms as transforms
|
import torchvision.transforms as transforms
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from torch import nn
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -75,7 +75,7 @@
|
|||||||
"quantized_model = torch.quantization.quantize_dynamic(\n",
|
"quantized_model = torch.quantization.quantize_dynamic(\n",
|
||||||
" model=model,\n",
|
" model=model,\n",
|
||||||
" qconfig_spec = {\n",
|
" qconfig_spec = {\n",
|
||||||
" torch.nn.Linear : torch.quantization.default_dynamic_qconfig,\n",
|
" nn.Linear : torch.quantization.default_dynamic_qconfig,\n",
|
||||||
" },\n",
|
" },\n",
|
||||||
" dtype=torch.qint8,\n",
|
" dtype=torch.qint8,\n",
|
||||||
" )\n",
|
" )\n",
|
||||||
|
|||||||
@@ -23,7 +23,6 @@ import math
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import functional as F
|
|
||||||
from torch.nn import init
|
from torch.nn import init
|
||||||
|
|
||||||
from .binarizer import MagnitudeBinarizer, ThresholdBinarizer, TopKBinarizer
|
from .binarizer import MagnitudeBinarizer, ThresholdBinarizer, TopKBinarizer
|
||||||
@@ -104,4 +103,4 @@ class MaskedLinear(nn.Linear):
|
|||||||
# Mask weights with computed mask
|
# Mask weights with computed mask
|
||||||
weight_thresholded = mask * self.weight
|
weight_thresholded = mask * self.weight
|
||||||
# Compute output (linear layer) with masked weights
|
# Compute output (linear layer) with masked weights
|
||||||
return F.linear(input, weight_thresholded, self.bias)
|
return nn.functional.linear(input, weight_thresholded, self.bias)
|
||||||
|
|||||||
@@ -24,8 +24,7 @@ import random
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
from torch import nn
|
||||||
import torch.nn.functional as F
|
|
||||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
|
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
from tqdm import tqdm, trange
|
from tqdm import tqdm, trange
|
||||||
@@ -168,11 +167,11 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
|
|||||||
|
|
||||||
# multi-gpu training (should be after apex fp16 initialization)
|
# multi-gpu training (should be after apex fp16 initialization)
|
||||||
if args.n_gpu > 1:
|
if args.n_gpu > 1:
|
||||||
model = torch.nn.DataParallel(model)
|
model = nn.DataParallel(model)
|
||||||
|
|
||||||
# Distributed training (should be after apex fp16 initialization)
|
# Distributed training (should be after apex fp16 initialization)
|
||||||
if args.local_rank != -1:
|
if args.local_rank != -1:
|
||||||
model = torch.nn.parallel.DistributedDataParallel(
|
model = nn.parallel.DistributedDataParallel(
|
||||||
model,
|
model,
|
||||||
device_ids=[args.local_rank],
|
device_ids=[args.local_rank],
|
||||||
output_device=args.local_rank,
|
output_device=args.local_rank,
|
||||||
@@ -287,9 +286,9 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
|
|||||||
)
|
)
|
||||||
|
|
||||||
loss_logits = (
|
loss_logits = (
|
||||||
F.kl_div(
|
nn.functional.kl_div(
|
||||||
input=F.log_softmax(logits_stu / args.temperature, dim=-1),
|
input=nn.functional.log_softmax(logits_stu / args.temperature, dim=-1),
|
||||||
target=F.softmax(logits_tea / args.temperature, dim=-1),
|
target=nn.functional.softmax(logits_tea / args.temperature, dim=-1),
|
||||||
reduction="batchmean",
|
reduction="batchmean",
|
||||||
)
|
)
|
||||||
* (args.temperature ** 2)
|
* (args.temperature ** 2)
|
||||||
@@ -320,9 +319,9 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
|
|||||||
and (step + 1) == len(epoch_iterator)
|
and (step + 1) == len(epoch_iterator)
|
||||||
):
|
):
|
||||||
if args.fp16:
|
if args.fp16:
|
||||||
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
|
nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
|
||||||
else:
|
else:
|
||||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
|
nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
|
||||||
|
|
||||||
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
|
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
|
||||||
tb_writer.add_scalar("threshold", threshold, global_step)
|
tb_writer.add_scalar("threshold", threshold, global_step)
|
||||||
@@ -436,8 +435,8 @@ def evaluate(args, model, tokenizer, prefix=""):
|
|||||||
eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
|
eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
|
||||||
|
|
||||||
# multi-gpu eval
|
# multi-gpu eval
|
||||||
if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
|
if args.n_gpu > 1 and not isinstance(model, nn.DataParallel):
|
||||||
model = torch.nn.DataParallel(model)
|
model = nn.DataParallel(model)
|
||||||
|
|
||||||
# Eval!
|
# Eval!
|
||||||
logger.info("***** Running evaluation {} *****".format(prefix))
|
logger.info("***** Running evaluation {} *****".format(prefix))
|
||||||
|
|||||||
@@ -25,8 +25,7 @@ import timeit
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
from torch import nn
|
||||||
import torch.nn.functional as F
|
|
||||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
|
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
from tqdm import tqdm, trange
|
from tqdm import tqdm, trange
|
||||||
@@ -176,11 +175,11 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
|
|||||||
|
|
||||||
# multi-gpu training (should be after apex fp16 initialization)
|
# multi-gpu training (should be after apex fp16 initialization)
|
||||||
if args.n_gpu > 1:
|
if args.n_gpu > 1:
|
||||||
model = torch.nn.DataParallel(model)
|
model = nn.DataParallel(model)
|
||||||
|
|
||||||
# Distributed training (should be after apex fp16 initialization)
|
# Distributed training (should be after apex fp16 initialization)
|
||||||
if args.local_rank != -1:
|
if args.local_rank != -1:
|
||||||
model = torch.nn.parallel.DistributedDataParallel(
|
model = nn.parallel.DistributedDataParallel(
|
||||||
model,
|
model,
|
||||||
device_ids=[args.local_rank],
|
device_ids=[args.local_rank],
|
||||||
output_device=args.local_rank,
|
output_device=args.local_rank,
|
||||||
@@ -308,17 +307,17 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
|
|||||||
)
|
)
|
||||||
|
|
||||||
loss_start = (
|
loss_start = (
|
||||||
F.kl_div(
|
nn.functional.kl_div(
|
||||||
input=F.log_softmax(start_logits_stu / args.temperature, dim=-1),
|
input=nn.functional.log_softmax(start_logits_stu / args.temperature, dim=-1),
|
||||||
target=F.softmax(start_logits_tea / args.temperature, dim=-1),
|
target=nn.functional.softmax(start_logits_tea / args.temperature, dim=-1),
|
||||||
reduction="batchmean",
|
reduction="batchmean",
|
||||||
)
|
)
|
||||||
* (args.temperature ** 2)
|
* (args.temperature ** 2)
|
||||||
)
|
)
|
||||||
loss_end = (
|
loss_end = (
|
||||||
F.kl_div(
|
nn.functional.kl_div(
|
||||||
input=F.log_softmax(end_logits_stu / args.temperature, dim=-1),
|
input=nn.functional.log_softmax(end_logits_stu / args.temperature, dim=-1),
|
||||||
target=F.softmax(end_logits_tea / args.temperature, dim=-1),
|
target=nn.functional.softmax(end_logits_tea / args.temperature, dim=-1),
|
||||||
reduction="batchmean",
|
reduction="batchmean",
|
||||||
)
|
)
|
||||||
* (args.temperature ** 2)
|
* (args.temperature ** 2)
|
||||||
@@ -346,9 +345,9 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
|
|||||||
tr_loss += loss.item()
|
tr_loss += loss.item()
|
||||||
if (step + 1) % args.gradient_accumulation_steps == 0:
|
if (step + 1) % args.gradient_accumulation_steps == 0:
|
||||||
if args.fp16:
|
if args.fp16:
|
||||||
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
|
nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
|
||||||
else:
|
else:
|
||||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
|
nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
|
||||||
|
|
||||||
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
|
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
|
||||||
tb_writer.add_scalar("threshold", threshold, global_step)
|
tb_writer.add_scalar("threshold", threshold, global_step)
|
||||||
@@ -454,8 +453,8 @@ def evaluate(args, model, tokenizer, prefix=""):
|
|||||||
eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
|
eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
|
||||||
|
|
||||||
# multi-gpu eval
|
# multi-gpu eval
|
||||||
if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
|
if args.n_gpu > 1 and not isinstance(model, nn.DataParallel):
|
||||||
model = torch.nn.DataParallel(model)
|
model = nn.DataParallel(model)
|
||||||
|
|
||||||
# Eval!
|
# Eval!
|
||||||
logger.info("***** Running evaluation {} *****".format(prefix))
|
logger.info("***** Running evaluation {} *****".format(prefix))
|
||||||
|
|||||||
@@ -1,19 +1,19 @@
|
|||||||
import torch
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
class ClassificationHead(torch.nn.Module):
|
class ClassificationHead(nn.Module):
|
||||||
"""Classification Head for transformer encoders"""
|
"""Classification Head for transformer encoders"""
|
||||||
|
|
||||||
def __init__(self, class_size, embed_size):
|
def __init__(self, class_size, embed_size):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.class_size = class_size
|
self.class_size = class_size
|
||||||
self.embed_size = embed_size
|
self.embed_size = embed_size
|
||||||
# self.mlp1 = torch.nn.Linear(embed_size, embed_size)
|
# self.mlp1 = nn.Linear(embed_size, embed_size)
|
||||||
# self.mlp2 = (torch.nn.Linear(embed_size, class_size))
|
# self.mlp2 = (nn.Linear(embed_size, class_size))
|
||||||
self.mlp = torch.nn.Linear(embed_size, class_size)
|
self.mlp = nn.Linear(embed_size, class_size)
|
||||||
|
|
||||||
def forward(self, hidden_state):
|
def forward(self, hidden_state):
|
||||||
# hidden_state = F.relu(self.mlp1(hidden_state))
|
# hidden_state = nn.functional.relu(self.mlp1(hidden_state))
|
||||||
# hidden_state = self.mlp2(hidden_state)
|
# hidden_state = self.mlp2(hidden_state)
|
||||||
logits = self.mlp(hidden_state)
|
logits = self.mlp(hidden_state)
|
||||||
return logits
|
return logits
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ from typing import List, Optional, Tuple, Union
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
from torch import nn
|
||||||
from tqdm import trange
|
from tqdm import trange
|
||||||
|
|
||||||
from pplm_classification_head import ClassificationHead
|
from pplm_classification_head import ClassificationHead
|
||||||
@@ -160,7 +160,7 @@ def perturb_past(
|
|||||||
new_accumulated_hidden = accumulated_hidden + torch.sum(hidden, dim=1).detach()
|
new_accumulated_hidden = accumulated_hidden + torch.sum(hidden, dim=1).detach()
|
||||||
# TODO: Check the layer-norm consistency of this with trained discriminator (Sumanth)
|
# TODO: Check the layer-norm consistency of this with trained discriminator (Sumanth)
|
||||||
logits = all_logits[:, -1, :]
|
logits = all_logits[:, -1, :]
|
||||||
probs = F.softmax(logits, dim=-1)
|
probs = nn.functional.softmax(logits, dim=-1)
|
||||||
|
|
||||||
loss = 0.0
|
loss = 0.0
|
||||||
loss_list = []
|
loss_list = []
|
||||||
@@ -173,7 +173,7 @@ def perturb_past(
|
|||||||
print(" pplm_bow_loss:", loss.data.cpu().numpy())
|
print(" pplm_bow_loss:", loss.data.cpu().numpy())
|
||||||
|
|
||||||
if loss_type == 2 or loss_type == 3:
|
if loss_type == 2 or loss_type == 3:
|
||||||
ce_loss = torch.nn.CrossEntropyLoss()
|
ce_loss = nn.CrossEntropyLoss()
|
||||||
# TODO why we need to do this assignment and not just using unpert_past? (Sumanth)
|
# TODO why we need to do this assignment and not just using unpert_past? (Sumanth)
|
||||||
curr_unpert_past = unpert_past
|
curr_unpert_past = unpert_past
|
||||||
curr_probs = torch.unsqueeze(probs, dim=1)
|
curr_probs = torch.unsqueeze(probs, dim=1)
|
||||||
@@ -195,7 +195,7 @@ def perturb_past(
|
|||||||
|
|
||||||
kl_loss = 0.0
|
kl_loss = 0.0
|
||||||
if kl_scale > 0.0:
|
if kl_scale > 0.0:
|
||||||
unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1)
|
unpert_probs = nn.functional.softmax(unpert_logits[:, -1, :], dim=-1)
|
||||||
unpert_probs = unpert_probs + SMALL_CONST * (unpert_probs <= SMALL_CONST).float().to(device).detach()
|
unpert_probs = unpert_probs + SMALL_CONST * (unpert_probs <= SMALL_CONST).float().to(device).detach()
|
||||||
correction = SMALL_CONST * (probs <= SMALL_CONST).float().to(device).detach()
|
correction = SMALL_CONST * (probs <= SMALL_CONST).float().to(device).detach()
|
||||||
corrected_probs = probs + correction.detach()
|
corrected_probs = probs + correction.detach()
|
||||||
@@ -527,10 +527,10 @@ def generate_text_pplm(
|
|||||||
else:
|
else:
|
||||||
pert_logits[0, token_idx] /= repetition_penalty
|
pert_logits[0, token_idx] /= repetition_penalty
|
||||||
|
|
||||||
pert_probs = F.softmax(pert_logits, dim=-1)
|
pert_probs = nn.functional.softmax(pert_logits, dim=-1)
|
||||||
|
|
||||||
if classifier is not None:
|
if classifier is not None:
|
||||||
ce_loss = torch.nn.CrossEntropyLoss()
|
ce_loss = nn.CrossEntropyLoss()
|
||||||
prediction = classifier(torch.mean(unpert_last_hidden, dim=1))
|
prediction = classifier(torch.mean(unpert_last_hidden, dim=1))
|
||||||
label = torch.tensor([class_label], device=device, dtype=torch.long)
|
label = torch.tensor([class_label], device=device, dtype=torch.long)
|
||||||
unpert_discrim_loss = ce_loss(prediction, label)
|
unpert_discrim_loss = ce_loss(prediction, label)
|
||||||
@@ -541,7 +541,7 @@ def generate_text_pplm(
|
|||||||
# Fuse the modified model and original model
|
# Fuse the modified model and original model
|
||||||
if perturb:
|
if perturb:
|
||||||
|
|
||||||
unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1)
|
unpert_probs = nn.functional.softmax(unpert_logits[:, -1, :], dim=-1)
|
||||||
|
|
||||||
pert_probs = (pert_probs ** gm_scale) * (unpert_probs ** (1 - gm_scale)) # + SMALL_CONST
|
pert_probs = (pert_probs ** gm_scale) * (unpert_probs ** (1 - gm_scale)) # + SMALL_CONST
|
||||||
pert_probs = top_k_filter(pert_probs, k=top_k, probs=True) # + SMALL_CONST
|
pert_probs = top_k_filter(pert_probs, k=top_k, probs=True) # + SMALL_CONST
|
||||||
@@ -552,7 +552,7 @@ def generate_text_pplm(
|
|||||||
|
|
||||||
else:
|
else:
|
||||||
pert_logits = top_k_filter(pert_logits, k=top_k) # + SMALL_CONST
|
pert_logits = top_k_filter(pert_logits, k=top_k) # + SMALL_CONST
|
||||||
pert_probs = F.softmax(pert_logits, dim=-1)
|
pert_probs = nn.functional.softmax(pert_logits, dim=-1)
|
||||||
|
|
||||||
# sample or greedy
|
# sample or greedy
|
||||||
if sample:
|
if sample:
|
||||||
|
|||||||
@@ -23,10 +23,10 @@ import time
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
import torch.utils.data as data
|
import torch.utils.data as data
|
||||||
from nltk.tokenize.treebank import TreebankWordDetokenizer
|
from nltk.tokenize.treebank import TreebankWordDetokenizer
|
||||||
|
from torch import nn
|
||||||
from torchtext import data as torchtext_data
|
from torchtext import data as torchtext_data
|
||||||
from torchtext import datasets
|
from torchtext import datasets
|
||||||
from tqdm import tqdm, trange
|
from tqdm import tqdm, trange
|
||||||
@@ -42,7 +42,7 @@ example_sentence = "This is incredible! I love it, this is the best chicken I ha
|
|||||||
max_length_seq = 100
|
max_length_seq = 100
|
||||||
|
|
||||||
|
|
||||||
class Discriminator(torch.nn.Module):
|
class Discriminator(nn.Module):
|
||||||
"""Transformer encoder followed by a Classification Head"""
|
"""Transformer encoder followed by a Classification Head"""
|
||||||
|
|
||||||
def __init__(self, class_size, pretrained_model="gpt2-medium", cached_mode=False, device="cpu"):
|
def __init__(self, class_size, pretrained_model="gpt2-medium", cached_mode=False, device="cpu"):
|
||||||
@@ -76,7 +76,7 @@ class Discriminator(torch.nn.Module):
|
|||||||
avg_hidden = self.avg_representation(x.to(self.device))
|
avg_hidden = self.avg_representation(x.to(self.device))
|
||||||
|
|
||||||
logits = self.classifier_head(avg_hidden)
|
logits = self.classifier_head(avg_hidden)
|
||||||
probs = F.log_softmax(logits, dim=-1)
|
probs = nn.functional.log_softmax(logits, dim=-1)
|
||||||
|
|
||||||
return probs
|
return probs
|
||||||
|
|
||||||
@@ -140,7 +140,7 @@ def train_epoch(data_loader, discriminator, optimizer, epoch=0, log_interval=10,
|
|||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|
||||||
output_t = discriminator(input_t)
|
output_t = discriminator(input_t)
|
||||||
loss = F.nll_loss(output_t, target_t)
|
loss = nn.functional.nll_loss(output_t, target_t)
|
||||||
loss.backward(retain_graph=True)
|
loss.backward(retain_graph=True)
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
@@ -167,7 +167,7 @@ def evaluate_performance(data_loader, discriminator, device="cpu"):
|
|||||||
input_t, target_t = input_t.to(device), target_t.to(device)
|
input_t, target_t = input_t.to(device), target_t.to(device)
|
||||||
output_t = discriminator(input_t)
|
output_t = discriminator(input_t)
|
||||||
# sum up batch loss
|
# sum up batch loss
|
||||||
test_loss += F.nll_loss(output_t, target_t, reduction="sum").item()
|
test_loss += nn.functional.nll_loss(output_t, target_t, reduction="sum").item()
|
||||||
# get the index of the max log-probability
|
# get the index of the max log-probability
|
||||||
pred_t = output_t.argmax(dim=1, keepdim=True)
|
pred_t = output_t.argmax(dim=1, keepdim=True)
|
||||||
correct += pred_t.eq(target_t.view_as(pred_t)).sum().item()
|
correct += pred_t.eq(target_t.view_as(pred_t)).sum().item()
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from pathlib import Path
|
|||||||
import pytest
|
import pytest
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
import lightning_base
|
import lightning_base
|
||||||
from convert_pl_checkpoint_to_hf import convert_pl_to_hf
|
from convert_pl_checkpoint_to_hf import convert_pl_to_hf
|
||||||
@@ -183,7 +184,7 @@ class TestSummarizationDistiller(TestCasePlus):
|
|||||||
|
|
||||||
logits = model(input_ids, attention_mask=mask, decoder_input_ids=decoder_input_ids, use_cache=False).logits
|
logits = model(input_ids, attention_mask=mask, decoder_input_ids=decoder_input_ids, use_cache=False).logits
|
||||||
|
|
||||||
lprobs = torch.nn.functional.log_softmax(logits, dim=-1)
|
lprobs = nn.functional.log_softmax(logits, dim=-1)
|
||||||
smoothed_loss, nll_loss = label_smoothed_nll_loss(
|
smoothed_loss, nll_loss = label_smoothed_nll_loss(
|
||||||
lprobs, lm_labels, 0.1, ignore_index=model.config.pad_token_id
|
lprobs, lm_labels, 0.1, ignore_index=model.config.pad_token_id
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ from typing import List
|
|||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import functional as F
|
|
||||||
|
|
||||||
from finetune import SummarizationModule, TranslationModule
|
from finetune import SummarizationModule, TranslationModule
|
||||||
from finetune import main as ft_main
|
from finetune import main as ft_main
|
||||||
@@ -123,8 +122,8 @@ class SummarizationDistiller(SummarizationModule):
|
|||||||
assert t_logits_slct.size() == s_logits_slct.size()
|
assert t_logits_slct.size() == s_logits_slct.size()
|
||||||
loss_ce = (
|
loss_ce = (
|
||||||
self.ce_loss_fct(
|
self.ce_loss_fct(
|
||||||
F.log_softmax(s_logits_slct / self.temperature, dim=-1),
|
nn.functional.log_softmax(s_logits_slct / self.temperature, dim=-1),
|
||||||
F.softmax(t_logits_slct / self.temperature, dim=-1),
|
nn.functional.softmax(t_logits_slct / self.temperature, dim=-1),
|
||||||
)
|
)
|
||||||
* (self.temperature) ** 2
|
* (self.temperature) ** 2
|
||||||
)
|
)
|
||||||
@@ -160,10 +159,10 @@ class SummarizationDistiller(SummarizationModule):
|
|||||||
assert lm_logits.shape[-1] == self.model.config.vocab_size
|
assert lm_logits.shape[-1] == self.model.config.vocab_size
|
||||||
if self.hparams.label_smoothing == 0:
|
if self.hparams.label_smoothing == 0:
|
||||||
# Same behavior as modeling_bart.py, besides ignoring pad_token_id
|
# Same behavior as modeling_bart.py, besides ignoring pad_token_id
|
||||||
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
|
loss_fct = nn.CrossEntropyLoss(ignore_index=pad_token_id)
|
||||||
student_lm_loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), labels.view(-1))
|
student_lm_loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), labels.view(-1))
|
||||||
else:
|
else:
|
||||||
lprobs = F.log_softmax(lm_logits, dim=-1)
|
lprobs = nn.functional.log_softmax(lm_logits, dim=-1)
|
||||||
student_lm_loss, _ = label_smoothed_nll_loss(
|
student_lm_loss, _ = label_smoothed_nll_loss(
|
||||||
lprobs, labels, self.hparams.label_smoothing, ignore_index=pad_token_id
|
lprobs, labels, self.hparams.label_smoothing, ignore_index=pad_token_id
|
||||||
)
|
)
|
||||||
@@ -230,9 +229,9 @@ class SummarizationDistiller(SummarizationModule):
|
|||||||
teacher_states = torch.stack([hidden_states_T[j] for j in matches])
|
teacher_states = torch.stack([hidden_states_T[j] for j in matches])
|
||||||
assert student_states.shape == teacher_states.shape, f"{student_states.shape} != {teacher_states.shape}"
|
assert student_states.shape == teacher_states.shape, f"{student_states.shape} != {teacher_states.shape}"
|
||||||
if normalize_hidden:
|
if normalize_hidden:
|
||||||
student_states = F.layer_norm(student_states, student_states.shape[1:])
|
student_states = nn.functional.layer_norm(student_states, student_states.shape[1:])
|
||||||
teacher_states = F.layer_norm(teacher_states, teacher_states.shape[1:])
|
teacher_states = nn.functional.layer_norm(teacher_states, teacher_states.shape[1:])
|
||||||
mse = F.mse_loss(student_states, teacher_states, reduction="none")
|
mse = nn.functional.mse_loss(student_states, teacher_states, reduction="none")
|
||||||
masked_mse = (mse * mask.unsqueeze(0).unsqueeze(-1)).sum() / valid_count
|
masked_mse = (mse * mask.unsqueeze(0).unsqueeze(-1)).sum() / valid_count
|
||||||
return masked_mse
|
return masked_mse
|
||||||
|
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ from typing import Dict, List, Tuple
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
|
from torch import nn
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
|
from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
|
||||||
@@ -151,12 +152,12 @@ class SummarizationModule(BaseTransformer):
|
|||||||
lm_logits = outputs["logits"]
|
lm_logits = outputs["logits"]
|
||||||
if self.hparams.label_smoothing == 0:
|
if self.hparams.label_smoothing == 0:
|
||||||
# Same behavior as modeling_bart.py, besides ignoring pad_token_id
|
# Same behavior as modeling_bart.py, besides ignoring pad_token_id
|
||||||
ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
|
ce_loss_fct = nn.CrossEntropyLoss(ignore_index=pad_token_id)
|
||||||
|
|
||||||
assert lm_logits.shape[-1] == self.vocab_size
|
assert lm_logits.shape[-1] == self.vocab_size
|
||||||
loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))
|
loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))
|
||||||
else:
|
else:
|
||||||
lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
|
lprobs = nn.functional.log_softmax(lm_logits, dim=-1)
|
||||||
loss, nll_loss = label_smoothed_nll_loss(
|
loss, nll_loss = label_smoothed_nll_loss(
|
||||||
lprobs, tgt_ids, self.hparams.label_smoothing, ignore_index=pad_token_id
|
lprobs, tgt_ids, self.hparams.label_smoothing, ignore_index=pad_token_id
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -9,8 +9,8 @@ from typing import Any, Callable, Dict, List, Optional, Set, Union
|
|||||||
import datasets
|
import datasets
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
import librosa
|
import librosa
|
||||||
from lang_trans import arabic
|
from lang_trans import arabic
|
||||||
|
|||||||
@@ -5,9 +5,9 @@ from dataclasses import dataclass, field
|
|||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
|
||||||
from datasets import DatasetDict, load_dataset
|
from datasets import DatasetDict, load_dataset
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
import librosa
|
import librosa
|
||||||
from transformers import (
|
from transformers import (
|
||||||
|
|||||||
Reference in New Issue
Block a user