Test model outputs equivalence (#6445)
* Test model outputs equivalence * Fix failing tests * From dict to kwargs * DistilBERT * Addressing @sgugger and @patrickvonplaten's comments
This commit is contained in:
@@ -21,13 +21,17 @@ import tensorflow as tf
|
|||||||
from .configuration_longformer import LongformerConfig
|
from .configuration_longformer import LongformerConfig
|
||||||
from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
|
from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
|
||||||
from .modeling_tf_bert import TFBertIntermediate, TFBertOutput, TFBertPooler, TFBertSelfOutput
|
from .modeling_tf_bert import TFBertIntermediate, TFBertOutput, TFBertPooler, TFBertSelfOutput
|
||||||
from .modeling_tf_outputs import TFBaseModelOutputWithPooling, TFMaskedLMOutput, TFQuestionAnsweringModelOutput
|
from .modeling_tf_outputs import (
|
||||||
|
TFBaseModelOutput,
|
||||||
|
TFBaseModelOutputWithPooling,
|
||||||
|
TFMaskedLMOutput,
|
||||||
|
TFQuestionAnsweringModelOutput,
|
||||||
|
)
|
||||||
from .modeling_tf_roberta import TFRobertaEmbeddings, TFRobertaLMHead
|
from .modeling_tf_roberta import TFRobertaEmbeddings, TFRobertaLMHead
|
||||||
from .modeling_tf_utils import (
|
from .modeling_tf_utils import (
|
||||||
TFMaskedLanguageModelingLoss,
|
TFMaskedLanguageModelingLoss,
|
||||||
TFPreTrainedModel,
|
TFPreTrainedModel,
|
||||||
TFQuestionAnsweringLoss,
|
TFQuestionAnsweringLoss,
|
||||||
cast_bool_to_primitive,
|
|
||||||
get_initializer,
|
get_initializer,
|
||||||
keras_serializable,
|
keras_serializable,
|
||||||
shape_list,
|
shape_list,
|
||||||
@@ -833,33 +837,41 @@ class TFLongformerEncoder(tf.keras.layers.Layer):
|
|||||||
TFLongformerLayer(config, i, name="layer_._{}".format(i)) for i in range(config.num_hidden_layers)
|
TFLongformerLayer(config, i, name="layer_._{}".format(i)) for i in range(config.num_hidden_layers)
|
||||||
]
|
]
|
||||||
|
|
||||||
def call(self, inputs, training=False):
|
def call(
|
||||||
hidden_states, attention_mask, output_attentions, output_hidden_states, padding_len = inputs
|
self,
|
||||||
|
hidden_states,
|
||||||
|
attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
padding_len=0,
|
||||||
|
output_attentions=None,
|
||||||
|
output_hidden_states=None,
|
||||||
|
return_dict=None,
|
||||||
|
training=False,
|
||||||
|
):
|
||||||
|
|
||||||
all_hidden_states = ()
|
all_hidden_states = () if output_hidden_states else None
|
||||||
all_attentions = ()
|
all_attentions = () if output_attentions else None
|
||||||
for i, layer_module in enumerate(self.layer):
|
for i, layer_module in enumerate(self.layer):
|
||||||
if cast_bool_to_primitive(output_hidden_states, self.output_hidden_states) is True:
|
if output_hidden_states:
|
||||||
hidden_states_to_add = hidden_states[:, :-padding_len] if padding_len > 0 else hidden_states
|
hidden_states_to_add = hidden_states[:, :-padding_len] if padding_len > 0 else hidden_states
|
||||||
all_hidden_states = all_hidden_states + (hidden_states_to_add,)
|
all_hidden_states = all_hidden_states + (hidden_states_to_add,)
|
||||||
|
|
||||||
layer_outputs = layer_module([hidden_states, attention_mask, output_attentions], training=training)
|
layer_outputs = layer_module([hidden_states, attention_mask, output_attentions], training=training)
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
if cast_bool_to_primitive(output_attentions, self.output_attentions) is True:
|
if output_attentions:
|
||||||
all_attentions = all_attentions + (layer_outputs[1],)
|
all_attentions = all_attentions + (layer_outputs[1],)
|
||||||
|
|
||||||
# Add last layer
|
# Add last layer
|
||||||
if cast_bool_to_primitive(output_hidden_states, self.output_hidden_states) is True:
|
if output_hidden_states:
|
||||||
hidden_states_to_add = hidden_states[:, :-padding_len] if padding_len > 0 else hidden_states
|
hidden_states_to_add = hidden_states[:, :-padding_len] if padding_len > 0 else hidden_states
|
||||||
all_hidden_states = all_hidden_states + (hidden_states_to_add,)
|
all_hidden_states = all_hidden_states + (hidden_states_to_add,)
|
||||||
|
|
||||||
outputs = (hidden_states,)
|
if not return_dict:
|
||||||
if cast_bool_to_primitive(output_hidden_states, self.output_hidden_states) is True:
|
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
|
||||||
outputs = outputs + (all_hidden_states,)
|
return TFBaseModelOutput(
|
||||||
if cast_bool_to_primitive(output_attentions, self.output_attentions) is True:
|
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
|
||||||
outputs = outputs + (all_attentions,)
|
)
|
||||||
return outputs # outputs, (hidden states), (attentions)
|
|
||||||
|
|
||||||
|
|
||||||
@keras_serializable
|
@keras_serializable
|
||||||
@@ -992,7 +1004,12 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
embedding_output = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
|
embedding_output = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
|
||||||
encoder_outputs = self.encoder(
|
encoder_outputs = self.encoder(
|
||||||
[embedding_output, extended_attention_mask, output_attentions, output_hidden_states, padding_len],
|
embedding_output,
|
||||||
|
attention_mask=extended_attention_mask,
|
||||||
|
padding_len=padding_len,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
training=training,
|
training=training,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ import os.path
|
|||||||
import random
|
import random
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
from typing import List
|
from typing import List, Tuple
|
||||||
|
|
||||||
from transformers import is_torch_available
|
from transformers import is_torch_available
|
||||||
from transformers.testing_utils import require_multigpu, require_torch, slow, torch_device
|
from transformers.testing_utils import require_multigpu, require_torch, slow, torch_device
|
||||||
@@ -37,6 +37,11 @@ if is_torch_available():
|
|||||||
BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
|
MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
|
||||||
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
||||||
|
MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||||
|
MODEL_FOR_MASKED_LM_MAPPING,
|
||||||
|
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
||||||
|
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||||
|
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||||
top_k_top_p_filtering,
|
top_k_top_p_filtering,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -63,14 +68,39 @@ class ModelTesterMixin:
|
|||||||
test_chunking = False
|
test_chunking = False
|
||||||
is_encoder_decoder = False
|
is_encoder_decoder = False
|
||||||
|
|
||||||
def _prepare_for_class(self, inputs_dict, model_class):
|
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||||
|
inputs_dict = copy.deepcopy(inputs_dict)
|
||||||
if model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
|
if model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
|
||||||
return {
|
inputs_dict = {
|
||||||
k: v.unsqueeze(1).expand(-1, self.model_tester.num_choices, -1).contiguous()
|
k: v.unsqueeze(1).expand(-1, self.model_tester.num_choices, -1).contiguous()
|
||||||
if isinstance(v, torch.Tensor) and v.ndim > 1
|
if isinstance(v, torch.Tensor) and v.ndim > 1
|
||||||
else v
|
else v
|
||||||
for k, v in inputs_dict.items()
|
for k, v in inputs_dict.items()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if return_labels:
|
||||||
|
if model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
|
||||||
|
inputs_dict["labels"] = torch.ones(self.model_tester.batch_size, dtype=torch.long, device=torch_device)
|
||||||
|
elif model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values():
|
||||||
|
inputs_dict["start_positions"] = torch.zeros(
|
||||||
|
self.model_tester.batch_size, dtype=torch.long, device=torch_device
|
||||||
|
)
|
||||||
|
inputs_dict["end_positions"] = torch.zeros(
|
||||||
|
self.model_tester.batch_size, dtype=torch.long, device=torch_device
|
||||||
|
)
|
||||||
|
elif model_class in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values():
|
||||||
|
inputs_dict["labels"] = torch.zeros(
|
||||||
|
self.model_tester.batch_size, dtype=torch.long, device=torch_device
|
||||||
|
)
|
||||||
|
elif model_class in [
|
||||||
|
*MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.values(),
|
||||||
|
*MODEL_FOR_CAUSAL_LM_MAPPING.values(),
|
||||||
|
*MODEL_FOR_MASKED_LM_MAPPING.values(),
|
||||||
|
*MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.values(),
|
||||||
|
]:
|
||||||
|
inputs_dict["labels"] = torch.zeros(
|
||||||
|
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
|
||||||
|
)
|
||||||
return inputs_dict
|
return inputs_dict
|
||||||
|
|
||||||
def test_save_load(self):
|
def test_save_load(self):
|
||||||
@@ -663,6 +693,64 @@ class ModelTesterMixin:
|
|||||||
# self.assertTrue(model.transformer.wte.weight.shape, model.lm_head.weight.shape)
|
# self.assertTrue(model.transformer.wte.weight.shape, model.lm_head.weight.shape)
|
||||||
# self.assertTrue(check_same_values(model.transformer.wte, model.lm_head))
|
# self.assertTrue(check_same_values(model.transformer.wte, model.lm_head))
|
||||||
|
|
||||||
|
def test_model_outputs_equivalence(self):
|
||||||
|
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
|
||||||
|
with torch.no_grad():
|
||||||
|
tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
|
||||||
|
dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()
|
||||||
|
|
||||||
|
def recursive_check(tuple_object, dict_object):
|
||||||
|
if isinstance(tuple_object, (List, Tuple)):
|
||||||
|
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
|
||||||
|
recursive_check(tuple_iterable_value, dict_iterable_value)
|
||||||
|
elif tuple_object is None:
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
self.assertTrue(
|
||||||
|
torch.allclose(tuple_object, dict_object, atol=1e-5),
|
||||||
|
msg=f"Tuple and dict output are not equal. Difference: {torch.max(torch.abs(tuple_object - dict_object))}",
|
||||||
|
)
|
||||||
|
|
||||||
|
recursive_check(tuple_output, dict_output)
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||||
|
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||||
|
check_equivalence(model, tuple_inputs, dict_inputs)
|
||||||
|
|
||||||
|
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||||
|
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||||
|
check_equivalence(model, tuple_inputs, dict_inputs)
|
||||||
|
|
||||||
|
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||||
|
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||||
|
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
|
||||||
|
|
||||||
|
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||||
|
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||||
|
check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True})
|
||||||
|
|
||||||
|
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||||
|
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||||
|
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
|
||||||
|
|
||||||
|
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||||
|
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||||
|
check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True})
|
||||||
|
|
||||||
|
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||||
|
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||||
|
check_equivalence(
|
||||||
|
model, tuple_inputs, dict_inputs, {"output_hidden_states": True, "output_attentions": True}
|
||||||
|
)
|
||||||
|
|
||||||
def test_inputs_embeds(self):
|
def test_inputs_embeds(self):
|
||||||
|
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|||||||
@@ -37,6 +37,8 @@ class T5ModelTester:
|
|||||||
self.batch_size = 13
|
self.batch_size = 13
|
||||||
self.encoder_seq_length = 7
|
self.encoder_seq_length = 7
|
||||||
self.decoder_seq_length = 9
|
self.decoder_seq_length = 9
|
||||||
|
# For common tests
|
||||||
|
self.seq_length = self.decoder_seq_length
|
||||||
self.is_training = True
|
self.is_training = True
|
||||||
self.use_attention_mask = True
|
self.use_attention_mask = True
|
||||||
self.use_labels = True
|
self.use_labels = True
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ import random
|
|||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
from importlib import import_module
|
from importlib import import_module
|
||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
from transformers import is_tf_available, is_torch_available
|
from transformers import is_tf_available, is_torch_available
|
||||||
from transformers.testing_utils import _tf_gpu_memory_limit, require_tf, slow
|
from transformers.testing_utils import _tf_gpu_memory_limit, require_tf, slow
|
||||||
@@ -78,6 +79,8 @@ class TFModelTesterMixin:
|
|||||||
is_encoder_decoder = False
|
is_encoder_decoder = False
|
||||||
|
|
||||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||||
|
inputs_dict = copy.deepcopy(inputs_dict)
|
||||||
|
|
||||||
if model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
|
if model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
|
||||||
inputs_dict = {
|
inputs_dict = {
|
||||||
k: tf.tile(tf.expand_dims(v, 1), (1, self.model_tester.num_choices) + (1,) * (v.ndim - 1))
|
k: tf.tile(tf.expand_dims(v, 1), (1, self.model_tester.num_choices) + (1,) * (v.ndim - 1))
|
||||||
@@ -88,20 +91,21 @@ class TFModelTesterMixin:
|
|||||||
|
|
||||||
if return_labels:
|
if return_labels:
|
||||||
if model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
|
if model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
|
||||||
inputs_dict["labels"] = tf.ones(self.model_tester.batch_size)
|
inputs_dict["labels"] = tf.ones(self.model_tester.batch_size, dtype=tf.int32)
|
||||||
elif model_class in TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING.values():
|
elif model_class in TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING.values():
|
||||||
inputs_dict["start_positions"] = tf.zeros(self.model_tester.batch_size)
|
inputs_dict["start_positions"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
|
||||||
inputs_dict["end_positions"] = tf.zeros(self.model_tester.batch_size)
|
inputs_dict["end_positions"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
|
||||||
elif model_class in TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values():
|
elif model_class in TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values():
|
||||||
inputs_dict["labels"] = tf.zeros(self.model_tester.batch_size)
|
inputs_dict["labels"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
|
||||||
elif model_class in TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.values():
|
elif model_class in [
|
||||||
inputs_dict["labels"] = tf.zeros((self.model_tester.batch_size, self.model_tester.seq_length))
|
*TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.values(),
|
||||||
elif model_class in TF_MODEL_FOR_CAUSAL_LM_MAPPING.values():
|
*TF_MODEL_FOR_CAUSAL_LM_MAPPING.values(),
|
||||||
inputs_dict["labels"] = tf.zeros((self.model_tester.batch_size, self.model_tester.seq_length))
|
*TF_MODEL_FOR_MASKED_LM_MAPPING.values(),
|
||||||
elif model_class in TF_MODEL_FOR_MASKED_LM_MAPPING.values():
|
*TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.values(),
|
||||||
inputs_dict["labels"] = tf.zeros((self.model_tester.batch_size, self.model_tester.seq_length))
|
]:
|
||||||
elif model_class in TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.values():
|
inputs_dict["labels"] = tf.zeros(
|
||||||
inputs_dict["labels"] = tf.zeros((self.model_tester.batch_size, self.model_tester.seq_length))
|
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=tf.int32
|
||||||
|
)
|
||||||
return inputs_dict
|
return inputs_dict
|
||||||
|
|
||||||
def test_initialization(self):
|
def test_initialization(self):
|
||||||
@@ -517,6 +521,61 @@ class TFModelTesterMixin:
|
|||||||
max_diff = np.amax(np.abs(out_1 - out_2))
|
max_diff = np.amax(np.abs(out_1 - out_2))
|
||||||
self.assertLessEqual(max_diff, 1e-5)
|
self.assertLessEqual(max_diff, 1e-5)
|
||||||
|
|
||||||
|
def test_model_outputs_equivalence(self):
|
||||||
|
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
|
||||||
|
tuple_output = model(tuple_inputs, return_dict=False, **additional_kwargs)
|
||||||
|
dict_output = model(dict_inputs, return_dict=True, **additional_kwargs).to_tuple()
|
||||||
|
|
||||||
|
def recursive_check(tuple_object, dict_object):
|
||||||
|
if isinstance(tuple_object, (List, Tuple)):
|
||||||
|
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
|
||||||
|
recursive_check(tuple_iterable_value, dict_iterable_value)
|
||||||
|
elif tuple_object is None:
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
self.assertTrue(
|
||||||
|
all(tf.equal(tuple_object, dict_object)),
|
||||||
|
msg=f"Tuple and dict output are not equal. Difference: {tf.math.reduce_max(tf.abs(tuple_object - dict_object))}",
|
||||||
|
)
|
||||||
|
|
||||||
|
recursive_check(tuple_output, dict_output)
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config)
|
||||||
|
|
||||||
|
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||||
|
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||||
|
check_equivalence(model, tuple_inputs, dict_inputs)
|
||||||
|
|
||||||
|
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||||
|
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||||
|
check_equivalence(model, tuple_inputs, dict_inputs)
|
||||||
|
|
||||||
|
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||||
|
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||||
|
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
|
||||||
|
|
||||||
|
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||||
|
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||||
|
check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True})
|
||||||
|
|
||||||
|
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||||
|
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||||
|
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
|
||||||
|
|
||||||
|
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||||
|
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||||
|
check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True})
|
||||||
|
|
||||||
|
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||||
|
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||||
|
check_equivalence(
|
||||||
|
model, tuple_inputs, dict_inputs, {"output_hidden_states": True, "output_attentions": True}
|
||||||
|
)
|
||||||
|
|
||||||
def _get_embeds(self, wte, input_ids):
|
def _get_embeds(self, wte, input_ids):
|
||||||
# ^^ In our TF models, the input_embeddings can take slightly different forms,
|
# ^^ In our TF models, the input_embeddings can take slightly different forms,
|
||||||
# so we try a few of them.
|
# so we try a few of them.
|
||||||
|
|||||||
Reference in New Issue
Block a user