Restore TF embeddings and attention layers to their previous version (#9890)
* Refacto BERT * Restore all the concerned models * Remove print * Update template * Apply Sylvain's and Morgan's comments * Fix cast * Put the cast inside call * Remove cond in ebds * Fix funnel * Restore previous dot product (attention_scores) computation * Add ConvBERT and BART * Make all the S2S models ONNX compliant * Fix test * Fix check copies
This commit is contained in:
@@ -866,7 +866,8 @@ class TFModelTesterMixin:
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
|
||||
inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
|
||||
inputs = copy.deepcopy(inputs_dict)
|
||||
|
||||
if not self.is_encoder_decoder:
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
@@ -882,6 +883,8 @@ class TFModelTesterMixin:
|
||||
inputs["inputs_embeds"] = model.get_input_embeddings()(encoder_input_ids)
|
||||
inputs["decoder_inputs_embeds"] = model.get_input_embeddings()(decoder_input_ids)
|
||||
|
||||
inputs = self._prepare_for_class(inputs, model_class)
|
||||
|
||||
model(inputs)
|
||||
|
||||
def test_graph_mode_with_inputs_embeds(self):
|
||||
@@ -890,7 +893,8 @@ class TFModelTesterMixin:
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
|
||||
inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
|
||||
inputs = copy.deepcopy(inputs_dict)
|
||||
|
||||
if not self.is_encoder_decoder:
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
@@ -906,6 +910,8 @@ class TFModelTesterMixin:
|
||||
inputs["inputs_embeds"] = model.get_input_embeddings()(encoder_input_ids)
|
||||
inputs["decoder_inputs_embeds"] = model.get_input_embeddings()(decoder_input_ids)
|
||||
|
||||
inputs = self._prepare_for_class(inputs, model_class)
|
||||
|
||||
@tf.function
|
||||
def run_in_graph_mode():
|
||||
return model(inputs)
|
||||
|
||||
Reference in New Issue
Block a user