Add FlaxBartForCausalLM (#15995)
* add causal lm * add CausalLM tests * Add FlaxBartForCausalLM * Add EncoderDecoder model tests * change docstring * make repo-consistency * suggested changes * remove jax ops * correction * rename pre-trained decoder model
This commit is contained in:
@@ -22,6 +22,7 @@ import numpy as np
|
||||
from transformers import is_flax_available, is_torch_available
|
||||
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow, torch_device
|
||||
|
||||
from ..bart.test_modeling_flax_bart import FlaxBartStandaloneDecoderModelTester
|
||||
from ..bert.test_modeling_flax_bert import FlaxBertModelTester
|
||||
from ..gpt2.test_modeling_flax_gpt2 import FlaxGPT2ModelTester
|
||||
from ..test_modeling_flax_common import ids_tensor
|
||||
@@ -31,6 +32,7 @@ if is_flax_available():
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
EncoderDecoderConfig,
|
||||
FlaxBartForCausalLM,
|
||||
FlaxBertModel,
|
||||
FlaxEncoderDecoderModel,
|
||||
FlaxGPT2LMHeadModel,
|
||||
@@ -360,6 +362,7 @@ class FlaxEncoderDecoderMixin:
|
||||
self.assertTrue(decoder_config.cross_attention_hidden_size is None)
|
||||
|
||||
# check without `enc_to_dec_proj` projection
|
||||
decoder_config.hidden_size = config.hidden_size
|
||||
self.assertTrue(config.hidden_size == decoder_config.hidden_size)
|
||||
self.check_equivalence_pt_to_flax(config, decoder_config, inputs_dict)
|
||||
self.check_equivalence_flax_to_pt(config, decoder_config, inputs_dict)
|
||||
@@ -456,6 +459,43 @@ class FlaxGPT2EncoderDecoderModelTest(FlaxEncoderDecoderMixin, unittest.TestCase
|
||||
self.assertEqual(summary, [EXPECTED_SUMMARY_STUDENTS])
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxBartEncoderDecoderModelTest(FlaxEncoderDecoderMixin, unittest.TestCase):
|
||||
def get_encoder_decoder_model(self, config, decoder_config):
|
||||
encoder_model = FlaxBertModel(config)
|
||||
decoder_model = FlaxBartForCausalLM(decoder_config)
|
||||
return encoder_model, decoder_model
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
model_tester_encoder = FlaxBertModelTester(self, batch_size=13)
|
||||
model_tester_decoder = FlaxBartStandaloneDecoderModelTester(self, batch_size=13)
|
||||
encoder_config_and_inputs = model_tester_encoder.prepare_config_and_inputs()
|
||||
decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs_for_decoder()
|
||||
(config, input_ids, token_type_ids, attention_mask) = encoder_config_and_inputs
|
||||
(
|
||||
decoder_config,
|
||||
decoder_input_ids,
|
||||
decoder_attention_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
) = decoder_config_and_inputs
|
||||
|
||||
# make sure that cross attention layers are added
|
||||
decoder_config.add_cross_attention = True
|
||||
return {
|
||||
"config": config,
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_config": decoder_config,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
}
|
||||
|
||||
def get_pretrained_model(self):
|
||||
return FlaxEncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "facebook/bart-base")
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxEncoderDecoderModelTest(unittest.TestCase):
|
||||
def get_from_encoderdecoder_pretrained_model(self):
|
||||
|
||||
Reference in New Issue
Block a user