@@ -23,7 +23,7 @@ import tensorflow as tf
|
||||
|
||||
from .configuration_ctrl import CTRLConfig
|
||||
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
||||
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, shape_list
|
||||
from .modeling_tf_utils import TFMainLayer, TFPreTrainedModel, TFSharedEmbeddings, shape_list
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -164,9 +164,9 @@ class TFEncoderLayer(tf.keras.layers.Layer):
|
||||
return outputs
|
||||
|
||||
|
||||
class TFCTRLMainLayer(tf.keras.layers.Layer):
|
||||
class TFCTRLMainLayer(TFMainLayer):
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
super().__init__(config, **kwargs)
|
||||
self.output_hidden_states = config.output_hidden_states
|
||||
self.output_attentions = config.output_attentions
|
||||
self.output_past = config.output_past
|
||||
|
||||
Reference in New Issue
Block a user