import os
import torch
import gradio as gr
from tqdm import tqdm
from PIL import Image
import torch.nn.functional as F
from torchvision import transforms as tfms
from transformers import CLIPTextModel, CLIPTokenizer, logging
from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel, DiffusionPipeline

torch_device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
if "mps" == torch_device: os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = "1"

# Load the pipeline
model_path = "CompVis/stable-diffusion-v1-4"
sd_pipeline = DiffusionPipeline.from_pretrained(
    model_path,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float32
).to(torch_device)

# Load textual inversions
sd_pipeline.load_textual_inversion("sd-concepts-library/illustration-style")
sd_pipeline.load_textual_inversion("sd-concepts-library/line-art")
sd_pipeline.load_textual_inversion("sd-concepts-library/hitokomoru-style-nao")
sd_pipeline.load_textual_inversion("sd-concepts-library/style-of-marc-allante")
sd_pipeline.load_textual_inversion("sd-concepts-library/midjourney-style")
sd_pipeline.load_textual_inversion("sd-concepts-library/hanfu-anime-style")
sd_pipeline.load_textual_inversion("sd-concepts-library/birb-style")

# Map UI style names to their textual-inversion placeholder tokens
style_token_dict = {
    "Illustration Style": '<illustration-style>',
    "Line Art": '<line-art>',
    "Hitokomoru Style": '<hitokomoru-style-nao>',
    "Marc Allante": '<Marc_Allante>',
    "Midjourney": '<midjourney-style>',
    "Hanfu Anime": '<hanfu-anime-style>',
    "Birb Style": '<birb-style>'
}


def set_timesteps(scheduler, num_inference_steps):
    scheduler.set_timesteps(num_inference_steps)
    # Cast timesteps to float32 (float64 is not supported on MPS)
    scheduler.timesteps = scheduler.timesteps.to(torch.float32)

def pil_to_latent(input_im):
    # Encode a PIL image into scaled VAE latents (inputs rescaled from [0, 1] to [-1, 1])
    with torch.no_grad():
        latent = sd_pipeline.vae.encode(tfms.ToTensor()(input_im).unsqueeze(0).to(torch_device) * 2 - 1)
    return 0.18215 * latent.latent_dist.sample()

def latents_to_pil(latents):
    # Decode scaled latents back into a list of PIL images
    latents = (1 / 0.18215) * latents
    with torch.no_grad():
        image = sd_pipeline.vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    images = (image * 255).round().astype("uint8")
    pil_images = [Image.fromarray(image) for image in images]
    return pil_images

def generate_with_pipeline(prompt, num_inference_steps, guidance_scale, seed):
    generator = torch.Generator(device=torch_device).manual_seed(seed)
    image = sd_pipeline(
        prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=generator
    ).images[0]
    return image

def inference(text, style, inference_step, guidance_scale, seed, guidance_method, loss_scale):
    prompt = text + " " + style_token_dict[style]

    # Generate image with pipeline
    image_pipeline = generate_with_pipeline(prompt, inference_step, guidance_scale, seed)

    # Guided generation (using guidance_method and loss_scale) needs a custom denoising
    # loop rather than the stock pipeline call; see the sketch after this function.
    # For now the guided output is a placeholder that mirrors the pipeline output.
    image_guide = image_pipeline

    return image_pipeline, image_guide
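
# --- Sketch: loss-guided generation -------------------------------------------
# A minimal sketch (not wired into the app) of how `guidance_method` and
# `loss_scale` could be used. The name `guidance_loss` is hypothetical; it
# assumes decoded images in [0, 1] with shape (B, 3, H, W) and the standard
# diffusers components already attached to `sd_pipeline`.

def guidance_loss(images, method):
    if method == 'Grayscale':
        # Penalise colour: distance of each channel from the per-pixel mean
        gray = images.mean(dim=1, keepdim=True)
        return F.mse_loss(images, gray.expand_as(images))
    elif method == 'Bright':
        # Pull mean brightness towards a high target value
        return torch.abs(images.mean() - 0.9)
    elif method == 'Contrast':
        # Reward a wider spread of pixel values
        return -images.std()
    elif method == 'Symmetry':
        # Penalise differences between the image and its horizontal mirror
        return F.mse_loss(images, torch.flip(images, dims=[3]))
    else:  # 'Saturation'
        # Reward channels that differ from the grayscale image
        gray = images.mean(dim=1, keepdim=True)
        return -F.mse_loss(images, gray.expand_as(images))

# Inside a custom denoising loop, the loss gradient would nudge the latents
# before each scheduler step, roughly:
#
#   latents = latents.detach().requires_grad_()
#   latents_x0 = latents - sigma * noise_pred            # predicted clean latents
#   decoded = sd_pipeline.vae.decode(latents_x0 / 0.18215).sample / 2 + 0.5
#   loss = guidance_loss(decoded.clamp(0, 1), guidance_method) * loss_scale
#   grad = torch.autograd.grad(loss, latents)[0]
#   latents = latents.detach() - grad * sigma**2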

title = "Stable Diffusion with Textual Inversion"
description = "A simple Gradio interface to infer Stable Diffusion and generate images with different art styles"
examples = [["A sweet potato farm", 'Illustration Style', 10, 4.5, 1, 'Grayscale', 100],
            ["Sky full of cotton candy", 'Line Art', 10, 9.5, 2, 'Bright', 200]]

demo = gr.Interface(inference, 
                    inputs = [gr.Textbox(label="Prompt", type="text"),
                              gr.Dropdown(label="Style", choices=list(style_token_dict.keys()), value="Illustration Style"), 
                              gr.Slider(10, 30, 10, step = 1, label="Inference steps"),
                              gr.Slider(1, 10, 7.5, step = 0.1, label="Guidance scale"),
                              gr.Slider(0, 10000, 1, step = 1, label="Seed"),
                              gr.Dropdown(label="Guidance method", choices=['Grayscale', 'Bright', 'Contrast', 
                                                                  'Symmetry', 'Saturation'], value="Grayscale"),
                              gr.Slider(100, 10000, 200, step = 100, label="Loss scale")],
                    outputs= [gr.Image(width=320, height=320, label="Generated art"),
                              gr.Image(width=320, height=320, label="Generated art with guidance")],
                    title=title,
                    description=description,
                    examples=examples)

demo.launch()