# matgen / model.py
# Author: jingyangcarl. Last commit: d9c3c15 ("change path on server").
import gc
# get socket and check if the name is vgldgx01
import socket
if socket.gethostname() != "vgldgx01":
import spaces #[uncomment to use ZeroGPU]
import numpy as np
import PIL.Image
import torch
from controlnet_aux.util import HWC3
from diffusers import (
ControlNetModel,
DiffusionPipeline,
StableDiffusionControlNetPipeline,
StableDiffusionImg2ImgPipeline,
UniPCMultistepScheduler,
DDIMScheduler, #rgb2x
)
import torchvision
from torchvision import transforms
from cv_utils import resize_image
from preprocessor import Preprocessor
from settings import MAX_IMAGE_RESOLUTION, MAX_NUM_IMAGES
from tqdm.auto import tqdm
import subprocess
from rgb2x.pipeline_rgb2x import StableDiffusionAOVMatEstPipeline
from app_texnet import image_to_temp_path
import os
import time
import tempfile
# Task name -> Hugging Face Hub repo id of the ControlNet weights for that task.
# Model.load_pipe() and Model.load_controlnet_weight() index this mapping by
# task name, so requesting a task that is not listed raises KeyError.
#
# Only the texture ControlNet ("texnet") is enabled. The stock ControlNet v1.1
# checkpoints used by the other process_* methods are disabled; to re-enable a
# task, add its entry back:
#   "Openpose":      "lllyasviel/control_v11p_sd15_openpose"
#   "Canny":         "lllyasviel/control_v11p_sd15_canny"
#   "MLSD":          "lllyasviel/control_v11p_sd15_mlsd"
#   "scribble":      "lllyasviel/control_v11p_sd15_scribble"
#   "softedge":      "lllyasviel/control_v11p_sd15_softedge"
#   "segmentation":  "lllyasviel/control_v11p_sd15_seg"
#   "depth":         "lllyasviel/control_v11f1p_sd15_depth"
#   "NormalBae":     "lllyasviel/control_v11p_sd15_normalbae"
#   "lineart":       "lllyasviel/control_v11p_sd15_lineart"
#   "lineart_anime": "lllyasviel/control_v11p_sd15s2_lineart_anime"
#   "shuffle":       "lllyasviel/control_v11e_sd15_shuffle"
#   "ip2p":          "lllyasviel/control_v11e_sd15_ip2p"
#   "inpaint":       "lllyasviel/control_v11e_sd15_inpaint"
# During development "texnet" was loaded from a local training checkpoint:
#   "/home/jyang/projects/ObjectReal/logs/train_texnet_deploy/checkpoint-55000/controlnet"
CONTROLNET_MODEL_IDS = dict(
    texnet="jingyangcarl/texnet",
)
def download_all_controlnet_weights() -> None:
    """Load every ControlNet listed in ``CONTROLNET_MODEL_IDS`` once.

    Each loaded model is discarded immediately. Presumably the purpose is to
    populate the local Hugging Face cache ahead of time (TODO confirm).
    """
    for task_name in CONTROLNET_MODEL_IDS:
        ControlNetModel.from_pretrained(CONTROLNET_MODEL_IDS[task_name])
class Model:
def __init__(
    self, base_model_id: str = "stable-diffusion-v1-5/stable-diffusion-v1-5", task_name: str = "Canny"
) -> None:
    """Load every diffusion pipeline the app uses and make Blender available.

    Args:
        base_model_id: Hub id of the Stable Diffusion base model that the
            ControlNet pipeline (``self.pipe``) is built on.
        task_name: Key of ``CONTROLNET_MODEL_IDS`` selecting the initial
            ControlNet. NOTE(review): the default "Canny" is commented out of
            ``CONTROLNET_MODEL_IDS``, so constructing with the default raises
            KeyError inside ``load_pipe``. Pass an enabled task such as
            "texnet".

    Raises:
        subprocess.CalledProcessError: downloading or extracting Blender failed.
        FileNotFoundError: the Blender executable is still missing after
            extraction.

    Side effects:
        Downloads and extracts Blender 3.2.2 into /tmp (via ``wget`` and
        ``tar``) when ``self.blender_path`` does not exist yet.
    """
    self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Empty ids guarantee that the first load_pipe() call is never a cache hit.
    self.base_model_id = ""
    self.task_name = ""
    self.pipe = self.load_pipe(base_model_id, task_name)

    # Plain SD 1.5 img2img pipeline. process_texnet swaps it in to refine the
    # coarse ControlNet texture. The former "runwayml/stable-diffusion-v1-5"
    # repo has been removed from the Hugging Face Hub, so load the maintained
    # mirror (the same id this constructor uses as its default base model).
    self.pipe_base = StableDiffusionImg2ImgPipeline.from_pretrained(
        "stable-diffusion-v1-5/stable-diffusion-v1-5", safety_checker=None, torch_dtype=torch.float16
    ).to(self.device)
    self.preprocessor = Preprocessor()

    # RGB -> X pipeline, used by process_texnet to estimate material maps.
    self.pipe_rgb2x = StableDiffusionAOVMatEstPipeline.from_pretrained(
        "zheng95z/rgb-to-x",
        torch_dtype=torch.float16,
    ).to(self.device)
    self.pipe_rgb2x.scheduler = DDIMScheduler.from_config(
        self.pipe_rgb2x.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing"
    )
    self.pipe_rgb2x.set_progress_bar_config(disable=True)

    # Blender is run headless by process_texnet to build the .blend output.
    self.blender_path = '/tmp/blender-3.2.2-linux-x64/blender'
    if not os.path.exists(self.blender_path):
        archive_path = "/tmp/blender-3.2.2-linux-x64.tar.xz"
        print("Downloading Blender...")
        subprocess.run(["wget", "https://download.blender.org/release/Blender3.2/blender-3.2.2-linux-x64.tar.xz", "-O", archive_path], check=True)
        subprocess.run(["tar", "-xf", archive_path, "-C", "/tmp"], check=True)
        if not os.path.exists(self.blender_path):
            # Fail here with a clear message instead of later, inside process_texnet.
            raise FileNotFoundError(f"Blender executable not found at {self.blender_path} after extraction.")
        print("Blender downloaded and extracted.")
