add class wireframes for Bert decoder
This commit is contained in:
@@ -331,6 +331,14 @@ class BertEncoderLayer(nn.Module):
|
|||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
class BertDecoderLayer(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def forward(self, hidden_state, encoder_output):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
class BertEncoder(nn.Module):
|
class BertEncoder(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super(BertEncoder, self).__init__()
|
super(BertEncoder, self).__init__()
|
||||||
@@ -363,6 +371,14 @@ class BertEncoder(nn.Module):
|
|||||||
return outputs # last-layer hidden state, (all hidden states), (all attentions)
|
return outputs # last-layer hidden state, (all hidden states), (all attentions)
|
||||||
|
|
||||||
|
|
||||||
|
class BertDecoder(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def forward(self, encoder_output):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
class BertPooler(nn.Module):
|
class BertPooler(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super(BertPooler, self).__init__()
|
super(BertPooler, self).__init__()
|
||||||
|
|||||||
Reference in New Issue
Block a user