Add FlaxVisionEncoderDecoderModel (#13359)
* Start the work on FlaxVisionEncoderDecoderModel
* Add FlaxVisionEncoderDecoderModel
* Add VisionEncoderDecoderConfig
* Make FlaxVisionEncoderDecoderModel visible to transformers
* Add test
* Fix wrong getattr usage
* Fix tests
* Add FlaxAutoModelForVision2Seq
* Expose FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING
* clean-up
* add integration test
* update expected logits
* update expected scores
* Add ViT2GPT2ModelIntegrationTest + some cleaning
* Add projection layer + PT/Flax equivalence tests
* Fix import
* minor changes
* make test slow again
* Apply suggestions
* Add modeling_flax_vision_encoder_decoder to _ignore_modules in get_model_modules()
* fix copies
* Apply suggestions from code review

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* split long strings in multiple lines
* decoder_input_ids can't be None
* Add back test_configuration_tie
* Remove attention_mask parameter
* fix test - encoder_last_hidden_state should be encoder_outputs.last_hidden_state instead of the projected vector
* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Remove more encoder_attention_mask
* remove encoder_attention_mask when calling self.decode (in FlaxVisionEncoderDecoderModule)
* Fix style + pass 1s instead of None as encoder_attention_mask
* fix init_weights
* pass None for encoder_attention_mask
* pass 1s instead of None as encoder_attention_mask
* Fix doc style

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
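Note on the "pass 1s instead of None as encoder_attention_mask" items: a vision encoder has no padding tokens, so every encoder position can be attended to. The sketch below shows one way such an all-ones mask could be derived from the encoder output shape. It is an illustration only, not the code from this commit: the helper name `all_ones_encoder_mask` is hypothetical, NumPy stands in for `jax.numpy`, and the shape `(2, 197, 768)` is an assumed ViT-like example.

```python
import numpy as np


def all_ones_encoder_mask(encoder_hidden_states):
    """Return a (batch, seq_len) mask of 1s for the given encoder output.

    Hypothetical helper: image patches are never padded, so no position
    needs to be masked out in the decoder's cross-attention.
    """
    batch_size, sequence_length = encoder_hidden_states.shape[:2]
    return np.ones((batch_size, sequence_length), dtype="i4")


# Assumed example shape: 2 images, 197 patch positions, hidden size 768.
hidden = np.zeros((2, 197, 768), dtype=np.float32)
mask = all_ones_encoder_mask(hidden)
```

Passing an explicit mask of 1s rather than None keeps the decoder's cross-attention inputs uniform regardless of which decoder implementation is plugged in.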
@@ -29,7 +29,6 @@ from .test_modeling_flax_gpt2 import FlaxGPT2ModelTester
 
 if is_flax_available():
     from transformers import (
-        AutoConfig,
         AutoTokenizer,
         EncoderDecoderConfig,
         FlaxBertModel,
@@ -350,12 +349,6 @@ class FlaxEncoderDecoderModelTest(unittest.TestCase):
     def get_from_encoderdecoder_pretrained_model(self):
         return FlaxEncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "gpt2")
 
-    def get_decoder_config(self):
-        config = AutoConfig.from_pretrained("gpt2")
-        config.is_decoder = True
-        config.add_cross_attention = True
-        return config
-
     def _check_configuration_tie(self, model):
 
         module = model.module.bind(model.params)