Spaces:
Sleeping
Sleeping
import gc | |
# get socket and check if the name is vgldgx01 | |
import socket | |
if socket.gethostname() != "vgldgx01": | |
import spaces #[uncomment to use ZeroGPU] | |
import numpy as np | |
import PIL.Image | |
import torch | |
from controlnet_aux.util import HWC3 | |
from diffusers import ( | |
ControlNetModel, | |
DiffusionPipeline, | |
StableDiffusionControlNetPipeline, | |
StableDiffusionImg2ImgPipeline, | |
UniPCMultistepScheduler, | |
DDIMScheduler, #rgb2x | |
) | |
import torchvision | |
from torchvision import transforms | |
from cv_utils import resize_image | |
from preprocessor import Preprocessor | |
from settings import MAX_IMAGE_RESOLUTION, MAX_NUM_IMAGES | |
from tqdm.auto import tqdm | |
import subprocess | |
from rgb2x.pipeline_rgb2x import StableDiffusionAOVMatEstPipeline | |
from app_texnet import image_to_temp_path | |
import os | |
import time | |
import tempfile | |
# Task name -> Hugging Face Hub repo (or local checkpoint directory) holding the
# ControlNet weights for that task. Only "texnet" is enabled in this app.
# The upstream ControlNet v1.1 checkpoints used by the generic ``process_*``
# methods of ``Model`` are kept here for reference; add an entry back to
# enable the corresponding task:
#   "Openpose":      "lllyasviel/control_v11p_sd15_openpose"
#   "Canny":         "lllyasviel/control_v11p_sd15_canny"
#   "MLSD":          "lllyasviel/control_v11p_sd15_mlsd"
#   "scribble":      "lllyasviel/control_v11p_sd15_scribble"
#   "softedge":      "lllyasviel/control_v11p_sd15_softedge"
#   "segmentation":  "lllyasviel/control_v11p_sd15_seg"
#   "depth":         "lllyasviel/control_v11f1p_sd15_depth"
#   "NormalBae":     "lllyasviel/control_v11p_sd15_normalbae"
#   "lineart":       "lllyasviel/control_v11p_sd15_lineart"
#   "lineart_anime": "lllyasviel/control_v11p_sd15s2_lineart_anime"
#   "shuffle":       "lllyasviel/control_v11e_sd15_shuffle"
#   "ip2p":          "lllyasviel/control_v11e_sd15_ip2p"
#   "inpaint":       "lllyasviel/control_v11e_sd15_inpaint"
CONTROLNET_MODEL_IDS = dict(
    texnet="jingyangcarl/texnet",
)
def download_all_controlnet_weights() -> None:
    """Load every checkpoint listed in ``CONTROLNET_MODEL_IDS`` once.

    ``ControlNetModel.from_pretrained`` fetches missing weights into the local
    Hugging Face cache. The loaded models are discarded, so this function is
    only useful for pre-fetching.
    """
    checkpoints = list(CONTROLNET_MODEL_IDS.values())
    for checkpoint in checkpoints:
        ControlNetModel.from_pretrained(checkpoint)
