Fix docstrings for TF BLIP (#22618)
* Fix docstrings for TFBLIP
* Fix missing line in TF port!
* Use values from torch tests now other bugs fixed
* Use values from torch tests now other bugs fixed
* Fix doctest string
This commit is contained in:
@@ -1020,7 +1020,7 @@ class TFBlipModel(TFBlipPreTrainedModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
pooled_output = text_outputs[1]
|
pooled_output = text_outputs[1]
|
||||||
text_features = self.text_projection(pooled_output)
|
text_features = self.blip.text_projection(pooled_output)
|
||||||
|
|
||||||
return text_features
|
return text_features
|
||||||
|
|
||||||
@@ -1057,7 +1057,7 @@ class TFBlipModel(TFBlipPreTrainedModel):
|
|||||||
vision_outputs = self.blip.vision_model(pixel_values=pixel_values, return_dict=return_dict)
|
vision_outputs = self.blip.vision_model(pixel_values=pixel_values, return_dict=return_dict)
|
||||||
|
|
||||||
pooled_output = vision_outputs[1] # pooled_output
|
pooled_output = vision_outputs[1] # pooled_output
|
||||||
image_features = self.visual_projection(pooled_output)
|
image_features = self.blip.visual_projection(pooled_output)
|
||||||
|
|
||||||
return image_features
|
return image_features
|
||||||
|
|
||||||
@@ -1238,7 +1238,7 @@ class TFBlipForConditionalGeneration(TFBlipPreTrainedModel):
|
|||||||
|
|
||||||
>>> outputs = model.generate(**inputs)
|
>>> outputs = model.generate(**inputs)
|
||||||
>>> print(processor.decode(outputs[0], skip_special_tokens=True))
|
>>> print(processor.decode(outputs[0], skip_special_tokens=True))
|
||||||
two cats are laying on a couch
|
two cats sleeping on a couch
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -1410,7 +1410,6 @@ class TFBlipForQuestionAnswering(TFBlipPreTrainedModel):
|
|||||||
>>> inputs["labels"] = labels
|
>>> inputs["labels"] = labels
|
||||||
>>> outputs = model(**inputs)
|
>>> outputs = model(**inputs)
|
||||||
>>> loss = outputs.loss
|
>>> loss = outputs.loss
|
||||||
>>> loss.backward()
|
|
||||||
|
|
||||||
>>> # inference
|
>>> # inference
|
||||||
>>> text = "How many cats are in the picture?"
|
>>> text = "How many cats are in the picture?"
|
||||||
|
|||||||
@@ -462,6 +462,7 @@ class TFBlipTextEncoder(tf.keras.layers.Layer):
|
|||||||
next_decoder_cache += (layer_outputs[-1],)
|
next_decoder_cache += (layer_outputs[-1],)
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
||||||
|
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
|
||||||
|
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|||||||
@@ -783,7 +783,7 @@ class TFBlipModelIntegrationTest(unittest.TestCase):
|
|||||||
# Test output
|
# Test output
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
predictions[0].numpy().tolist(),
|
predictions[0].numpy().tolist(),
|
||||||
[30522, 1037, 3861, 1997, 1037, 2450, 3564, 2006, 1996, 3509, 2007, 2014, 3899, 102],
|
[30522, 1037, 3861, 1997, 1037, 2450, 1998, 2014, 3899, 2006, 1996, 3509, 102],
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_inference_vqa(self):
|
def test_inference_vqa(self):
|
||||||
@@ -810,6 +810,6 @@ class TFBlipModelIntegrationTest(unittest.TestCase):
|
|||||||
out_itm = model(**inputs)
|
out_itm = model(**inputs)
|
||||||
out = model(**inputs, use_itm_head=False, training=False)
|
out = model(**inputs, use_itm_head=False, training=False)
|
||||||
|
|
||||||
expected_scores = tf.convert_to_tensor([[0.9798, 0.0202]])
|
expected_scores = tf.convert_to_tensor([[0.0029, 0.9971]])
|
||||||
self.assertTrue(np.allclose(tf.nn.softmax(out_itm[0]).numpy(), expected_scores, rtol=1e-3, atol=1e-3))
|
self.assertTrue(np.allclose(tf.nn.softmax(out_itm[0]).numpy(), expected_scores, rtol=1e-3, atol=1e-3))
|
||||||
self.assertTrue(np.allclose(out[0], tf.convert_to_tensor([[0.5053]]), rtol=1e-3, atol=1e-3))
|
self.assertTrue(np.allclose(out[0], tf.convert_to_tensor([[0.5162]]), rtol=1e-3, atol=1e-3))
|
||||||
|
|||||||
Reference in New Issue
Block a user