[style] consistent nn. and nn.functional: part 4 examples (#12156)

* consistent nn. and nn.functional: p4 examples

* restore
This commit is contained in:
Stas Bekman
2021-06-14 12:28:24 -07:00
committed by GitHub
parent 372ab9cd6d
commit 88e84186e5
26 changed files with 130 additions and 126 deletions

View File

@@ -11,6 +11,7 @@ import torch
import torch.utils.checkpoint as checkpoint
from elasticsearch import Elasticsearch # noqa: F401
from elasticsearch.helpers import bulk, streaming_bulk # noqa: F401
from torch import nn
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
from tqdm import tqdm
@@ -116,14 +117,14 @@ class ELI5DatasetQARetriver(Dataset):
return self.make_example(idx % self.data.num_rows)
class RetrievalQAEmbedder(torch.nn.Module):
class RetrievalQAEmbedder(nn.Module):
def __init__(self, sent_encoder, dim):
super(RetrievalQAEmbedder, self).__init__()
self.sent_encoder = sent_encoder
self.output_dim = 128
self.project_q = torch.nn.Linear(dim, self.output_dim, bias=False)
self.project_a = torch.nn.Linear(dim, self.output_dim, bias=False)
self.ce_loss = torch.nn.CrossEntropyLoss(reduction="mean")
self.project_q = nn.Linear(dim, self.output_dim, bias=False)
self.project_a = nn.Linear(dim, self.output_dim, bias=False)
self.ce_loss = nn.CrossEntropyLoss(reduction="mean")
def embed_sentences_checkpointed(self, input_ids, attention_mask, checkpoint_batch_size=-1):
# reproduces BERT forward pass with checkpointing