Spaces:
Running
Running
import spaces | |
import random | |
import torch | |
import cv2 | |
import gradio as gr | |
import numpy as np | |
from huggingface_hub import snapshot_download | |
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor, pipeline | |
from diffusers.utils import load_image | |
from kolors.pipelines.pipeline_controlnet_xl_kolors_img2img import StableDiffusionXLControlNetImg2ImgPipeline | |
from kolors.models.modeling_chatglm import ChatGLMModel | |
from kolors.models.tokenization_chatglm import ChatGLMTokenizer | |
from kolors.models.controlnet import ControlNetModel | |
from diffusers import AutoencoderKL | |
from kolors.models.unet_2d_condition import UNet2DConditionModel | |
from diffusers import EulerDiscreteScheduler | |
from PIL import Image, ImageDraw, ImageFont | |
from annotator.midas import MidasDetector | |
from annotator.dwpose import DWposeDetector | |
from annotator.util import resize_image, HWC3 | |
import os | |
# Device string that the model components defined below are moved to.
device = "cuda"

# Download the base Kolors checkpoint, then the depth, canny and pose
# ControlNet checkpoints (same order as the names on the left). The returned
# values are used further down as directory prefixes for from_pretrained().
(
    ckpt_dir,
    ckpt_dir_depth,
    ckpt_dir_canny,
    ckpt_dir_pose,
) = (
    snapshot_download(repo_id=repo_name)
    for repo_name in (
        "Kwai-Kolors/Kolors",
        "Kwai-Kolors/Kolors-ControlNet-Depth",
        "Kwai-Kolors/Kolors-ControlNet-Canny",
        "Kwai-Kolors/Kolors-ControlNet-Pose",
    )
)

# Korean -> English translation pipeline, used to pre-process prompts.
translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
def _load_fp16(model_cls, path):
    """Load ``model_cls`` from ``path``, cast it to half precision and move it to ``device``."""
    return model_cls.from_pretrained(path, revision=None).half().to(device)


# Text encoder and tokenizer both live in the base checkpoint's
# "text_encoder" sub-directory.
text_encoder = ChatGLMModel.from_pretrained(
    f'{ckpt_dir}/text_encoder',
    torch_dtype=torch.float16,
).half().to(device)
tokenizer = ChatGLMTokenizer.from_pretrained(f'{ckpt_dir}/text_encoder')

# Remaining base-model components.
vae = _load_fp16(AutoencoderKL, f"{ckpt_dir}/vae")
scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler")
unet = _load_fp16(UNet2DConditionModel, f"{ckpt_dir}/unet")

# One ControlNet per conditioning type, each from its own checkpoint.
controlnet_depth = _load_fp16(ControlNetModel, f"{ckpt_dir_depth}")
controlnet_canny = _load_fp16(ControlNetModel, f"{ckpt_dir_canny}")
controlnet_pose = _load_fp16(ControlNetModel, f"{ckpt_dir_pose}")
def _build_controlnet_pipeline(controlnet):
    """Assemble an img2img ControlNet pipeline around ``controlnet``.

    All pipelines are given the same module-level vae, text encoder,
    tokenizer, unet and scheduler objects; only the ControlNet differs.
    """
    return StableDiffusionXLControlNetImg2ImgPipeline(
        vae=vae,
        controlnet=controlnet,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        force_zeros_for_empty_prompt=False
    )


pipe_depth = _build_controlnet_pipeline(controlnet_depth)
pipe_canny = _build_controlnet_pipeline(controlnet_canny)
pipe_pose = _build_controlnet_pipeline(controlnet_pose)
def translate_korean_to_english(text):
    """Translate ``text`` to English when it contains Hangul syllables.

    Text with no character in the range U+AC00..U+D7A3 is returned
    unchanged. Otherwise the module-level ``translator`` pipeline is called
    and the ``'translation_text'`` field of its first result is returned.
    """
    contains_hangul = any(0xAC00 <= ord(ch) <= 0xD7A3 for ch in text)
    if not contains_hangul:
        return text
    return translator(text, max_length=512)[0]['translation_text']
