New serving (#9419)

* Add a serving method

* Add albert

* Add serving for BERT and BART

* Add more models

* Finish the serving addition

* Temp fix

* Restore DPR

* Fix funnel attribute

* Fix attributes GPT2

* Fix OpenAIGPT attribute

* Fix T5 attributes

* Fix Bart attributes

* Fix TransfoXL attributes

* Add versioning

* better test

* Update template

* Fix Flaubert

* Fix T5

* Apply style

* Remove unused imports

* Deactivate extra parameters

* Remove too long test + saved_model default to False

* Ignore the saved model test for some models

* Fix some inputs

* Fix mpnet serving

* Trigger CI

* Address all comments
This commit is contained in:
Julien Plu
2021-01-07 11:48:49 +01:00
committed by GitHub
parent 390cf16bc8
commit 812045adcc
36 changed files with 1773 additions and 68 deletions

View File

@@ -776,6 +776,16 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
)
return outputs
def serving_output(self, output):
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFBaseModelOutput(
last_hidden_state=output.last_hidden_state,
hidden_states=hs,
attentions=attns,
)
@add_start_docstrings("""{{cookiecutter.modelname}} Model with a `language modeling` head on top. """, {{cookiecutter.uppercase_modelname}}_START_DOCSTRING)
@@ -874,6 +884,16 @@ class TF{{cookiecutter.camelcase_modelname}}ForMaskedLM(TF{{cookiecutter.camelca
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def serving_output(self, output):
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFMaskedLMOutput(
logits=output.logits,
hidden_states=hs,
attentions=attns,
)
@add_start_docstrings(
"""{{cookiecutter.modelname}} Model with a `language modeling` head on top for CLM fine-tuning. """, {{cookiecutter.uppercase_modelname}}_START_DOCSTRING
@@ -972,6 +992,16 @@ class TF{{cookiecutter.camelcase_modelname}}ForCausalLM(TF{{cookiecutter.camelca
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def serving_output(self, output):
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFCausalLMOutput(
logits=output.logits,
hidden_states=hs,
attentions=attns,
)
class TF{{cookiecutter.camelcase_modelname}}ClassificationHead(tf.keras.layers.Layer):
"""Head for sentence-level classification tasks."""
@@ -1083,6 +1113,16 @@ class TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification(TF{{cookie
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def serving_output(self, output):
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFSequenceClassifierOutput(
logits=output.logits,
hidden_states=hs,
attentions=attns,
)
@add_start_docstrings(
@@ -1207,6 +1247,27 @@ class TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice(TF{{cookiecutter.c
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@tf.function(input_signature=[{
"input_ids": tf.TensorSpec((None, None, None), tf.int32, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None, None), tf.int32, name="attention_mask"),
"token_type_ids": tf.TensorSpec((None, None, None), tf.int32, name="token_type_ids"),
}])
def serving(self, inputs):
output = self.call(inputs)
return self.serving_output(output)
def serving_output(self, output):
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFMultipleChoiceModelOutput(
logits=output.logits,
hidden_states=hs,
attentions=attns,
)
@add_start_docstrings(
"""{{cookiecutter.modelname}} Model with a token classification head on top (a linear layer on top of
@@ -1295,6 +1356,16 @@ class TF{{cookiecutter.camelcase_modelname}}ForTokenClassification(TF{{cookiecut
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def serving_output(self, output):
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFTokenClassifierOutput(
logits=output.logits,
hidden_states=hs,
attentions=attns,
)
@add_start_docstrings(
@@ -1398,6 +1469,17 @@ class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutte
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def serving_output(self, output):
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFQuestionAnsweringModelOutput(
start_logits=output.start_logits,
end_logits=output.end_logits,
hidden_states=hs,
attentions=attns,
)
{% else %}
import math
@@ -1792,6 +1874,21 @@ class TF{{cookiecutter.camelcase_modelname}}PreTrainedModel(TFPreTrainedModel):
"input_ids": input_ids,
}
return dummy_inputs
@tf.function(
input_signature=[
{
"input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
"decoder_input_ids": tf.TensorSpec((None, None), tf.int32, name="decoder_input_ids"),
"decoder_attention_mask": tf.TensorSpec((None, None), tf.int32, name="decoder_attention_mask"),
}
]
)
def serving(self, inputs):
output = self.call(inputs)
return self.serving_output(output)
{{cookiecutter.uppercase_modelname}}_START_DOCSTRING = r"""
@@ -2356,6 +2453,23 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
encoder_attentions=inputs["encoder_outputs"].attentions,
)
def serving_output(self, output):
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None,
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
return TFSeq2SeqModelOutput(
last_hidden_state=output.last_hidden_state,
past_key_values=pkv,
decoder_hidden_states=dec_hs,
decoder_attentions=dec_attns,
encoder_last_hidden_state=output.encoder_last_hidden_state,
encoder_hidden_states=enc_hs,
encoder_attentions=enc_attns,
)
def get_input_embeddings(self):
return self.shared
@@ -2501,6 +2615,23 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out
encoder_attentions=outputs.encoder_attentions, # 2 of e out
)
def serving_output(self, output):
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None,
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
return TFSeq2SeqLMOutput(
logits=output.logits,
past_key_values=pkv,
decoder_hidden_states=dec_hs,
decoder_attentions=dec_attns,
encoder_last_hidden_state=output.encoder_last_hidden_state,
encoder_hidden_states=enc_hs,
encoder_attentions=enc_attns,
)
def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, use_cache, **kwargs) -> Dict:
assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}"