Add post_process_depth_estimation for GLPN (#34413)
* add depth postprocessing for GLPN
* remove previous temp fix for glpn tests
* Style changes for GLPN's `post_process_depth_estimation`

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* additional style fix

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
6cc4a67b3d
commit
a769ed45e1
@@ -14,7 +14,11 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Image processor class for GLPN."""
|
"""Image processor class for GLPN."""
|
||||||
|
|
||||||
from typing import List, Optional, Union
|
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ...modeling_outputs import DepthEstimatorOutput
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import PIL.Image
|
import PIL.Image
|
||||||
@@ -27,12 +31,17 @@ from ...image_utils import (
|
|||||||
get_image_size,
|
get_image_size,
|
||||||
infer_channel_dimension_format,
|
infer_channel_dimension_format,
|
||||||
is_scaled_image,
|
is_scaled_image,
|
||||||
|
is_torch_available,
|
||||||
make_list_of_images,
|
make_list_of_images,
|
||||||
to_numpy_array,
|
to_numpy_array,
|
||||||
valid_images,
|
valid_images,
|
||||||
validate_preprocess_arguments,
|
validate_preprocess_arguments,
|
||||||
)
|
)
|
||||||
from ...utils import TensorType, filter_out_non_signature_kwargs, logging
|
from ...utils import TensorType, filter_out_non_signature_kwargs, logging, requires_backends
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
@@ -218,3 +227,44 @@ class GLPNImageProcessor(BaseImageProcessor):
|
|||||||
|
|
||||||
data = {"pixel_values": images}
|
data = {"pixel_values": images}
|
||||||
return BatchFeature(data=data, tensor_type=return_tensors)
|
return BatchFeature(data=data, tensor_type=return_tensors)
|
||||||
|
|
||||||
|
def post_process_depth_estimation(
    self,
    outputs: "DepthEstimatorOutput",
    target_sizes: Optional[Union[TensorType, List[Tuple[int, int]]]] = None,
) -> List[Dict[str, TensorType]]:
    """
    Converts the raw output of [`DepthEstimatorOutput`] into per-image depth predictions, optionally resized
    to the requested target sizes.
    Only supports PyTorch.

    Args:
        outputs ([`DepthEstimatorOutput`]):
            Raw outputs of the model. Only `outputs.predicted_depth` is read.
        target_sizes (`TensorType` or `List[Tuple[int, int]]`, *optional*):
            Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
            (height, width) of each image in the batch. If left to None, predictions will not be resized.
            An individual entry may also be `None`, in which case that prediction is left untouched.

    Returns:
        `List[Dict[str, TensorType]]`: A list with one dictionary per image. Each dictionary has a single key,
        `"predicted_depth"`, holding the (possibly resized) depth tensor for that image.

    Raises:
        `ValueError`: If `target_sizes` is given and its length differs from the batch dimension of
        `outputs.predicted_depth`.
    """
    requires_backends(self, "torch")

    predicted_depth = outputs.predicted_depth

    if target_sizes is not None and len(predicted_depth) != len(target_sizes):
        raise ValueError(
            "Make sure that you pass in as many target sizes as the batch dimension of the predicted depth"
        )

    # Nothing to resize: hand back the per-image slices unchanged.
    if target_sizes is None:
        return [{"predicted_depth": depth} for depth in predicted_depth]

    results = []
    for depth, target_size in zip(predicted_depth, target_sizes):
        if target_size is not None:
            # `interpolate` only accepts a list/tuple of ints as `size`; a row of a `(batch_size, 2)` tensor
            # would otherwise be treated as one scalar size and replicated per spatial dimension.
            size = tuple(int(dim) for dim in target_size)
            # Add batch and channel axes for `interpolate` (assumes each `depth` is 2-D (height, width)).
            depth = torch.nn.functional.interpolate(
                depth[None, None, ...], size=size, mode="bicubic", align_corners=False
            )
            # Drop exactly the two axes added above. A bare `squeeze()` would also remove a spatial
            # dimension of size 1 and silently change the rank of the result.
            depth = depth[0, 0]

        results.append({"predicted_depth": depth})

    return results
|
||||||
|
|||||||
@@ -723,20 +723,18 @@ class GLPNForDepthEstimation(GLPNPreTrainedModel):
|
|||||||
|
|
||||||
>>> with torch.no_grad():
|
>>> with torch.no_grad():
|
||||||
... outputs = model(**inputs)
|
... outputs = model(**inputs)
|
||||||
... predicted_depth = outputs.predicted_depth
|
|
||||||
|
|
||||||
>>> # interpolate to original size
|
>>> # interpolate to original size
|
||||||
>>> prediction = torch.nn.functional.interpolate(
|
>>> post_processed_output = image_processor.post_process_depth_estimation(
|
||||||
... predicted_depth.unsqueeze(1),
|
... outputs,
|
||||||
... size=image.size[::-1],
|
... target_sizes=[(image.height, image.width)],
|
||||||
... mode="bicubic",
|
|
||||||
... align_corners=False,
|
|
||||||
... )
|
... )
|
||||||
|
|
||||||
>>> # visualize the prediction
|
>>> # visualize the prediction
|
||||||
>>> output = prediction.squeeze().cpu().numpy()
|
>>> predicted_depth = post_processed_output[0]["predicted_depth"]
|
||||||
>>> formatted = (output * 255 / np.max(output)).astype("uint8")
|
>>> depth = predicted_depth * 255 / predicted_depth.max()
|
||||||
>>> depth = Image.fromarray(formatted)
|
>>> depth = depth.detach().cpu().numpy()
|
||||||
|
>>> depth = Image.fromarray(depth.astype("uint8"))
|
||||||
```"""
|
```"""
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
output_hidden_states = (
|
output_hidden_states = (
|
||||||
|
|||||||
@@ -157,14 +157,6 @@ class GLPNModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
|||||||
self.model_tester = GLPNModelTester(self)
|
self.model_tester = GLPNModelTester(self)
|
||||||
self.config_tester = GLPNConfigTester(self, config_class=GLPNConfig)
|
self.config_tester = GLPNConfigTester(self, config_class=GLPNConfig)
|
||||||
|
|
||||||
@unittest.skip(reason="Failing after #32550")
|
|
||||||
def test_pipeline_depth_estimation(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@unittest.skip(reason="Failing after #32550")
|
|
||||||
def test_pipeline_depth_estimation_fp16(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def test_config(self):
|
def test_config(self):
|
||||||
self.config_tester.run_common_tests()
|
self.config_tester.run_common_tests()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user