def process_canny_condition(image, canny_threods=(100, 200)):
    """Build a 3-channel Canny edge control image from ``image``.

    Args:
        image: Array-like image handed to ``cv2.Canny`` (a copy is used, so
            the caller's array is not modified).
        canny_threods: Indexable pair ``(threshold1, threshold2)`` forwarded
            to ``cv2.Canny``. The default is an immutable tuple so it cannot
            be altered between calls.

    Returns:
        A ``PIL.Image.Image`` holding the edge map replicated on 3 channels.
    """
    low_threshold = canny_threods[0]
    high_threshold = canny_threods[1]

    edges = cv2.Canny(image.copy(), low_threshold, high_threshold)
    # Canny yields a single-channel map; replicate it so the result is HxWx3.
    edges = edges[:, :, None]
    edges = np.concatenate([edges, edges, edges], axis=2)
    return Image.fromarray(HWC3(edges))
# Depth detector, instantiated once at module import.
model_midas = MidasDetector()


def process_depth_condition_midas(img, res = 1024):
    """Build a depth control image for ``img``.

    ``img`` is passed through ``HWC3`` and ``resize_image`` (with ``res``)
    before being given to the MiDaS detector; the detector output is then
    scaled back to the height and width of the original ``img``.

    Returns:
        A ``PIL.Image.Image`` with the same width/height as ``img``.
    """
    orig_h, orig_w, _ = img.shape
    detector_input = resize_image(HWC3(img), res)
    depth_map = HWC3(model_midas(detector_input))
    # cv2.resize takes the target size as (width, height).
    return Image.fromarray(cv2.resize(depth_map, (orig_w, orig_h)))
# Pose detector, instantiated once at module import.
model_dwpose = DWposeDetector()


def process_dwpose_condition(image, res=1024):
    """Build a pose control image for ``image``.

    ``image`` is passed through ``HWC3`` and ``resize_image`` (with ``res``)
    and the detector is run on that resized copy, mirroring
    ``process_depth_condition_midas``. Previously the resized copy was
    computed and then discarded, so ``res`` had no effect.

    Returns:
        A ``PIL.Image.Image`` with the same width/height as ``image``.
    """
    h, w, _ = image.shape
    img = resize_image(HWC3(image), res)
    # The detector returns a pair; only the rendered pose image is needed.
    _, out_img = model_dwpose(img)
    result = HWC3(out_img)
    # cv2.resize takes the target size as (width, height).
    result = cv2.resize(result, (w, h))
    return Image.fromarray(result)
# Upper bound for seeds (seed slider maximum and random draw): int32 maximum.
MAX_SEED = np.iinfo(np.int32).max
# Resolution passed to resize_image and to the condition pre-processors.
MAX_IMAGE_SIZE = 1024
def infer_depth(prompt,
                image = None,
                negative_prompt = "nsfw, facial shadows, low resolution, jpeg artifacts, blurry, bad quality, dark face, neon lights",
                seed = 397886929,
                randomize_seed = False,
                guidance_scale = 6.0,
                num_inference_steps = 50,
                controlnet_conditioning_scale = 0.7,
                control_guidance_end = 0.9,
                strength = 1.0
                ):
    """Run depth-conditioned img2img generation.

    Args:
        prompt: Text prompt; translated to English first when it contains
            Korean characters.
        image: Source image (the UI supplies a PIL image). Required.
        negative_prompt: Negative prompt, translated the same way as
            ``prompt``.
        seed: Seed for the torch generator; replaced when ``randomize_seed``
            is true.
        randomize_seed: When true, draw a seed in ``[0, MAX_SEED]``.
        guidance_scale: Forwarded to the pipeline call.
        num_inference_steps: Forwarded to the pipeline call (as an int).
        controlnet_conditioning_scale: Forwarded to the pipeline call.
        control_guidance_end: Forwarded to the pipeline call.
        strength: Forwarded to the pipeline call.

    Returns:
        ``([condition_image, generated_image], seed)`` where ``seed`` is the
        integer seed that was actually used.

    Raises:
        gr.Error: If ``image`` is ``None``.
    """
    # Fail fast with a readable message instead of crashing inside resize_image.
    if image is None:
        raise gr.Error("Please provide an input image before running Depth.")

    prompt = translate_korean_to_english(prompt)
    negative_prompt = translate_korean_to_english(negative_prompt)

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # manual_seed() needs an int; coerce in case a float was passed in.
    seed = int(seed)
    generator = torch.Generator().manual_seed(seed)

    init_image = resize_image(image, MAX_IMAGE_SIZE)
    pipe = pipe_depth.to(device)
    condi_img = process_depth_condition_midas(np.array(init_image), MAX_IMAGE_SIZE)

    image = pipe(
        prompt=prompt,
        image=init_image,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        control_guidance_end=control_guidance_end,
        strength=strength,
        control_image=condi_img,
        negative_prompt=negative_prompt,
        num_inference_steps=int(num_inference_steps),
        guidance_scale=guidance_scale,
        num_images_per_prompt=1,
        generator=generator,
    ).images[0]
    return [condi_img, image], seed
