This commit is contained in:
Julien Plu
2021-01-10 15:10:15 +01:00
committed by GitHub
parent 96f1f74aaf
commit 4f7022d68d
19 changed files with 151 additions and 415 deletions

View File

@@ -777,6 +777,7 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
return outputs
# Copied from transformers.models.distilbert.modeling_tf_distilbert.TFDistilBertModel.serving_output
def serving_output(self, output):
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
@@ -885,15 +886,12 @@ class TF{{cookiecutter.camelcase_modelname}}ForMaskedLM(TF{{cookiecutter.camelca
attentions=outputs.attentions,
)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForMaskedLM.serving_output
def serving_output(self, output):
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFMaskedLMOutput(
logits=output.logits,
hidden_states=hs,
attentions=attns,
)
return TFMaskedLMOutput(logits=output.logits, hidden_states=hs, attentions=attns)
@add_start_docstrings(
"""{{cookiecutter.modelname}} Model with a `language modeling` head on top for CLM fine-tuning. """, {{cookiecutter.uppercase_modelname}}_START_DOCSTRING
@@ -993,15 +991,12 @@ class TF{{cookiecutter.camelcase_modelname}}ForCausalLM(TF{{cookiecutter.camelca
attentions=outputs.attentions,
)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel.serving_output
def serving_output(self, output):
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFCausalLMOutput(
logits=output.logits,
hidden_states=hs,
attentions=attns,
)
return TFCausalLMOutput(logits=output.logits, hidden_states=hs, attentions=attns)
class TF{{cookiecutter.camelcase_modelname}}ClassificationHead(tf.keras.layers.Layer):
"""Head for sentence-level classification tasks."""
@@ -1114,15 +1109,12 @@ class TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification(TF{{cookie
attentions=outputs.attentions,
)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForSequenceClassification.serving_output
def serving_output(self, output):
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFSequenceClassifierOutput(
logits=output.logits,
hidden_states=hs,
attentions=attns,
)
return TFSequenceClassifierOutput(logits=output.logits, hidden_states=hs, attentions=attns)
@add_start_docstrings(
@@ -1258,15 +1250,12 @@ class TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice(TF{{cookiecutter.c
return self.serving_output(output)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForMultipleChoice.serving_output
def serving_output(self, output):
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFMultipleChoiceModelOutput(
logits=output.logits,
hidden_states=hs,
attentions=attns,
)
return TFMultipleChoiceModelOutput(logits=output.logits, hidden_states=hs, attentions=attns)
@add_start_docstrings(
@@ -1357,15 +1346,12 @@ class TF{{cookiecutter.camelcase_modelname}}ForTokenClassification(TF{{cookiecut
attentions=outputs.attentions,
)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForTokenClassification.serving_output
def serving_output(self, output):
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFTokenClassifierOutput(
logits=output.logits,
hidden_states=hs,
attentions=attns,
)
return TFTokenClassifierOutput(logits=output.logits, hidden_states=hs, attentions=attns)
@add_start_docstrings(
@@ -1470,15 +1456,13 @@ class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutte
attentions=outputs.attentions,
)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForQuestionAnswering.serving_output
def serving_output(self, output):
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFQuestionAnsweringModelOutput(
start_logits=output.start_logits,
end_logits=output.end_logits,
hidden_states=hs,
attentions=attns,
start_logits=output.start_logits, end_logits=output.end_logits, hidden_states=hs, attentions=attns
)
{% else %}
@@ -2454,6 +2438,7 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
encoder_attentions=inputs["encoder_outputs"].attentions,
)
# Copied from transformers.models.bart.modeling_tf_bart.TFBartModel.serving_output
def serving_output(self, output):
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None,
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
@@ -2616,6 +2601,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
encoder_attentions=outputs.encoder_attentions, # 2 of e out
)
# Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.serving_output
def serving_output(self, output):
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None,
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None