def load_pipe(self, base_model_id: str, task_name: str) -> DiffusionPipeline:
    """Return a ControlNet pipeline for ``base_model_id`` + ``task_name``.

    If both ids match what is already loaded and ``self.pipe`` exists and is
    not None, ``self.pipe`` is returned unchanged. Otherwise this builds a new
    ``StableDiffusionControlNetPipeline``: fp16 weights, no safety checker,
    UniPC scheduler, moved to ``self.device``. It then records the ids in
    ``self.base_model_id`` / ``self.task_name``. The caller must assign the
    result to ``self.pipe``.

    Raises:
        KeyError: ``task_name`` is not a key of ``CONTROLNET_MODEL_IDS``.
    """
    if (
        base_model_id == self.base_model_id
        and task_name == self.task_name
        and getattr(self, "pipe", None) is not None
    ):
        return self.pipe

    controlnet = ControlNetModel.from_pretrained(CONTROLNET_MODEL_IDS[task_name], torch_dtype=torch.float16)
    pipe = StableDiffusionControlNetPipeline.from_pretrained(
        base_model_id, safety_checker=None, controlnet=controlnet, torch_dtype=torch.float16
    )
    pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
    pipe.to(self.device)
    # xformers attention does not work on ZeroGPU Spaces (SPACES_ZERO_GPU=1),
    # so enable it only on other CUDA hosts.
    if self.device.type == "cuda" and os.environ.get("SPACES_ZERO_GPU", "0") != "1":
        pipe.enable_xformers_memory_efficient_attention()
    torch.cuda.empty_cache()
    gc.collect()
    self.base_model_id = base_model_id
    self.task_name = task_name
    return pipe
def set_base_model(self, base_model_id: str) -> str:
    """Rebuild the ControlNet pipeline on a different SD base model.

    Returns the base model id that is active afterwards. On success this is
    ``base_model_id``. If loading it fails, the failure is reported and the
    previous base model is reloaded and returned. An empty id, or the id
    already in use, is a no-op. If reloading the previous model also fails,
    that exception propagates.
    """
    if not base_model_id or base_model_id == self.base_model_id:
        return self.base_model_id
    # Drop the current pipeline first so two pipelines are not held at once.
    del self.pipe
    torch.cuda.empty_cache()
    gc.collect()
    try:
        self.pipe = self.load_pipe(base_model_id, self.task_name)
    except Exception as err:  # noqa: BLE001 - best-effort switch, fall back to the previous model
        # load_pipe only updates self.base_model_id on success, so it still
        # names the previous model here.
        print(f"Failed to load base model {base_model_id!r} ({err!r}); reverting to {self.base_model_id!r}.")
        self.pipe = self.load_pipe(self.base_model_id, self.task_name)
    return self.base_model_id
def load_controlnet_weight(self, task_name: str) -> None:
    """Replace the ControlNet inside ``self.pipe`` with the one for ``task_name``.

    Does nothing when ``task_name`` is already the active task. Otherwise the
    current ControlNet (if any) is deleted and caches are cleared before the
    new weights are loaded, presumably to keep peak GPU memory down. Raises
    KeyError if ``task_name`` is not in ``CONTROLNET_MODEL_IDS``.
    """
    if task_name != self.task_name:
        if self.pipe is not None and hasattr(self.pipe, "controlnet"):
            del self.pipe.controlnet
        torch.cuda.empty_cache()
        gc.collect()

        repo_id = CONTROLNET_MODEL_IDS[task_name]
        new_controlnet = ControlNetModel.from_pretrained(repo_id, torch_dtype=torch.float16)
        new_controlnet.to(self.device)
        torch.cuda.empty_cache()
        gc.collect()

        self.pipe.controlnet = new_controlnet
        self.task_name = task_name
def get_prompt(self, prompt: str, additional_prompt: str) -> str:
    """Combine the user prompt with the additional prompt as "prompt, additional".

    When ``prompt`` is empty (or otherwise falsy), only ``additional_prompt``
    is returned.
    """
    if prompt:
        return f"{prompt}, {additional_prompt}"
    return additional_prompt
