File size: 4,225 Bytes
96ef3e3
 
7b9c473
96ef3e3
559bd5d
96ef3e3
 
 
 
 
 
 
 
 
 
54a4474
96ef3e3
53040c7
4de6b2d
7b9c473
5e0b0a9
96ef3e3
 
 
 
 
 
 
 
 
 
 
29d319f
96ef3e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7b9c473
96ef3e3
53040c7
96ef3e3
 
 
 
 
 
7b9c473
96ef3e3
 
 
 
 
53040c7
96ef3e3
 
 
 
 
 
7b9c473
96ef3e3
 
 
 
 
53040c7
96ef3e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import os
import subprocess

# Fetch the IP-Adapter sources and the SD1.5 IP-Adapter checkpoint at startup.
# Each step is skipped when its output already exists, so restarting the app
# neither clones into a nested directory nor downloads a duplicate checkpoint.
if not os.path.isdir("IP_Adapter"):
    # Clone directly into an importable package name (hyphens are not valid in
    # Python module names) rather than cloning and renaming afterwards.
    subprocess.run(
        ["git", "clone", "https://github.com/tencent-ailab/IP-Adapter.git", "IP_Adapter"],
        check=True,
    )
if not os.path.isfile("ip-adapter_sd15.bin"):
    # Download to a temporary name and rename only on success, so an
    # interrupted download is never mistaken for a complete checkpoint.
    subprocess.run(
        [
            "wget",
            "-O",
            "ip-adapter_sd15.bin.part",
            "https://huggingface.co/h94/IP-Adapter/resolve/main/models/ip-adapter_sd15.bin",
        ],
        check=True,
    )
    os.replace("ip-adapter_sd15.bin.part", "ip-adapter_sd15.bin")
# Debug aid: list the package contents in the startup logs (failure is harmless).
subprocess.run(["ls", "IP_Adapter/ip_adapter"], check=False)
import gradio as gr
import torch
from PIL import Image
from diffusers import (
    StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, 
    StableDiffusionInpaintPipelineLegacy, DDIMScheduler, AutoencoderKL
)
from IP_Adapter.ip_adapter import IPAdapter

# Paths and device
base_model_path = "stable-diffusion-v1-5/stable-diffusion-v1-5"
vae_model_path = "stabilityai/sd-vae-ft-mse"
# Hub repo passed to IPAdapter as its image-encoder location.
image_encoder_repo = "InvokeAI/ip_adapter_sd_image_encoder"
# Path inside the cloned repo; not passed to IPAdapter in this file
# (image_encoder_repo is used instead).
image_encoder_path = "IP_Adapter/ip_adapter/models/image_encoder/"
# Checkpoint file downloaded at startup into the working directory.
ip_ckpt = "ip-adapter_sd15.bin"
# Use the GPU when one is available, otherwise fall back to the CPU.
# NOTE(review): CPU generation is untested here and will be very slow.
device = "cuda" if torch.cuda.is_available() else "cpu"

# DDIM scheduler and fine-tuned VAE, both shared by every pipeline that the
# generate_* functions build.
noise_scheduler = DDIMScheduler(
    beta_schedule="scaled_linear",
    beta_start=0.00085,
    beta_end=0.012,
    num_train_timesteps=1000,
    steps_offset=1,
    clip_sample=False,
    set_alpha_to_one=False,
)
# Loaded without an explicit torch dtype.
vae = AutoencoderKL.from_pretrained(vae_model_path)

def image_grid(imgs, rows, cols):
    """Paste ``imgs`` row by row into a single ``rows`` x ``cols`` RGB image.

    Every cell takes the size of the first image.

    Args:
        imgs: Sequence of PIL images, exactly ``rows * cols`` long.
        rows: Number of grid rows.
        cols: Number of grid columns.

    Returns:
        A new RGB ``PIL.Image`` of size ``(cols * w, rows * h)``.

    Raises:
        ValueError: If ``imgs`` is empty or its length is not ``rows * cols``.
    """
    # Explicit checks instead of ``assert``, which is stripped under ``python -O``.
    if not imgs:
        raise ValueError("image_grid needs at least one image")
    if len(imgs) != rows * cols:
        raise ValueError(
            f"expected rows * cols = {rows * cols} images, got {len(imgs)}"
        )
    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    for idx, img in enumerate(imgs):
        row, col = divmod(idx, cols)
        grid.paste(img, box=(col * w, row * h))
    return grid