class Model:
    """ControlNet-based texture generator behind the TexNet demo.

    Holds three float16 diffusion pipelines on ``self.device``:

    * ``self.pipe``: a ``StableDiffusionControlNetPipeline`` built from
      ``base_model_id`` plus the ControlNet registered for the current task in
      ``CONTROLNET_MODEL_IDS``, using a ``UniPCMultistepScheduler``.
    * ``self.pipe_base``: a plain SD v1.5 ``StableDiffusionImg2ImgPipeline``.
      ``process_texnet`` uses it to refine the coarse ControlNet output.
    * ``self.pipe_rgb2x``: the rgb-to-x material-estimation pipeline. It
      derives normal / roughness / metallic maps from the refined texture.

    ``__init__`` also downloads and unpacks Blender 3.2.2 into ``/tmp`` if it
    is missing. ``process_texnet`` runs Blender headless to write a ``.blend``.

    Every ``process_*`` method needs its task key present in
    ``CONTROLNET_MODEL_IDS``; ``load_controlnet_weight`` raises ``KeyError``
    otherwise.
    """

    def __init__(
        self, base_model_id: str = "stable-diffusion-v1-5/stable-diffusion-v1-5", task_name: str = "Canny"
    ) -> None:
        """Load all pipelines onto the available device and make sure Blender is installed.

        Args:
            base_model_id: Stable Diffusion checkpoint the ControlNet pipeline is built on.
            task_name: Key of ``CONTROLNET_MODEL_IDS`` whose ControlNet is loaded first.
                NOTE(review): the default ``"Canny"`` must be registered in
                ``CONTROLNET_MODEL_IDS`` (the shipped dict only has ``"texnet"``),
                otherwise construction raises ``KeyError``.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # load_pipe compares against these to decide whether a reload is needed.
        self.base_model_id = ""
        self.task_name = ""
        self.pipe = self.load_pipe(base_model_id, task_name)
        # Plain img2img pipeline for the refinement pass of process_texnet.
        # The original runwayml/stable-diffusion-v1-5 repo is no longer available
        # on the Hub; this mirror (also the default base_model_id) hosts the
        # same v1.5 weights.
        self.pipe_base = StableDiffusionImg2ImgPipeline.from_pretrained(
            "stable-diffusion-v1-5/stable-diffusion-v1-5", safety_checker=None, torch_dtype=torch.float16
        ).to(self.device)
        self.preprocessor = Preprocessor()
        # set up pipe_rgb2x
        self.pipe_rgb2x = StableDiffusionAOVMatEstPipeline.from_pretrained(
            "zheng95z/rgb-to-x",
            torch_dtype=torch.float16,
        ).to(self.device)
        self.pipe_rgb2x.scheduler = DDIMScheduler.from_config(
            self.pipe_rgb2x.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing"
        )
        self.pipe_rgb2x.set_progress_bar_config(disable=True)
        # setup blender
        self.blender_path = "/tmp/blender-3.2.2-linux-x64/blender"
        self._ensure_blender()

    def _ensure_blender(self) -> None:
        """Download and unpack Blender 3.2.2 into /tmp unless ``self.blender_path`` exists.

        Raises:
            subprocess.CalledProcessError: if ``wget`` or ``tar`` fails.
        """
        if os.path.exists(self.blender_path):
            return
        archive = "/tmp/blender-3.2.2-linux-x64.tar.xz"
        print("Downloading Blender...")
        subprocess.run(
            ["wget", "https://download.blender.org/release/Blender3.2/blender-3.2.2-linux-x64.tar.xz", "-O", archive],
            check=True,
        )
        subprocess.run(["tar", "-xf", archive, "-C", "/tmp"], check=True)
        print("Blender downloaded and extracted.")

    @staticmethod
    def _controlnet_id(task_name: str) -> str:
        """Return the ControlNet checkpoint registered for ``task_name``.

        Raises:
            KeyError: if ``task_name`` is not in ``CONTROLNET_MODEL_IDS``. The
                message lists the registered tasks.
        """
        try:
            return CONTROLNET_MODEL_IDS[task_name]
        except KeyError:
            raise KeyError(
                f"No ControlNet registered for task {task_name!r}; available: {sorted(CONTROLNET_MODEL_IDS)}"
            ) from None

    @staticmethod
    def _try_enable_xformers(pipe: DiffusionPipeline) -> None:
        """Enable xformers attention on ``pipe`` if possible.

        If xformers is unavailable, print a message and keep the pipeline's
        default attention.
        """
        try:
            pipe.enable_xformers_memory_efficient_attention()
        except (ModuleNotFoundError, ValueError) as err:
            print(f"xformers unavailable, using default attention: {err}")

    def load_pipe(self, base_model_id: str, task_name: str) -> DiffusionPipeline:
        """Return a ControlNet pipeline for ``base_model_id`` + ``task_name``.

        If both ids match the currently loaded ones and ``self.pipe`` exists, it
        is returned unchanged. Otherwise this method:

        * loads the task's ControlNet and the base model in float16;
        * switches the pipeline to a UniPC scheduler;
        * moves it to ``self.device``;
        * records the new ids on ``self``.

        The caller assigns the returned pipeline to ``self.pipe``.

        Raises:
            KeyError: if ``task_name`` is not in ``CONTROLNET_MODEL_IDS``.
        """
        if (
            base_model_id == self.base_model_id
            and task_name == self.task_name
            and getattr(self, "pipe", None) is not None
        ):
            return self.pipe
        controlnet = ControlNetModel.from_pretrained(self._controlnet_id(task_name), torch_dtype=torch.float16)
        pipe = StableDiffusionControlNetPipeline.from_pretrained(
            base_model_id, safety_checker=None, controlnet=controlnet, torch_dtype=torch.float16
        )
        pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
        pipe.to(self.device)
        # On ZeroGPU (SPACES_ZERO_GPU=1), xformers attention did not work
        # according to the original author, so the default attention is kept there.
        if self.device.type == "cuda" and os.environ.get("SPACES_ZERO_GPU", "0") != "1":
            self._try_enable_xformers(pipe)
        torch.cuda.empty_cache()
        gc.collect()
        self.base_model_id = base_model_id
        self.task_name = task_name
        return pipe

    def set_base_model(self, base_model_id: str) -> str:
        """Rebuild ``self.pipe`` on a different base model and return the active base model id.

        An empty id, or the id already in use, is a no-op. If loading the new
        model fails, the failure is printed and the previous base model is
        reloaded; its id is then returned.
        """
        if not base_model_id or base_model_id == self.base_model_id:
            return self.base_model_id
        del self.pipe
        torch.cuda.empty_cache()
        gc.collect()
        try:
            self.pipe = self.load_pipe(base_model_id, self.task_name)
        except Exception as err:  # noqa: BLE001 - deliberate fallback to the previous model
            print(f"Failed to load base model {base_model_id!r} ({err}); reloading {self.base_model_id!r}.")
            self.pipe = self.load_pipe(self.base_model_id, self.task_name)
        return self.base_model_id

    def load_controlnet_weight(self, task_name: str) -> None:
        """Swap the ControlNet inside ``self.pipe`` for the one registered for ``task_name``.

        Does nothing if that task is already loaded.

        Raises:
            KeyError: if ``task_name`` is not in ``CONTROLNET_MODEL_IDS``. This is
                checked before the current ControlNet is released, so the
                pipeline stays usable.
        """
        if task_name == self.task_name:
            return
        model_id = self._controlnet_id(task_name)
        if self.pipe is not None and hasattr(self.pipe, "controlnet"):
            del self.pipe.controlnet
        torch.cuda.empty_cache()
        gc.collect()
        controlnet = ControlNetModel.from_pretrained(model_id, torch_dtype=torch.float16)
        controlnet.to(self.device)
        torch.cuda.empty_cache()
        gc.collect()
        self.pipe.controlnet = controlnet
        self.task_name = task_name

    def get_prompt(self, prompt: str, additional_prompt: str) -> str:
        """Join ``prompt`` and ``additional_prompt`` with ", ".

        Returns ``additional_prompt`` alone when ``prompt`` is empty.
        """
        return additional_prompt if not prompt else f"{prompt}, {additional_prompt}"

    # @spaces.GPU #[uncomment to use ZeroGPU]
    def run_pipe(
        self,
        prompt: str,
        negative_prompt: str,
        control_image: PIL.Image.Image,
        num_images: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
        pipe: DiffusionPipeline | None = None,
    ) -> list[PIL.Image.Image]:
        """Run ``pipe`` (default: ``self.pipe``) once and return its images.

        ``control_image`` is passed as the pipeline's ``image`` argument:

        * for ``self.pipe`` it is the ControlNet conditioning;
        * for the img2img ``self.pipe_base`` it is the init image.

        A fresh CPU ``torch.Generator`` seeded with ``seed`` is used for every
        call.
        """
        generator = torch.Generator().manual_seed(seed)
        active_pipe = self.pipe if pipe is None else pipe
        return active_pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            num_images_per_prompt=num_images,
            num_inference_steps=num_steps,
            generator=generator,
            image=control_image,
        ).images

    def _check_request(self, image: object, image_resolution: int, num_images: int) -> None:
        """Validate the inputs shared by every ``process_*`` method.

        Raises:
            ValueError: if ``image`` is None, or if ``image_resolution`` /
                ``num_images`` exceed ``MAX_IMAGE_RESOLUTION`` / ``MAX_NUM_IMAGES``.
        """
        if image is None:
            raise ValueError("An input image is required.")
        if image_resolution > MAX_IMAGE_RESOLUTION:
            raise ValueError(f"image_resolution must be <= {MAX_IMAGE_RESOLUTION}, got {image_resolution}.")
        if num_images > MAX_NUM_IMAGES:
            raise ValueError(f"num_images must be <= {MAX_NUM_IMAGES}, got {num_images}.")

    @staticmethod
    def _passthrough_control_image(image: np.ndarray, image_resolution: int) -> PIL.Image.Image:
        """Use the input itself as the control image.

        Forces it to 3 channels with ``HWC3``, resizes it to
        ``image_resolution`` and converts it to PIL.
        """
        image = HWC3(image)
        image = resize_image(image, resolution=image_resolution)
        return PIL.Image.fromarray(image)

    def _generate(
        self,
        task_name: str,
        control_image: PIL.Image.Image,
        prompt: str,
        additional_prompt: str,
        negative_prompt: str,
        num_images: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
    ) -> list[PIL.Image.Image]:
        """Load the ControlNet for ``task_name``, run ``self.pipe`` and return ``[control_image, *results]``."""
        self.load_controlnet_weight(task_name)
        results = self.run_pipe(
            prompt=self.get_prompt(prompt, additional_prompt),
            negative_prompt=negative_prompt,
            control_image=control_image,
            num_images=num_images,
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            seed=seed,
        )
        return [control_image, *results]

    @staticmethod
    def _safe_filename_part(text: str, max_length: int = 64) -> str:
        """Turn free text into a short, path-safe filename fragment.

        Each run of characters other than word characters, '.' or '-' becomes
        '_'. The result is truncated to ``max_length`` characters, so user
        text cannot add path separators or exceed filename length limits.
        """
        import re

        return re.sub(r"[^\w.-]+", "_", text)[:max_length]

    def _refine_textures(
        self,
        tex_coarse: list[PIL.Image.Image],
        control_image: PIL.Image.Image,
        prompt: str,
        negative_prompt: str,
        num_steps: int,
        guidance_scale: float,
        seed: int,
    ) -> list[PIL.Image.Image]:
        """Refine each coarse texture with one img2img pass of ``self.pipe_base``.

        Pixels that are exactly black in ``control_image`` are copied from the
        control image twice: into the img2img input and into the refined
        result. The masked input is Gaussian-blurred (kernel 5, sigma 1)
        before img2img.
        """
        control = np.asarray(control_image)
        # Loop-invariant mask of control pixels whose channels sum to 0.
        # Assumes an (H, W, C) control image, so the mask is (H, W, 1).
        mask = (control.sum(axis=-1) == 0)[..., None]
        blur = transforms.GaussianBlur(kernel_size=5, sigma=1)
        tex_fine = []
        for result_coarse in tex_coarse:
            # clean up GPU cache
            torch.cuda.empty_cache()
            gc.collect()
            image_masked = PIL.Image.fromarray(np.where(mask, control, np.asarray(result_coarse)))
            result_fine = self.run_pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                control_image=blur(image_masked),
                num_images=1,
                num_steps=num_steps,
                guidance_scale=guidance_scale,
                seed=seed,
                pipe=self.pipe_base,
            )[0]
            tex_fine.append(PIL.Image.fromarray(np.where(mask, control, np.asarray(result_fine))))
        return tex_fine

    def _rgb2x(
        self,
        image: PIL.Image.Image,
        seed: int,
        inference_step: int = 50,
        aovs: tuple[str, ...] = ("albedo", "normal", "roughness", "metallic", "irradiance"),
    ) -> dict[str, PIL.Image.Image]:
        """Estimate material maps (AOVs) for ``image`` with ``self.pipe_rgb2x``.

        The image is resized so its longer side is 1000 px, with both sides
        rounded down to a multiple of 8. Each AOV in ``aovs`` is generated in
        order with its own prompt, and every result is resized back to the
        input size. All AOVs share one generator seeded with ``seed``, so the
        noise each map receives depends on which AOVs were generated before it.

        Returns:
            Mapping from AOV name to the generated image.
        """
        prompts = {
            "albedo": "Albedo (diffuse basecolor)",
            "normal": "Camera-space Normal",
            "roughness": "Roughness",
            "metallic": "Metallicness",
            "irradiance": "Irradiance (diffuse lighting)",
        }
        # The pipeline expects a float (C, H, W) tensor in [0, 1]. PILToTensor
        # would give uint8 values in [0, 255].
        # NOTE(review): the reference rgb2x demo also linearises sRGB input
        # (x ** 2.2); that is not done here.
        photo = transforms.ToTensor()(image).to(self.device)
        old_height, old_width = photo.shape[1], photo.shape[2]
        ratio = old_height / old_width
        max_side = 1000
        if old_height > old_width:
            new_height = max_side
            new_width = int(new_height / ratio)
        else:
            new_width = max_side
            new_height = int(new_width * ratio)
        # Round both sides down to a multiple of 8 (presumably required by the
        # pipeline's latent resolution).
        new_height = new_height // 8 * 8
        new_width = new_width // 8 * 8
        photo = torchvision.transforms.Resize((new_height, new_width))(photo)
        restore_size = torchvision.transforms.Resize((old_height, old_width))
        generator = torch.Generator(device=self.device).manual_seed(seed)
        maps: dict[str, PIL.Image.Image] = {}
        for aov_name in tqdm(aovs, desc="Running Pipeline", leave=False):
            generated_image = self.pipe_rgb2x(
                prompt=prompts[aov_name],
                photo=photo,
                num_inference_steps=inference_step,
                height=new_height,
                width=new_width,
                generator=generator,
                required_aovs=[aov_name],
            ).images[0][0]
            maps[aov_name] = restore_size(generated_image)
        return maps

    @staticmethod
    def _run_blend_generation(
        blender_path: str,
        generate_script_path: str,
        obj_path: str,
        base_color_path: str,
        normal_map_path: str,
        roughness_path: str,
        metallic_path: str,
        output_blend: str,
    ) -> None:
        """Run ``generate_script_path`` in headless Blender.

        The remaining paths are passed as positional script arguments after
        ``--``.

        Raises:
            subprocess.CalledProcessError: if Blender exits with a non-zero status.
        """
        cmd = [
            blender_path, "--background", "--python", generate_script_path, "--",
            obj_path, base_color_path, normal_map_path, roughness_path, metallic_path, output_blend,
        ]
        subprocess.run(cmd, check=True)

    # @spaces.GPU #[uncomment to use ZeroGPU]
    def process_texnet(
        self,
        obj_name: str,
        represented_image: np.ndarray | None,  # not used
        image: np.ndarray,
        prompt: str,
        additional_prompt: str,
        negative_prompt: str,
        num_images: int,
        image_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
        low_threshold: int,
        high_threshold: int,
    ) -> tuple[
        list[PIL.Image.Image], list[PIL.Image.Image], list[PIL.Image.Image], list[PIL.Image.Image], list[str]
    ]:
        """Generate a texture set for example mesh ``obj_name`` and pack it into a ``.blend`` file.

        Steps:

        1. Preprocess ``image`` with the ``texnet`` preprocessor into a control image.
        2. Generate ``num_images`` coarse textures with the TexNet ControlNet.
        3. Refine each texture with a masked, blurred img2img pass
           (see ``_refine_textures``).
        4. Estimate normal, roughness and metallic maps for the first refined
           texture with rgb2x.
        5. Run Blender headless on ``examples/<obj_name>/mesh.obj`` with
           ``rgb2x/generate_blend.py``, using the first refined texture as base
           color plus the estimated maps. All images are rotated with PIL
           ``rotate(90)`` first. The ``.blend`` is written into a fresh
           temporary directory.

        Returns:
            ``(refined_textures, [normal], [roughness], [metallic], [blend_path])``.

        Raises:
            ValueError: if the image is missing, ``num_images`` < 1, or a limit
                (``MAX_IMAGE_RESOLUTION`` / ``MAX_NUM_IMAGES``) is exceeded.
            subprocess.CalledProcessError: if Blender fails.
        """
        self._check_request(image, image_resolution, num_images)
        if num_images < 1:
            raise ValueError(f"num_images must be at least 1, got {num_images}.")
        self.preprocessor.load("texnet")
        control_image = self.preprocessor(
            image=image,
            low_threshold=low_threshold,
            high_threshold=high_threshold,
            image_resolution=image_resolution,
            output_type="pil",
        )
        self.load_controlnet_weight("texnet")
        full_prompt = self.get_prompt(prompt, additional_prompt)
        tex_coarse = self.run_pipe(
            prompt=full_prompt,
            negative_prompt=negative_prompt,
            control_image=control_image,
            num_images=num_images,
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            seed=seed,
        )
        tex_fine = self._refine_textures(
            tex_coarse, control_image, full_prompt, negative_prompt, num_steps, guidance_scale, seed
        )

        # Only the first refined texture is turned into material maps. Albedo
        # is still generated first (though unused) because all AOVs share one
        # seeded generator. Irradiance is last and unused, so it is skipped.
        maps = self._rgb2x(
            tex_fine[0], seed=seed, inference_step=num_steps, aovs=("albedo", "normal", "roughness", "metallic")
        )
        base_color_path = image_to_temp_path(tex_fine[0].rotate(90), "base_color")
        normal_map_path = image_to_temp_path(maps["normal"].rotate(90), "normal_map")
        roughness_path = image_to_temp_path(maps["roughness"].rotate(90), "roughness")
        metallic_path = image_to_temp_path(maps["metallic"].rotate(90), "metallic")

        # obj_name and prompt are user-controlled; sanitise them before using
        # them in a filename.
        current_timecode = time.strftime("%Y%m%d_%H%M%S")
        blend_name = (
            f"{self._safe_filename_part(obj_name)}_{self._safe_filename_part(prompt)}_{current_timecode}.blend"
        )
        output_blend_path = os.path.join(tempfile.mkdtemp(), blend_name)

        self._run_blend_generation(
            blender_path=self.blender_path,
            generate_script_path="rgb2x/generate_blend.py",
            obj_path=f"examples/{obj_name}/mesh.obj",
            base_color_path=base_color_path,
            normal_map_path=normal_map_path,
            roughness_path=roughness_path,
            metallic_path=metallic_path,
            output_blend=output_blend_path,
        )
        # gallery outputs
        return [*tex_fine], [maps["normal"]], [maps["roughness"]], [maps["metallic"]], [output_blend_path]

    # @spaces.GPU #[uncomment to use ZeroGPU]
    def process_canny(
        self,
        image: np.ndarray,
        prompt: str,
        additional_prompt: str,
        negative_prompt: str,
        num_images: int,
        image_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
        low_threshold: int,
        high_threshold: int,
    ) -> list[PIL.Image.Image]:
        """Canny-edge ControlNet: return ``[edge map, *generated images]``."""
        self._check_request(image, image_resolution, num_images)
        self.preprocessor.load("Canny")
        control_image = self.preprocessor(
            image=image, low_threshold=low_threshold, high_threshold=high_threshold, detect_resolution=image_resolution
        )
        return self._generate(
            "Canny", control_image, prompt, additional_prompt, negative_prompt,
            num_images=num_images, num_steps=num_steps, guidance_scale=guidance_scale, seed=seed,
        )

    def process_mlsd(
        self,
        image: np.ndarray,
        prompt: str,
        additional_prompt: str,
        negative_prompt: str,
        num_images: int,
        image_resolution: int,
        preprocess_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
        value_threshold: float,
        distance_threshold: float,
    ) -> list[PIL.Image.Image]:
        """MLSD straight-line ControlNet: return ``[line map, *generated images]``."""
        self._check_request(image, image_resolution, num_images)
        self.preprocessor.load("MLSD")
        control_image = self.preprocessor(
            image=image,
            image_resolution=image_resolution,
            detect_resolution=preprocess_resolution,
            thr_v=value_threshold,
            thr_d=distance_threshold,
        )
        return self._generate(
            "MLSD", control_image, prompt, additional_prompt, negative_prompt,
            num_images=num_images, num_steps=num_steps, guidance_scale=guidance_scale, seed=seed,
        )

    def process_scribble(
        self,
        image: np.ndarray,
        prompt: str,
        additional_prompt: str,
        negative_prompt: str,
        num_images: int,
        image_resolution: int,
        preprocess_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
        preprocessor_name: str,
    ) -> list[PIL.Image.Image]:
        """Scribble ControlNet: return ``[control image, *generated images]``.

        ``preprocessor_name`` selects how the control image is made:

        * ``"None"``: use the input image directly;
        * ``"HED"`` or ``"PidiNet"``: run that detector.

        Raises:
            ValueError: on invalid inputs or an unknown ``preprocessor_name``.
        """
        self._check_request(image, image_resolution, num_images)
        if preprocessor_name == "None":
            control_image = self._passthrough_control_image(image, image_resolution)
        elif preprocessor_name == "HED":
            self.preprocessor.load(preprocessor_name)
            control_image = self.preprocessor(
                image=image,
                image_resolution=image_resolution,
                detect_resolution=preprocess_resolution,
                scribble=False,
            )
        elif preprocessor_name == "PidiNet":
            self.preprocessor.load(preprocessor_name)
            control_image = self.preprocessor(
                image=image,
                image_resolution=image_resolution,
                detect_resolution=preprocess_resolution,
                safe=False,
            )
        else:
            raise ValueError(f"Unknown scribble preprocessor: {preprocessor_name!r}")
        return self._generate(
            "scribble", control_image, prompt, additional_prompt, negative_prompt,
            num_images=num_images, num_steps=num_steps, guidance_scale=guidance_scale, seed=seed,
        )

    def process_scribble_interactive(
        self,
        image_and_mask: dict[str, np.ndarray | list[np.ndarray]] | None,
        prompt: str,
        additional_prompt: str,
        negative_prompt: str,
        num_images: int,
        image_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
    ) -> list[PIL.Image.Image]:
        """Scribble ControlNet driven by a user drawing.

        The inverted ``image_and_mask["composite"]`` is used as the control
        image. Returns ``[control image, *generated images]``.
        """
        self._check_request(image_and_mask, image_resolution, num_images)
        image = 255 - image_and_mask["composite"]  # type: ignore
        control_image = self._passthrough_control_image(image, image_resolution)
        return self._generate(
            "scribble", control_image, prompt, additional_prompt, negative_prompt,
            num_images=num_images, num_steps=num_steps, guidance_scale=guidance_scale, seed=seed,
        )

    def process_softedge(
        self,
        image: np.ndarray,
        prompt: str,
        additional_prompt: str,
        negative_prompt: str,
        num_images: int,
        image_resolution: int,
        preprocess_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
        preprocessor_name: str,
    ) -> list[PIL.Image.Image]:
        """Soft-edge ControlNet: return ``[control image, *generated images]``.

        ``preprocessor_name`` is one of ``"None"``, ``"HED"``, ``"HED safe"``,
        ``"PidiNet"`` or ``"PidiNet safe"``.

        Raises:
            ValueError: on invalid inputs or any other ``preprocessor_name``.
        """
        self._check_request(image, image_resolution, num_images)
        if preprocessor_name == "None":
            control_image = self._passthrough_control_image(image, image_resolution)
        elif preprocessor_name in ["HED", "HED safe"]:
            safe = "safe" in preprocessor_name
            self.preprocessor.load("HED")
            control_image = self.preprocessor(
                image=image,
                image_resolution=image_resolution,
                detect_resolution=preprocess_resolution,
                scribble=safe,
            )
        elif preprocessor_name in ["PidiNet", "PidiNet safe"]:
            safe = "safe" in preprocessor_name
            self.preprocessor.load("PidiNet")
            control_image = self.preprocessor(
                image=image,
                image_resolution=image_resolution,
                detect_resolution=preprocess_resolution,
                safe=safe,
            )
        else:
            raise ValueError(f"Unknown softedge preprocessor: {preprocessor_name!r}")
        return self._generate(
            "softedge", control_image, prompt, additional_prompt, negative_prompt,
            num_images=num_images, num_steps=num_steps, guidance_scale=guidance_scale, seed=seed,
        )

    def process_openpose(
        self,
        image: np.ndarray,
        prompt: str,
        additional_prompt: str,
        negative_prompt: str,
        num_images: int,
        image_resolution: int,
        preprocess_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
        preprocessor_name: str,
    ) -> list[PIL.Image.Image]:
        """OpenPose ControlNet: return ``[pose map, *generated images]``.

        Any ``preprocessor_name`` other than ``"None"`` runs the Openpose
        detector with hands and face.
        """
        self._check_request(image, image_resolution, num_images)
        if preprocessor_name == "None":
            control_image = self._passthrough_control_image(image, image_resolution)
        else:
            self.preprocessor.load("Openpose")
            control_image = self.preprocessor(
                image=image,
                image_resolution=image_resolution,
                detect_resolution=preprocess_resolution,
                hand_and_face=True,
            )
        return self._generate(
            "Openpose", control_image, prompt, additional_prompt, negative_prompt,
            num_images=num_images, num_steps=num_steps, guidance_scale=guidance_scale, seed=seed,
        )

    def process_segmentation(
        self,
        image: np.ndarray,
        prompt: str,
        additional_prompt: str,
        negative_prompt: str,
        num_images: int,
        image_resolution: int,
        preprocess_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
        preprocessor_name: str,
    ) -> list[PIL.Image.Image]:
        """Segmentation ControlNet: return ``[segmentation map, *generated images]``.

        Any ``preprocessor_name`` other than ``"None"`` is loaded as the
        preprocessor by that name.
        """
        self._check_request(image, image_resolution, num_images)
        if preprocessor_name == "None":
            control_image = self._passthrough_control_image(image, image_resolution)
        else:
            self.preprocessor.load(preprocessor_name)
            control_image = self.preprocessor(
                image=image,
                image_resolution=image_resolution,
                detect_resolution=preprocess_resolution,
            )
        return self._generate(
            "segmentation", control_image, prompt, additional_prompt, negative_prompt,
            num_images=num_images, num_steps=num_steps, guidance_scale=guidance_scale, seed=seed,
        )

    def process_depth(
        self,
        image: np.ndarray,
        prompt: str,
        additional_prompt: str,
        negative_prompt: str,
        num_images: int,
        image_resolution: int,
        preprocess_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
        preprocessor_name: str,
    ) -> list[PIL.Image.Image]:
        """Depth ControlNet: return ``[depth map, *generated images]``.

        Any ``preprocessor_name`` other than ``"None"`` is loaded as the
        preprocessor by that name.
        """
        self._check_request(image, image_resolution, num_images)
        if preprocessor_name == "None":
            control_image = self._passthrough_control_image(image, image_resolution)
        else:
            self.preprocessor.load(preprocessor_name)
            control_image = self.preprocessor(
                image=image,
                image_resolution=image_resolution,
                detect_resolution=preprocess_resolution,
            )
        return self._generate(
            "depth", control_image, prompt, additional_prompt, negative_prompt,
            num_images=num_images, num_steps=num_steps, guidance_scale=guidance_scale, seed=seed,
        )

    def process_normal(
        self,
        image: np.ndarray,
        prompt: str,
        additional_prompt: str,
        negative_prompt: str,
        num_images: int,
        image_resolution: int,
        preprocess_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
        preprocessor_name: str,
    ) -> list[PIL.Image.Image]:
        """Normal-map ControlNet: return ``[normal map, *generated images]``.

        Any ``preprocessor_name`` other than ``"None"`` runs the NormalBae
        detector.
        """
        self._check_request(image, image_resolution, num_images)
        if preprocessor_name == "None":
            control_image = self._passthrough_control_image(image, image_resolution)
        else:
            self.preprocessor.load("NormalBae")
            control_image = self.preprocessor(
                image=image,
                image_resolution=image_resolution,
                detect_resolution=preprocess_resolution,
            )
        return self._generate(
            "NormalBae", control_image, prompt, additional_prompt, negative_prompt,
            num_images=num_images, num_steps=num_steps, guidance_scale=guidance_scale, seed=seed,
        )

    def process_lineart(
        self,
        image: np.ndarray,
        prompt: str,
        additional_prompt: str,
        negative_prompt: str,
        num_images: int,
        image_resolution: int,
        preprocess_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
        preprocessor_name: str,
    ) -> list[PIL.Image.Image]:
        """Line-art ControlNet: return ``[line-art map, *generated images]``.

        ``preprocessor_name`` is one of ``"None"``, ``"None (anime)"``,
        ``"Lineart"``, ``"Lineart coarse"`` or ``"Lineart (anime)"``. Names
        containing "anime" use the ``lineart_anime`` ControlNet.

        Raises:
            ValueError: on invalid inputs or any other ``preprocessor_name``.
        """
        self._check_request(image, image_resolution, num_images)
        if preprocessor_name in ["None", "None (anime)"]:
            control_image = self._passthrough_control_image(image, image_resolution)
        elif preprocessor_name in ["Lineart", "Lineart coarse"]:
            coarse = "coarse" in preprocessor_name
            self.preprocessor.load("Lineart")
            control_image = self.preprocessor(
                image=image,
                image_resolution=image_resolution,
                detect_resolution=preprocess_resolution,
                coarse=coarse,
            )
        elif preprocessor_name == "Lineart (anime)":
            self.preprocessor.load("LineartAnime")
            control_image = self.preprocessor(
                image=image,
                image_resolution=image_resolution,
                detect_resolution=preprocess_resolution,
            )
        else:
            raise ValueError(f"Unknown lineart preprocessor: {preprocessor_name!r}")
        task_name = "lineart_anime" if "anime" in preprocessor_name else "lineart"
        return self._generate(
            task_name, control_image, prompt, additional_prompt, negative_prompt,
            num_images=num_images, num_steps=num_steps, guidance_scale=guidance_scale, seed=seed,
        )

    def process_shuffle(
        self,
        image: np.ndarray,
        prompt: str,
        additional_prompt: str,
        negative_prompt: str,
        num_images: int,
        image_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
        preprocessor_name: str,
    ) -> list[PIL.Image.Image]:
        """Content-shuffle ControlNet: return ``[shuffled image, *generated images]``.

        Any ``preprocessor_name`` other than ``"None"`` is loaded as the
        preprocessor by that name.
        """
        self._check_request(image, image_resolution, num_images)
        if preprocessor_name == "None":
            control_image = self._passthrough_control_image(image, image_resolution)
        else:
            self.preprocessor.load(preprocessor_name)
            control_image = self.preprocessor(
                image=image,
                image_resolution=image_resolution,
            )
        return self._generate(
            "shuffle", control_image, prompt, additional_prompt, negative_prompt,
            num_images=num_images, num_steps=num_steps, guidance_scale=guidance_scale, seed=seed,
        )

    def process_ip2p(
        self,
        image: np.ndarray,
        prompt: str,
        additional_prompt: str,
        negative_prompt: str,
        num_images: int,
        image_resolution: int,
        num_steps: int,
        guidance_scale: float,
        seed: int,
    ) -> list[PIL.Image.Image]:
        """Instruct-pix2pix ControlNet.

        The resized input is the control image. Returns
        ``[control image, *generated images]``.
        """
        self._check_request(image, image_resolution, num_images)
        control_image = self._passthrough_control_image(image, image_resolution)
        return self._generate(
            "ip2p", control_image, prompt, additional_prompt, negative_prompt,
            num_images=num_images, num_steps=num_steps, guidance_scale=guidance_scale, seed=seed,
        )