def infer_canny(prompt,
                image = None,
                negative_prompt = "nsfw, facial shadows, low resolution, jpeg artifacts, blurry, bad quality, dark face, neon lights",
                seed = 397886929,
                randomize_seed = False,
                guidance_scale = 6.0,
                num_inference_steps = 50,
                controlnet_conditioning_scale = 0.7,
                control_guidance_end = 0.9,
                strength = 1.0
                ):
    """Run Canny-edge-conditioned img2img generation.

    Args:
        prompt: Text prompt; translated to English first when it contains
            Korean characters.
        image: Source image (the UI supplies a PIL image). Required.
        negative_prompt: Negative prompt, translated the same way as
            ``prompt``.
        seed: Seed for the torch generator; replaced when ``randomize_seed``
            is true.
        randomize_seed: When true, draw a seed in ``[0, MAX_SEED]``.
        guidance_scale: Forwarded to the pipeline call.
        num_inference_steps: Forwarded to the pipeline call (as an int).
        controlnet_conditioning_scale: Forwarded to the pipeline call.
        control_guidance_end: Forwarded to the pipeline call.
        strength: Forwarded to the pipeline call.

    Returns:
        ``([condition_image, generated_image], seed)`` where ``seed`` is the
        integer seed that was actually used.

    Raises:
        gr.Error: If ``image`` is ``None``.
    """
    # Fail fast with a readable message instead of crashing inside resize_image.
    if image is None:
        raise gr.Error("Please provide an input image before running Canny.")

    prompt = translate_korean_to_english(prompt)
    negative_prompt = translate_korean_to_english(negative_prompt)

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # manual_seed() needs an int; coerce in case a float was passed in.
    seed = int(seed)
    generator = torch.Generator().manual_seed(seed)

    init_image = resize_image(image, MAX_IMAGE_SIZE)
    pipe = pipe_canny.to(device)
    condi_img = process_canny_condition(np.array(init_image))

    image = pipe(
        prompt=prompt,
        image=init_image,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        control_guidance_end=control_guidance_end,
        strength=strength,
        control_image=condi_img,
        negative_prompt=negative_prompt,
        num_inference_steps=int(num_inference_steps),
        guidance_scale=guidance_scale,
        num_images_per_prompt=1,
        generator=generator,
    ).images[0]
    return [condi_img, image], seed
def infer_pose(prompt,
               image = None,
               negative_prompt = "nsfw, facial shadows, low resolution, jpeg artifacts, blurry, bad quality, dark face, neon lights",
               seed = 66,
               randomize_seed = False,
               guidance_scale = 6.0,
               num_inference_steps = 50,
               controlnet_conditioning_scale = 0.7,
               control_guidance_end = 0.9,
               strength = 1.0
               ):
    """Run pose-conditioned img2img generation.

    Args:
        prompt: Text prompt; translated to English first when it contains
            Korean characters.
        image: Source image (the UI supplies a PIL image). Required.
        negative_prompt: Negative prompt, translated the same way as
            ``prompt``.
        seed: Seed for the torch generator; replaced when ``randomize_seed``
            is true.
        randomize_seed: When true, draw a seed in ``[0, MAX_SEED]``.
        guidance_scale: Forwarded to the pipeline call.
        num_inference_steps: Forwarded to the pipeline call (as an int).
        controlnet_conditioning_scale: Forwarded to the pipeline call.
        control_guidance_end: Forwarded to the pipeline call.
        strength: Forwarded to the pipeline call.

    Returns:
        ``([condition_image, generated_image], seed)`` where ``seed`` is the
        integer seed that was actually used.

    Raises:
        gr.Error: If ``image`` is ``None``.
    """
    # Fail fast with a readable message instead of crashing inside resize_image.
    if image is None:
        raise gr.Error("Please provide an input image before running Pose.")

    prompt = translate_korean_to_english(prompt)
    negative_prompt = translate_korean_to_english(negative_prompt)

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # manual_seed() needs an int; coerce in case a float was passed in.
    seed = int(seed)
    generator = torch.Generator().manual_seed(seed)

    init_image = resize_image(image, MAX_IMAGE_SIZE)
    pipe = pipe_pose.to(device)
    condi_img = process_dwpose_condition(np.array(init_image), MAX_IMAGE_SIZE)

    image = pipe(
        prompt=prompt,
        image=init_image,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        control_guidance_end=control_guidance_end,
        strength=strength,
        control_image=condi_img,
        negative_prompt=negative_prompt,
        num_inference_steps=int(num_inference_steps),
        guidance_scale=guidance_scale,
        num_images_per_prompt=1,
        generator=generator,
    ).images[0]
    return [condi_img, image], seed