# Most recently used IPAdapter, keyed by the diffusers pipeline class it wraps.
_ip_model_cache = {}


def _get_ip_model(pipeline_cls):
    """Return an IPAdapter around a ``pipeline_cls`` pipeline, loading it on first use.

    Building the pipeline and IPAdapter reloads the Stable Diffusion weights and
    the IP-Adapter checkpoint, so the result is reused across clicks. The cache
    holds at most one pipeline: a request for a different pipeline class drops
    the cached one before loading the new one.
    """
    ip_model = _ip_model_cache.get(pipeline_cls)
    if ip_model is None:
        # Release our reference to the previous pipeline before loading another.
        _ip_model_cache.clear()
        pipe = pipeline_cls.from_pretrained(
            base_model_path,
            scheduler=noise_scheduler,
            vae=vae,
            feature_extractor=None,
            safety_checker=None,
        )
        ip_model = IPAdapter(pipe, image_encoder_repo, ip_ckpt, device)
        _ip_model_cache[pipeline_cls] = ip_model
    return ip_model


def _require_images(images_by_label):
    """Raise ``gr.Error`` naming every UI input (label -> value) left empty."""
    missing = [label for label, img in images_by_label.items() if img is None]
    if missing:
        raise gr.Error(f"Please provide: {', '.join(missing)}.")


def generate_variations(upload_img):
    """Generate 4 IP-Adapter variations of ``upload_img``; returns a 1x4 grid."""
    _require_images({"Upload Image": upload_img})
    ip_model = _get_ip_model(StableDiffusionPipeline)
    images = ip_model.generate(pil_image=upload_img, num_samples=4, num_inference_steps=50, seed=42)
    return image_grid(images, 1, 4)


def generate_img2img(base_img, guide_img):
    """Run IP-Adapter img2img and return the 4 results as a 1x4 grid.

    ``base_img`` is passed as ``pil_image`` and ``guide_img`` as the
    img2img ``image`` (strength 0.6).
    """
    _require_images({"Base Image": base_img, "Guide Image": guide_img})
    ip_model = _get_ip_model(StableDiffusionImg2ImgPipeline)
    images = ip_model.generate(pil_image=base_img, image=guide_img, strength=0.6, num_samples=4, num_inference_steps=50, seed=42)
    return image_grid(images, 1, 4)


def generate_inpaint(input_img, masked_img, mask_img):
    """Run IP-Adapter inpainting and return the 4 results as a 1x4 grid.

    ``input_img`` is passed as ``pil_image``, ``masked_img`` as ``image`` and
    ``mask_img`` as ``mask_image`` (strength 0.7).
    """
    _require_images({"Input Image": input_img, "Masked Image": masked_img, "Mask": mask_img})
    ip_model = _get_ip_model(StableDiffusionInpaintPipelineLegacy)
    images = ip_model.generate(pil_image=input_img, image=masked_img, mask_image=mask_img,
                               strength=0.7, num_samples=4, num_inference_steps=50, seed=42)
    return image_grid(images, 1, 4)

# Gradio Interface: one tab per mode, each button wired to its generate_* handler.
with gr.Blocks() as demo:
    gr.Markdown("# IP-Adapter Image Manipulation Demo")

    # Variations of a single uploaded image.
    with gr.Tab("Image Variations"):
        with gr.Row():
            img_input = gr.Image(label="Upload Image", type="pil")
            img_output = gr.Image(label="Generated Variations")
        img_btn = gr.Button("Generate Variations")
        img_btn.click(generate_variations, inputs=[img_input], outputs=img_output)

    # Image-to-image from a base image and a guide image.
    with gr.Tab("Image-to-Image"):
        with gr.Row():
            img1 = gr.Image(label="Base Image", type="pil")
            img2 = gr.Image(label="Guide Image", type="pil")
            img2_out = gr.Image(label="Output")
        btn2 = gr.Button("Generate Img2Img")
        btn2.click(generate_img2img, inputs=[img1, img2], outputs=img2_out)

    # Inpainting from an input image, a masked image and the mask itself.
    with gr.Tab("Inpainting"):
        with gr.Row():
            inpaint_img = gr.Image(label="Input Image", type="pil")
            masked = gr.Image(label="Masked Image", type="pil")
            mask = gr.Image(label="Mask", type="pil")
            inpaint_out = gr.Image(label="Inpainted")
        btn3 = gr.Button("Generate Inpainting")
        btn3.click(generate_inpaint, inputs=[inpaint_img, masked, mask], outputs=inpaint_out)

demo.launch()