Created
April 29, 2025 03:56
-
-
Save jvcleave/2b3c686237b986c255627f91f952f723 to your computer and use it in GitHub Desktop.
convert_depthpro_full.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import logging | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import coremltools as ct | |
import numpy as np | |
from src.depth_pro.depth_pro import DepthProConfig | |
from src.depth_pro.network.decoder import MultiresConvDecoder | |
from src.depth_pro.network.encoder import DepthProEncoder | |
from src.depth_pro.network.fov import FOVNetwork | |
from src.depth_pro.utils import load_rgb | |
from src.depth_pro.depth_pro import create_backbone_model | |
# Optional: Detailed logging | |
logging.basicConfig(level=logging.INFO) | |
# CONFIG to load
CONFIG = DepthProConfig(
    patch_encoder_preset="dinov2l16_384",
    image_encoder_preset="dinov2l16_384",
    checkpoint_uri="./checkpoints/depth_pro.pt",
    decoder_features=256,
    use_fov_head=True,
    fov_encoder_preset="dinov2l16_384",
)
# Set these manually after initialization
# (these fields are attached after construction and read later by
# create_scaled_model / DepthDecoder — presumably DepthProConfig does not
# declare them; TODO confirm against the DepthProConfig definition.)
CONFIG.encoder_scale_size = (384, 384)  # size the image is resized to before the FOV encoder
CONFIG.head_paddings = [1, 0, 1, 0]  # per-conv padding for the depth head
CONFIG.fov_head_paddings = [1, 1, 1, 0]  # per-conv padding for the FOV head
class DepthDecoder(nn.Module):
    """Final depth stage: predicts field of view and converts the canonical
    inverse-depth map from the head into depth.

    Expects ``inputs`` as ``[image, features, features_0]`` where ``features``
    feeds the depth head and ``features_0`` feeds the FOV estimator.
    """

    def __init__(self, head: nn.Module, fov: "FOVNetwork", encoder_scale_size):
        super().__init__()
        self.head = head
        self.fov = fov
        # Spatial size the image is resized to before the FOV encoder.
        self.encoder_scale_size = encoder_scale_size

    def forward(self, inputs):
        image, features, features_0 = inputs
        if hasattr(self.fov, "encoder"):
            # Dedicated FOV encoder path: encode a downscaled image, drop the
            # leading token, and fuse with a downsampled decoder feature map.
            scaled = F.interpolate(
                image,
                size=self.encoder_scale_size,
                mode="bilinear",
                align_corners=False,
            )
            tokens = self.fov.encoder(scaled)[:, 1:].permute(0, 2, 1)
            lowres = self.fov.downsample(features_0.detach())
            fov_input = tokens.reshape_as(lowres) + lowres
        else:
            # No FOV encoder: estimate FOV straight from decoder features.
            fov_input = features_0.detach()
        fov_deg = self.fov.head(fov_input)
        # Turn predicted FOV (degrees) into a focal-length scale factor.
        f_px = 0.5 * torch.tan(np.pi * fov_deg.to(torch.float32) / 360.0)
        canonical_inverse_depth = self.head(features)
        inverse_depth = canonical_inverse_depth * f_px
        # Clamp before inverting to avoid division blow-ups.
        return 1.0 / inverse_depth.clamp(min=1e-4, max=1e4)
