[style] consistent nn. and nn.functional: part 4 examples (#12156)
* consistent nn. and nn.functional: p4 examples * restore
This commit is contained in:
@@ -11,6 +11,7 @@ import torch
|
||||
import torch.utils.checkpoint as checkpoint
|
||||
from elasticsearch import Elasticsearch # noqa: F401
|
||||
from elasticsearch.helpers import bulk, streaming_bulk # noqa: F401
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
|
||||
from tqdm import tqdm
|
||||
|
||||
@@ -116,14 +117,14 @@ class ELI5DatasetQARetriver(Dataset):
|
||||
return self.make_example(idx % self.data.num_rows)
|
||||
|
||||
|
||||
class RetrievalQAEmbedder(torch.nn.Module):
|
||||
class RetrievalQAEmbedder(nn.Module):
|
||||
def __init__(self, sent_encoder, dim):
|
||||
super(RetrievalQAEmbedder, self).__init__()
|
||||
self.sent_encoder = sent_encoder
|
||||
self.output_dim = 128
|
||||
self.project_q = torch.nn.Linear(dim, self.output_dim, bias=False)
|
||||
self.project_a = torch.nn.Linear(dim, self.output_dim, bias=False)
|
||||
self.ce_loss = torch.nn.CrossEntropyLoss(reduction="mean")
|
||||
self.project_q = nn.Linear(dim, self.output_dim, bias=False)
|
||||
self.project_a = nn.Linear(dim, self.output_dim, bias=False)
|
||||
self.ce_loss = nn.CrossEntropyLoss(reduction="mean")
|
||||
|
||||
def embed_sentences_checkpointed(self, input_ids, attention_mask, checkpoint_batch_size=-1):
|
||||
# reproduces BERT forward pass with checkpointing
|
||||
|
||||
Reference in New Issue
Block a user