[Bart] Refactor - fix issues, consistency with the library, naming (#8900)

* remove make on the fly linear embedding

* start refactor

* big first refactor

* save intermediate

* save intermediat

* correct mask issue

* save tests

* refactor padding masks

* make all tests pass

* further refactor

* make pegasus test pass

* fix bool if

* fix leftover tests

* continue

* bart renaming

* delete torchscript test hack

* fix imports in tests

* correct shift

* fix docs and repo cons

* re-add fix for FSTM

* typo in test

* fix typo

* fix another typo

* continue

* hot fix 2 for tf

* small fixes

* refactor types linting

* continue

* finish refactor

* fix import in tests

* better bart names

* further refactor and add test

* delete hack

* apply sylvains and lysandres commens

* small perf improv

* further perf improv

* improv perf

* fix typo

* make style

* small perf improv
This commit is contained in:
Patrick von Platen
2020-12-09 20:55:24 +01:00
committed by GitHub
parent 75627148ee
commit 06971ac4f9
13 changed files with 1079 additions and 823 deletions

View File

@@ -302,6 +302,8 @@ class ModelTesterMixin:
# Question Answering model returns start_logits and end_logits
if model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values():
correct_outlen += 1 # start_logits and end_logits instead of only 1 output
if "past_key_values" in outputs:
correct_outlen += 1 # past_key_values have been returned
self.assertEqual(out_len, correct_outlen)
@@ -386,7 +388,7 @@ class ModelTesterMixin:
try:
if model.config.is_encoder_decoder:
model.config.use_cache = False # TODO: this should be deleted after bug #7474 is solved
model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]
decoder_input_ids = inputs["decoder_input_ids"]
@@ -1020,7 +1022,6 @@ class ModelTesterMixin:
)
def test_inputs_embeds(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes: