Overhaul TF serving signatures + dummy inputs (#23234)
* Let's try autodetecting serving sigs * Don't clobber existing sigs * Change shapes for multiplechoice models * Make default dummy inputs smarter too * Fix missing f-string * Let's YOLO a serving output too * Read __class__.__name__ properly * Don't just pass naked lists in there and expect it to be okay * Code cleanup * Update default serving sig * Clearer error messages * Further updates to the default serving output * make fixup * Update the serving output a bit more * Cleanups and renames, raise errors appropriately when we can't infer inputs * More renames * we're building in a functional context again, yolo * import DUMMY_INPUTS from the right place * import DUMMY_INPUTS from the right place * Support cross-attention in the dummies * Support cross-attention in the dummies * Complete removal of dummy/serving overrides in BERT * Complete removal of dummy/serving overrides in RoBERTa * Obliterate lots and lots of serving sig and dummy overrides * merge type hint changes * Fix for token_type_ids with vocab_size 1 * Add missing property decorator * Fix T5 and hopefully some models that take conv inputs * More signature pruning * Fix T5's signature * Fix Wav2Vec2 signature * Fix LongformerForMultipleChoice input signature * Fix BLIP and LED * Better default serving output error handling * Fix BART dummies * Fix dummies for cross-attention, esp encoder-decoder models * Fix visionencoderdecoder signature * Fix BLIP serving output * Small tweak to BART dummies * Cleanup the ugly parameter inspection line that I used in a few places * committed a breakpoint again * Move the text_dims check * Remove blip_text serving_output * Add decoder_input_ids to the default input sig * Remove all the manual overrides for encoder-decoder model signatures * Tweak longformer/led input sigs * Tweak default serving output * output.keys() -> output * make fixup
This commit is contained in:
@@ -803,23 +803,6 @@ class TF{{cookiecutter.camelcase_modelname}}PreTrainedModel(TFPreTrainedModel):
|
||||
config_class = {{cookiecutter.camelcase_modelname}}Config
|
||||
base_model_prefix = "{{cookiecutter.lowercase_modelname}}"
|
||||
|
||||
@property
|
||||
def dummy_inputs(self):
|
||||
"""
|
||||
Dummy inputs to build the network.
|
||||
|
||||
Returns:
|
||||
`Dict[str, tf.Tensor]`: The dummy inputs.
|
||||
"""
|
||||
dummy = {"input_ids": tf.constant(DUMMY_INPUTS, dtype=tf.int64)}
|
||||
# Add `encoder_hidden_states` to make the cross-attention layers' weights initialized
|
||||
if self.config.add_cross_attention:
|
||||
batch_size, seq_len = tf.constant(DUMMY_INPUTS).shape
|
||||
shape = (batch_size, seq_len) + (self.config.hidden_size,)
|
||||
h = tf.random.uniform(shape=shape)
|
||||
dummy["encoder_hidden_states"] = h
|
||||
|
||||
return dummy
|
||||
|
||||
|
||||
{{cookiecutter.uppercase_modelname}}_START_DOCSTRING = r"""
|
||||
@@ -991,24 +974,6 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
|
||||
|
||||
return outputs
|
||||
|
||||
def serving_output(
|
||||
self, output: TFBaseModelOutputWithPastAndCrossAttentions
|
||||
) -> TFBaseModelOutputWithPastAndCrossAttentions:
|
||||
output_cache = self.config.use_cache and self.config.is_decoder
|
||||
pkv = tf.convert_to_tensor(output.past_key_values) if output_cache else None
|
||||
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
|
||||
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
|
||||
cross_attns = tf.convert_to_tensor(output.cross_attentions) if output.cross_attentions is not None else None
|
||||
if not (self.config.output_attentions and self.config.add_cross_attention):
|
||||
cross_attns = None
|
||||
|
||||
return TFBaseModelOutputWithPastAndCrossAttentions(
|
||||
last_hidden_state=output.last_hidden_state,
|
||||
past_key_values=pkv,
|
||||
hidden_states=hs,
|
||||
attentions=attns,
|
||||
cross_attentions=cross_attns,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings("""{{cookiecutter.modelname}} Model with a `language modeling` head on top. """, {{cookiecutter.uppercase_modelname}}_START_DOCSTRING)
|
||||
@@ -1084,13 +1049,6 @@ class TF{{cookiecutter.camelcase_modelname}}ForMaskedLM(TF{{cookiecutter.camelca
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForMaskedLM.serving_output
|
||||
def serving_output(self, output: TFMaskedLMOutput) -> TFMaskedLMOutput:
|
||||
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
|
||||
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
|
||||
|
||||
return TFMaskedLMOutput(logits=output.logits, hidden_states=hs, attentions=attns)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""{{cookiecutter.modelname}} Model with a `language modeling` head on top for CLM fine-tuning. """, {{cookiecutter.uppercase_modelname}}_START_DOCSTRING
|
||||
@@ -1206,19 +1164,6 @@ class TF{{cookiecutter.camelcase_modelname}}ForCausalLM(TF{{cookiecutter.camelca
|
||||
cross_attentions=outputs.cross_attentions,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel.serving_output
|
||||
def serving_output(self, output: TFCausalLMOutputWithCrossAttentions) -> TFCausalLMOutputWithCrossAttentions:
|
||||
output_cache = self.config.use_cache and self.config.is_decoder
|
||||
pkv = tf.convert_to_tensor(output.past_key_values) if output_cache else None
|
||||
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
|
||||
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
|
||||
cross_attns = tf.convert_to_tensor(output.cross_attentions) if output.cross_attentions is not None else None
|
||||
if not (self.config.output_attentions and self.config.add_cross_attention):
|
||||
cross_attns = None
|
||||
|
||||
return TFCausalLMOutputWithCrossAttentions(
|
||||
logits=output.logits, past_key_values=pkv, hidden_states=hs, attentions=attns, cross_attentions=cross_attns
|
||||
)
|
||||
|
||||
|
||||
class TF{{cookiecutter.camelcase_modelname}}ClassificationHead(tf.keras.layers.Layer):
|
||||
@@ -1318,13 +1263,6 @@ class TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification(TF{{cookie
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForSequenceClassification.serving_output
|
||||
def serving_output(self, output: TFSequenceClassifierOutput) -> TFSequenceClassifierOutput:
|
||||
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
|
||||
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
|
||||
|
||||
return TFSequenceClassifierOutput(logits=output.logits, hidden_states=hs, attentions=attns)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""{{cookiecutter.modelname}} Model with a multiple choice classification head on top (a linear layer on top of
|
||||
@@ -1343,16 +1281,6 @@ class TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice(TF{{cookiecutter.c
|
||||
units=1, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
|
||||
)
|
||||
|
||||
@property
|
||||
def dummy_inputs(self) -> Dict[str, tf.Tensor]:
|
||||
"""
|
||||
Dummy inputs to build the network.
|
||||
|
||||
Returns:
|
||||
tf.Tensor with dummy inputs
|
||||
"""
|
||||
return {"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS, dtype=tf.int64)}
|
||||
|
||||
@unpack_inputs
|
||||
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
@@ -1441,24 +1369,6 @@ class TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice(TF{{cookiecutter.c
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
@tf.function(input_signature=[{
|
||||
"input_ids": tf.TensorSpec((None, None, None), tf.int32, name="input_ids"),
|
||||
"attention_mask": tf.TensorSpec((None, None, None), tf.int32, name="attention_mask"),
|
||||
"token_type_ids": tf.TensorSpec((None, None, None), tf.int32, name="token_type_ids"),
|
||||
}])
|
||||
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForMultipleChoice.serving
|
||||
def serving(self, inputs: Dict[str, tf.Tensor]) -> TFMultipleChoiceModelOutput:
|
||||
output = self.call(input_ids=inputs)
|
||||
|
||||
return self.serving_output(output)
|
||||
|
||||
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForMultipleChoice.serving_output
|
||||
def serving_output(self, output: TFMultipleChoiceModelOutput) -> TFMultipleChoiceModelOutput:
|
||||
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
|
||||
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
|
||||
|
||||
return TFMultipleChoiceModelOutput(logits=output.logits, hidden_states=hs, attentions=attns)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""{{cookiecutter.modelname}} Model with a token classification head on top (a linear layer on top of
|
||||
@@ -1532,13 +1442,6 @@ class TF{{cookiecutter.camelcase_modelname}}ForTokenClassification(TF{{cookiecut
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForTokenClassification.serving_output
|
||||
def serving_output(self, output: TFTokenClassifierOutput) -> TFTokenClassifierOutput:
|
||||
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
|
||||
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
|
||||
|
||||
return TFTokenClassifierOutput(logits=output.logits, hidden_states=hs, attentions=attns)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""{{cookiecutter.modelname}} Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
|
||||
@@ -1625,14 +1528,6 @@ class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutte
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForQuestionAnswering.serving_output
|
||||
def serving_output(self, output: TFQuestionAnsweringModelOutput) -> TFQuestionAnsweringModelOutput:
|
||||
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
|
||||
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
|
||||
|
||||
return TFQuestionAnsweringModelOutput(
|
||||
start_logits=output.start_logits, end_logits=output.end_logits, hidden_states=hs, attentions=attns
|
||||
)
|
||||
|
||||
{% else %}
|
||||
import random
|
||||
@@ -2777,26 +2672,6 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
|
||||
|
||||
return outputs
|
||||
|
||||
# Copied from transformers.models.bart.modeling_tf_bart.TFBartModel.serving_output
|
||||
def serving_output(self, output):
|
||||
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
|
||||
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
|
||||
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
|
||||
cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
|
||||
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
|
||||
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
|
||||
|
||||
return TFSeq2SeqModelOutput(
|
||||
last_hidden_state=output.last_hidden_state,
|
||||
past_key_values=pkv,
|
||||
decoder_hidden_states=dec_hs,
|
||||
decoder_attentions=dec_attns,
|
||||
cross_attentions=cross_attns,
|
||||
encoder_last_hidden_state=output.encoder_last_hidden_state,
|
||||
encoder_hidden_states=enc_hs,
|
||||
encoder_attentions=enc_attns,
|
||||
)
|
||||
|
||||
|
||||
# Copied from transformers.models.bart.modeling_tf_bart.BiasLayer
|
||||
class BiasLayer(tf.keras.layers.Layer):
|
||||
@@ -2944,26 +2819,6 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
|
||||
encoder_attentions=outputs.encoder_attentions, # 2 of e out
|
||||
)
|
||||
|
||||
# Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.serving_output
|
||||
def serving_output(self, output):
|
||||
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
|
||||
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
|
||||
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
|
||||
cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
|
||||
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
|
||||
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
|
||||
|
||||
return TFSeq2SeqLMOutput(
|
||||
logits=output.logits,
|
||||
past_key_values=pkv,
|
||||
decoder_hidden_states=dec_hs,
|
||||
decoder_attentions=dec_attns,
|
||||
cross_attentions=cross_attns,
|
||||
encoder_last_hidden_state=output.encoder_last_hidden_state,
|
||||
encoder_hidden_states=enc_hs,
|
||||
encoder_attentions=enc_attns,
|
||||
)
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
decoder_input_ids,
|
||||
|
||||
Reference in New Issue
Block a user