# Stylesheet passed to gr.Blocks(css=...); its single rule hides the footer.
css = (
    "\n"
    "footer {\n"
    "visibility: hidden;\n"
    "}\n"
)
def load_description(fp):
    """Return the entire contents of the UTF-8 text file at ``fp``."""
    with open(fp, mode='r', encoding='utf-8') as handle:
        return handle.read()
def text_to_image(text, size, position):
    """Render ``text`` in black on a white 1024x576 canvas.

    Args:
        text: Text to draw; ``'\\n'`` starts a new line. ``None`` is treated
            as an empty string.
        size: Font size used when the bundled TrueType font can be loaded.
        position: One of the nine ``"<top|middle|bottom>-<left|center|right>"``
            anchors. Any other value falls back to bottom-center.

    Returns:
        The rendered ``PIL.Image.Image``.
    """
    width, height = 1024, 576
    margin = 10
    image = Image.new("RGB", (width, height), "white")
    draw = ImageDraw.Draw(image)

    # Look for a font file shipped next to this script; fall back to the
    # PIL default font when none can be loaded.
    font_files = ["Arial_Unicode.ttf"]
    font = None
    for font_file in font_files:
        font_path = os.path.join(os.path.dirname(__file__), font_file)
        if os.path.exists(font_path):
            try:
                font = ImageFont.truetype(font_path, size=size)
                print(f"Using font: {font_file}")
                break
            except OSError:
                print(f"Error loading font: {font_file}")
    if font is None:
        print("No suitable font found. Using default font.")
        font = ImageFont.load_default()

    lines = (text or "").split('\n')

    # An empty line has a zero-height bounding box, which used to make blank
    # lines collapse. Use the height of a reference string for such lines.
    _, ref_top, _, ref_bottom = draw.textbbox((0, 0), "Ag", font=font)
    blank_line_height = ref_bottom - ref_top

    max_line_width = 0
    line_heights = []
    for line in lines:
        left, top, right, bottom = draw.textbbox((0, 0), line, font=font)
        max_line_width = max(max_line_width, right - left)
        line_heights.append((bottom - top) or blank_line_height)
    total_height = sum(line_heights)

    # Candidate anchor coordinates for the text block as a whole.
    x_left = margin
    x_center = (width - max_line_width) / 2
    x_right = width - max_line_width - margin
    y_top = margin
    y_middle = (height - total_height) / 2
    y_bottom = height - total_height - margin

    position_mapping = {
        "top-left": (x_left, y_top),
        "top-center": (x_center, y_top),
        "top-right": (x_right, y_top),
        "middle-left": (x_left, y_middle),
        "middle-center": (x_center, y_middle),
        "middle-right": (x_right, y_middle),
        "bottom-left": (x_left, y_bottom),
        "bottom-center": (x_center, y_bottom),
        "bottom-right": (x_right, y_bottom),
    }
    x, y = position_mapping.get(position, (x_center, y_bottom))

    # Every line starts at the same x; y advances by each line's height.
    for line, line_height in zip(lines, line_heights):
        draw.text((x, y), line, fill="black", font=font)
        y += line_height
    return image