# @spaces.GPU #[uncomment to use ZeroGPU]
@torch.autocast("cuda")
def run_pipe(
    self,
    prompt: str,
    negative_prompt: str,
    control_image: PIL.Image.Image,
    num_images: int,
    num_steps: int,
    guidance_scale: float,
    seed: int,
) -> list[PIL.Image.Image]:
    """Run ``self.pipe`` once and return its ``.images``.

    ``control_image`` is forwarded as the pipeline's ``image`` argument. That
    is the ControlNet condition, or the init image while process_texnet has
    swapped in the img2img pipeline. Randomness comes from a
    ``torch.Generator()`` (default CPU device) seeded with ``seed``.
    ``self.pipe`` is not moved here; it is expected to already be on
    ``self.device``.
    """
    pipe_kwargs = dict(
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_images_per_prompt=num_images,
        num_inference_steps=num_steps,
        generator=torch.Generator().manual_seed(seed),
        image=control_image,
    )
    output = self.pipe(**pipe_kwargs)
    return output.images
# @spaces.GPU #[uncomment to use ZeroGPU]
@torch.inference_mode()
def process_texnet(
    self,
    obj_name: str,
    represented_image: np.ndarray | None,  # not used
    image: np.ndarray,
    prompt: str,
    additional_prompt: str,
    negative_prompt: str,
    num_images: int,
    image_resolution: int,
    num_steps: int,
    guidance_scale: float,
    seed: int,
    low_threshold: int,
    high_threshold: int,
) -> tuple[list[PIL.Image.Image], list[PIL.Image.Image], list[PIL.Image.Image], list[PIL.Image.Image], list[str]]:
    """Generate textures for ``obj_name`` and export the first one as a .blend scene.

    Stages:
      1. Preprocess ``image`` with the "texnet" preprocessor into a control
         image, then generate ``num_images`` coarse textures with the TexNet
         ControlNet.
      2. Refine each coarse texture with the img2img pipeline
         (``self.pipe_base``). Pixels that are all-zero in the control image
         are copied from the control image before and after refinement.
      3. Run the RGB->X pipeline on the first refined texture to estimate
         albedo / normal / roughness / metallic / irradiance maps.
      4. Run Blender headless with ``rgb2x/generate_blend.py`` on
         ``examples/<obj_name>/mesh.obj`` plus the base-color, normal,
         roughness and metallic maps.

    Args:
        obj_name: Example object name; selects ``examples/<obj_name>/mesh.obj``.
        represented_image: Unused.
        image: Input passed to the "texnet" preprocessor.
        prompt, additional_prompt, negative_prompt: Text conditioning; the
            first two are joined by ``get_prompt``.
        num_images: Number of textures to generate (1..MAX_NUM_IMAGES).
        image_resolution: Preprocessor output resolution (<= MAX_IMAGE_RESOLUTION).
        num_steps: Diffusion steps for every stage.
        guidance_scale: Classifier-free guidance scale for stages 1-2.
        seed: Seed for every stage.
        low_threshold, high_threshold: Forwarded to the preprocessor.

    Returns:
        ``(refined_textures, [normal_map], [roughness_map], [metallic_map], [blend_path])``.

    Raises:
        ValueError: ``image`` is None, ``image_resolution`` is too large, or
            ``num_images`` is outside 1..MAX_NUM_IMAGES.
        subprocess.CalledProcessError: the Blender export failed.
    """
    if image is None:
        raise ValueError
    if image_resolution > MAX_IMAGE_RESOLUTION:
        raise ValueError
    # At least one texture is needed: stages 3 and 4 use tex_fine[0].
    if not 1 <= num_images <= MAX_NUM_IMAGES:
        raise ValueError

    # --- 1. coarse textures from the TexNet ControlNet ---------------------
    self.preprocessor.load("texnet")
    control_image = self.preprocessor(
        image=image, low_threshold=low_threshold, high_threshold=high_threshold, image_resolution=image_resolution, output_type="pil"
    )
    self.load_controlnet_weight("texnet")
    full_prompt = self.get_prompt(prompt, additional_prompt)
    tex_coarse = self.run_pipe(
        prompt=full_prompt,
        negative_prompt=negative_prompt,
        control_image=control_image,
        num_images=num_images,
        num_steps=num_steps,
        guidance_scale=guidance_scale,
        seed=seed,
    )

    # --- 2. refine each coarse texture with img2img ------------------------
    # Where every channel of the control image is 0, the control pixel is kept
    # verbatim. The mask and blur are identical for every texture, so build
    # them once.
    control_array = np.array(control_image)
    mask = (control_array.sum(axis=-1) == 0)[..., None]
    blur = transforms.GaussianBlur(kernel_size=5, sigma=1)
    # run_pipe() always calls self.pipe, so point it at the img2img pipeline
    # while refining. The finally clause restores the ControlNet pipeline even
    # if refinement raises. Otherwise later calls would keep running the
    # img2img pipeline, because load_controlnet_weight("texnet") is a no-op
    # once texnet is active.
    self.pipe_backup = self.pipe
    self.pipe = self.pipe_base
    tex_fine = []
    try:
        for result_coarse in tex_coarse:
            # clean up GPU cache between refinements
            torch.cuda.empty_cache()
            gc.collect()
            image_blurry = blur(PIL.Image.fromarray(np.where(mask, control_array, result_coarse)))
            result_fine = self.run_pipe(
                prompt=full_prompt,
                negative_prompt=negative_prompt,
                control_image=image_blurry,
                num_images=1,
                num_steps=num_steps,
                guidance_scale=guidance_scale,
                seed=seed,
            )[0]
            tex_fine.append(PIL.Image.fromarray(np.where(mask, control_array, result_fine)))
    finally:
        self.pipe = self.pipe_backup

    # --- 3. material maps with RGB -> X ------------------------------------
    # Insertion order fixes the order in which rgb2x() returns the maps.
    aov_prompts = {
        "albedo": "Albedo (diffuse basecolor)",
        "normal": "Camera-space Normal",
        "roughness": "Roughness",
        "metallic": "Metallicness",
        "irradiance": "Irradiance (diffuse lighting)",
    }

    def rgb2x(pipeline, photo, inference_step=50, num_samples=1):
        """Run ``pipeline`` once per AOV in ``aov_prompts`` order, per sample.

        ``photo`` is a (C, H, W) tensor. It is resized so its longer side is
        1000 px, with both sides rounded down to a multiple of 8. Every
        predicted map is resized back to (H, W). Returns a flat list of
        ``num_samples * len(aov_prompts)`` maps.
        """
        # Create the generator on the input's device rather than hard-coding
        # "cuda", so it matches wherever the tensor actually lives.
        generator = torch.Generator(device=photo.device).manual_seed(seed)
        old_height, old_width = photo.shape[1], photo.shape[2]
        ratio = old_height / old_width
        max_side = 1000
        if old_height > old_width:
            new_height = max_side
            new_width = int(new_height / ratio)
        else:
            new_width = max_side
            new_height = int(new_width * ratio)
        # Round both sides down to a multiple of 8 (a no-op when already aligned).
        new_height = new_height // 8 * 8
        new_width = new_width // 8 * 8
        photo = torchvision.transforms.Resize((new_height, new_width))(photo)
        restore_size = torchvision.transforms.Resize((old_height, old_width))

        maps = []
        for _ in tqdm(range(num_samples), desc="Running Pipeline", leave=False):
            for aov_name, aov_prompt in aov_prompts.items():
                generated = pipeline(
                    prompt=aov_prompt,
                    photo=photo,
                    num_inference_steps=inference_step,
                    height=new_height,
                    width=new_width,
                    generator=generator,
                    required_aovs=[aov_name],
                ).images[0][0]
                maps.append(restore_size(generated))
        return maps

    photo = torchvision.transforms.PILToTensor()(tex_fine[0]).to(self.pipe.device)
    # Only the first sample's maps are used below. The seeded generator
    # produces that sample first regardless of num_samples, so computing one
    # sample yields the same maps without num_images x 5 extra diffusion runs.
    aov_maps = dict(zip(aov_prompts, rgb2x(self.pipe_rgb2x, photo, inference_step=num_steps, num_samples=1)))

    # --- 4. export a .blend scene with Blender -----------------------------
    # Maps are looked up by AOV name so each file matches its role. Indexing
    # the flat list directly is off by one, because entry 0 is the albedo.
    # The 90-degree rotation is kept from the original export code (presumably
    # to match the mesh UV layout - TODO confirm).
    base_color_path = image_to_temp_path(tex_fine[0].rotate(90), "base_color")
    normal_map_path = image_to_temp_path(aov_maps["normal"].rotate(90), "normal_map")
    roughness_path = image_to_temp_path(aov_maps["roughness"].rotate(90), "roughness")
    metallic_path = image_to_temp_path(aov_maps["metallic"].rotate(90), "metallic")

    def slug(text, max_len=64):
        """Reduce ``text`` to a filename-safe token (alphanumerics, '-', '_'), truncated."""
        return "".join(ch if ch.isalnum() or ch in "-_" else "_" for ch in (text or ""))[:max_len]

    # User text goes into the filename. Sanitize it so it cannot introduce
    # path separators ('/', '..') or exceed filename length limits.
    current_timecode = time.strftime("%Y%m%d_%H%M%S")
    output_blend_path = os.path.join(
        tempfile.mkdtemp(), f"{slug(obj_name)}_{slug(prompt)}_{current_timecode}.blend"
    )

    # Script and mesh paths are relative to the current working directory.
    subprocess.run(
        [
            self.blender_path, "--background", "--python", "rgb2x/generate_blend.py", "--",
            f"examples/{obj_name}/mesh.obj",
            base_color_path, normal_map_path, roughness_path, metallic_path,
            output_blend_path,
        ],
        check=True,
    )

    return (
        tex_fine,
        [aov_maps["normal"]],
        [aov_maps["roughness"]],
        [aov_maps["metallic"]],
        [output_blend_path],
    )
