Overhaul TF serving signatures + dummy inputs (#23234)

* Let's try autodetecting serving sigs

* Don't clobber existing sigs

* Change shapes for multiplechoice models

* Make default dummy inputs smarter too

* Fix missing f-string

* Let's YOLO a serving output too

* Read __class__.__name__ properly

* Don't just pass naked lists in there and expect it to be okay

* Code cleanup

* Update default serving sig

* Clearer error messages

* Further updates to the default serving output

* make fixup

* Update the serving output a bit more

* Cleanups and renames, raise errors appropriately when we can't infer inputs

* More renames

* we're building in a functional context again, yolo

* import DUMMY_INPUTS from the right place

* import DUMMY_INPUTS from the right place

* Support cross-attention in the dummies

* Support cross-attention in the dummies

* Complete removal of dummy/serving overrides in BERT

* Complete removal of dummy/serving overrides in RoBERTa

* Obliterate lots and lots of serving sig and dummy overrides

* merge type hint changes

* Fix for token_type_ids with vocab_size 1

* Add missing property decorator

* Fix T5 and hopefully some models that take conv inputs

* More signature pruning

* Fix T5's signature

* Fix Wav2Vec2 signature

* Fix LongformerForMultipleChoice input signature

* Fix BLIP and LED

* Better default serving output error handling

* Fix BART dummies

* Fix dummies for cross-attention, esp encoder-decoder models

* Fix visionencoderdecoder signature

* Fix BLIP serving output

* Small tweak to BART dummies

* Cleanup the ugly parameter inspection line that I used in a few places

* committed a breakpoint again

* Move the text_dims check

* Remove blip_text serving_output

* Add decoder_input_ids to the default input sig

* Remove all the manual overrides for encoder-decoder model signatures

* Tweak longformer/led input sigs

* Tweak default serving output

* output.keys() -> output

* make fixup
This commit is contained in:
Matt
2023-05-24 17:03:24 +01:00
committed by GitHub
parent 3d7baef114
commit 814de8fac7
65 changed files with 274 additions and 4143 deletions

View File

@@ -803,23 +803,6 @@ class TF{{cookiecutter.camelcase_modelname}}PreTrainedModel(TFPreTrainedModel):
config_class = {{cookiecutter.camelcase_modelname}}Config
base_model_prefix = "{{cookiecutter.lowercase_modelname}}"
@property
def dummy_inputs(self):
"""
Dummy inputs to build the network.
Returns:
`Dict[str, tf.Tensor]`: The dummy inputs.
"""
dummy = {"input_ids": tf.constant(DUMMY_INPUTS, dtype=tf.int64)}
# Add `encoder_hidden_states` to make the cross-attention layers' weights initialized
if self.config.add_cross_attention:
batch_size, seq_len = tf.constant(DUMMY_INPUTS).shape
shape = (batch_size, seq_len) + (self.config.hidden_size,)
h = tf.random.uniform(shape=shape)
dummy["encoder_hidden_states"] = h
return dummy
{{cookiecutter.uppercase_modelname}}_START_DOCSTRING = r"""
@@ -991,24 +974,6 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
return outputs
def serving_output(
self, output: TFBaseModelOutputWithPastAndCrossAttentions
) -> TFBaseModelOutputWithPastAndCrossAttentions:
output_cache = self.config.use_cache and self.config.is_decoder
pkv = tf.convert_to_tensor(output.past_key_values) if output_cache else None
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
cross_attns = tf.convert_to_tensor(output.cross_attentions) if output.cross_attentions is not None else None
if not (self.config.output_attentions and self.config.add_cross_attention):
cross_attns = None
return TFBaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=output.last_hidden_state,
past_key_values=pkv,
hidden_states=hs,
attentions=attns,
cross_attentions=cross_attns,
)
@add_start_docstrings("""{{cookiecutter.modelname}} Model with a `language modeling` head on top. """, {{cookiecutter.uppercase_modelname}}_START_DOCSTRING)
@@ -1084,13 +1049,6 @@ class TF{{cookiecutter.camelcase_modelname}}ForMaskedLM(TF{{cookiecutter.camelca
attentions=outputs.attentions,
)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForMaskedLM.serving_output
def serving_output(self, output: TFMaskedLMOutput) -> TFMaskedLMOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFMaskedLMOutput(logits=output.logits, hidden_states=hs, attentions=attns)
@add_start_docstrings(
"""{{cookiecutter.modelname}} Model with a `language modeling` head on top for CLM fine-tuning. """, {{cookiecutter.uppercase_modelname}}_START_DOCSTRING
@@ -1206,19 +1164,6 @@ class TF{{cookiecutter.camelcase_modelname}}ForCausalLM(TF{{cookiecutter.camelca
cross_attentions=outputs.cross_attentions,
)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel.serving_output
def serving_output(self, output: TFCausalLMOutputWithCrossAttentions) -> TFCausalLMOutputWithCrossAttentions:
output_cache = self.config.use_cache and self.config.is_decoder
pkv = tf.convert_to_tensor(output.past_key_values) if output_cache else None
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
cross_attns = tf.convert_to_tensor(output.cross_attentions) if output.cross_attentions is not None else None
if not (self.config.output_attentions and self.config.add_cross_attention):
cross_attns = None
return TFCausalLMOutputWithCrossAttentions(
logits=output.logits, past_key_values=pkv, hidden_states=hs, attentions=attns, cross_attentions=cross_attns
)
class TF{{cookiecutter.camelcase_modelname}}ClassificationHead(tf.keras.layers.Layer):
@@ -1318,13 +1263,6 @@ class TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification(TF{{cookie
attentions=outputs.attentions,
)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForSequenceClassification.serving_output
def serving_output(self, output: TFSequenceClassifierOutput) -> TFSequenceClassifierOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFSequenceClassifierOutput(logits=output.logits, hidden_states=hs, attentions=attns)
@add_start_docstrings(
"""{{cookiecutter.modelname}} Model with a multiple choice classification head on top (a linear layer on top of
@@ -1343,16 +1281,6 @@ class TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice(TF{{cookiecutter.c
units=1, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
)
@property
def dummy_inputs(self) -> Dict[str, tf.Tensor]:
"""
Dummy inputs to build the network.
Returns:
tf.Tensor with dummy inputs
"""
return {"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS, dtype=tf.int64)}
@unpack_inputs
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
@add_code_sample_docstrings(
@@ -1441,24 +1369,6 @@ class TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice(TF{{cookiecutter.c
attentions=outputs.attentions,
)
@tf.function(input_signature=[{
"input_ids": tf.TensorSpec((None, None, None), tf.int32, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None, None), tf.int32, name="attention_mask"),
"token_type_ids": tf.TensorSpec((None, None, None), tf.int32, name="token_type_ids"),
}])
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForMultipleChoice.serving
def serving(self, inputs: Dict[str, tf.Tensor]) -> TFMultipleChoiceModelOutput:
output = self.call(input_ids=inputs)
return self.serving_output(output)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForMultipleChoice.serving_output
def serving_output(self, output: TFMultipleChoiceModelOutput) -> TFMultipleChoiceModelOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFMultipleChoiceModelOutput(logits=output.logits, hidden_states=hs, attentions=attns)
@add_start_docstrings(
"""{{cookiecutter.modelname}} Model with a token classification head on top (a linear layer on top of
@@ -1532,13 +1442,6 @@ class TF{{cookiecutter.camelcase_modelname}}ForTokenClassification(TF{{cookiecut
attentions=outputs.attentions,
)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForTokenClassification.serving_output
def serving_output(self, output: TFTokenClassifierOutput) -> TFTokenClassifierOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFTokenClassifierOutput(logits=output.logits, hidden_states=hs, attentions=attns)
@add_start_docstrings(
"""{{cookiecutter.modelname}} Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
@@ -1625,14 +1528,6 @@ class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutte
attentions=outputs.attentions,
)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForQuestionAnswering.serving_output
def serving_output(self, output: TFQuestionAnsweringModelOutput) -> TFQuestionAnsweringModelOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFQuestionAnsweringModelOutput(
start_logits=output.start_logits, end_logits=output.end_logits, hidden_states=hs, attentions=attns
)
{% else %}
import random
@@ -2777,26 +2672,6 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
return outputs
# Copied from transformers.models.bart.modeling_tf_bart.TFBartModel.serving_output
def serving_output(self, output):
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
return TFSeq2SeqModelOutput(
last_hidden_state=output.last_hidden_state,
past_key_values=pkv,
decoder_hidden_states=dec_hs,
decoder_attentions=dec_attns,
cross_attentions=cross_attns,
encoder_last_hidden_state=output.encoder_last_hidden_state,
encoder_hidden_states=enc_hs,
encoder_attentions=enc_attns,
)
# Copied from transformers.models.bart.modeling_tf_bart.BiasLayer
class BiasLayer(tf.keras.layers.Layer):
@@ -2944,26 +2819,6 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
encoder_attentions=outputs.encoder_attentions, # 2 of e out
)
# Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.serving_output
def serving_output(self, output):
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
return TFSeq2SeqLMOutput(
logits=output.logits,
past_key_values=pkv,
decoder_hidden_states=dec_hs,
decoder_attentions=dec_attns,
cross_attentions=cross_attns,
encoder_last_hidden_state=output.encoder_last_hidden_state,
encoder_hidden_states=enc_hs,
encoder_attentions=enc_attns,
)
def prepare_inputs_for_generation(
self,
decoder_input_ids,