import spaces  # HF ZeroGPU support; keep this import at the top, per Spaces guidance
import os
import re

import cv2
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import OmegaConf
from PIL import Image, ImageDraw
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from diffusers import (
    ControlNetModel,
    DiffusionPipeline,
    DDIMScheduler,
    AutoencoderKL,
)

from inference.manganinjia_pipeline import MangaNinjiaPipeline
from src.models.mutual_self_attention_multi_scale import ReferenceAttentionControl
from src.models.unet_2d_condition import UNet2DConditionModel
from src.models.refunet_2d_condition import RefUNet2DConditionModel
from src.point_network import PointNet
from src.annotator.lineart import BatchLineartDetector
val_configs = OmegaConf.load('./configs/inference.yaml')
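# The YAML is expected to provide the keys read below — a sketch inferred from
# the accesses in this file, not necessarily the full schema:
#   model_path:
#     pretrained_model_name_or_path, clip_vision_encoder_path,
#     controlnet_model_name, annotator_ckpts_path,
#     manga_control_model_path, point_net_path,
#     manga_reference_model_path, manga_main_model_path
#   inference_config:
#     output_path, device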
# download the checkpoints
from huggingface_hub import snapshot_download, hf_hub_download

os.makedirs("checkpoints", exist_ok=True)
# List of subdirectories to create inside "checkpoints"
subfolders = [
    "StableDiffusion",
    "models",
    "MangaNinjia",
]
# Create each subdirectory
for subfolder in subfolders:
    os.makedirs(os.path.join("checkpoints", subfolder), exist_ok=True)
# List of subdirectories to create inside "models"
models_subfolders = [
    "clip-vit-large-patch14",
    "control_v11p_sd15_lineart",
    "Annotators",
]
# Create each subdirectory
for subfolder in models_subfolders:
    os.makedirs(os.path.join("checkpoints/models", subfolder), exist_ok=True)
snapshot_download(
    repo_id="stable-diffusion-v1-5/stable-diffusion-v1-5",
    local_dir="./checkpoints/StableDiffusion"
)
snapshot_download(
    repo_id="openai/clip-vit-large-patch14",
    local_dir="./checkpoints/models/clip-vit-large-patch14"
)
snapshot_download(
    repo_id="lllyasviel/control_v11p_sd15_lineart",
    local_dir="./checkpoints/models/control_v11p_sd15_lineart"
)
hf_hub_download(
    repo_id="lllyasviel/Annotators",
    filename="sk_model.pth",
    local_dir="./checkpoints/models/Annotators"
)
snapshot_download(
    repo_id="Johanan0528/MangaNinjia",
    local_dir="./checkpoints/MangaNinjia"
)
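# After the downloads, the local layout is:
#   checkpoints/
#     StableDiffusion/                 (SD v1.5 weights)
#     MangaNinjia/                     (trained MangaNinjia weights)
#     models/
#       clip-vit-large-patch14/
#       control_v11p_sd15_lineart/
#       Annotators/sk_model.pth        (line-art detector checkpoint)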
# === load the checkpoint ===
pretrained_model_name_or_path = val_configs.model_path.pretrained_model_name_or_path
refnet_clip_vision_encoder_path = val_configs.model_path.clip_vision_encoder_path
controlnet_clip_vision_encoder_path = val_configs.model_path.clip_vision_encoder_path
controlnet_model_name_or_path = val_configs.model_path.controlnet_model_name
annotator_ckpts_path = val_configs.model_path.annotator_ckpts_path
output_root = val_configs.inference_config.output_path
device = val_configs.inference_config.device

preprocessor = BatchLineartDetector(annotator_ckpts_path)
in_channels_reference_unet = 4
in_channels_denoising_unet = 4
in_channels_controlnet = 4
noise_scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
vae = AutoencoderKL.from_pretrained(
    pretrained_model_name_or_path,
    subfolder='vae'
)
denoising_unet = UNet2DConditionModel.from_pretrained(
    pretrained_model_name_or_path, subfolder="unet",
    in_channels=in_channels_denoising_unet,
    low_cpu_mem_usage=False,
    ignore_mismatched_sizes=True
)
reference_unet = RefUNet2DConditionModel.from_pretrained(
    pretrained_model_name_or_path, subfolder="unet",
    in_channels=in_channels_reference_unet,
    low_cpu_mem_usage=False,
    ignore_mismatched_sizes=True
)
refnet_tokenizer = CLIPTokenizer.from_pretrained(refnet_clip_vision_encoder_path)
refnet_text_encoder = CLIPTextModel.from_pretrained(refnet_clip_vision_encoder_path)
refnet_image_enc = CLIPVisionModelWithProjection.from_pretrained(refnet_clip_vision_encoder_path)
controlnet = ControlNetModel.from_pretrained(
    controlnet_model_name_or_path,
    in_channels=in_channels_controlnet,
    low_cpu_mem_usage=False,
    ignore_mismatched_sizes=True
)
controlnet_tokenizer = CLIPTokenizer.from_pretrained(controlnet_clip_vision_encoder_path)
controlnet_text_encoder = CLIPTextModel.from_pretrained(controlnet_clip_vision_encoder_path)
controlnet_image_enc = CLIPVisionModelWithProjection.from_pretrained(controlnet_clip_vision_encoder_path)
point_net = PointNet()
reference_control_writer = ReferenceAttentionControl(
    reference_unet,
    do_classifier_free_guidance=False,
    mode="write",
    fusion_blocks="full",
)
reference_control_reader = ReferenceAttentionControl(
    denoising_unet,
    do_classifier_free_guidance=False,
    mode="read",
    fusion_blocks="full",
)
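# Reference-attention wiring: the reference UNet runs in "write" mode, caching
# its self-attention features, while the denoising UNet runs in "read" mode,
# presumably injecting those cached features so the reference image's
# appearance reaches the denoising branch.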
controlnet.load_state_dict(
    torch.load(val_configs.model_path.manga_control_model_path, map_location="cpu"),
    strict=False,
)
point_net.load_state_dict(
    torch.load(val_configs.model_path.point_net_path, map_location="cpu"),
    strict=False,
)
reference_unet.load_state_dict(
    torch.load(val_configs.model_path.manga_reference_model_path, map_location="cpu"),
    strict=False,
)
denoising_unet.load_state_dict(
    torch.load(val_configs.model_path.manga_main_model_path, map_location="cpu"),
    strict=False,
)
pipe = MangaNinjiaPipeline(
    reference_unet=reference_unet,
    controlnet=controlnet,
    denoising_unet=denoising_unet,
    vae=vae,
    refnet_tokenizer=refnet_tokenizer,
    refnet_text_encoder=refnet_text_encoder,
    refnet_image_encoder=refnet_image_enc,
    controlnet_tokenizer=controlnet_tokenizer,
    controlnet_text_encoder=controlnet_text_encoder,
    controlnet_image_encoder=controlnet_image_enc,
    scheduler=noise_scheduler,
    point_net=point_net,
)
pipe = pipe.to(torch.device(device))
def string_to_np_array(coord_string):
    coord_string = coord_string.strip('[]')
    coords = re.findall(r'\d+', coord_string)
    coords = list(map(int, coords))
    coord_array = np.array(coords).reshape(-1, 2)
    return coord_array
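# string_to_np_array parses the click-coordinate textboxes back into points,
# e.g. string_to_np_array("[[10, 20], [30, 40]]") -> array([[10, 20], [30, 40]]).
# It only extracts unsigned integers, which suffices here because all
# coordinates are non-negative pixel indices.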
def infer_single(is_lineart, ref_image, target_image, output_coords_ref, output_coords_base,
                 seed=-1, num_inference_steps=20, guidance_scale_ref=9, guidance_scale_point=15):
    """
    is_lineart: whether target_image is already line art
    ref_image / target_image: RGB PIL images
    output_coords_ref / output_coords_base: "[[row, col], ...]" strings of matching points
    """
    # torch.cuda.manual_seed() returns None, so build an explicit Generator
    # to make sure the seed is actually honored by the pipeline
    generator = torch.Generator(device=device).manual_seed(seed)
    # Encode the matching points as two 512x512 index maps: the cell of
    # pair i holds the value i + 1; all other cells stay zero.
    matrix1 = np.zeros((512, 512), dtype=np.uint8)
    matrix2 = np.zeros((512, 512), dtype=np.uint8)
    output_coords_ref = string_to_np_array(output_coords_ref)
    output_coords_base = string_to_np_array(output_coords_base)
    for index, (coords_ref, coords_base) in enumerate(zip(output_coords_ref, output_coords_base)):
        y1, x1 = coords_ref
        y2, x2 = coords_base
        matrix1[y1, x1] = index + 1
        matrix2[y2, x2] = index + 1
    point_ref = torch.from_numpy(matrix1).unsqueeze(0).unsqueeze(0)
    point_main = torch.from_numpy(matrix2).unsqueeze(0).unsqueeze(0)
    preprocessor.to(device, dtype=torch.float32)
    pipe_out = pipe(
        is_lineart,
        ref_image,
        target_image,
        target_image,
        denosing_steps=num_inference_steps,  # (sic) the pipeline's parameter name
        processing_res=512,
        match_input_res=True,
        batch_size=1,
        show_progress_bar=True,
        guidance_scale_ref=guidance_scale_ref,
        guidance_scale_point=guidance_scale_point,
        preprocessor=preprocessor,
        generator=generator,
        point_ref=point_ref,
        point_main=point_main,
    )
    return pipe_out
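# Point-map example: with matched pairs (100, 200) <-> (110, 205) and
# (50, 60) <-> (52, 64), matrix1[100, 200] = matrix2[110, 205] = 1 and
# matrix1[50, 60] = matrix2[52, 64] = 2; every other cell stays 0. The two
# maps reach the pipeline as point_ref / point_main.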
@spaces.GPU  # allocate a ZeroGPU device for inference (this Space runs on Zero)
def inference_single_image(ref_image,
                           tar_image,
                           ddim_steps,
                           scale_ref,
                           scale_point,
                           seed,
                           output_coords1,
                           output_coords2,
                           is_lineart,
                           ):
    if seed == -1:
        seed = np.random.randint(10000)
    pipe_out = infer_single(is_lineart, ref_image, tar_image,
                            output_coords_ref=output_coords1,
                            output_coords_base=output_coords2,
                            seed=seed,
                            num_inference_steps=ddim_steps,
                            guidance_scale_ref=scale_ref,
                            guidance_scale_point=scale_point)
    return pipe_out
# Click state shared across the two images
clicked_points_img1 = []  # points clicked on the reference image, as (row, col)
clicked_points_img2 = []  # points clicked on the target image, as (row, col)
current_img_idx = 0       # which image should receive the next click (alternates)
max_clicks = 14
point_size = 8            # half-width of the square marker drawn at each point
colors = [(255, 0, 0), (0, 255, 0)]  # marker colors, one per image
# Process images: resize both to 512x512
def process_image(ref, base):
    ref_resized = cv2.resize(ref, (512, 512))  # note: OpenCV takes (width, height)
    base_resized = cv2.resize(base, (512, 512))
    return ref_resized, base_resized
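# 512x512 matches the point maps built in infer_single and the pipeline's
# processing_res, so clicks made on the resized previews line up with the
# coordinates the model actually receives.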
# Draw a square marker at each (row, col) point and return the annotated image
def draw_points(img, points, color):
    img_pil = Image.fromarray(img.astype('uint8'))
    draw = ImageDraw.Draw(img_pil)
    width, height = img_pil.size
    for row, col in points:
        for dx in range(-point_size, point_size + 1):
            for dy in range(-point_size, point_size + 1):
                # PIL addresses pixels as (x, y) = (col, row); stay inside the image
                if 0 <= col + dy < width and 0 <= row + dx < height:
                    draw.point((col + dy, row + dx), fill=color)
    return np.array(img_pil)

# Function to handle click events
def get_select_coords(img1, img2, evt: gr.SelectData):
    global clicked_points_img1, clicked_points_img2, current_img_idx
    # Gradio reports clicks as (x, y); store them as (row, col)
    click_coords = (evt.index[1], evt.index[0])
    if current_img_idx == 0:
        clicked_points_img1.append(click_coords)
        if len(clicked_points_img1) > max_clicks:
            clicked_points_img1 = []
        current_img = img1
        clicked_points = clicked_points_img1
    else:
        clicked_points_img2.append(click_coords)
        if len(clicked_points_img2) > max_clicks:
            clicked_points_img2 = []
        current_img = img2
        clicked_points = clicked_points_img2
    current_img_idx = 1 - current_img_idx
    img_out = draw_points(current_img, clicked_points, colors[current_img_idx])
    coord_array = np.array([(x, y) for x, y in clicked_points])
    return img_out, str(coord_array.tolist())
# Undo the most recent click and redraw both images
def undo_last_point(ref, base):
    global clicked_points_img1, clicked_points_img2, current_img_idx
    # Step back to the image that received the last click
    current_img_idx = 1 - current_img_idx
    if current_img_idx == 0 and clicked_points_img1:
        clicked_points_img1.pop()  # undo last point on the reference image
    elif current_img_idx == 1 and clicked_points_img2:
        clicked_points_img2.pop()  # undo last point on the target image
    # Redraw both images without the removed point
    if current_img_idx == 0:
        current_img, current_img_other = ref, base
        clicked_points, clicked_points_other = clicked_points_img1, clicked_points_img2
    else:
        current_img, current_img_other = base, ref
        clicked_points, clicked_points_other = clicked_points_img2, clicked_points_img1
    img_out = draw_points(current_img, clicked_points, colors[current_img_idx])
    img_out_other = draw_points(current_img_other, clicked_points_other, colors[1 - current_img_idx])
    updated_coords = str(np.array([(x, y) for x, y in clicked_points]).tolist())
    updated_coords_other = str(np.array([(x, y) for x, y in clicked_points_other]).tolist())
    # Outputs are ordered (ref image, ref coords, base image, base coords)
    if current_img_idx == 0:
        return img_out, updated_coords, img_out_other, updated_coords_other
    else:
        return img_out_other, updated_coords_other, img_out, updated_coords
# Main function to run the image processing
def run_local(ref, base, *args, progress=gr.Progress(track_tqdm=True)):
    image = Image.fromarray(base)
    ref_image = Image.fromarray(ref)
    pipe_out = inference_single_image(ref_image.copy(), image.copy(), *args)
    to_save_dict = pipe_out.to_save_dict
    to_save_dict['edit2'] = pipe_out.img_pil
    # Return the colorized result plus the pipeline's 'edge2_black' output (the extracted line art)
    return [to_save_dict['edit2'], to_save_dict['edge2_black']]
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# MangaNinja: Line Art Colorization with Precise Reference Following")
        with gr.Row():
            baseline_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", columns=1, height=768)
            with gr.Accordion("Advanced Options", open=True):
                num_samples = 1
                ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=50, step=1)
                scale_ref = gr.Slider(label="Reference guidance scale", minimum=0, maximum=30.0, value=9, step=0.1)
                scale_point = gr.Slider(label="Point guidance scale", minimum=0, maximum=30.0, value=15, step=0.1)
                is_lineart = gr.Checkbox(label="Input is line art", value=False)
                seed = gr.Slider(label="Seed", minimum=-1, maximum=999999999, step=1, value=-1)
        gr.Markdown("### Tutorial")
        gr.Markdown("1. Upload the reference image and the target image. The target supports two modes: upload an RGB image and the model extracts the line art automatically, or upload line art directly and check the 'Input is line art' option.")
        gr.Markdown("2. Click 'Process Images' to resize both images to 512x512.")
        gr.Markdown("3. (Optional) **Starting from the reference image**, **alternately** click on the reference and target images to define matching points. Use 'Undo' to revert the last click.")
        gr.Markdown("4. Click 'Generate' to produce the result.")
| gr.Markdown("# Upload the reference image and target image") | |
| with gr.Row(): | |
| ref = gr.Image(label="Reference Image",) | |
| base = gr.Image(label="Target Image",) | |
| gr.Button("Process Images").click(process_image, inputs=[ref, base], outputs=[ref, base]) | |
| with gr.Row(): | |
| output_img1 = gr.Image(label="Reference Output") | |
| output_coords1 = gr.Textbox(lines=2, label="Clicked Coordinates Image 1 (npy format)") | |
| output_img2 = gr.Image(label="Base Output") | |
| output_coords2 = gr.Textbox(lines=2, label="Clicked Coordinates Image 2 (npy format)") | |
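        # The coordinate textboxes hold strings like "[[100, 200], [110, 205]]"
        # ((row, col) pairs); run_local passes them through string_to_np_array,
        # so they can also be edited by hand before generating.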
        # Image click handlers
        ref.select(get_select_coords, [ref, base], [output_img1, output_coords1])
        base.select(get_select_coords, [ref, base], [output_img2, output_coords2])
        # Undo button
        undo_button = gr.Button("Undo")
        undo_button.click(undo_last_point, inputs=[ref, base], outputs=[output_img1, output_coords1, output_img2, output_coords2])
        run_local_button = gr.Button("Generate")
        with gr.Row():
            gr.Examples(
                examples=[
                    ['test_cases/hz0.png', 'test_cases/manga_target_examples/target_1.jpg'],
                    ['test_cases/more_cases/az0.png', 'test_cases/manga_target_examples/target_2.jpg'],
                    ['test_cases/more_cases/hi0.png', 'test_cases/manga_target_examples/target_3.jpg'],
                    ['test_cases/more_cases/kn0.jpg', 'test_cases/manga_target_examples/target_4.jpg'],
                    ['test_cases/more_cases/rk0.jpg', 'test_cases/manga_target_examples/target_5.jpg'],
                ],
                inputs=[ref, base],
                cache_examples=False,
                examples_per_page=100,
            )
        run_local_button.click(fn=run_local,
                               inputs=[ref,
                                       base,
                                       ddim_steps,
                                       scale_ref,
                                       scale_point,
                                       seed,
                                       output_coords1,
                                       output_coords2,
                                       is_lineart,
                                       ],
                               outputs=[baseline_gallery])

demo.launch(show_api=False, show_error=True, ssr_mode=True)