This commit is contained in:
Julien Plu
2021-01-10 15:10:15 +01:00
committed by GitHub
parent 96f1f74aaf
commit 4f7022d68d
19 changed files with 151 additions and 415 deletions

View File

@@ -1128,11 +1128,7 @@ class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFMaskedLMOutput(
logits=output.logits,
hidden_states=hs,
attentions=attns,
)
return TFMaskedLMOutput(logits=output.logits, hidden_states=hs, attentions=attns)
class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
@@ -1241,11 +1237,7 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFCausalLMOutput(
logits=output.logits,
hidden_states=hs,
attentions=attns,
)
return TFCausalLMOutput(logits=output.logits, hidden_states=hs, attentions=attns)
@add_start_docstrings(
@@ -1348,11 +1340,7 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel, TFNextSentencePredi
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFNextSentencePredictorOutput(
logits=output.logits,
hidden_states=hs,
attentions=attns,
)
return TFNextSentencePredictorOutput(logits=output.logits, hidden_states=hs, attentions=attns)
@add_start_docstrings(
@@ -1453,11 +1441,7 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassific
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFSequenceClassifierOutput(
logits=output.logits,
hidden_states=hs,
attentions=attns,
)
return TFSequenceClassifierOutput(logits=output.logits, hidden_states=hs, attentions=attns)
@add_start_docstrings(
@@ -1605,11 +1589,7 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFMultipleChoiceModelOutput(
logits=output.logits,
hidden_states=hs,
attentions=attns,
)
return TFMultipleChoiceModelOutput(logits=output.logits, hidden_states=hs, attentions=attns)
@add_start_docstrings(
@@ -1715,11 +1695,7 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationL
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFTokenClassifierOutput(
logits=output.logits,
hidden_states=hs,
attentions=attns,
)
return TFTokenClassifierOutput(logits=output.logits, hidden_states=hs, attentions=attns)
@add_start_docstrings(
@@ -1839,8 +1815,5 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss)
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFQuestionAnsweringModelOutput(
start_logits=output.start_logits,
end_logits=output.end_logits,
hidden_states=hs,
attentions=attns,
start_logits=output.start_logits, end_logits=output.end_logits, hidden_states=hs, attentions=attns
)