[ProphetNet] Bart-like Refactor (#10501)
* first step to refactor
* make all fast tests pass
* make all slow tests pass
* save intermediate
* correct cache
* finish PR
* make fp16 work
This commit is contained in:
committed by
GitHub
parent
6290169eb3
commit
c503a1c15e
@@ -243,7 +243,7 @@ class ProphetNetModelTester:
|
||||
# There should be `num_layers` key value embeddings stored in decoder_past
|
||||
self.parent.assertEqual(len(decoder_past), config.num_decoder_layers)
|
||||
# There should be a self attn key, a self attn value, a cross attn key and a cross attn value stored in each decoder_past tuple
|
||||
self.parent.assertEqual(len(decoder_past[0]), 2) # cross-attention + uni-directional self-attention
|
||||
self.parent.assertEqual(len(decoder_past[0]), 4) # cross-attention + uni-directional self-attention
|
||||
|
||||
def create_and_check_with_lm_head(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user