Spaces:
Sleeping
Sleeping
File size: 2,382 Bytes
ba90c74 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 |
import os
import torch
import argparse
import cv2
from PIL import Image
import numpy as np
from gradio_depth2image import init_model, process
def prepare_depth(depth_image, depth_pad=10):
    """Return a copy of *depth_image* with valid depth inverted and holes filled.

    Non-zero pixels are treated as valid depth and mirrored inside their own
    ``[min, max]`` range (``d -> min + max - d``). The extreme values swap
    while the range itself is preserved. Zero pixels are treated as missing
    and set to *depth_pad*. The pad must be strictly below the smallest valid
    value, so it never collides with an inverted pixel.

    Args:
        depth_image: integer array, e.g. as returned by ``cv2.imread``.
            It is not modified.
        depth_pad: value written into missing (zero) pixels.

    Returns:
        A new array with the same shape and dtype as *depth_image*.

    Raises:
        ValueError: if the image has no non-zero pixels, or if *depth_pad*
            is not in ``[0, min_valid_depth)``.
    """
    valid = depth_image != 0
    if not valid.any():
        raise ValueError("depth image contains no non-zero pixels")
    # Widen before arithmetic: min + max overflows uint8 (e.g. 200 + 100).
    values = depth_image[valid].astype(np.int64)
    depth_min, depth_max = int(values.min()), int(values.max())
    if not 0 <= depth_pad < depth_min:
        raise ValueError(
            "depth_pad ({}) must be >= 0 and below the smallest non-zero "
            "depth value ({})".format(depth_pad, depth_min)
        )
    result = depth_image.copy()
    # Same as normalising (max - d) to [0, 1] and rescaling back to
    # [min, max], but exact in integers and well-defined when min == max.
    result[valid] = (depth_min + depth_max - values).astype(depth_image.dtype)
    result[~valid] = depth_pad  # keep holes distinguishable from pure black
    return result


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Generate images conditioned on an input image and a precomputed depth map."
    )
    parser.add_argument("--input_image", type=str, required=True)
    parser.add_argument("--input_depth", type=str, required=True)
    parser.add_argument("--prompt", type=str, required=True)
    # defaults
    parser.add_argument("--a_prompt", type=str, default="best quality, extremely detailed")
    parser.add_argument("--n_prompt", type=str, default="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality")
    parser.add_argument("--num_samples", type=int, default=1)
    parser.add_argument("--image_resolution", type=int, default=768)
    parser.add_argument("--detect_resolution", type=int, default=768)
    parser.add_argument("--ddim_steps", type=int, default=20)
    parser.add_argument("--scale", type=float, default=7.5)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--eta", type=float, default=0.0)
    parser.add_argument(
        "--depth_pad", type=int, default=10,
        help="value written into missing (zero) depth pixels; must be below the smallest valid depth",
    )
    args = parser.parse_args()

    # Read and validate inputs before the (expensive) model load so bad
    # paths or unusable depth maps fail fast.
    # NOTE(review): cv2.imread returns BGR channel order; confirm that
    # `process` expects BGR rather than RGB for input_image.
    input_image = cv2.imread(args.input_image)
    if input_image is None:
        parser.error("could not read --input_image: {}".format(args.input_image))
    raw_depth = cv2.imread(args.input_depth)
    if raw_depth is None:
        parser.error("could not read --input_depth: {}".format(args.input_depth))
    # Pixels with no depth; these are blanked out in the cropped outputs.
    hole_mask = raw_depth == 0
    try:
        depth_image = prepare_depth(raw_depth, args.depth_pad)
    except ValueError as err:
        parser.error(str(err))

    model, ddim_sampler = init_model()
    outputs = process(
        model, ddim_sampler,
        input_image, args.prompt, args.a_prompt, args.n_prompt,
        args.num_samples, args.image_resolution, args.detect_resolution,
        args.ddim_steps, args.scale, args.seed, args.eta, depth_image
    )
    for i in range(args.num_samples):
        out = outputs[i]
        Image.fromarray(out).save("control_depth_{}.png".format(i))
        # Crop a copy so outputs[i] itself is left untouched.
        # NOTE(review): assumes `out` has the same shape as the depth map;
        # confirm `process` does not resize to image_resolution.
        cropped = out.copy()
        cropped[hole_mask] = 255
        Image.fromarray(cropped).save("control_depth_{}_cropped.png".format(i))
|