# Build the Gradio UI: controls on the left, results on the right, then the
# event wiring and optional example galleries.
# NOTE(review): the source this was recovered from had lost its indentation;
# the nesting below follows the order of the `with` statements. In particular
# the placement of the `text_image_inputs` column (sibling of the image row)
# is a best guess and should be confirmed against the running app.
with gr.Blocks(theme="Nymbo/Nymbo_Theme", css=css) as Kolors:
    with gr.Row():
        with gr.Column(elem_id="col-left"):
            with gr.Row():
                prompt = gr.Textbox(
                    label="Prompt",
                    placeholder="Enter your prompt",
                    lines=2
                )
            with gr.Row():
                image_input_type = gr.Radio(["Upload Image", "Generate Text Image"], label="Input Type", value="Upload Image")
            with gr.Row():
                image = gr.Image(label="Image", type="pil", visible=True)
            # Hidden until "Generate Text Image" is selected.
            with gr.Column(visible=False) as text_image_inputs:
                text_input = gr.Textbox(label="Enter Text", lines=5, placeholder="Type your text here...")
                font_size = gr.Radio([48, 72, 96, 144], label="Font Size", value=72)
                text_position = gr.Dropdown(
                    ["top-left", "top-center", "top-right", "middle-left", "middle-center", "middle-right", "bottom-left", "bottom-center", "bottom-right"],
                    label="Text Position",
                    value="middle-center"
                )
                generate_text_image = gr.Button("Generate Text Image")
            with gr.Accordion("Advanced Settings", open=False):
                negative_prompt = gr.Textbox(
                    label="Negative prompt",
                    placeholder="Enter a negative prompt",
                    visible=True,
                    value="nsfw, facial shadows, low resolution, jpeg artifacts, blurry, bad quality, dark face, neon lights"
                )
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=0,
                )
                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                with gr.Row():
                    guidance_scale = gr.Slider(
                        label="Guidance scale",
                        minimum=0.0,
                        maximum=10.0,
                        step=0.1,
                        value=6.0,
                    )
                    num_inference_steps = gr.Slider(
                        label="Number of inference steps",
                        minimum=10,
                        maximum=50,
                        step=1,
                        value=30,
                    )
                with gr.Row():
                    controlnet_conditioning_scale = gr.Slider(
                        label="Controlnet Conditioning Scale",
                        minimum=0.0,
                        maximum=1.0,
                        step=0.1,
                        value=0.7,
                    )
                    control_guidance_end = gr.Slider(
                        label="Control Guidance End",
                        minimum=0.0,
                        maximum=1.0,
                        step=0.1,
                        value=0.9,
                    )
                with gr.Row():
                    strength = gr.Slider(
                        label="Strength",
                        minimum=0.0,
                        maximum=1.0,
                        step=0.1,
                        value=1.0,
                    )
            with gr.Row():
                canny_button = gr.Button("Canny", elem_id="button")
                depth_button = gr.Button("Depth", elem_id="button")
                pose_button = gr.Button("Pose", elem_id="button")
        with gr.Column(elem_id="col-right"):
            result = gr.Gallery(label="Result", show_label=False, columns=2)
            seed_used = gr.Number(label="Seed Used")

    def toggle_image_input(choice):
        """Show the upload widget or the text-image controls, per ``choice``."""
        return {
            image: gr.update(visible=choice == "Upload Image"),
            text_image_inputs: gr.update(visible=choice == "Generate Text Image")
        }

    image_input_type.change(toggle_image_input, image_input_type, [image, text_image_inputs])

    def generate_and_use_text_image(text, size, position):
        """Render the typed text to an image, which becomes the input image."""
        text_image = text_to_image(text, size, position)
        return text_image

    generate_text_image.click(
        generate_and_use_text_image,
        inputs=[text_input, font_size, text_position],
        outputs=image
    )

    # NOTE(review): canny_examples / depth_examples / pose_examples are not
    # defined in the visible source, so referencing them directly raised
    # NameError while building the UI. Each gallery is now created only when
    # its example list exists at module level and is non-empty.
    for examples_label, examples_fn, examples_name in (
        ("Canny", infer_canny, "canny_examples"),
        ("Depth", infer_depth, "depth_examples"),
        ("Pose", infer_pose, "pose_examples"),
    ):
        example_rows = globals().get(examples_name)
        if not example_rows:
            continue
        with gr.Row():
            gr.Examples(
                fn=examples_fn,
                examples=example_rows,
                inputs=[prompt, image],
                outputs=[result, seed_used],
                label=examples_label,
            )

    # The three generation buttons share the same inputs and outputs; the
    # input order matches the positional parameters of the infer_* functions.
    control_inputs = [prompt, image, negative_prompt, seed, randomize_seed, guidance_scale, num_inference_steps, controlnet_conditioning_scale, control_guidance_end, strength]
    control_outputs = [result, seed_used]

    canny_button.click(
        fn=infer_canny,
        inputs=control_inputs,
        outputs=control_outputs
    )
    depth_button.click(
        fn=infer_depth,
        inputs=control_inputs,
        outputs=control_outputs
    )
    pose_button.click(
        fn=infer_pose,
        inputs=control_inputs,
        outputs=control_outputs
    )

Kolors.queue().launch(debug=True)