# --- Define the model builder ---
class DepthProFullModel(nn.Module):
    """Single traceable module wiring encoder -> decoder -> depth decoder.

    Maps an RGB image tensor to a depth map; the raw input is also forwarded
    to the depth decoder so it can feed the FOV branch.
    """

    def __init__(self, encoder, decoder, depth_decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.depth_decoder = depth_decoder

    def forward(self, x):
        feats, feats_0 = self.decoder(self.encoder(x))
        return self.depth_decoder([x, feats, feats_0])
def create_scaled_model(config: DepthProConfig):
    """Build the complete DepthPro network (encoder + decoder + heads).

    Rebuilds the depth head and FOV head locally so that the padding values
    attached to ``config`` (``head_paddings`` / ``fov_head_paddings``) control
    each convolution — presumably to make the traced graph export-friendly
    (TODO confirm against the stock depth_pro construction).

    Args:
        config: DepthProConfig, extended after construction with
            ``encoder_scale_size``, ``head_paddings`` and ``fov_head_paddings``.

    Returns:
        A ``DepthProFullModel`` ready for checkpoint loading and tracing.
    """
    # Backbones for image patches, the full image, and the FOV estimator.
    patch_encoder, patch_encoder_config = create_backbone_model(preset=config.patch_encoder_preset)
    image_encoder, _ = create_backbone_model(preset=config.image_encoder_preset)
    fov_encoder, _ = create_backbone_model(preset=config.fov_encoder_preset)
    dims_encoder = patch_encoder_config.encoder_feature_dims
    hook_block_ids = patch_encoder_config.encoder_feature_layer_ids
    encoder = DepthProEncoder(
        dims_encoder=dims_encoder,
        patch_encoder=patch_encoder,
        image_encoder=image_encoder,
        hook_block_ids=hook_block_ids,
        decoder_features=config.decoder_features,
    )
    decoder = MultiresConvDecoder(
        dims_encoder=[config.decoder_features] + list(encoder.dims_encoder),
        dim_decoder=config.decoder_features,
    )
    num_features = config.decoder_features
    fov = FOVNetwork(num_features=num_features, fov_encoder=fov_encoder)
    # Setup FOV head
    # First stage: halves channels and spatial size; used as ``downsample``
    # when a dedicated FOV encoder is present.
    fov_head0 = [
        nn.Conv2d(num_features, num_features // 2, kernel_size=3, stride=2, padding=config.fov_head_paddings[0]),
        nn.ReLU(True),
    ]
    # Remaining stages progressively reduce to a single-channel FOV output.
    fov_head = [
        nn.Conv2d(num_features // 2, num_features // 4, kernel_size=3, stride=2, padding=config.fov_head_paddings[1]),
        nn.ReLU(True),
        nn.Conv2d(num_features // 4, num_features // 8, kernel_size=3, stride=2, padding=config.fov_head_paddings[2]),
        nn.ReLU(True),
        nn.Conv2d(num_features // 8, 1, kernel_size=6, stride=1, padding=config.fov_head_paddings[3]),
    ]
    if fov_encoder is not None:
        # Replace FOVNetwork's encoder/downsample with the padded variants.
        # NOTE(review): in this branch the ``fov_head`` list built above is
        # never installed — ``fov.head`` keeps whatever FOVNetwork itself
        # constructed. Confirm this is intentional.
        fov.encoder = nn.Sequential(
            fov_encoder,
            nn.Linear(fov_encoder.embed_dim, num_features // 2)
        )
        fov.downsample = nn.Sequential(*fov_head0)
    else:
        fov.head = nn.Sequential(*fov_head0 + fov_head)
    # Depth head: conv -> 2x upsample (transposed conv) -> conv/ReLU -> 1x1 conv.
    head = nn.Sequential(
        nn.Conv2d(config.decoder_features, config.decoder_features // 2, kernel_size=3, stride=1, padding=config.head_paddings[0]),
        nn.ConvTranspose2d(
            in_channels=config.decoder_features // 2,
            out_channels=config.decoder_features // 2,
            kernel_size=2,
            stride=2,
            padding=config.head_paddings[1],
            bias=True,
        ),
        nn.Conv2d(config.decoder_features // 2, 32, kernel_size=3, stride=1, padding=config.head_paddings[2]),
        nn.ReLU(True),
        nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=config.head_paddings[3]),
        nn.ReLU(),
    )
    # Zero-initialize the bias of the final 1x1 conv (index 4 in the Sequential).
    head[4].bias.data.fill_(0)
    depth_decoder = DepthDecoder(head, fov, config.encoder_scale_size)
    return DepthProFullModel(encoder, decoder, depth_decoder)
# --- Instantiate model ---
# (The original print strings contained mojibake emoji bytes; replaced with
# plain ASCII so the script runs cleanly under any terminal encoding.)
print("Creating model...")
model = create_scaled_model(CONFIG)
# NOTE(review): torch.load on a pickle checkpoint executes arbitrary code if
# the file is untrusted; consider weights_only=True on newer PyTorch.
state_dict = torch.load(CONFIG.checkpoint_uri, map_location="cpu")
# strict=False tolerates the locally rebuilt FOV/depth heads, but report what
# was skipped instead of discarding mismatches silently.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys:
    print(f"Warning - missing keys: {load_result.missing_keys}")
if load_result.unexpected_keys:
    print(f"Warning - unexpected keys: {load_result.unexpected_keys}")
model.eval()

# --- Prepare dummy input ---
# DepthPro's working resolution: 1536x1536 RGB.
dummy_input = torch.rand(1, 3, 1536, 1536)

# --- Convert ---
print("Converting to Core ML...")
traced = torch.jit.trace(model, dummy_input)
coreml_model = ct.convert(
    traced,
    convert_to="mlprogram",
    inputs=[
        ct.ImageType(name="image", shape=(1, 3, 1536, 1536), color_layout=ct.colorlayout.RGB),
    ],
    compute_units=ct.ComputeUnit.CPU_AND_GPU,
    minimum_deployment_target=ct.target.iOS17,
    compute_precision=ct.precision.FLOAT16,
)

# --- Save ---
out_path = "DepthPro_Full1536.mlpackage"
coreml_model.save(out_path)
print(f"Done! Saved model to {out_path}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.