Reformat (#9482)
This commit is contained in:
@@ -777,6 +777,7 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
|
||||
|
||||
return outputs
|
||||
|
||||
# Copied from transformers.models.distilbert.modeling_tf_distilbert.TFDistilBertModel.serving_output
|
||||
def serving_output(self, output):
|
||||
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
|
||||
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
|
||||
@@ -885,15 +886,12 @@ class TF{{cookiecutter.camelcase_modelname}}ForMaskedLM(TF{{cookiecutter.camelca
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForMaskedLM.serving_output
|
||||
def serving_output(self, output):
|
||||
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
|
||||
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
|
||||
|
||||
return TFMaskedLMOutput(
|
||||
logits=output.logits,
|
||||
hidden_states=hs,
|
||||
attentions=attns,
|
||||
)
|
||||
return TFMaskedLMOutput(logits=output.logits, hidden_states=hs, attentions=attns)
|
||||
|
||||
@add_start_docstrings(
|
||||
"""{{cookiecutter.modelname}} Model with a `language modeling` head on top for CLM fine-tuning. """, {{cookiecutter.uppercase_modelname}}_START_DOCSTRING
|
||||
@@ -993,15 +991,12 @@ class TF{{cookiecutter.camelcase_modelname}}ForCausalLM(TF{{cookiecutter.camelca
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel.serving_output
|
||||
def serving_output(self, output):
|
||||
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
|
||||
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
|
||||
|
||||
return TFCausalLMOutput(
|
||||
logits=output.logits,
|
||||
hidden_states=hs,
|
||||
attentions=attns,
|
||||
)
|
||||
return TFCausalLMOutput(logits=output.logits, hidden_states=hs, attentions=attns)
|
||||
|
||||
class TF{{cookiecutter.camelcase_modelname}}ClassificationHead(tf.keras.layers.Layer):
|
||||
"""Head for sentence-level classification tasks."""
|
||||
@@ -1114,15 +1109,12 @@ class TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification(TF{{cookie
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForSequenceClassification.serving_output
|
||||
def serving_output(self, output):
|
||||
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
|
||||
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
|
||||
|
||||
return TFSequenceClassifierOutput(
|
||||
logits=output.logits,
|
||||
hidden_states=hs,
|
||||
attentions=attns,
|
||||
)
|
||||
return TFSequenceClassifierOutput(logits=output.logits, hidden_states=hs, attentions=attns)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
@@ -1258,15 +1250,12 @@ class TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice(TF{{cookiecutter.c
|
||||
|
||||
return self.serving_output(output)
|
||||
|
||||
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForMultipleChoice.serving_output
|
||||
def serving_output(self, output):
|
||||
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
|
||||
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
|
||||
|
||||
return TFMultipleChoiceModelOutput(
|
||||
logits=output.logits,
|
||||
hidden_states=hs,
|
||||
attentions=attns,
|
||||
)
|
||||
return TFMultipleChoiceModelOutput(logits=output.logits, hidden_states=hs, attentions=attns)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
@@ -1357,15 +1346,12 @@ class TF{{cookiecutter.camelcase_modelname}}ForTokenClassification(TF{{cookiecut
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForTokenClassification.serving_output
|
||||
def serving_output(self, output):
|
||||
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
|
||||
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
|
||||
|
||||
return TFTokenClassifierOutput(
|
||||
logits=output.logits,
|
||||
hidden_states=hs,
|
||||
attentions=attns,
|
||||
)
|
||||
return TFTokenClassifierOutput(logits=output.logits, hidden_states=hs, attentions=attns)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
@@ -1470,15 +1456,13 @@ class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutte
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForQuestionAnswering.serving_output
|
||||
def serving_output(self, output):
|
||||
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
|
||||
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
|
||||
|
||||
return TFQuestionAnsweringModelOutput(
|
||||
start_logits=output.start_logits,
|
||||
end_logits=output.end_logits,
|
||||
hidden_states=hs,
|
||||
attentions=attns,
|
||||
start_logits=output.start_logits, end_logits=output.end_logits, hidden_states=hs, attentions=attns
|
||||
)
|
||||
|
||||
{% else %}
|
||||
@@ -2454,6 +2438,7 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
|
||||
encoder_attentions=inputs["encoder_outputs"].attentions,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.bart.modeling_tf_bart.TFBartModel.serving_output
|
||||
def serving_output(self, output):
|
||||
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None,
|
||||
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
|
||||
@@ -2616,6 +2601,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
|
||||
encoder_attentions=outputs.encoder_attentions, # 2 of e out
|
||||
)
|
||||
|
||||
# Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.serving_output
|
||||
def serving_output(self, output):
|
||||
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None,
|
||||
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
|
||||
|
||||
Reference in New Issue
Block a user