Reformat (#9482)
This commit is contained in:
@@ -594,16 +594,14 @@ class TFCTRLModel(TFCTRLPreTrainedModel):
|
||||
)
|
||||
return outputs
|
||||
|
||||
# Copied from transformers.models.gpt2.modeling_tf_gpt2.TFGPT2Model.serving_output
|
||||
def serving_output(self, output):
|
||||
pkv = tf.convert_to_tensor(output.past_key_values) if self.config.use_cache else None
|
||||
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
|
||||
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
|
||||
|
||||
return TFBaseModelOutputWithPast(
|
||||
last_hidden_state=output.last_hidden_state,
|
||||
past_key_values=pkv,
|
||||
hidden_states=hs,
|
||||
attentions=attns,
|
||||
last_hidden_state=output.last_hidden_state, past_key_values=pkv, hidden_states=hs, attentions=attns
|
||||
)
|
||||
|
||||
|
||||
@@ -741,17 +739,13 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||
attentions=transformer_outputs.attentions,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.gpt2.modeling_tf_gpt2.TFGPT2LMHeadModel.serving_output
|
||||
def serving_output(self, output):
|
||||
pkv = tf.convert_to_tensor(output.past_key_values) if self.config.use_cache else None
|
||||
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
|
||||
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
|
||||
|
||||
return TFCausalLMOutputWithPast(
|
||||
logits=output.logits,
|
||||
past_key_values=pkv,
|
||||
hidden_states=hs,
|
||||
attentions=attns,
|
||||
)
|
||||
return TFCausalLMOutputWithPast(logits=output.logits, past_key_values=pkv, hidden_states=hs, attentions=attns)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
@@ -910,12 +904,9 @@ class TFCTRLForSequenceClassification(TFCTRLPreTrainedModel, TFSequenceClassific
|
||||
attentions=transformer_outputs.attentions,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForSequenceClassification.serving_output
|
||||
def serving_output(self, output):
|
||||
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
|
||||
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
|
||||
|
||||
return TFSequenceClassifierOutput(
|
||||
logits=output.logits,
|
||||
hidden_states=hs,
|
||||
attentions=attns,
|
||||
)
|
||||
return TFSequenceClassifierOutput(logits=output.logits, hidden_states=hs, attentions=attns)
|
||||
|
||||
Reference in New Issue
Block a user