New serving (#9419)
* Add a serving method
* Add albert
* Add serving for BERT and BART
* Add more models
* Finish the serving addition
* Temp fix
* Restore DPR
* Fix funnel attribute
* Fix attributes GPT2
* Fix OpenAIGPT attribute
* Fix T5 attributes
* Fix Bart attributes
* Fix TransfoXL attributes
* Add versioning
* better test
* Update template
* Fix Flaubert
* Fix T5
* Apply style
* Remove unused imports
* Deactivate extra parameters
* Remove too long test + saved_model default to False
* Ignore the saved model test for some models
* Fix some inputs
* Fix mpnet serving
* Trigger CI
* Address all comments
This commit is contained in:
@@ -594,6 +594,18 @@ class TFCTRLModel(TFCTRLPreTrainedModel):
|
||||
)
|
||||
return outputs
|
||||
|
||||
def serving_output(self, output):
    """
    Repackage a model output as a ``TFBaseModelOutputWithPast`` for serving.

    Each optional field is passed through ``tf.convert_to_tensor`` only when
    the matching config flag is set; otherwise it is reported as ``None``.

    Args:
        output: object exposing ``last_hidden_state``, ``past_key_values``,
            ``hidden_states`` and ``attentions`` attributes.

    Returns:
        A ``TFBaseModelOutputWithPast`` whose ``last_hidden_state`` is taken
        from ``output`` unchanged.
    """
    # Optional fields default to None and are filled in only on request.
    past = None
    if self.config.use_cache:
        past = tf.convert_to_tensor(output.past_key_values)

    all_hidden = None
    if self.config.output_hidden_states:
        all_hidden = tf.convert_to_tensor(output.hidden_states)

    all_attentions = None
    if self.config.output_attentions:
        all_attentions = tf.convert_to_tensor(output.attentions)

    return TFBaseModelOutputWithPast(
        last_hidden_state=output.last_hidden_state,
        past_key_values=past,
        hidden_states=all_hidden,
        attentions=all_attentions,
    )
|
||||
|
||||
|
||||
class TFCTRLLMHead(tf.keras.layers.Layer):
|
||||
def __init__(self, config, input_embeddings, **kwargs):
|
||||
@@ -729,6 +741,18 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||
attentions=transformer_outputs.attentions,
|
||||
)
|
||||
|
||||
def serving_output(self, output):
    """
    Repackage a model output as a ``TFCausalLMOutputWithPast`` for serving.

    Each optional field is passed through ``tf.convert_to_tensor`` only when
    the matching config flag is set; otherwise it is reported as ``None``.

    Args:
        output: object exposing ``logits``, ``past_key_values``,
            ``hidden_states`` and ``attentions`` attributes.

    Returns:
        A ``TFCausalLMOutputWithPast`` whose ``logits`` are taken from
        ``output`` unchanged.
    """
    # Optional fields default to None and are filled in only on request.
    past = None
    if self.config.use_cache:
        past = tf.convert_to_tensor(output.past_key_values)

    all_hidden = None
    if self.config.output_hidden_states:
        all_hidden = tf.convert_to_tensor(output.hidden_states)

    all_attentions = None
    if self.config.output_attentions:
        all_attentions = tf.convert_to_tensor(output.attentions)

    return TFCausalLMOutputWithPast(
        logits=output.logits,
        past_key_values=past,
        hidden_states=all_hidden,
        attentions=all_attentions,
    )
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
@@ -885,3 +909,13 @@ class TFCTRLForSequenceClassification(TFCTRLPreTrainedModel, TFSequenceClassific
|
||||
hidden_states=transformer_outputs.hidden_states,
|
||||
attentions=transformer_outputs.attentions,
|
||||
)
|
||||
|
||||
def serving_output(self, output):
    """
    Repackage a model output as a ``TFSequenceClassifierOutput`` for serving.

    ``hidden_states`` and ``attentions`` are passed through
    ``tf.convert_to_tensor`` only when the matching config flag is set;
    otherwise they are reported as ``None``.

    Args:
        output: object exposing ``logits``, ``hidden_states`` and
            ``attentions`` attributes.

    Returns:
        A ``TFSequenceClassifierOutput`` whose ``logits`` are taken from
        ``output`` unchanged.
    """
    # Optional fields default to None and are filled in only on request.
    all_hidden = None
    if self.config.output_hidden_states:
        all_hidden = tf.convert_to_tensor(output.hidden_states)

    all_attentions = None
    if self.config.output_attentions:
        all_attentions = tf.convert_to_tensor(output.attentions)

    return TFSequenceClassifierOutput(
        logits=output.logits,
        hidden_states=all_hidden,
        attentions=all_attentions,
    )
|
||||
|
||||
Reference in New Issue
Block a user