# @spaces.GPU #[uncomment to use ZeroGPU]
@torch.inference_mode()
def process_canny(
    self,
    image: np.ndarray,
    prompt: str,
    additional_prompt: str,
    negative_prompt: str,
    num_images: int,
    image_resolution: int,
    num_steps: int,
    guidance_scale: float,
    seed: int,
    low_threshold: int,
    high_threshold: int,
) -> list[PIL.Image.Image]:
    """Generate images conditioned on the "Canny" preprocessor output of ``image``.

    ``image_resolution`` is passed to the preprocessor as ``detect_resolution``.
    Returns ``[control_image, *generated_images]``. Raises ValueError when
    ``image`` is None, ``image_resolution`` > MAX_IMAGE_RESOLUTION, or
    ``num_images`` > MAX_NUM_IMAGES.

    NOTE(review): "Canny" is commented out of CONTROLNET_MODEL_IDS, so
    load_controlnet_weight("Canny") currently raises KeyError.
    """
    if image is None or image_resolution > MAX_IMAGE_RESOLUTION or num_images > MAX_NUM_IMAGES:
        raise ValueError

    self.preprocessor.load("Canny")
    edges = self.preprocessor(
        image=image,
        low_threshold=low_threshold,
        high_threshold=high_threshold,
        detect_resolution=image_resolution,
    )
    self.load_controlnet_weight("Canny")
    generated = self.run_pipe(
        prompt=self.get_prompt(prompt, additional_prompt),
        negative_prompt=negative_prompt,
        control_image=edges,
        num_images=num_images,
        num_steps=num_steps,
        guidance_scale=guidance_scale,
        seed=seed,
    )
    return [edges, *generated]
