This commit is contained in:
Julien Plu
2021-01-10 15:10:15 +01:00
committed by GitHub
parent 96f1f74aaf
commit 4f7022d68d
19 changed files with 151 additions and 415 deletions

View File

@@ -594,16 +594,14 @@ class TFCTRLModel(TFCTRLPreTrainedModel):
)
return outputs
# Copied from transformers.models.gpt2.modeling_tf_gpt2.TFGPT2Model.serving_output
def serving_output(self, output):
pkv = tf.convert_to_tensor(output.past_key_values) if self.config.use_cache else None
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFBaseModelOutputWithPast(
last_hidden_state=output.last_hidden_state,
past_key_values=pkv,
hidden_states=hs,
attentions=attns,
last_hidden_state=output.last_hidden_state, past_key_values=pkv, hidden_states=hs, attentions=attns
)
@@ -741,17 +739,13 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
attentions=transformer_outputs.attentions,
)
# Copied from transformers.models.gpt2.modeling_tf_gpt2.TFGPT2LMHeadModel.serving_output
def serving_output(self, output):
pkv = tf.convert_to_tensor(output.past_key_values) if self.config.use_cache else None
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFCausalLMOutputWithPast(
logits=output.logits,
past_key_values=pkv,
hidden_states=hs,
attentions=attns,
)
return TFCausalLMOutputWithPast(logits=output.logits, past_key_values=pkv, hidden_states=hs, attentions=attns)
@add_start_docstrings(
@@ -910,12 +904,9 @@ class TFCTRLForSequenceClassification(TFCTRLPreTrainedModel, TFSequenceClassific
attentions=transformer_outputs.attentions,
)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForSequenceClassification.serving_output
def serving_output(self, output):
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFSequenceClassifierOutput(
logits=output.logits,
hidden_states=hs,
attentions=attns,
)
return TFSequenceClassifierOutput(logits=output.logits, hidden_states=hs, attentions=attns)