[style] consistent nn. and nn.functional: part 4 examples (#12156)

* consistent nn. and nn.functional: p4 examples

* restore
This commit is contained in:
Stas Bekman
2021-06-14 12:28:24 -07:00
committed by GitHub
parent 372ab9cd6d
commit 88e84186e5
26 changed files with 130 additions and 126 deletions

View File

@@ -17,7 +17,7 @@
import logging import logging
import torch import torch
import torch.nn as nn from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss from torch.nn import CrossEntropyLoss, MSELoss
from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
@@ -270,6 +270,7 @@ class AlbertForSequenceClassificationWithPabee(AlbertPreTrainedModel):
from transformers import AlbertTokenizer from transformers import AlbertTokenizer
from pabee import AlbertForSequenceClassificationWithPabee from pabee import AlbertForSequenceClassificationWithPabee
from torch import nn
import torch import torch
tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2') tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')

View File

@@ -294,6 +294,7 @@ class BertForSequenceClassificationWithPabee(BertPreTrainedModel):
from transformers import BertTokenizer, BertForSequenceClassification from transformers import BertTokenizer, BertForSequenceClassification
from pabee import BertForSequenceClassificationWithPabee from pabee import BertForSequenceClassificationWithPabee
from torch import nn
import torch import torch
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

View File

@@ -25,6 +25,7 @@ import random
import numpy as np import numpy as np
import torch import torch
from torch import nn
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange from tqdm import tqdm, trange
@@ -117,11 +118,11 @@ def train(args, train_dataset, model, tokenizer):
# multi-gpu training (should be after apex fp16 initialization) # multi-gpu training (should be after apex fp16 initialization)
if args.n_gpu > 1: if args.n_gpu > 1:
model = torch.nn.DataParallel(model) model = nn.DataParallel(model)
# Distributed training (should be after apex fp16 initialization) # Distributed training (should be after apex fp16 initialization)
if args.local_rank != -1: if args.local_rank != -1:
model = torch.nn.parallel.DistributedDataParallel( model = nn.parallel.DistributedDataParallel(
model, model,
device_ids=[args.local_rank], device_ids=[args.local_rank],
output_device=args.local_rank, output_device=args.local_rank,
@@ -203,9 +204,9 @@ def train(args, train_dataset, model, tokenizer):
tr_loss += loss.item() tr_loss += loss.item()
if (step + 1) % args.gradient_accumulation_steps == 0: if (step + 1) % args.gradient_accumulation_steps == 0:
if args.fp16: if args.fp16:
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm) nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
else: else:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
optimizer.step() optimizer.step()
scheduler.step() # Update learning rate schedule scheduler.step() # Update learning rate schedule
@@ -291,8 +292,8 @@ def evaluate(args, model, tokenizer, prefix="", patience=0):
eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size) eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
# multi-gpu eval # multi-gpu eval
if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel): if args.n_gpu > 1 and not isinstance(model, nn.DataParallel):
model = torch.nn.DataParallel(model) model = nn.DataParallel(model)
# Eval! # Eval!
logger.info("***** Running evaluation {} *****".format(prefix)) logger.info("***** Running evaluation {} *****".format(prefix))

View File

@@ -26,6 +26,7 @@ from datetime import datetime
import numpy as np import numpy as np
import torch import torch
from torch import nn
from torch.utils.data import DataLoader, SequentialSampler, Subset from torch.utils.data import DataLoader, SequentialSampler, Subset
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm from tqdm import tqdm
@@ -415,11 +416,11 @@ def main():
# Distributed and parallel training # Distributed and parallel training
model.to(args.device) model.to(args.device)
if args.local_rank != -1: if args.local_rank != -1:
model = torch.nn.parallel.DistributedDataParallel( model = nn.parallel.DistributedDataParallel(
model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
) )
elif args.n_gpu > 1: elif args.n_gpu > 1:
model = torch.nn.DataParallel(model) model = nn.DataParallel(model)
# Print/save training arguments # Print/save training arguments
os.makedirs(args.output_dir, exist_ok=True) os.makedirs(args.output_dir, exist_ok=True)

View File

@@ -10,6 +10,7 @@ from datetime import datetime
import numpy as np import numpy as np
import torch import torch
from torch import nn
from torch.utils.data import DataLoader, RandomSampler, TensorDataset from torch.utils.data import DataLoader, RandomSampler, TensorDataset
from tqdm import tqdm from tqdm import tqdm
@@ -352,11 +353,11 @@ def main():
# Distributed and parallel training # Distributed and parallel training
model.to(args.device) model.to(args.device)
if args.local_rank != -1: if args.local_rank != -1:
model = torch.nn.parallel.DistributedDataParallel( model = nn.parallel.DistributedDataParallel(
model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
) )
elif args.n_gpu > 1: elif args.n_gpu > 1:
model = torch.nn.DataParallel(model) model = nn.DataParallel(model)
# Print/save training arguments # Print/save training arguments
os.makedirs(args.output_dir, exist_ok=True) os.makedirs(args.output_dir, exist_ok=True)

View File

