# Spaces: Sleeping  (appears to be hosting-page status text captured with this file; not part of the script)
import os | |
import torch | |
import argparse | |
import cv2 | |
from PIL import Image | |
import numpy as np | |
from gradio_depth2image import init_model, process | |
def build_arg_parser():
    """Build the command-line parser for the depth-conditioned generation script.

    Returns:
        argparse.ArgumentParser: parser with the three required arguments
        (``--input_image``, ``--input_depth``, ``--prompt``) and the optional
        sampling parameters with their defaults.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_image", type=str, required=True)
    parser.add_argument("--input_depth", type=str, required=True)
    parser.add_argument("--prompt", type=str, required=True)

    # defaults
    parser.add_argument("--a_prompt", type=str, default="best quality, extremely detailed")
    parser.add_argument("--n_prompt", type=str, default="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality")
    parser.add_argument("--num_samples", type=int, default=1)
    parser.add_argument("--image_resolution", type=int, default=768)
    parser.add_argument("--detect_resolution", type=int, default=768)
    parser.add_argument("--ddim_steps", type=int, default=20)
    parser.add_argument("--scale", type=float, default=7.5)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--eta", type=float, default=0.0)
    return parser


def read_image(path):
    """Read an image with ``cv2.imread`` and fail loudly if it cannot be loaded.

    Args:
        path: filesystem path of the image.

    Returns:
        The array returned by ``cv2.imread`` (default flags).

    Raises:
        FileNotFoundError: if OpenCV could not read the file. ``cv2.imread``
            signals failure by returning ``None`` instead of raising.
    """
    image = cv2.imread(path)
    if image is None:
        raise FileNotFoundError("could not read image: {}".format(path))
    return image


def invert_depth(depth_image, depth_pad=10):
    """Invert the non-zero depth values and pad the zero-valued background.

    Every non-zero pixel ``v`` is mapped to ``depth_max + depth_min - v``,
    where the extrema are taken over the non-zero pixels only, so the value
    range is preserved while the ordering is flipped. Pixels that are zero
    in the input are set to ``depth_pad`` (not completely black).

    The mapping is done in integer arithmetic: it is exact and it is well
    defined even when all non-zero pixels share a single value (a
    normalise/de-normalise round trip would divide by zero in that case).

    Args:
        depth_image: integer-typed numpy array; zero marks background.
        depth_pad: value written to background pixels. It must be strictly
            smaller than the smallest non-zero depth so that background
            stays distinguishable from foreground afterwards.

    Returns:
        A new array with the same shape and dtype as ``depth_image``; the
        input is not modified.

    Raises:
        ValueError: if there are no non-zero pixels, or if ``depth_pad`` is
            not strictly below the smallest non-zero depth value.
    """
    foreground = depth_image != 0
    if not foreground.any():
        raise ValueError("depth image contains no non-zero pixels")

    values = depth_image[foreground]
    depth_min = int(values.min())
    depth_max = int(values.max())
    if depth_pad >= depth_min:
        raise ValueError(
            "depth_pad ({}) must be smaller than the minimum non-zero depth ({})".format(
                depth_pad, depth_min
            )
        )

    inverted = depth_image.copy()
    # Widen before subtracting so narrow unsigned dtypes cannot wrap around.
    flipped = (depth_min + depth_max) - values.astype(np.int64)
    inverted[foreground] = flipped.astype(depth_image.dtype)
    inverted[~foreground] = depth_pad
    return inverted


def background_mask(depth_image, depth_pad, output_shape):
    """Return a boolean mask of the pixels equal to ``depth_pad``.

    If the spatial size of ``depth_image`` differs from ``output_shape``,
    the mask is resized with nearest-neighbour interpolation so it can be
    used to index an array of that shape.

    Args:
        depth_image: padded depth map (background pixels equal ``depth_pad``).
        depth_pad: background marker value.
        output_shape: shape of the array the mask will be applied to; only
            the first two entries (height, width) are used.

    Returns:
        Boolean numpy array whose first two dimensions match ``output_shape``.
    """
    mask = depth_image == depth_pad
    height, width = output_shape[:2]
    if mask.shape[:2] != (height, width):
        # cv2.resize takes dsize as (width, height) and does not accept bool.
        resized = cv2.resize(
            mask.astype(np.uint8), (width, height), interpolation=cv2.INTER_NEAREST
        )
        mask = resized.astype(bool)
    return mask


def main():
    """Run depth-conditioned generation from the command line and save results.

    For each of the first ``--num_samples`` outputs, writes
    ``control_depth_<i>.png`` and ``control_depth_<i>_cropped.png`` (background
    pixels painted white) into the current working directory.
    """
    args = build_arg_parser().parse_args()
    depth_pad = 10

    # Load and validate the inputs before initialising the model so that a
    # bad path or an unusable depth map fails before the model is loaded.
    # NOTE(review): cv2.imread yields BGR channel order - confirm that
    # `process` expects that ordering for the input image.
    input_image = read_image(args.input_image)
    depth_image = invert_depth(read_image(args.input_depth), depth_pad)

    model, ddim_sampler = init_model()

    outputs = process(
        model, ddim_sampler,
        input_image, args.prompt, args.a_prompt, args.n_prompt,
        args.num_samples, args.image_resolution, args.detect_resolution,
        args.ddim_steps, args.scale, args.seed, args.eta, depth_image
    )

    # Assumes `process` returns an indexable of numpy image arrays - TODO confirm.
    for i in range(args.num_samples):
        out = outputs[i]
        Image.fromarray(out).save("control_depth_{}.png".format(i))

        # crop output: paint everything that was background in the depth map white
        cropped = out.copy()
        cropped[background_mask(depth_image, depth_pad, cropped.shape)] = 255
        Image.fromarray(cropped).save("control_depth_{}_cropped.png".format(i))


if __name__ == "__main__":
    main()