[FlaxBert] Add ForCausalLM (#16995)
* [FlaxBert] Add ForCausalLM * make style * fix output attentions * Add RobertaForCausalLM * remove comment * fix fx-to-pt model loading * remove comment * add modeling tests * add enc-dec model tests * add big_bird * add electra * make style * make repo-consitency * add to docs * remove roberta test * quality * amend cookiecutter * fix attention_mask bug in flax bert model tester * tighten pt-fx thresholds to 1e-5 * add 'copied from' statements * amend 'copied from' statements * amend 'copied from' statements * quality
This commit is contained in:
@@ -33,6 +33,7 @@ if is_flax_available():
|
||||
AutoTokenizer,
|
||||
EncoderDecoderConfig,
|
||||
FlaxBartForCausalLM,
|
||||
FlaxBertForCausalLM,
|
||||
FlaxBertModel,
|
||||
FlaxEncoderDecoderModel,
|
||||
FlaxGPT2LMHeadModel,
|
||||
@@ -545,6 +546,43 @@ class FlaxBartEncoderDecoderModelTest(FlaxEncoderDecoderMixin, unittest.TestCase
|
||||
return FlaxEncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "facebook/bart-base")
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxBertEncoderDecoderModelTest(FlaxEncoderDecoderMixin, unittest.TestCase):
|
||||
def get_encoder_decoder_model(self, config, decoder_config):
|
||||
encoder_model = FlaxBertModel(config)
|
||||
decoder_model = FlaxBertForCausalLM(decoder_config)
|
||||
return encoder_model, decoder_model
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
model_tester_encoder = FlaxBertModelTester(self, batch_size=13)
|
||||
model_tester_decoder = FlaxBertModelTester(self, batch_size=13)
|
||||
encoder_config_and_inputs = model_tester_encoder.prepare_config_and_inputs()
|
||||
decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs_for_decoder()
|
||||
(config, input_ids, token_type_ids, attention_mask) = encoder_config_and_inputs
|
||||
(
|
||||
decoder_config,
|
||||
decoder_input_ids,
|
||||
decoder_attention_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
) = decoder_config_and_inputs
|
||||
|
||||
# make sure that cross attention layers are added
|
||||
decoder_config.add_cross_attention = True
|
||||
return {
|
||||
"config": config,
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_config": decoder_config,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
}
|
||||
|
||||
def get_pretrained_model(self):
|
||||
return FlaxEncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "bert-base-cased")
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxEncoderDecoderModelTest(unittest.TestCase):
|
||||
def get_from_encoderdecoder_pretrained_model(self):
|
||||
|
||||
Reference in New Issue
Block a user