@torch.inference_mode()
def process_mlsd(
    self,
    image: np.ndarray,
    prompt: str,
    additional_prompt: str,
    negative_prompt: str,
    num_images: int,
    image_resolution: int,
    preprocess_resolution: int,
    num_steps: int,
    guidance_scale: float,
    seed: int,
    value_threshold: float,
    distance_threshold: float,
) -> list[PIL.Image.Image]:
    """Generate images conditioned on the "MLSD" preprocessor output of ``image``.

    ``value_threshold`` / ``distance_threshold`` are forwarded to the
    preprocessor as ``thr_v`` / ``thr_d``. Returns
    ``[control_image, *generated_images]``. Raises ValueError for a missing
    image or out-of-range resolution/count.

    NOTE(review): "MLSD" is commented out of CONTROLNET_MODEL_IDS, so
    load_controlnet_weight("MLSD") currently raises KeyError.
    """
    if image is None or image_resolution > MAX_IMAGE_RESOLUTION or num_images > MAX_NUM_IMAGES:
        raise ValueError

    self.preprocessor.load("MLSD")
    line_map = self.preprocessor(
        image=image,
        image_resolution=image_resolution,
        detect_resolution=preprocess_resolution,
        thr_v=value_threshold,
        thr_d=distance_threshold,
    )
    self.load_controlnet_weight("MLSD")
    generated = self.run_pipe(
        prompt=self.get_prompt(prompt, additional_prompt),
        negative_prompt=negative_prompt,
        control_image=line_map,
        num_images=num_images,
        num_steps=num_steps,
        guidance_scale=guidance_scale,
        seed=seed,
    )
    return [line_map, *generated]
@torch.inference_mode()
def process_scribble(
    self,
    image: np.ndarray,
    prompt: str,
    additional_prompt: str,
    negative_prompt: str,
    num_images: int,
    image_resolution: int,
    preprocess_resolution: int,
    num_steps: int,
    guidance_scale: float,
    seed: int,
    preprocessor_name: str,
) -> list[PIL.Image.Image]:
    """Generate images conditioned on a scribble derived from ``image``.

    ``preprocessor_name`` selects the control image:
      * "None": ``image`` itself, converted with HWC3 and resized;
      * "HED": the HED preprocessor with ``scribble=False``;
      * "PidiNet": the PidiNet preprocessor with ``safe=False``.

    Returns ``[control_image, *generated_images]``.

    Raises:
        ValueError: ``image`` is None, ``image_resolution`` / ``num_images``
            exceed their maxima, or ``preprocessor_name`` is not one of the
            names above. The last case previously fell through and raised
            UnboundLocalError; process_softedge already rejects unknown
            names this way.

    NOTE(review): "scribble" is commented out of CONTROLNET_MODEL_IDS, so
    load_controlnet_weight("scribble") currently raises KeyError.
    """
    if image is None:
        raise ValueError
    if image_resolution > MAX_IMAGE_RESOLUTION:
        raise ValueError
    if num_images > MAX_NUM_IMAGES:
        raise ValueError
    if preprocessor_name == "None":
        image = HWC3(image)
        image = resize_image(image, resolution=image_resolution)
        control_image = PIL.Image.fromarray(image)
    elif preprocessor_name == "HED":
        self.preprocessor.load(preprocessor_name)
        control_image = self.preprocessor(
            image=image,
            image_resolution=image_resolution,
            detect_resolution=preprocess_resolution,
            scribble=False,
        )
    elif preprocessor_name == "PidiNet":
        self.preprocessor.load(preprocessor_name)
        control_image = self.preprocessor(
            image=image,
            image_resolution=image_resolution,
            detect_resolution=preprocess_resolution,
            safe=False,
        )
    else:
        raise ValueError(f"Unknown scribble preprocessor: {preprocessor_name!r}")
    self.load_controlnet_weight("scribble")
    results = self.run_pipe(
        prompt=self.get_prompt(prompt, additional_prompt),
        negative_prompt=negative_prompt,
        control_image=control_image,
        num_images=num_images,
        num_steps=num_steps,
        guidance_scale=guidance_scale,
        seed=seed,
    )
    return [control_image, *results]
@torch.inference_mode()
def process_scribble_interactive(
    self,
    image_and_mask: dict[str, np.ndarray | list[np.ndarray]] | None,
    prompt: str,
    additional_prompt: str,
    negative_prompt: str,
    num_images: int,
    image_resolution: int,
    num_steps: int,
    guidance_scale: float,
    seed: int,
) -> list[PIL.Image.Image]:
    """Generate images from a user-drawn scribble.

    ``image_and_mask["composite"]`` is inverted (``255 - value``), converted
    with HWC3, resized to ``image_resolution`` and used directly as the
    control image. The dict presumably comes from the UI's image editor -
    confirm against the caller. Returns ``[control_image, *generated_images]``.
    Raises ValueError when ``image_and_mask`` is None or the resolution/count
    is out of range.

    NOTE(review): "scribble" is commented out of CONTROLNET_MODEL_IDS, so
    load_controlnet_weight("scribble") currently raises KeyError.
    """
    if image_and_mask is None or image_resolution > MAX_IMAGE_RESOLUTION or num_images > MAX_NUM_IMAGES:
        raise ValueError

    inverted = 255 - image_and_mask["composite"]  # type: ignore
    scribble = PIL.Image.fromarray(resize_image(HWC3(inverted), resolution=image_resolution))
    self.load_controlnet_weight("scribble")
    generated = self.run_pipe(
        prompt=self.get_prompt(prompt, additional_prompt),
        negative_prompt=negative_prompt,
        control_image=scribble,
        num_images=num_images,
        num_steps=num_steps,
        guidance_scale=guidance_scale,
        seed=seed,
    )
    return [scribble, *generated]
