[ProphetNet] Bart-like Refactor (#10501)

* first step to refactor

* make all fast tests pass

* make all slow tests pass

* save intermediate

* correct cache

* finish PR

* make fp16 work
Patrick von Platen
2021-03-04 23:27:12 +03:00
committed by GitHub
parent 6290169eb3
commit c503a1c15e
3 changed files with 245 additions and 187 deletions

@@ -243,7 +243,7 @@ class ProphetNetModelTester:
# There should be `num_layers` key value embeddings stored in decoder_past
self.parent.assertEqual(len(decoder_past), config.num_decoder_layers)
# There should be a self attn key, a self attn value, a cross attn key and a cross attn value stored in each decoder_past tuple
-self.parent.assertEqual(len(decoder_past[0]), 2) # cross-attention + uni-directional self-attention
+self.parent.assertEqual(len(decoder_past[0]), 4) # cross-attention + uni-directional self-attention
def create_and_check_with_lm_head(
self,
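
The cache layout that the assertions in this hunk check can be sketched as follows. This is a minimal illustration in plain Python, not the test suite's code: the helper name `make_dummy_decoder_past`, the constant `NUM_DECODER_LAYERS`, and the string placeholders (standing in for tensors) are all assumptions made for the example. The entry order follows the comment in the hunk.

```python
# Hypothetical stand-ins: the layer count and helper name below are
# illustrative assumptions, not values taken from the ProphetNet tests.
NUM_DECODER_LAYERS = 3


def make_dummy_decoder_past(num_layers):
    """Build a cache with one tuple per decoder layer, four entries each."""
    # Strings stand in for the cached key/value tensors of one layer.
    layer_cache = (
        "self_attn_key",
        "self_attn_value",
        "cross_attn_key",
        "cross_attn_value",
    )
    return tuple(layer_cache for _ in range(num_layers))


decoder_past = make_dummy_decoder_past(NUM_DECODER_LAYERS)

# Same shape checks as in the hunk: one entry per decoder layer,
# and four cached items per layer.
assert len(decoder_past) == NUM_DECODER_LAYERS
assert len(decoder_past[0]) == 4
```

With this layout, each per-layer tuple holds both the self-attention and the cross-attention key/value pairs, which is why the expected length per layer is 4.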