@@ -9,6 +9,7 @@ import time
import numpy as np import numpy as np
import torch import torch
from torch import nn
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange from tqdm import tqdm, trange
@@ -135,11 +136,11 @@ def train(args, train_dataset, model, tokenizer, train_highway=False):
# multi-gpu training (should be after apex fp16 initialization) # multi-gpu training (should be after apex fp16 initialization)
if args.n_gpu > 1: if args.n_gpu > 1:
model = torch.nn.DataParallel(model) model = nn.DataParallel(model)
# Distributed training (should be after apex fp16 initialization) # Distributed training (should be after apex fp16 initialization)
if args.local_rank != -1: if args.local_rank != -1:
model = torch.nn.parallel.DistributedDataParallel( model = nn.parallel.DistributedDataParallel(
model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
) )
@@ -190,9 +191,9 @@ def train(args, train_dataset, model, tokenizer, train_highway=False):
tr_loss += loss.item() tr_loss += loss.item()
if (step + 1) % args.gradient_accumulation_steps == 0: if (step + 1) % args.gradient_accumulation_steps == 0:
if args.fp16: if args.fp16:
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm) nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
else: else:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
optimizer.step() optimizer.step()
scheduler.step() # Update learning rate schedule scheduler.step() # Update learning rate schedule
@@ -255,7 +256,7 @@ def evaluate(args, model, tokenizer, prefix="", output_layer=-1, eval_highway=Fa
# multi-gpu eval # multi-gpu eval
if args.n_gpu > 1: if args.n_gpu > 1:
model = torch.nn.DataParallel(model) model = nn.DataParallel(model)
# Eval! # Eval!
logger.info("***** Running evaluation {} *****".format(prefix)) logger.info("***** Running evaluation {} *****".format(prefix))

View File

@@ -1,6 +1,6 @@
from __future__ import absolute_import, division, print_function, unicode_literals from __future__ import absolute_import, division, print_function, unicode_literals
import torch.nn as nn from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss from torch.nn import CrossEntropyLoss, MSELoss
from transformers import RobertaConfig from transformers import RobertaConfig

View File

@@ -21,8 +21,7 @@ import time
import psutil import psutil
import torch import torch
import torch.nn as nn from torch import nn
import torch.nn.functional as F
from torch.optim import AdamW from torch.optim import AdamW
from torch.utils.data import BatchSampler, DataLoader, RandomSampler from torch.utils.data import BatchSampler, DataLoader, RandomSampler
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
@@ -412,8 +411,8 @@ class Distiller:
loss_ce = ( loss_ce = (
self.ce_loss_fct( self.ce_loss_fct(
F.log_softmax(s_logits_slct / self.temperature, dim=-1), nn.functional.log_softmax(s_logits_slct / self.temperature, dim=-1),
F.softmax(t_logits_slct / self.temperature, dim=-1), nn.functional.softmax(t_logits_slct / self.temperature, dim=-1),
) )
* (self.temperature) ** 2 * (self.temperature) ** 2
) )
@@ -492,9 +491,9 @@ class Distiller:
self.iter() self.iter()
if self.n_iter % self.params.gradient_accumulation_steps == 0: if self.n_iter % self.params.gradient_accumulation_steps == 0:
if self.fp16: if self.fp16:
torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.params.max_grad_norm) nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.params.max_grad_norm)
else: else:
torch.nn.utils.clip_grad_norm_(self.student.parameters(), self.params.max_grad_norm) nn.utils.clip_grad_norm_(self.student.parameters(), self.params.max_grad_norm)
self.optimizer.step() self.optimizer.step()
self.optimizer.zero_grad() self.optimizer.zero_grad()
self.scheduler.step() self.scheduler.step()

View File

@@ -24,8 +24,7 @@ import timeit
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange from tqdm import tqdm, trange
@@ -138,11 +137,11 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
# multi-gpu training (should be after apex fp16 initialization) # multi-gpu training (should be after apex fp16 initialization)
if args.n_gpu > 1: if args.n_gpu > 1:
model = torch.nn.DataParallel(model) model = nn.DataParallel(model)
# Distributed training (should be after apex fp16 initialization) # Distributed training (should be after apex fp16 initialization)
if args.local_rank != -1: if args.local_rank != -1:
model = torch.nn.parallel.DistributedDataParallel( model = nn.parallel.DistributedDataParallel(
model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
) )
@@ -232,15 +231,15 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
loss_fct = nn.KLDivLoss(reduction="batchmean") loss_fct = nn.KLDivLoss(reduction="batchmean")
loss_start = ( loss_start = (
loss_fct( loss_fct(
F.log_softmax(start_logits_stu / args.temperature, dim=-1), nn.functional.log_softmax(start_logits_stu / args.temperature, dim=-1),
F.softmax(start_logits_tea / args.temperature, dim=-1), nn.functional.softmax(start_logits_tea / args.temperature, dim=-1),
) )
* (args.temperature ** 2) * (args.temperature ** 2)
) )
loss_end = ( loss_end = (
loss_fct( loss_fct(
F.log_softmax(end_logits_stu / args.temperature, dim=-1), nn.functional.log_softmax(end_logits_stu / args.temperature, dim=-1),
F.softmax(end_logits_tea / args.temperature, dim=-1), nn.functional.softmax(end_logits_tea / args.temperature, dim=-1),
) )
* (args.temperature ** 2) * (args.temperature ** 2)
) )
@@ -262,9 +261,9 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
tr_loss += loss.item() tr_loss += loss.item()
if (step + 1) % args.gradient_accumulation_steps == 0: if (step + 1) % args.gradient_accumulation_steps == 0:
if args.fp16: if args.fp16:
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm) nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
else: else:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
optimizer.step() optimizer.step()
scheduler.step() # Update learning rate schedule scheduler.step() # Update learning rate schedule
@@ -326,8 +325,8 @@ def evaluate(args, model, tokenizer, prefix=""):
eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size) eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
# multi-gpu evaluate # multi-gpu evaluate
if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel): if args.n_gpu > 1 and not isinstance(model, nn.DataParallel):
model = torch.nn.DataParallel(model) model = nn.DataParallel(model)
# Eval! # Eval!
logger.info("***** Running evaluation {} *****".format(prefix)) logger.info("***** Running evaluation {} *****".format(prefix))