@torch.inference_mode()
def process_softedge(
    self,
    image: np.ndarray,
    prompt: str,
    additional_prompt: str,
    negative_prompt: str,
    num_images: int,
    image_resolution: int,
    preprocess_resolution: int,
    num_steps: int,
    guidance_scale: float,
    seed: int,
    preprocessor_name: str,
) -> list[PIL.Image.Image]:
    """Generate images conditioned on a soft-edge map of ``image``.

    ``preprocessor_name`` selects the control image:
      * "None": ``image`` itself, converted with HWC3 and resized;
      * "HED" / "HED safe": the HED preprocessor, ``safe=True`` for the
        "safe" variant;
      * "PidiNet" / "PidiNet safe": the PidiNet preprocessor, likewise.

    Returns ``[control_image, *generated_images]``.

    Raises:
        ValueError: ``image`` is None, ``image_resolution`` / ``num_images``
            exceed their maxima, or ``preprocessor_name`` is unknown.

    NOTE(review): "softedge" is commented out of CONTROLNET_MODEL_IDS, so
    load_controlnet_weight("softedge") currently raises KeyError.
    """
    if image is None:
        raise ValueError
    if image_resolution > MAX_IMAGE_RESOLUTION:
        raise ValueError
    if num_images > MAX_NUM_IMAGES:
        raise ValueError
    if preprocessor_name == "None":
        image = HWC3(image)
        image = resize_image(image, resolution=image_resolution)
        control_image = PIL.Image.fromarray(image)
    elif preprocessor_name in ["HED", "HED safe"]:
        safe = "safe" in preprocessor_name
        self.preprocessor.load("HED")
        control_image = self.preprocessor(
            image=image,
            image_resolution=image_resolution,
            detect_resolution=preprocess_resolution,
            # The "safe" variant maps to the detector's `safe` flag, as in the
            # PidiNet branch below. Passing `scribble` instead would request
            # scribble-style output, which belongs to the scribble task.
            safe=safe,
        )
    elif preprocessor_name in ["PidiNet", "PidiNet safe"]:
        safe = "safe" in preprocessor_name
        self.preprocessor.load("PidiNet")
        control_image = self.preprocessor(
            image=image,
            image_resolution=image_resolution,
            detect_resolution=preprocess_resolution,
            safe=safe,
        )
    else:
        raise ValueError(f"Unknown softedge preprocessor: {preprocessor_name!r}")
    self.load_controlnet_weight("softedge")
    results = self.run_pipe(
        prompt=self.get_prompt(prompt, additional_prompt),
        negative_prompt=negative_prompt,
        control_image=control_image,
        num_images=num_images,
        num_steps=num_steps,
        guidance_scale=guidance_scale,
        seed=seed,
    )
    return [control_image, *results]
@torch.inference_mode()
def process_openpose(
    self,
    image: np.ndarray,
    prompt: str,
    additional_prompt: str,
    negative_prompt: str,
    num_images: int,
    image_resolution: int,
    preprocess_resolution: int,
    num_steps: int,
    guidance_scale: float,
    seed: int,
    preprocessor_name: str,
) -> list[PIL.Image.Image]:
    """Generate images conditioned on an "Openpose" control image.

    With ``preprocessor_name == "None"`` the input is used as-is (HWC3 +
    resize). Any other value runs the Openpose preprocessor with
    ``hand_and_face=True``. Returns ``[control_image, *generated_images]``.
    Raises ValueError for a missing image or out-of-range resolution/count.

    NOTE(review): "Openpose" is commented out of CONTROLNET_MODEL_IDS, so
    load_controlnet_weight("Openpose") currently raises KeyError.
    """
    if image is None or image_resolution > MAX_IMAGE_RESOLUTION or num_images > MAX_NUM_IMAGES:
        raise ValueError

    use_raw_input = preprocessor_name == "None"
    if not use_raw_input:
        self.preprocessor.load("Openpose")
        pose_map = self.preprocessor(
            image=image,
            image_resolution=image_resolution,
            detect_resolution=preprocess_resolution,
            hand_and_face=True,
        )
    else:
        pose_map = PIL.Image.fromarray(resize_image(HWC3(image), resolution=image_resolution))

    self.load_controlnet_weight("Openpose")
    generated = self.run_pipe(
        prompt=self.get_prompt(prompt, additional_prompt),
        negative_prompt=negative_prompt,
        control_image=pose_map,
        num_images=num_images,
        num_steps=num_steps,
        guidance_scale=guidance_scale,
        seed=seed,
    )
    return [pose_map, *generated]
