Test model outputs equivalence (#6445)
* Test model outputs equivalence
* Fix failing tests
* From dict to kwargs
* DistilBERT
* Addressing @sgugger and @patrickvonplaten's comments
This commit is contained in:
@@ -21,13 +21,17 @@ import tensorflow as tf
|
||||
from .configuration_longformer import LongformerConfig
|
||||
from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
|
||||
from .modeling_tf_bert import TFBertIntermediate, TFBertOutput, TFBertPooler, TFBertSelfOutput
|
||||
from .modeling_tf_outputs import TFBaseModelOutputWithPooling, TFMaskedLMOutput, TFQuestionAnsweringModelOutput
|
||||
from .modeling_tf_outputs import (
|
||||
TFBaseModelOutput,
|
||||
TFBaseModelOutputWithPooling,
|
||||
TFMaskedLMOutput,
|
||||
TFQuestionAnsweringModelOutput,
|
||||
)
|
||||
from .modeling_tf_roberta import TFRobertaEmbeddings, TFRobertaLMHead
|
||||
from .modeling_tf_utils import (
|
||||
TFMaskedLanguageModelingLoss,
|
||||
TFPreTrainedModel,
|
||||
TFQuestionAnsweringLoss,
|
||||
cast_bool_to_primitive,
|
||||
get_initializer,
|
||||
keras_serializable,
|
||||
shape_list,
|
||||
@@ -833,33 +837,41 @@ class TFLongformerEncoder(tf.keras.layers.Layer):
|
||||
TFLongformerLayer(config, i, name="layer_._{}".format(i)) for i in range(config.num_hidden_layers)
|
||||
]
|
||||
|
||||
def call(self, inputs, training=False):
|
||||
hidden_states, attention_mask, output_attentions, output_hidden_states, padding_len = inputs
|
||||
def call(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
padding_len=0,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
training=False,
|
||||
):
|
||||
|
||||
all_hidden_states = ()
|
||||
all_attentions = ()
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_attentions = () if output_attentions else None
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
if cast_bool_to_primitive(output_hidden_states, self.output_hidden_states) is True:
|
||||
if output_hidden_states:
|
||||
hidden_states_to_add = hidden_states[:, :-padding_len] if padding_len > 0 else hidden_states
|
||||
all_hidden_states = all_hidden_states + (hidden_states_to_add,)
|
||||
|
||||
layer_outputs = layer_module([hidden_states, attention_mask, output_attentions], training=training)
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if cast_bool_to_primitive(output_attentions, self.output_attentions) is True:
|
||||
if output_attentions:
|
||||
all_attentions = all_attentions + (layer_outputs[1],)
|
||||
|
||||
# Add last layer
|
||||
if cast_bool_to_primitive(output_hidden_states, self.output_hidden_states) is True:
|
||||
if output_hidden_states:
|
||||
hidden_states_to_add = hidden_states[:, :-padding_len] if padding_len > 0 else hidden_states
|
||||
all_hidden_states = all_hidden_states + (hidden_states_to_add,)
|
||||
|
||||
outputs = (hidden_states,)
|
||||
if cast_bool_to_primitive(output_hidden_states, self.output_hidden_states) is True:
|
||||
outputs = outputs + (all_hidden_states,)
|
||||
if cast_bool_to_primitive(output_attentions, self.output_attentions) is True:
|
||||
outputs = outputs + (all_attentions,)
|
||||
return outputs # outputs, (hidden states), (attentions)
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
|
||||
return TFBaseModelOutput(
|
||||
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
|
||||
)
|
||||
|
||||
|
||||
@keras_serializable
|
||||
@@ -992,7 +1004,12 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
|
||||
|
||||
embedding_output = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
|
||||
encoder_outputs = self.encoder(
|
||||
[embedding_output, extended_attention_mask, output_attentions, output_hidden_states, padding_len],
|
||||
embedding_output,
|
||||
attention_mask=extended_attention_mask,
|
||||
padding_len=padding_len,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user