View File

@@ -11,6 +11,7 @@ import torch
import torch.utils.checkpoint as checkpoint import torch.utils.checkpoint as checkpoint
from elasticsearch import Elasticsearch # noqa: F401 from elasticsearch import Elasticsearch # noqa: F401
from elasticsearch.helpers import bulk, streaming_bulk # noqa: F401 from elasticsearch.helpers import bulk, streaming_bulk # noqa: F401
from torch import nn
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
from tqdm import tqdm from tqdm import tqdm
@@ -116,14 +117,14 @@ class ELI5DatasetQARetriver(Dataset):
return self.make_example(idx % self.data.num_rows) return self.make_example(idx % self.data.num_rows)
class RetrievalQAEmbedder(torch.nn.Module): class RetrievalQAEmbedder(nn.Module):
def __init__(self, sent_encoder, dim): def __init__(self, sent_encoder, dim):
super(RetrievalQAEmbedder, self).__init__() super(RetrievalQAEmbedder, self).__init__()
self.sent_encoder = sent_encoder self.sent_encoder = sent_encoder
self.output_dim = 128 self.output_dim = 128
self.project_q = torch.nn.Linear(dim, self.output_dim, bias=False) self.project_q = nn.Linear(dim, self.output_dim, bias=False)
self.project_a = torch.nn.Linear(dim, self.output_dim, bias=False) self.project_a = nn.Linear(dim, self.output_dim, bias=False)
self.ce_loss = torch.nn.CrossEntropyLoss(reduction="mean") self.ce_loss = nn.CrossEntropyLoss(reduction="mean")
def embed_sentences_checkpointed(self, input_ids, attention_mask, checkpoint_batch_size=-1): def embed_sentences_checkpointed(self, input_ids, attention_mask, checkpoint_batch_size=-1):
# reproduces BERT forward pass with checkpointing # reproduces BERT forward pass with checkpointing

View File