@torch.inference_mode()
def process_segmentation(
    self,
    image: np.ndarray,
    prompt: str,
    additional_prompt: str,
    negative_prompt: str,
    num_images: int,
    image_resolution: int,
    preprocess_resolution: int,
    num_steps: int,
    guidance_scale: float,
    seed: int,
    preprocessor_name: str,
) -> list[PIL.Image.Image]:
    """Generate images conditioned on a segmentation control image.

    ``preprocessor_name == "None"`` uses the input as-is (HWC3 + resize).
    Otherwise the preprocessor is loaded by exactly that name. Returns
    ``[control_image, *generated_images]``. Raises ValueError for a missing
    image or out-of-range resolution/count.

    NOTE(review): "segmentation" is commented out of CONTROLNET_MODEL_IDS, so
    load_controlnet_weight("segmentation") currently raises KeyError.
    """
    if image is None or image_resolution > MAX_IMAGE_RESOLUTION or num_images > MAX_NUM_IMAGES:
        raise ValueError

    if preprocessor_name != "None":
        self.preprocessor.load(preprocessor_name)
        seg_map = self.preprocessor(
            image=image,
            image_resolution=image_resolution,
            detect_resolution=preprocess_resolution,
        )
    else:
        seg_map = PIL.Image.fromarray(resize_image(HWC3(image), resolution=image_resolution))

    self.load_controlnet_weight("segmentation")
    generated = self.run_pipe(
        prompt=self.get_prompt(prompt, additional_prompt),
        negative_prompt=negative_prompt,
        control_image=seg_map,
        num_images=num_images,
        num_steps=num_steps,
        guidance_scale=guidance_scale,
        seed=seed,
    )
    return [seg_map, *generated]
@torch.inference_mode()
def process_depth(
    self,
    image: np.ndarray,
    prompt: str,
    additional_prompt: str,
    negative_prompt: str,
    num_images: int,
    image_resolution: int,
    preprocess_resolution: int,
    num_steps: int,
    guidance_scale: float,
    seed: int,
    preprocessor_name: str,
) -> list[PIL.Image.Image]:
    """Generate images conditioned on a depth control image.

    ``preprocessor_name == "None"`` uses the input as-is (HWC3 + resize).
    Otherwise the preprocessor is loaded by exactly that name. Returns
    ``[control_image, *generated_images]``. Raises ValueError for a missing
    image or out-of-range resolution/count.

    NOTE(review): "depth" is commented out of CONTROLNET_MODEL_IDS, so
    load_controlnet_weight("depth") currently raises KeyError.
    """
    if image is None or image_resolution > MAX_IMAGE_RESOLUTION or num_images > MAX_NUM_IMAGES:
        raise ValueError

    if preprocessor_name != "None":
        self.preprocessor.load(preprocessor_name)
        depth_map = self.preprocessor(
            image=image,
            image_resolution=image_resolution,
            detect_resolution=preprocess_resolution,
        )
    else:
        depth_map = PIL.Image.fromarray(resize_image(HWC3(image), resolution=image_resolution))

    self.load_controlnet_weight("depth")
    generated = self.run_pipe(
        prompt=self.get_prompt(prompt, additional_prompt),
        negative_prompt=negative_prompt,
        control_image=depth_map,
        num_images=num_images,
        num_steps=num_steps,
        guidance_scale=guidance_scale,
        seed=seed,
    )
    return [depth_map, *generated]
@torch.inference_mode()
def process_normal(
    self,
    image: np.ndarray,
    prompt: str,
    additional_prompt: str,
    negative_prompt: str,
    num_images: int,
    image_resolution: int,
    preprocess_resolution: int,
    num_steps: int,
    guidance_scale: float,
    seed: int,
    preprocessor_name: str,
) -> list[PIL.Image.Image]:
    """Generate images conditioned on a "NormalBae" control image.

    ``preprocessor_name == "None"`` uses the input as-is (HWC3 + resize).
    Any other value runs the NormalBae preprocessor. Returns
    ``[control_image, *generated_images]``. Raises ValueError for a missing
    image or out-of-range resolution/count.

    NOTE(review): "NormalBae" is commented out of CONTROLNET_MODEL_IDS, so
    load_controlnet_weight("NormalBae") currently raises KeyError.
    """
    if image is None or image_resolution > MAX_IMAGE_RESOLUTION or num_images > MAX_NUM_IMAGES:
        raise ValueError

    if preprocessor_name != "None":
        self.preprocessor.load("NormalBae")
        normal_map = self.preprocessor(
            image=image,
            image_resolution=image_resolution,
            detect_resolution=preprocess_resolution,
        )
    else:
        normal_map = PIL.Image.fromarray(resize_image(HWC3(image), resolution=image_resolution))

    self.load_controlnet_weight("NormalBae")
    generated = self.run_pipe(
        prompt=self.get_prompt(prompt, additional_prompt),
        negative_prompt=negative_prompt,
        control_image=normal_map,
        num_images=num_images,
        num_steps=num_steps,
        guidance_scale=guidance_scale,
        seed=seed,
    )
    return [normal_map, *generated]
