Test model outputs equivalence (#6445)

* Test model outputs equivalence

* Fix failing tests

* From dict to kwargs

* DistilBERT

* Addressing @sgugger and @patrickvonplaten's comments
This commit is contained in:
Lysandre Debut
2020-08-13 11:59:35 -04:00
committed by GitHub
parent 54c687e97c
commit f7cbc13db7
4 changed files with 197 additions and 31 deletions

View File

@@ -21,13 +21,17 @@ import tensorflow as tf
from .configuration_longformer import LongformerConfig
from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
from .modeling_tf_bert import TFBertIntermediate, TFBertOutput, TFBertPooler, TFBertSelfOutput
from .modeling_tf_outputs import TFBaseModelOutputWithPooling, TFMaskedLMOutput, TFQuestionAnsweringModelOutput
from .modeling_tf_outputs import (
TFBaseModelOutput,
TFBaseModelOutputWithPooling,
TFMaskedLMOutput,
TFQuestionAnsweringModelOutput,
)
from .modeling_tf_roberta import TFRobertaEmbeddings, TFRobertaLMHead
from .modeling_tf_utils import (
TFMaskedLanguageModelingLoss,
TFPreTrainedModel,
TFQuestionAnsweringLoss,
cast_bool_to_primitive,
get_initializer,
keras_serializable,
shape_list,
@@ -833,33 +837,41 @@ class TFLongformerEncoder(tf.keras.layers.Layer):
TFLongformerLayer(config, i, name="layer_._{}".format(i)) for i in range(config.num_hidden_layers)
]
# NOTE(review): this span is a rendered diff hunk for a `call` method. Two variants of
# several statements appear back to back, without +/- markers and without indentation,
# so the text is not runnable as shown. Which variant is the "before" and which the
# "after" is inferred from context below — confirm against the real file.
#
# Variant 1 of the signature (presumably the superseded one): every value arrives
# packed in a single `inputs` sequence and is unpacked on the next line.
def call(self, inputs, training=False):
hidden_states, attention_mask, output_attentions, output_hidden_states, padding_len = inputs
# Variant 2 of the signature (presumably the replacement): the same five values are
# explicit parameters, and `head_mask` and `return_dict` are added.
# NOTE(review): `head_mask` is not referenced anywhere else in the visible body.
def call(
self,
hidden_states,
attention_mask=None,
head_mask=None,
padding_len=0,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
):
# Collector initialisation, variant 1: always start from empty tuples.
all_hidden_states = ()
all_attentions = ()
# Collector initialisation, variant 2: start from an empty tuple only when the matching
# flag is truthy, otherwise None (the None entries are filtered out in the tuple return below).
all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
for i, layer_module in enumerate(self.layer):
# Same guard in two forms: the first routes the flag through cast_bool_to_primitive
# together with self.output_hidden_states; the second tests the flag directly.
if cast_bool_to_primitive(output_hidden_states, self.output_hidden_states) is True:
if output_hidden_states:
# Record the input of this layer; when padding_len > 0 the trailing padding_len
# positions along axis 1 are sliced off first.
hidden_states_to_add = hidden_states[:, :-padding_len] if padding_len > 0 else hidden_states
all_hidden_states = all_hidden_states + (hidden_states_to_add,)
# The layer receives a packed list; element 0 of its result becomes the next
# hidden_states and element 1 is what gets collected as attentions.
layer_outputs = layer_module([hidden_states, attention_mask, output_attentions], training=training)
hidden_states = layer_outputs[0]
# Same guard in two forms again, this time for attentions.
if cast_bool_to_primitive(output_attentions, self.output_attentions) is True:
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
# Add last layer
# (the output of the final layer is recorded with the same padding slice as above)
if cast_bool_to_primitive(output_hidden_states, self.output_hidden_states) is True:
if output_hidden_states:
hidden_states_to_add = hidden_states[:, :-padding_len] if padding_len > 0 else hidden_states
all_hidden_states = all_hidden_states + (hidden_states_to_add,)
# Return path, variant 1: build a plain tuple, appending each collection only when
# its cast_bool_to_primitive guard holds.
outputs = (hidden_states,)
if cast_bool_to_primitive(output_hidden_states, self.output_hidden_states) is True:
outputs = outputs + (all_hidden_states,)
if cast_bool_to_primitive(output_attentions, self.output_attentions) is True:
outputs = outputs + (all_attentions,)
return outputs # outputs, (hidden states), (attentions)
# Return path, variant 2: when return_dict is falsy (including its default of None)
# return a tuple of the non-None values; otherwise wrap all three in TFBaseModelOutput.
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
return TFBaseModelOutput(
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
)
@keras_serializable
@@ -992,7 +1004,12 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
embedding_output = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
encoder_outputs = self.encoder(
[embedding_output, extended_attention_mask, output_attentions, output_hidden_states, padding_len],
embedding_output,
attention_mask=extended_attention_mask,
padding_len=padding_len,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)