[Bart] Refactor - fix issues, consistency with the library, naming (#8900)
* remove make on the fly linear embedding * start refactor * big first refactor * save intermediate * save intermediat * correct mask issue * save tests * refactor padding masks * make all tests pass * further refactor * make pegasus test pass * fix bool if * fix leftover tests * continue * bart renaming * delete torchscript test hack * fix imports in tests * correct shift * fix docs and repo cons * re-add fix for FSTM * typo in test * fix typo * fix another typo * continue * hot fix 2 for tf * small fixes * refactor types linting * continue * finish refactor * fix import in tests * better bart names * further refactor and add test * delete hack * apply sylvains and lysandres commens * small perf improv * further perf improv * improv perf * fix typo * make style * small perf improv
This commit is contained in:
committed by
GitHub
parent
75627148ee
commit
06971ac4f9
@@ -302,6 +302,8 @@ class ModelTesterMixin:
|
||||
# Question Answering model returns start_logits and end_logits
|
||||
if model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values():
|
||||
correct_outlen += 1 # start_logits and end_logits instead of only 1 output
|
||||
if "past_key_values" in outputs:
|
||||
correct_outlen += 1 # past_key_values have been returned
|
||||
|
||||
self.assertEqual(out_len, correct_outlen)
|
||||
|
||||
@@ -386,7 +388,7 @@ class ModelTesterMixin:
|
||||
|
||||
try:
|
||||
if model.config.is_encoder_decoder:
|
||||
model.config.use_cache = False # TODO: this should be deleted after bug #7474 is solved
|
||||
model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
|
||||
input_ids = inputs["input_ids"]
|
||||
attention_mask = inputs["attention_mask"]
|
||||
decoder_input_ids = inputs["decoder_input_ids"]
|
||||
@@ -1020,7 +1022,6 @@ class ModelTesterMixin:
|
||||
)
|
||||
|
||||
def test_inputs_embeds(self):
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
|
||||
Reference in New Issue
Block a user