@torch.inference_mode()
def process_lineart(
    self,
    image: np.ndarray,
    prompt: str,
    additional_prompt: str,
    negative_prompt: str,
    num_images: int,
    image_resolution: int,
    preprocess_resolution: int,
    num_steps: int,
    guidance_scale: float,
    seed: int,
    preprocessor_name: str,
) -> list[PIL.Image.Image]:
    """Generate images conditioned on a line-art control image.

    ``preprocessor_name`` selects the control image:
      * "None" / "None (anime)": ``image`` itself, converted with HWC3 and resized;
      * "Lineart" / "Lineart coarse": the Lineart preprocessor, ``coarse=True``
        for the coarse variant;
      * "Lineart (anime)": the LineartAnime preprocessor.

    Names containing "anime" use the "lineart_anime" ControlNet; all others
    use "lineart". Returns ``[control_image, *generated_images]``.

    Raises:
        ValueError: ``image`` is None, ``image_resolution`` / ``num_images``
            exceed their maxima, or ``preprocessor_name`` is unknown. An
            unknown name previously fell through and raised
            UnboundLocalError.

    NOTE(review): "lineart" and "lineart_anime" are commented out of
    CONTROLNET_MODEL_IDS, so load_controlnet_weight currently raises KeyError.
    """
    if image is None:
        raise ValueError
    if image_resolution > MAX_IMAGE_RESOLUTION:
        raise ValueError
    if num_images > MAX_NUM_IMAGES:
        raise ValueError
    if preprocessor_name in ["None", "None (anime)"]:
        image = HWC3(image)
        image = resize_image(image, resolution=image_resolution)
        control_image = PIL.Image.fromarray(image)
    elif preprocessor_name in ["Lineart", "Lineart coarse"]:
        coarse = "coarse" in preprocessor_name
        self.preprocessor.load("Lineart")
        control_image = self.preprocessor(
            image=image,
            image_resolution=image_resolution,
            detect_resolution=preprocess_resolution,
            coarse=coarse,
        )
    elif preprocessor_name == "Lineart (anime)":
        self.preprocessor.load("LineartAnime")
        control_image = self.preprocessor(
            image=image,
            image_resolution=image_resolution,
            detect_resolution=preprocess_resolution,
        )
    else:
        raise ValueError(f"Unknown lineart preprocessor: {preprocessor_name!r}")
    if "anime" in preprocessor_name:
        self.load_controlnet_weight("lineart_anime")
    else:
        self.load_controlnet_weight("lineart")
    results = self.run_pipe(
        prompt=self.get_prompt(prompt, additional_prompt),
        negative_prompt=negative_prompt,
        control_image=control_image,
        num_images=num_images,
        num_steps=num_steps,
        guidance_scale=guidance_scale,
        seed=seed,
    )
    return [control_image, *results]
@torch.inference_mode()
def process_shuffle(
    self,
    image: np.ndarray,
    prompt: str,
    additional_prompt: str,
    negative_prompt: str,
    num_images: int,
    image_resolution: int,
    num_steps: int,
    guidance_scale: float,
    seed: int,
    preprocessor_name: str,
) -> list[PIL.Image.Image]:
    """Generate images conditioned on a "shuffle" control image.

    ``preprocessor_name == "None"`` uses the input as-is (HWC3 + resize).
    Otherwise the preprocessor is loaded by that name and called with only
    ``image_resolution`` (no detect resolution). Returns
    ``[control_image, *generated_images]``. Raises ValueError for a missing
    image or out-of-range resolution/count.

    NOTE(review): "shuffle" is commented out of CONTROLNET_MODEL_IDS, so
    load_controlnet_weight("shuffle") currently raises KeyError.
    """
    if image is None or image_resolution > MAX_IMAGE_RESOLUTION or num_images > MAX_NUM_IMAGES:
        raise ValueError

    if preprocessor_name != "None":
        self.preprocessor.load(preprocessor_name)
        shuffled = self.preprocessor(
            image=image,
            image_resolution=image_resolution,
        )
    else:
        shuffled = PIL.Image.fromarray(resize_image(HWC3(image), resolution=image_resolution))

    self.load_controlnet_weight("shuffle")
    generated = self.run_pipe(
        prompt=self.get_prompt(prompt, additional_prompt),
        negative_prompt=negative_prompt,
        control_image=shuffled,
        num_images=num_images,
        num_steps=num_steps,
        guidance_scale=guidance_scale,
        seed=seed,
    )
    return [shuffled, *generated]
@torch.inference_mode()
def process_ip2p(
    self,
    image: np.ndarray,
    prompt: str,
    additional_prompt: str,
    negative_prompt: str,
    num_images: int,
    image_resolution: int,
    num_steps: int,
    guidance_scale: float,
    seed: int,
) -> list[PIL.Image.Image]:
    """Generate images with the "ip2p" ControlNet, using ``image`` directly as the condition.

    The input is converted with HWC3 and resized to ``image_resolution``; no
    preprocessor is run. Returns ``[control_image, *generated_images]``.
    Raises ValueError for a missing image or out-of-range resolution/count.

    NOTE(review): "ip2p" is commented out of CONTROLNET_MODEL_IDS, so
    load_controlnet_weight("ip2p") currently raises KeyError.
    """
    if image is None or image_resolution > MAX_IMAGE_RESOLUTION or num_images > MAX_NUM_IMAGES:
        raise ValueError

    source = PIL.Image.fromarray(resize_image(HWC3(image), resolution=image_resolution))
    self.load_controlnet_weight("ip2p")
    generated = self.run_pipe(
        prompt=self.get_prompt(prompt, additional_prompt),
        negative_prompt=negative_prompt,
        control_image=source,
        num_images=num_images,
        num_steps=num_steps,
        guidance_scale=guidance_scale,
        seed=seed,
    )
    return [source, *generated]