[Seq2Seq] Fix a couple of bugs and clean examples (#7474)
* clean T5 * fix t5 tests * fix index typo * fix tf common test * fix examples * change positional ordering for Bart and FSTM * add signature test * clean docs and add tests * add docs to encoder decoder * clean docs * correct two doc strings * remove sig test for TF Elektra & Funnel * fix tf t5 slow tests * fix input_ids to inputs in tf * Update src/transformers/modeling_bart.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/modeling_bart.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * implement lysandre results * make style * fix encoder decoder typo * fix tf slow tests * fix slow tests * renaming * remove unused input Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
a42f62d34f
commit
62f5ae68ec
@@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import inspect
|
||||
import os.path
|
||||
import random
|
||||
import tempfile
|
||||
@@ -158,6 +159,28 @@ class ModelTesterMixin:
|
||||
max_diff = np.amax(np.abs(out_1 - out_2))
|
||||
self.assertLessEqual(max_diff, 1e-5)
|
||||
|
||||
def test_forward_signature(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
signature = inspect.signature(model.forward)
|
||||
# signature.parameters is an OrderedDict => so arg_names order is deterministic
|
||||
arg_names = [*signature.parameters.keys()]
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
expected_arg_names = [
|
||||
"input_ids",
|
||||
"attention_mask",
|
||||
"decoder_input_ids",
|
||||
"decoder_attention_mask",
|
||||
"encoder_outputs",
|
||||
]
|
||||
self.assertListEqual(arg_names[:5], expected_arg_names)
|
||||
else:
|
||||
expected_arg_names = ["input_ids"]
|
||||
self.assertListEqual(arg_names[:1], expected_arg_names)
|
||||
|
||||
def test_attention_outputs(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
seq_len = getattr(self.model_tester, "seq_length", None)
|
||||
@@ -187,7 +210,7 @@ class ModelTesterMixin:
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class), return_dict=True)
|
||||
attentions = outputs[-1]
|
||||
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
|
||||
|
||||
@@ -272,10 +295,22 @@ class ModelTesterMixin:
|
||||
model = model_class(config=configs_no_init)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)["input_ids"] # Let's keep only input_ids
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
try:
|
||||
traced_gpt2 = torch.jit.trace(model, inputs)
|
||||
if model.config.is_encoder_decoder:
|
||||
model.config.use_cache = False # TODO: this should be deleted after bug #7474 is solved
|
||||
input_ids = inputs["input_ids"]
|
||||
attention_mask = inputs["attention_mask"]
|
||||
decoder_input_ids = inputs["decoder_input_ids"]
|
||||
decoder_attention_mask = inputs["decoder_attention_mask"]
|
||||
|
||||
traced_model = torch.jit.trace(
|
||||
model, (input_ids, attention_mask, decoder_input_ids, decoder_attention_mask)
|
||||
)
|
||||
else:
|
||||
input_ids = inputs["input_ids"]
|
||||
traced_model = torch.jit.trace(model, input_ids)
|
||||
except RuntimeError:
|
||||
self.fail("Couldn't trace module.")
|
||||
|
||||
@@ -283,7 +318,7 @@ class ModelTesterMixin:
|
||||
pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")
|
||||
|
||||
try:
|
||||
torch.jit.save(traced_gpt2, pt_file_name)
|
||||
torch.jit.save(traced_model, pt_file_name)
|
||||
except Exception:
|
||||
self.fail("Couldn't save module.")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user