Add TFVisionEncoderDecoderModel (#14148)
* Start the work on TFVisionEncoderDecoderModel
* Expose TFVisionEncoderDecoderModel
* fix import
* Add modeling_tf_vision_encoder_decoder to _ignore_modules in get_model_modules()
* reorder
* Apply the fix for checkpoint loading as in #14016
* remove attention_mask + fix VISION_DUMMY_INPUTS
* A minimal change to make TF generate() work for vision models as encoder in encoder-decoder setting
* fix wrong condition: shape_list(input_ids) == 2
* add tests
* use personal TFViTModel checkpoint (for now)
* Add equivalence tests + projection layer
* style
* make sure projection layer can run
* Add examples
* Apply suggestions from code review
  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* Clean comments (need to work on TODOs for PyTorch models)
* Remove TF -> PT in check_pt_tf_equivalence for TFVisionEncoderDecoderModel
* fixes
* Revert changes in PT code.
* Update tests/test_modeling_tf_vision_encoder_decoder.py
  Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* Add test_inference_coco_en for TF test
* fix quality
* fix name
* build doc
* add main_input_name
* Fix ckpt name in test
* fix diff between master and this PR
* fix doc
* fix style and quality
* fix missing doc
* fix labels handling
* Delete auto.rst
* Add the changes done in #14016
* fix prefix
* Apply suggestions from code review
  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* make style

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
@@ -490,7 +490,7 @@ class TFEncoderDecoderMixin:
     def test_real_model_save_load_from_pretrained(self):
         model_2 = self.get_pretrained_model()
         input_ids = ids_tensor([13, 5], model_2.config.encoder.vocab_size)
-        decoder_input_ids = ids_tensor([13, 1], model_2.config.encoder.vocab_size)
+        decoder_input_ids = ids_tensor([13, 1], model_2.config.decoder.vocab_size)
         attention_mask = ids_tensor([13, 5], vocab_size=2)
 
         outputs = model_2(
@@ -650,7 +650,7 @@ class TFGPT2EncoderDecoderModelTest(TFEncoderDecoderMixin, unittest.TestCase):
 
         # make sure that cross attention layers are added
         decoder_config.add_cross_attention = True
-        # disable cache for now
+        # disable cache for now
         decoder_config.use_cache = False
         return {
             "config": config,
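The first hunk changes which vocabulary size `decoder_input_ids` is sampled from. A minimal sketch of why that matters is below. It assumes plain Python lists in place of tensors; `sample_ids`, `embed`, and both vocabulary sizes are hypothetical stand-ins for illustration, not the test suite's `ids_tensor` helper or any real model config.

```python
import random


def sample_ids(shape, vocab_size, seed=0):
    """Hypothetical stand-in for a test helper: ids drawn from [0, vocab_size)."""
    rng = random.Random(seed)
    rows, cols = shape
    return [[rng.randrange(vocab_size) for _ in range(cols)] for _ in range(rows)]


def embed(ids, table):
    """Toy embedding lookup; raises IndexError for ids beyond the table."""
    return [[table[i] for i in row] for row in ids]


# Assumed sizes, chosen only so that the encoder vocabulary is the larger one.
ENCODER_VOCAB = 100
DECODER_VOCAB = 10

# One row per decoder token id.
decoder_table = [[float(i)] for i in range(DECODER_VOCAB)]

# Sampling from the decoder's own vocabulary always yields valid lookups.
good_ids = sample_ids([13, 1], DECODER_VOCAB)
good_vectors = embed(good_ids, decoder_table)

# Sampling from the encoder's vocabulary can produce ids the decoder lacks.
bad_ids = sample_ids([13, 1], ENCODER_VOCAB)
try:
    embed(bad_ids, decoder_table)
    lookup_failed = False
except IndexError:
    lookup_failed = True
```

When the two vocabularies happen to be the same size, both versions of the line behave identically, which is presumably how the original slipped through.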