@@ -25,7 +25,6 @@ from typing import Dict, List, Tuple
import numpy as np import numpy as np
import torch import torch
from torch import nn from torch import nn
from torch.nn import functional as F
from torch.nn.modules.batchnorm import BatchNorm2d from torch.nn.modules.batchnorm import BatchNorm2d
from torchvision.ops import RoIPool from torchvision.ops import RoIPool
from torchvision.ops.boxes import batched_nms, nms from torchvision.ops.boxes import batched_nms, nms
@@ -85,7 +84,7 @@ def pad_list_tensors(
too_small = True too_small = True
tensor_i = tensor_i.unsqueeze(-1) tensor_i = tensor_i.unsqueeze(-1)
assert isinstance(tensor_i, torch.Tensor) assert isinstance(tensor_i, torch.Tensor)
tensor_i = F.pad( tensor_i = nn.functional.pad(
input=tensor_i, input=tensor_i,
pad=(0, 0, 0, max_detections - preds_per_image[i]), pad=(0, 0, 0, max_detections - preds_per_image[i]),
mode="constant", mode="constant",
@@ -701,7 +700,7 @@ class RPNOutputs(object):
# Main Classes # Main Classes
class Conv2d(torch.nn.Conv2d): class Conv2d(nn.Conv2d):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
norm = kwargs.pop("norm", None) norm = kwargs.pop("norm", None)
activation = kwargs.pop("activation", None) activation = kwargs.pop("activation", None)
@@ -712,9 +711,9 @@ class Conv2d(torch.nn.Conv2d):
def forward(self, x): def forward(self, x):
if x.numel() == 0 and self.training: if x.numel() == 0 and self.training:
assert not isinstance(self.norm, torch.nn.SyncBatchNorm) assert not isinstance(self.norm, nn.SyncBatchNorm)
if x.numel() == 0: if x.numel() == 0:
assert not isinstance(self.norm, torch.nn.GroupNorm) assert not isinstance(self.norm, nn.GroupNorm)
output_shape = [ output_shape = [
(i + 2 * p - (di * (k - 1) + 1)) // s + 1 (i + 2 * p - (di * (k - 1) + 1)) // s + 1
for i, p, di, k, s in zip( for i, p, di, k, s in zip(
@@ -752,7 +751,7 @@ class LastLevelMaxPool(nn.Module):
self.in_feature = "p5" self.in_feature = "p5"
def forward(self, x): def forward(self, x):
return [F.max_pool2d(x, kernel_size=1, stride=2, padding=0)] return [nn.functional.max_pool2d(x, kernel_size=1, stride=2, padding=0)]
class LastLevelP6P7(nn.Module): class LastLevelP6P7(nn.Module):
@@ -769,7 +768,7 @@ class LastLevelP6P7(nn.Module):
def forward(self, c5): def forward(self, c5):
p6 = self.p6(c5) p6 = self.p6(c5)
p7 = self.p7(F.relu(p6)) p7 = self.p7(nn.functional.relu(p6))
return [p6, p7] return [p6, p7]
@@ -790,11 +789,11 @@ class BasicStem(nn.Module):
def forward(self, x): def forward(self, x):
x = self.conv1(x) x = self.conv1(x)
x = F.relu_(x) x = nn.functional.relu_(x)
if self.caffe_maxpool: if self.caffe_maxpool:
x = F.max_pool2d(x, kernel_size=3, stride=2, padding=0, ceil_mode=True) x = nn.functional.max_pool2d(x, kernel_size=3, stride=2, padding=0, ceil_mode=True)
else: else:
x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1) x = nn.functional.max_pool2d(x, kernel_size=3, stride=2, padding=1)
return x return x
@property @property
@@ -881,10 +880,10 @@ class BottleneckBlock(ResNetBlockBase):
def forward(self, x): def forward(self, x):
out = self.conv1(x) out = self.conv1(x)
out = F.relu_(out) out = nn.functional.relu_(out)
out = self.conv2(out) out = self.conv2(out)
out = F.relu_(out) out = nn.functional.relu_(out)
out = self.conv3(out) out = self.conv3(out)
@@ -894,7 +893,7 @@ class BottleneckBlock(ResNetBlockBase):
shortcut = x shortcut = x
out += shortcut out += shortcut
out = F.relu_(out) out = nn.functional.relu_(out)
return out return out
@@ -1159,7 +1158,7 @@ class ROIOutputs(object):
return boxes.view(num_pred, K * B).split(preds_per_image, dim=0) return boxes.view(num_pred, K * B).split(preds_per_image, dim=0)
def _predict_objs(self, obj_logits, preds_per_image): def _predict_objs(self, obj_logits, preds_per_image):
probs = F.softmax(obj_logits, dim=-1) probs = nn.functional.softmax(obj_logits, dim=-1)
probs = probs.split(preds_per_image, dim=0) probs = probs.split(preds_per_image, dim=0)
return probs return probs
@@ -1490,7 +1489,7 @@ class RPNHead(nn.Module):
pred_objectness_logits = [] pred_objectness_logits = []
pred_anchor_deltas = [] pred_anchor_deltas = []
for x in features: for x in features:
t = F.relu(self.conv(x)) t = nn.functional.relu(self.conv(x))
pred_objectness_logits.append(self.objectness_logits(t)) pred_objectness_logits.append(self.objectness_logits(t))
pred_anchor_deltas.append(self.anchor_deltas(t)) pred_anchor_deltas.append(self.anchor_deltas(t))
return pred_objectness_logits, pred_anchor_deltas return pred_objectness_logits, pred_anchor_deltas
@@ -1650,7 +1649,7 @@ class FastRCNNOutputLayers(nn.Module):
cls_emb = self.cls_embedding(max_class) # [b] --> [b, 256] cls_emb = self.cls_embedding(max_class) # [b] --> [b, 256]
roi_features = torch.cat([roi_features, cls_emb], -1) # [b, 2048] + [b, 256] --> [b, 2304] roi_features = torch.cat([roi_features, cls_emb], -1) # [b, 2048] + [b, 256] --> [b, 2304]
roi_features = self.fc_attr(roi_features) roi_features = self.fc_attr(roi_features)
roi_features = F.relu(roi_features) roi_features = nn.functional.relu(roi_features)
attr_scores = self.attr_score(roi_features) attr_scores = self.attr_score(roi_features)
return scores, attr_scores, proposal_deltas return scores, attr_scores, proposal_deltas
else: else:

View File

@@ -20,8 +20,8 @@ from typing import Tuple
import numpy as np import numpy as np
import torch import torch
import torch.nn.functional as F
from PIL import Image from PIL import Image
from torch import nn
from utils import img_tensorize from utils import img_tensorize
@@ -63,7 +63,9 @@ class ResizeShortestEdge:
img = np.asarray(pil_image) img = np.asarray(pil_image)
else: else:
img = img.permute(2, 0, 1).unsqueeze(0) # 3, 0, 1) # hw(c) -> nchw img = img.permute(2, 0, 1).unsqueeze(0) # 3, 0, 1) # hw(c) -> nchw
img = F.interpolate(img, (newh, neww), mode=self.interp_method, align_corners=False).squeeze(0) img = nn.functional.interpolate(
img, (newh, neww), mode=self.interp_method, align_corners=False
).squeeze(0)
img_augs.append(img) img_augs.append(img)
return img_augs return img_augs
@@ -85,7 +87,7 @@ class Preprocess:
max_size = tuple(max(s) for s in zip(*[img.shape for img in images])) max_size = tuple(max(s) for s in zip(*[img.shape for img in images]))
image_sizes = [im.shape[-2:] for im in images] image_sizes = [im.shape[-2:] for im in images]
images = [ images = [
F.pad( nn.functional.pad(
im, im,
[0, max_size[-1] - size[1], 0, max_size[-2] - size[0]], [0, max_size[-1] - size[1], 0, max_size[-2] - size[0]],
value=self.pad_value, value=self.pad_value,

View File

@@ -25,8 +25,8 @@ import random
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn
from sklearn.metrics import f1_score from sklearn.metrics import f1_score
from torch import nn
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange from tqdm import tqdm, trange
@@ -107,11 +107,11 @@ def train(args, train_dataset, model, tokenizer, criterion):
# multi-gpu training (should be after apex fp16 initialization) # multi-gpu training (should be after apex fp16 initialization)
if args.n_gpu > 1: if args.n_gpu > 1:
model = torch.nn.DataParallel(model) model = nn.DataParallel(model)
# Distributed training (should be after apex fp16 initialization) # Distributed training (should be after apex fp16 initialization)
if args.local_rank != -1: if args.local_rank != -1:
model = torch.nn.parallel.DistributedDataParallel( model = nn.parallel.DistributedDataParallel(
model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
) )
@@ -166,9 +166,9 @@ def train(args, train_dataset, model, tokenizer, criterion):
tr_loss += loss.item() tr_loss += loss.item()
if (step + 1) % args.gradient_accumulation_steps == 0: if (step + 1) % args.gradient_accumulation_steps == 0:
if args.fp16: if args.fp16:
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm) nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
else: else:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
optimizer.step() optimizer.step()
scheduler.step() # Update learning rate schedule scheduler.step() # Update learning rate schedule
@@ -248,8 +248,8 @@ def evaluate(args, model, tokenizer, criterion, prefix=""):
) )
# multi-gpu eval # multi-gpu eval
if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel): if args.n_gpu > 1 and not isinstance(model, nn.DataParallel):
model = torch.nn.DataParallel(model) model = nn.DataParallel(model)
# Eval! # Eval!
logger.info("***** Running evaluation {} *****".format(prefix)) logger.info("***** Running evaluation {} *****".format(prefix))

View File

@@ -19,10 +19,10 @@ import os
from collections import Counter from collections import Counter
import torch import torch
import torch.nn as nn
import torchvision import torchvision
import torchvision.transforms as transforms import torchvision.transforms as transforms
from PIL import Image from PIL import Image
from torch import nn
from torch.utils.data import Dataset from torch.utils.data import Dataset

View File

@@ -75,7 +75,7 @@
"quantized_model = torch.quantization.quantize_dynamic(\n", "quantized_model = torch.quantization.quantize_dynamic(\n",
" model=model,\n", " model=model,\n",
" qconfig_spec = {\n", " qconfig_spec = {\n",
" torch.nn.Linear : torch.quantization.default_dynamic_qconfig,\n", " nn.Linear : torch.quantization.default_dynamic_qconfig,\n",
" },\n", " },\n",
" dtype=torch.qint8,\n", " dtype=torch.qint8,\n",
" )\n", " )\n",

View File

@@ -23,7 +23,6 @@ import math
import torch import torch
from torch import nn from torch import nn
from torch.nn import functional as F
from torch.nn import init from torch.nn import init
from .binarizer import MagnitudeBinarizer, ThresholdBinarizer, TopKBinarizer from .binarizer import MagnitudeBinarizer, ThresholdBinarizer, TopKBinarizer
@@ -104,4 +103,4 @@ class MaskedLinear(nn.Linear):
# Mask weights with computed mask # Mask weights with computed mask
weight_thresholded = mask * self.weight weight_thresholded = mask * self.weight
# Compute output (linear layer) with masked weights # Compute output (linear layer) with masked weights
return F.linear(input, weight_thresholded, self.bias) return nn.functional.linear(input, weight_thresholded, self.bias)

View File

@@ -24,8 +24,7 @@ import random
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange from tqdm import tqdm, trange
@@ -168,11 +167,11 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
# multi-gpu training (should be after apex fp16 initialization) # multi-gpu training (should be after apex fp16 initialization)
if args.n_gpu > 1: if args.n_gpu > 1:
model = torch.nn.DataParallel(model) model = nn.DataParallel(model)
# Distributed training (should be after apex fp16 initialization) # Distributed training (should be after apex fp16 initialization)
if args.local_rank != -1: if args.local_rank != -1:
model = torch.nn.parallel.DistributedDataParallel( model = nn.parallel.DistributedDataParallel(
model, model,
device_ids=[args.local_rank], device_ids=[args.local_rank],
output_device=args.local_rank, output_device=args.local_rank,
@@ -287,9 +286,9 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
) )
loss_logits = ( loss_logits = (
F.kl_div( nn.functional.kl_div(
input=F.log_softmax(logits_stu / args.temperature, dim=-1), input=nn.functional.log_softmax(logits_stu / args.temperature, dim=-1),
target=F.softmax(logits_tea / args.temperature, dim=-1), target=nn.functional.softmax(logits_tea / args.temperature, dim=-1),
reduction="batchmean", reduction="batchmean",
) )
* (args.temperature ** 2) * (args.temperature ** 2)
@@ -320,9 +319,9 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
and (step + 1) == len(epoch_iterator) and (step + 1) == len(epoch_iterator)
): ):
if args.fp16: if args.fp16:
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm) nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
else: else:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0: if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
tb_writer.add_scalar("threshold", threshold, global_step) tb_writer.add_scalar("threshold", threshold, global_step)
@@ -436,8 +435,8 @@ def evaluate(args, model, tokenizer, prefix=""):
eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size) eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
# multi-gpu eval # multi-gpu eval
if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel): if args.n_gpu > 1 and not isinstance(model, nn.DataParallel):
model = torch.nn.DataParallel(model) model = nn.DataParallel(model)
# Eval! # Eval!
logger.info("***** Running evaluation {} *****".format(prefix)) logger.info("***** Running evaluation {} *****".format(prefix))

View File

@@ -25,8 +25,7 @@ import timeit
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange from tqdm import tqdm, trange
@@ -176,11 +175,11 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
# multi-gpu training (should be after apex fp16 initialization) # multi-gpu training (should be after apex fp16 initialization)
if args.n_gpu > 1: if args.n_gpu > 1:
model = torch.nn.DataParallel(model) model = nn.DataParallel(model)
# Distributed training (should be after apex fp16 initialization) # Distributed training (should be after apex fp16 initialization)
if args.local_rank != -1: if args.local_rank != -1:
model = torch.nn.parallel.DistributedDataParallel( model = nn.parallel.DistributedDataParallel(
model, model,
device_ids=[args.local_rank], device_ids=[args.local_rank],
output_device=args.local_rank, output_device=args.local_rank,
@@ -308,17 +307,17 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
) )
loss_start = ( loss_start = (
F.kl_div( nn.functional.kl_div(
input=F.log_softmax(start_logits_stu / args.temperature, dim=-1), input=nn.functional.log_softmax(start_logits_stu / args.temperature, dim=-1),
target=F.softmax(start_logits_tea / args.temperature, dim=-1), target=nn.functional.softmax(start_logits_tea / args.temperature, dim=-1),
reduction="batchmean", reduction="batchmean",
) )
* (args.temperature ** 2) * (args.temperature ** 2)
) )
loss_end = ( loss_end = (
F.kl_div( nn.functional.kl_div(
input=F.log_softmax(end_logits_stu / args.temperature, dim=-1), input=nn.functional.log_softmax(end_logits_stu / args.temperature, dim=-1),
target=F.softmax(end_logits_tea / args.temperature, dim=-1), target=nn.functional.softmax(end_logits_tea / args.temperature, dim=-1),
reduction="batchmean", reduction="batchmean",
) )
* (args.temperature ** 2) * (args.temperature ** 2)
@@ -346,9 +345,9 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
tr_loss += loss.item() tr_loss += loss.item()
if (step + 1) % args.gradient_accumulation_steps == 0: if (step + 1) % args.gradient_accumulation_steps == 0:
if args.fp16: if args.fp16:
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm) nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
else: else:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0: if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
tb_writer.add_scalar("threshold", threshold, global_step) tb_writer.add_scalar("threshold", threshold, global_step)
@@ -454,8 +453,8 @@ def evaluate(args, model, tokenizer, prefix=""):
eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size) eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
# multi-gpu eval # multi-gpu eval
if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel): if args.n_gpu > 1 and not isinstance(model, nn.DataParallel):
model = torch.nn.DataParallel(model) model = nn.DataParallel(model)
# Eval! # Eval!
logger.info("***** Running evaluation {} *****".format(prefix)) logger.info("***** Running evaluation {} *****".format(prefix))

View File

@@ -1,19 +1,19 @@
import torch from torch import nn
class ClassificationHead(torch.nn.Module): class ClassificationHead(nn.Module):
"""Classification Head for transformer encoders""" """Classification Head for transformer encoders"""
def __init__(self, class_size, embed_size): def __init__(self, class_size, embed_size):
super().__init__() super().__init__()
self.class_size = class_size self.class_size = class_size
self.embed_size = embed_size self.embed_size = embed_size
# self.mlp1 = torch.nn.Linear(embed_size, embed_size) # self.mlp1 = nn.Linear(embed_size, embed_size)
# self.mlp2 = (torch.nn.Linear(embed_size, class_size)) # self.mlp2 = (nn.Linear(embed_size, class_size))
self.mlp = torch.nn.Linear(embed_size, class_size) self.mlp = nn.Linear(embed_size, class_size)
def forward(self, hidden_state): def forward(self, hidden_state):
# hidden_state = F.relu(self.mlp1(hidden_state)) # hidden_state = nn.functional.relu(self.mlp1(hidden_state))
# hidden_state = self.mlp2(hidden_state) # hidden_state = self.mlp2(hidden_state)
logits = self.mlp(hidden_state) logits = self.mlp(hidden_state)
return logits return logits

View File

@@ -30,7 +30,7 @@ from typing import List, Optional, Tuple, Union
import numpy as np import numpy as np
import torch import torch
import torch.nn.functional as F from torch import nn
from tqdm import trange from tqdm import trange
from pplm_classification_head import ClassificationHead from pplm_classification_head import ClassificationHead
@@ -160,7 +160,7 @@ def perturb_past(
new_accumulated_hidden = accumulated_hidden + torch.sum(hidden, dim=1).detach() new_accumulated_hidden = accumulated_hidden + torch.sum(hidden, dim=1).detach()
# TODO: Check the layer-norm consistency of this with trained discriminator (Sumanth) # TODO: Check the layer-norm consistency of this with trained discriminator (Sumanth)
logits = all_logits[:, -1, :] logits = all_logits[:, -1, :]
probs = F.softmax(logits, dim=-1) probs = nn.functional.softmax(logits, dim=-1)
loss = 0.0 loss = 0.0
loss_list = [] loss_list = []
@@ -173,7 +173,7 @@ def perturb_past(
print(" pplm_bow_loss:", loss.data.cpu().numpy()) print(" pplm_bow_loss:", loss.data.cpu().numpy())
if loss_type == 2 or loss_type == 3: if loss_type == 2 or loss_type == 3:
ce_loss = torch.nn.CrossEntropyLoss() ce_loss = nn.CrossEntropyLoss()
# TODO why we need to do this assignment and not just using unpert_past? (Sumanth) # TODO why we need to do this assignment and not just using unpert_past? (Sumanth)
curr_unpert_past = unpert_past curr_unpert_past = unpert_past
curr_probs = torch.unsqueeze(probs, dim=1) curr_probs = torch.unsqueeze(probs, dim=1)
@@ -195,7 +195,7 @@ def perturb_past(
kl_loss = 0.0 kl_loss = 0.0
if kl_scale > 0.0: if kl_scale > 0.0:
unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1) unpert_probs = nn.functional.softmax(unpert_logits[:, -1, :], dim=-1)
unpert_probs = unpert_probs + SMALL_CONST * (unpert_probs <= SMALL_CONST).float().to(device).detach() unpert_probs = unpert_probs + SMALL_CONST * (unpert_probs <= SMALL_CONST).float().to(device).detach()
correction = SMALL_CONST * (probs <= SMALL_CONST).float().to(device).detach() correction = SMALL_CONST * (probs <= SMALL_CONST).float().to(device).detach()
corrected_probs = probs + correction.detach() corrected_probs = probs + correction.detach()
@@ -527,10 +527,10 @@ def generate_text_pplm(
else: else:
pert_logits[0, token_idx] /= repetition_penalty pert_logits[0, token_idx] /= repetition_penalty
pert_probs = F.softmax(pert_logits, dim=-1) pert_probs = nn.functional.softmax(pert_logits, dim=-1)
if classifier is not None: if classifier is not None:
ce_loss = torch.nn.CrossEntropyLoss() ce_loss = nn.CrossEntropyLoss()
prediction = classifier(torch.mean(unpert_last_hidden, dim=1)) prediction = classifier(torch.mean(unpert_last_hidden, dim=1))
label = torch.tensor([class_label], device=device, dtype=torch.long) label = torch.tensor([class_label], device=device, dtype=torch.long)
unpert_discrim_loss = ce_loss(prediction, label) unpert_discrim_loss = ce_loss(prediction, label)
@@ -541,7 +541,7 @@ def generate_text_pplm(
# Fuse the modified model and original model # Fuse the modified model and original model
if perturb: if perturb:
unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1) unpert_probs = nn.functional.softmax(unpert_logits[:, -1, :], dim=-1)
pert_probs = (pert_probs ** gm_scale) * (unpert_probs ** (1 - gm_scale)) # + SMALL_CONST pert_probs = (pert_probs ** gm_scale) * (unpert_probs ** (1 - gm_scale)) # + SMALL_CONST
pert_probs = top_k_filter(pert_probs, k=top_k, probs=True) # + SMALL_CONST pert_probs = top_k_filter(pert_probs, k=top_k, probs=True) # + SMALL_CONST
@@ -552,7 +552,7 @@ def generate_text_pplm(
else: else:
pert_logits = top_k_filter(pert_logits, k=top_k) # + SMALL_CONST pert_logits = top_k_filter(pert_logits, k=top_k) # + SMALL_CONST
pert_probs = F.softmax(pert_logits, dim=-1) pert_probs = nn.functional.softmax(pert_logits, dim=-1)
# sample or greedy # sample or greedy
if sample: if sample:

View File

@@ -23,10 +23,10 @@ import time
import numpy as np import numpy as np
import torch import torch
import torch.nn.functional as F
import torch.optim as optim import torch.optim as optim
import torch.utils.data as data import torch.utils.data as data
from nltk.tokenize.treebank import TreebankWordDetokenizer from nltk.tokenize.treebank import TreebankWordDetokenizer
from torch import nn
from torchtext import data as torchtext_data from torchtext import data as torchtext_data
from torchtext import datasets from torchtext import datasets
from tqdm import tqdm, trange from tqdm import tqdm, trange
@@ -42,7 +42,7 @@ example_sentence = "This is incredible! I love it, this is the best chicken I ha
max_length_seq = 100 max_length_seq = 100
class Discriminator(torch.nn.Module): class Discriminator(nn.Module):
"""Transformer encoder followed by a Classification Head""" """Transformer encoder followed by a Classification Head"""
def __init__(self, class_size, pretrained_model="gpt2-medium", cached_mode=False, device="cpu"): def __init__(self, class_size, pretrained_model="gpt2-medium", cached_mode=False, device="cpu"):
@@ -76,7 +76,7 @@ class Discriminator(torch.nn.Module):
avg_hidden = self.avg_representation(x.to(self.device)) avg_hidden = self.avg_representation(x.to(self.device))
logits = self.classifier_head(avg_hidden) logits = self.classifier_head(avg_hidden)
probs = F.log_softmax(logits, dim=-1) probs = nn.functional.log_softmax(logits, dim=-1)
return probs return probs
@@ -140,7 +140,7 @@ def train_epoch(data_loader, discriminator, optimizer, epoch=0, log_interval=10,
optimizer.zero_grad() optimizer.zero_grad()
output_t = discriminator(input_t) output_t = discriminator(input_t)
loss = F.nll_loss(output_t, target_t) loss = nn.functional.nll_loss(output_t, target_t)
loss.backward(retain_graph=True) loss.backward(retain_graph=True)
optimizer.step() optimizer.step()
@@ -167,7 +167,7 @@ def evaluate_performance(data_loader, discriminator, device="cpu"):
input_t, target_t = input_t.to(device), target_t.to(device) input_t, target_t = input_t.to(device), target_t.to(device)
output_t = discriminator(input_t) output_t = discriminator(input_t)
# sum up batch loss # sum up batch loss
test_loss += F.nll_loss(output_t, target_t, reduction="sum").item() test_loss += nn.functional.nll_loss(output_t, target_t, reduction="sum").item()
# get the index of the max log-probability # get the index of the max log-probability
pred_t = output_t.argmax(dim=1, keepdim=True) pred_t = output_t.argmax(dim=1, keepdim=True)
correct += pred_t.eq(target_t.view_as(pred_t)).sum().item() correct += pred_t.eq(target_t.view_as(pred_t)).sum().item()

View File

@@ -8,6 +8,7 @@ from pathlib import Path
import pytest import pytest
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from torch import nn
import lightning_base import lightning_base
from convert_pl_checkpoint_to_hf import convert_pl_to_hf from convert_pl_checkpoint_to_hf import convert_pl_to_hf
@@ -183,7 +184,7 @@ class TestSummarizationDistiller(TestCasePlus):
logits = model(input_ids, attention_mask=mask, decoder_input_ids=decoder_input_ids, use_cache=False).logits logits = model(input_ids, attention_mask=mask, decoder_input_ids=decoder_input_ids, use_cache=False).logits
lprobs = torch.nn.functional.log_softmax(logits, dim=-1) lprobs = nn.functional.log_softmax(logits, dim=-1)
smoothed_loss, nll_loss = label_smoothed_nll_loss( smoothed_loss, nll_loss = label_smoothed_nll_loss(
lprobs, lm_labels, 0.1, ignore_index=model.config.pad_token_id lprobs, lm_labels, 0.1, ignore_index=model.config.pad_token_id
) )

View File

@@ -10,7 +10,6 @@ from typing import List
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from torch import nn from torch import nn
from torch.nn import functional as F
from finetune import SummarizationModule, TranslationModule from finetune import SummarizationModule, TranslationModule
from finetune import main as ft_main from finetune import main as ft_main
@@ -123,8 +122,8 @@ class SummarizationDistiller(SummarizationModule):
assert t_logits_slct.size() == s_logits_slct.size() assert t_logits_slct.size() == s_logits_slct.size()
loss_ce = ( loss_ce = (
self.ce_loss_fct( self.ce_loss_fct(
F.log_softmax(s_logits_slct / self.temperature, dim=-1), nn.functional.log_softmax(s_logits_slct / self.temperature, dim=-1),
F.softmax(t_logits_slct / self.temperature, dim=-1), nn.functional.softmax(t_logits_slct / self.temperature, dim=-1),
) )
* (self.temperature) ** 2 * (self.temperature) ** 2
) )
@@ -160,10 +159,10 @@ class SummarizationDistiller(SummarizationModule):
assert lm_logits.shape[-1] == self.model.config.vocab_size assert lm_logits.shape[-1] == self.model.config.vocab_size
if self.hparams.label_smoothing == 0: if self.hparams.label_smoothing == 0:
# Same behavior as modeling_bart.py, besides ignoring pad_token_id # Same behavior as modeling_bart.py, besides ignoring pad_token_id
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id) loss_fct = nn.CrossEntropyLoss(ignore_index=pad_token_id)
student_lm_loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), labels.view(-1)) student_lm_loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), labels.view(-1))
else: else:
lprobs = F.log_softmax(lm_logits, dim=-1) lprobs = nn.functional.log_softmax(lm_logits, dim=-1)
student_lm_loss, _ = label_smoothed_nll_loss( student_lm_loss, _ = label_smoothed_nll_loss(
lprobs, labels, self.hparams.label_smoothing, ignore_index=pad_token_id lprobs, labels, self.hparams.label_smoothing, ignore_index=pad_token_id
) )
@@ -230,9 +229,9 @@ class SummarizationDistiller(SummarizationModule):
teacher_states = torch.stack([hidden_states_T[j] for j in matches]) teacher_states = torch.stack([hidden_states_T[j] for j in matches])
assert student_states.shape == teacher_states.shape, f"{student_states.shape} != {teacher_states.shape}" assert student_states.shape == teacher_states.shape, f"{student_states.shape} != {teacher_states.shape}"
if normalize_hidden: if normalize_hidden:
student_states = F.layer_norm(student_states, student_states.shape[1:]) student_states = nn.functional.layer_norm(student_states, student_states.shape[1:])
teacher_states = F.layer_norm(teacher_states, teacher_states.shape[1:]) teacher_states = nn.functional.layer_norm(teacher_states, teacher_states.shape[1:])
mse = F.mse_loss(student_states, teacher_states, reduction="none") mse = nn.functional.mse_loss(student_states, teacher_states, reduction="none")
masked_mse = (mse * mask.unsqueeze(0).unsqueeze(-1)).sum() / valid_count masked_mse = (mse * mask.unsqueeze(0).unsqueeze(-1)).sum() / valid_count
return masked_mse return masked_mse

View File

@@ -13,6 +13,7 @@ from typing import Dict, List, Tuple
import numpy as np import numpy as np
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from torch import nn
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
@@ -151,12 +152,12 @@ class SummarizationModule(BaseTransformer):
lm_logits = outputs["logits"] lm_logits = outputs["logits"]
if self.hparams.label_smoothing == 0: if self.hparams.label_smoothing == 0:
# Same behavior as modeling_bart.py, besides ignoring pad_token_id # Same behavior as modeling_bart.py, besides ignoring pad_token_id
ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id) ce_loss_fct = nn.CrossEntropyLoss(ignore_index=pad_token_id)
assert lm_logits.shape[-1] == self.vocab_size assert lm_logits.shape[-1] == self.vocab_size
loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1)) loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))
else: else:
lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1) lprobs = nn.functional.log_softmax(lm_logits, dim=-1)
loss, nll_loss = label_smoothed_nll_loss( loss, nll_loss = label_smoothed_nll_loss(
lprobs, tgt_ids, self.hparams.label_smoothing, ignore_index=pad_token_id lprobs, tgt_ids, self.hparams.label_smoothing, ignore_index=pad_token_id
) )

View File

@@ -9,8 +9,8 @@ from typing import Any, Callable, Dict, List, Optional, Set, Union
import datasets import datasets
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn
from packaging import version from packaging import version
from torch import nn
import librosa import librosa
from lang_trans import arabic from lang_trans import arabic

View File

@@ -5,9 +5,9 @@ from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
import torch import torch
import torch.nn as nn
from datasets import DatasetDict, load_dataset from datasets import DatasetDict, load_dataset
from packaging import version from packaging import version
from torch import nn
import librosa import librosa
from transformers import ( from transformers import (