[PyTorch Bart] Split Bart into different models (#9343)
* first try * remove old template * finish bart * finish mbart * delete unnecessary line * init pegasus * save intermediate * correct pegasus * finish pegasus * remove cookie cutter leftover * add marian * finish blenderbot * replace in file * correctly split blenderbot * delete "old" folder * correct "add statement" * adapt config for tf comp * correct configs for tf * remove ipdb * fix more stuff * fix mbart * push pegasus fix * fix mbart * more fixes * fix research projects code * finish docs for bart, mbart, and marian * delete unnecessary file * correct attn typo * correct configs * remove pegasus for seq class * correct peg docs * correct peg docs * finish configs * further improve docs * add copied from statements to mbart * fix copied from in mbart * add copy statements to marian * add copied from to marian * add pegasus copied from * finish pegasus * finish copied from * Apply suggestions from code review * make style * backward comp blenderbot * apply lysandres and sylvains suggestions * apply suggestions * push last fixes * fix docs * fix tok tests * fix imports code style * fix doc
This commit is contained in:
committed by
GitHub
parent
4eec5d0cf6
commit
eef66035a2
@@ -246,16 +246,16 @@ TOLERANCE = 1e-4
|
||||
class TFBartModelIntegrationTest(unittest.TestCase):
|
||||
def test_inference_no_head(self):
|
||||
model = TFBartModel.from_pretrained("facebook/bart-large", from_pt=True)
|
||||
|
||||
input_ids = _long_tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
|
||||
inputs_dict = prepare_bart_inputs_dict(model.config, input_ids)
|
||||
# with torch.no_grad():
|
||||
output = model(**inputs_dict)[0]
|
||||
attention_mask = tf.cast(tf.math.not_equal(input_ids, model.config.pad_token_id), tf.int8)
|
||||
output = model(input_ids=input_ids, attention_mask=attention_mask)[0]
|
||||
expected_shape = (1, 11, 1024)
|
||||
self.assertEqual(output.shape, expected_shape)
|
||||
expected_slice = tf.Tensor(
|
||||
expected_slice = tf.convert_to_tensor(
|
||||
[[0.7144, 0.8143, -1.2813], [0.7144, 0.8143, -1.2813], [-0.0467, 2.5911, -2.1845]],
|
||||
)
|
||||
self.assertTrue(tf.debugging.assert_near(output[:, :3, :3], expected_slice, atol=TOLERANCE))
|
||||
tf.debugging.assert_near(output[:, :3, :3], expected_slice, atol=TOLERANCE)
|
||||
|
||||
def test_cnn_summarization_same_as_fairseq_hard(self):
|
||||
hf = TFBartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn", from_pt=True)
|
||||
|
||||
Reference in New Issue
Block a user