Add FlaxVisionEncoderDecoderModel (#13359)
* Start the work on FlaxVisionEncoderDecoderModel * Add FlaxVisionEncoderDecoderModel * Add VisionEncoderDecoderConfig * Make FlaxVisionEncoderDecoderModel visible to transformers * Add test * Fix wrong getattr usage * Fix tests * Add FlaxAutoModelForVision2Seq * Expose FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING * clean-up * add integration test * update expected logits * update expected scores * Add ViT2GPT2ModelIntegrationTest + some cleaning * Add projection layer + PT/Flax equivalence tests * Fix import * minor changes * make test slow again * Apply suggestions * Add modeling_flax_vision_encoder_decoder to _ignore_modules in get_model_modules() * fix copies * Apply suggestions from code review Co-authored-by: Suraj Patil <surajp815@gmail.com> * split long strings in multiple lines * decoder_input_ids can't be None * Add back test_configuration_tie * Remove attention_mask parameter * fix test - encoder_last_hidden_state should be encoder_outputs.last_hidden_state instead of the projected vector * Apply suggestions from code review Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Remove more encoder_attention_mask * remove encoder_attention_mask when calling self.decode (in FlaxVisionEncoderDecoderModule) * Fix style + pass 1s instead of None as encoder_attention_mask * fix init_weights * pass None for encoder_attention_mask * pass 1s instead of None as encoder_attention_mask * Fix doc style Co-authored-by: ydshieh <ydshieh@users.noreply.github.com> Co-authored-by: Suraj Patil <surajp815@gmail.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -34,6 +34,7 @@ if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
BertLMHeadModel,
|
||||
DeiTModel,
|
||||
TrOCRForCausalLM,
|
||||
@@ -48,7 +49,7 @@ if is_torch_available():
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
from transformers import TrOCRProcessor
|
||||
from transformers import TrOCRProcessor, ViTFeatureExtractor
|
||||
|
||||
|
||||
@require_torch
|
||||
@@ -656,3 +657,69 @@ class TrOCRModelIntegrationTest(unittest.TestCase):
|
||||
).to(torch_device)
|
||||
|
||||
self.assertTrue(torch.allclose(logits[0, 0, :10], expected_slice, atol=1e-4))
|
||||
|
||||
|
||||
@require_vision
|
||||
@require_torch
|
||||
class ViT2GPT2ModelIntegrationTest(unittest.TestCase):
|
||||
@slow
|
||||
def test_inference_coco_en(self):
|
||||
|
||||
loc = "ydshieh/vit-gpt2-coco-en"
|
||||
|
||||
feature_extractor = ViTFeatureExtractor.from_pretrained(loc)
|
||||
tokenizer = AutoTokenizer.from_pretrained(loc)
|
||||
model = VisionEncoderDecoderModel.from_pretrained(loc)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
# We will verify our results on an image of cute cats
|
||||
img = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
|
||||
pixel_values = feature_extractor(images=img, return_tensors="pt").pixel_values.to(torch_device)
|
||||
|
||||
decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]]).to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
logits = model(pixel_values, decoder_input_ids)[0].detach().cpu().numpy()
|
||||
|
||||
# verify the logits
|
||||
expected_shape = (1, 1, model.config.decoder.vocab_size)
|
||||
self.assertEqual(logits.shape, expected_shape)
|
||||
|
||||
EXPECTED_LOGIT_SLICE = np.array(
|
||||
[
|
||||
-38.705807,
|
||||
-30.639929,
|
||||
-31.41903,
|
||||
-39.012012,
|
||||
-38.38696,
|
||||
-34.887207,
|
||||
-33.290855,
|
||||
-35.68447,
|
||||
-38.508484,
|
||||
-36.124645,
|
||||
]
|
||||
)
|
||||
max_diff = np.amax(np.abs(logits[0, 0, :10] - EXPECTED_LOGIT_SLICE))
|
||||
self.assertLessEqual(max_diff, 1e-4)
|
||||
|
||||
def generate_step(pixel_values):
|
||||
|
||||
outputs = model.generate(
|
||||
pixel_values, max_length=16, num_beams=4, return_dict_in_generate=True, output_scores=True
|
||||
)
|
||||
output_ids = outputs.sequences
|
||||
preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
||||
preds = [pred.strip() for pred in preds]
|
||||
|
||||
return preds, outputs.sequences_scores.detach().cpu().numpy()
|
||||
|
||||
preds, scores = generate_step(pixel_values)
|
||||
|
||||
EXPECTED_SCORES = np.array([-0.59562886])
|
||||
max_diff = np.amax(np.abs(scores - EXPECTED_SCORES))
|
||||
self.assertLessEqual(max_diff, 1e-4)
|
||||
|
||||
# should produce
|
||||
# ["a cat laying on top of a couch next to another cat"]
|
||||
self.assertEqual(preds, ["a cat laying on top of a couch next to another cat"])
|
||||
|
||||
Reference in New Issue
Block a user