|
import gc |
|
import os |
|
import numpy as np |
|
import torch |
|
|
|
from diffusers.training_utils import set_seed |
|
|
|
|
|
from DepthCrafter.depthcrafter.depth_crafter_ppl import DepthCrafterPipeline |
|
from DepthCrafter.depthcrafter.unet import DiffusersUNetSpatioTemporalConditionModelDepthCrafter |
|
|
|
class DepthCrafterDemo:
    """Load a DepthCrafter pipeline and turn its output into clipped depth maps.

    The constructor builds the pipeline in float16 and configures CPU
    offloading and attention options; ``infer`` runs the pipeline on a
    sequence of frames and post-processes the prediction into a tensor.
    """

    # Accepted values for the ``cpu_offload`` constructor argument.
    _CPU_OFFLOAD_OPTIONS = (None, "sequential", "model")

    # Post-processing constants: the normalized prediction is multiplied by
    # ``_PREDICTION_SCALE``, floored at ``_MIN_SCALED_PREDICTION`` and then
    # inverted as ``_DEPTH_NUMERATOR / value``.
    # NOTE(review): the origin of 3900 and 10000 is not visible in this
    # file -- presumably chosen to match another depth source; confirm.
    _PREDICTION_SCALE = 3900
    _MIN_SCALED_PREDICTION = 1e-5
    _DEPTH_NUMERATOR = 10000.0

    def __init__(
        self,
        unet_path: str,
        pre_train_path: str,
        cpu_offload: str = "model",
        device: str = "cuda:0",
    ):
        """Build the pipeline and configure memory-saving options.

        Args:
            unet_path: Location passed to ``from_pretrained`` for the UNet.
            pre_train_path: Location passed to ``from_pretrained`` for the
                rest of the pipeline.
            cpu_offload: ``"sequential"``, ``"model"`` or ``None``. With
                ``None`` the whole pipeline is moved to ``device`` instead.
            device: Target device; only used when ``cpu_offload`` is ``None``.

        Raises:
            ValueError: If ``cpu_offload`` is not one of the accepted values.
        """
        # Reject a bad option before loading any weights, so a typo does not
        # cost a full model load.
        if cpu_offload not in self._CPU_OFFLOAD_OPTIONS:
            raise ValueError(f"Unknown cpu offload option: {cpu_offload}")

        unet = DiffusersUNetSpatioTemporalConditionModelDepthCrafter.from_pretrained(
            unet_path,
            low_cpu_mem_usage=True,
            torch_dtype=torch.float16,
        )

        self.pipe = DepthCrafterPipeline.from_pretrained(
            pre_train_path,
            unet=unet,
            torch_dtype=torch.float16,
            variant="fp16",
        )

        if cpu_offload == "sequential":
            self.pipe.enable_sequential_cpu_offload()
        elif cpu_offload == "model":
            self.pipe.enable_model_cpu_offload()
        else:
            # cpu_offload is None: keep the whole pipeline on the device.
            self.pipe.to(device)

        # Best effort: xformers is optional, so report the failure and go on.
        try:
            self.pipe.enable_xformers_memory_efficient_attention()
        except Exception as e:
            print(e)
            print("Xformers is not enabled")
        self.pipe.enable_attention_slicing()

    @staticmethod
    def _prediction_to_depth(res, near, far):
        """Convert a raw pipeline prediction into a clipped depth tensor.

        Args:
            res: NumPy array whose last axis is averaged away.
            near: Lower clipping bound of the returned depth.
            far: Upper clipping bound of the returned depth.

        Returns:
            A ``torch.Tensor`` with a singleton axis inserted at position 1
            and every value inside ``[near, far]``.
        """
        # Collapse the last axis to its mean.
        res = res.sum(-1) / res.shape[-1]

        # Min-max normalize over the whole array. A constant prediction has a
        # zero range; dividing would produce NaN everywhere, so map it to
        # zeros instead (which the inversion below resolves to ``far``).
        # ``== 0`` is used so NaN inputs still propagate.
        lowest = res.min()
        value_range = res.max() - lowest
        if value_range == 0:
            normalized = np.zeros_like(res)
        else:
            normalized = (res - lowest) / value_range

        depths = torch.from_numpy(normalized).unsqueeze(1)
        depths *= DepthCrafterDemo._PREDICTION_SCALE
        # Floor the value so the inversion below never divides by zero.
        depths[depths < DepthCrafterDemo._MIN_SCALED_PREDICTION] = (
            DepthCrafterDemo._MIN_SCALED_PREDICTION
        )
        depths = DepthCrafterDemo._DEPTH_NUMERATOR / depths
        return depths.clip(near, far)

    def infer(
        self,
        frames,
        near,
        far,
        num_denoising_steps: int,
        guidance_scale: float,
        window_size: int = 110,
        overlap: int = 25,
        seed: int = 42,
        track_time: bool = True,
    ):
        """Run the pipeline on ``frames`` and return clipped depth maps.

        Args:
            frames: Input handed to the pipeline; ``frames.shape[1]`` and
                ``frames.shape[2]`` are passed as ``height`` and ``width``.
            near: Lower clipping bound of the returned depth.
            far: Upper clipping bound of the returned depth.
            num_denoising_steps: Passed as ``num_inference_steps``.
            guidance_scale: Passed through to the pipeline.
            window_size: Passed through to the pipeline.
            overlap: Passed through to the pipeline.
            seed: Seed given to ``set_seed`` before inference.
            track_time: Passed through to the pipeline.

        Returns:
            A ``torch.Tensor`` of depth values clipped to ``[near, far]``,
            with a singleton axis inserted at position 1.
        """
        set_seed(seed)

        with torch.inference_mode():
            res = self.pipe(
                frames,
                height=frames.shape[1],
                width=frames.shape[2],
                output_type="np",
                guidance_scale=guidance_scale,
                num_inference_steps=num_denoising_steps,
                window_size=window_size,
                overlap=overlap,
                track_time=track_time,
            ).frames[0]

        return self._prediction_to_depth(res, near, far)
|
|