Samell fixed in tf template (#7044)
This commit is contained in:
@@ -19,9 +19,6 @@
|
||||
# In this template, replace all the XXX (various casings) with your model name
|
||||
####################################################
|
||||
|
||||
|
||||
import logging
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from .configuration_xxx import XxxConfig
|
||||
@@ -47,12 +44,14 @@ from .modeling_tf_utils import (
|
||||
TFSequenceClassificationLoss,
|
||||
TFTokenClassificationLoss,
|
||||
get_initializer,
|
||||
keras_serializable,
|
||||
shape_list,
|
||||
)
|
||||
from .tokenization_utils import BatchEncoding
|
||||
from .utils import logging
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CONFIG_FOR_DOC = "XXXConfig"
|
||||
_TOKENIZER_FOR_DOC = "XxxTokenizer"
|
||||
@@ -115,6 +114,7 @@ class TFXxxLayer(tf.keras.layers.Layer):
|
||||
# The full model without a specific pretrained or finetuning head is
|
||||
# provided as a tf.keras.layers.Layer usually called "TFXxxMainLayer"
|
||||
####################################################
|
||||
@keras_serializable
|
||||
class TFXxxMainLayer(tf.keras.layers.Layer):
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
Reference in New Issue
Block a user