From 31adbb247c8c3ec248e30f89a3e4278622915ff3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Mon, 7 Oct 2019 16:43:21 +0200 Subject: [PATCH] add class wireframes for Bert decoder --- transformers/modeling_bert.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/transformers/modeling_bert.py b/transformers/modeling_bert.py index 20b49c592f..f2e2dba589 100644 --- a/transformers/modeling_bert.py +++ b/transformers/modeling_bert.py @@ -331,6 +331,14 @@ class BertEncoderLayer(nn.Module): return outputs +class BertDecoderLayer(nn.Module): + def __init__(self, config): + raise NotImplementedError + + def forward(self, hidden_state, encoder_output): + raise NotImplementedError + + class BertEncoder(nn.Module): def __init__(self, config): super(BertEncoder, self).__init__() @@ -363,6 +371,14 @@ class BertEncoder(nn.Module): return outputs # last-layer hidden state, (all hidden states), (all attentions) +class BertDecoder(nn.Module): + def __init__(self, config): + raise NotImplementedError + + def forward(self, encoder_output): + raise NotImplementedError + + class BertPooler(nn.Module): def __init__(self, config): super(BertPooler, self).__init__()