File size: 4,611 Bytes
af10332
 
7d9be07
 
 
 
 
 
0619f29
7d9be07
 
 
0619f29
7d9be07
 
 
 
0619f29
7d9be07
 
 
 
 
 
 
 
 
 
 
 
 
a5b71c1
0619f29
7d9be07
 
 
 
 
 
 
 
 
 
 
 
 
 
0619f29
7d9be07
 
 
 
 
 
 
 
 
 
 
 
28786bc
 
7d9be07
 
 
 
 
 
 
 
 
 
 
0619f29
75d0567
0619f29
7d9be07
 
 
 
 
 
0619f29
70ba535
7d9be07
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a5b71c1
0619f29
7d9be07
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import gradio as gr

from transformers import AutoModelForConditionalGeneration, AutoTokenizer

# Basic settings
#
# The components below use Gradio's top-level component classes
# (``gr.Textbox``, ``gr.Radio``, ``gr.Number``). The legacy ``gr.inputs``
# namespace is deprecated, and it never provided ``RadioButtons`` or a
# ``widgets.ToggleButtons`` class.

prompt = gr.Textbox(label="Промпт")
model_list = [
    "DALL-E 2",
    "VQGAN+CLIP",
    "BigGAN",
    "StyleGAN2",
    "VQGAN",
    "CLIP",
    "VQGAN+CLIP-Vanilla",
    "VQGAN+CLIP-Cutout",
    "VQGAN+CLIP-RandomizedCutout",
]
# Single-choice selector over the supported model names.
model = gr.Radio(choices=model_list, label="Модель")

# Advanced settings

negative_prompt = gr.Textbox(label="Отрицательный промпт")
sampling_method = gr.Radio(choices=["greedy", "top-k", "nucleus"], label="Метод выборки")
# ``gr.Number`` spells its bounds ``minimum`` / ``maximum`` (not ``min`` / ``max``).
sampling_steps = gr.Number(minimum=1, maximum=100, label="Количество шагов выборки")
cfg_scale = gr.Number(minimum=0.1, maximum=1.0, label="Масштаб CFG")
seed = gr.Number(minimum=0, maximum=2**31 - 1, label="Случайное число")

# Quality enhancement

# "none" is offered explicitly because the caller skips upscaling when the
# selected value equals "none".
upscale_algorithm = gr.Radio(choices=["none", "bicubic", "lanczos"], label="Алгоритм увеличения")

# Функция генерации изображения

# Maps each model name offered in the UI to the checkpoint identifier that is
# passed to ``from_pretrained``.
_MODEL_CHECKPOINTS = {
    "DALL-E 2": "google/dalle-2-1024",
    "VQGAN+CLIP": "openai/vqgan-clip",
    "BigGAN": "karras2022/biggan-deep-256",
    "StyleGAN2": "NVlabs/stylegan2-ada",
    "VQGAN": "vqgan/vqgan_imagenet_f16_1024",
    "CLIP": "openai/clip",
    "VQGAN+CLIP-Vanilla": "vqgan/vqgan_imagenet_f16_1024_clip_vanilla",
    "VQGAN+CLIP-Cutout": "vqgan/vqgan_imagenet_f16_1024_clip_cutout",
    "VQGAN+CLIP-RandomizedCutout": "vqgan/vqgan_imagenet_f16_1024_clip_randomized_cutout",
}


def generate_image(prompt, model, negative_prompt, sampling_method, sampling_steps, cfg_scale, seed):
    """Generate an image for a text prompt with the selected model.

    Args:
        prompt: Text describing the desired image.
        model: Display name of the model; must be a key of
            ``_MODEL_CHECKPOINTS``.
        negative_prompt: Extra text appended to ``prompt`` (separated by a
            single space) when it is non-empty.
        sampling_method: Forwarded unchanged to the model's ``generate`` call.
        sampling_steps: Forwarded unchanged to the model's ``generate`` call.
        cfg_scale: Forwarded unchanged to the model's ``generate`` call.
        seed: Forwarded unchanged to the model's ``generate`` call.

    Returns:
        Whatever the loaded model's ``generate`` call returns.

    Raises:
        ValueError: If ``model`` is not one of the known model names.
    """
    # Validate the name before any loading so a bad selection fails fast with
    # a clear message instead of an AttributeError on a plain string later.
    checkpoint = _MODEL_CHECKPOINTS.get(model)
    if checkpoint is None:
        known = ", ".join(_MODEL_CHECKPOINTS)
        raise ValueError(f"Unknown model {model!r}; expected one of: {known}")

    # NOTE(review): the checkpoint identifiers and the keyword arguments given
    # to ``generate`` are kept exactly as the original code had them; whether
    # they match the installed transformers API has not been verified here.
    generator = AutoModelForConditionalGeneration.from_pretrained(checkpoint)

    # NOTE(review): the negative prompt is concatenated onto the positive
    # prompt, as before. That likely does not suppress those concepts —
    # confirm how the model expects negative prompts to be supplied.
    # Skip the concatenation when there is nothing to add, so the prompt does
    # not gain a trailing space or the literal text "None".
    full_prompt = f"{prompt} {negative_prompt}" if negative_prompt else prompt

    return generator.generate(
        text=full_prompt,
        sampling_method=sampling_method,
        sampling_steps=sampling_steps,
        cfg_scale=cfg_scale,
        seed=seed,
    )

# Функция улучшения качества изображения

def upscale_image(image, upscale_algorithm):
    """Return ``image`` resized to twice its width and height.

    Args:
        image: Array-like image exposing ``shape`` as (height, width, ...).
        upscale_algorithm: ``"bicubic"`` or ``"lanczos"``. Any other value
            leaves the image untouched.

    Returns:
        The resized image, or the original ``image`` object when the
        algorithm name is not recognised.
    """
    # Guard clause: anything other than the two supported names is a no-op.
    if upscale_algorithm != "bicubic" and upscale_algorithm != "lanczos":
        return image

    # NOTE(review): cv2 is not imported in the visible part of this file —
    # confirm it is brought into scope elsewhere.
    if upscale_algorithm == "bicubic":
        interpolation = cv2.INTER_CUBIC
    else:
        interpolation = cv2.INTER_LANCZOS4

    doubled_size = (image.shape[1] * 2, image.shape[0] * 2)  # (width, height)
    return cv2.resize(image, dsize=doubled_size, interpolation=interpolation)

# Функция отображения изображения

def show_image(image):
    """Display ``image`` in an OpenCV window and block until a key is pressed.

    The window is closed once the key wait ends, including when the wait is
    interrupted by an exception.

    Args:
        image: Three-channel image array accepted by ``cv2.cvtColor``.
    """
    # NOTE(review): cv2 is not imported in the visible part of this file —
    # confirm it is brought into scope elsewhere.
    window_name = "Image"

    # Swap the R and B channels before display.
    # NOTE(review): ``imshow`` interprets arrays as BGR, so this swap
    # presumably assumes the incoming image is RGB — confirm with the producer.
    swapped = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    cv2.imshow(window_name, swapped)
    try:
        cv2.waitKey(0)
    finally:
        # Release the GUI window instead of leaving it open after the wait.
        cv2.destroyWindow(window_name)

# Основная функция

def main():
    """Generate an image from the module-level settings, optionally upscale it, and show it.

    Reads the ``value`` attribute of each module-level component, passes the
    values to ``generate_image``, applies ``upscale_image`` unless the
    selected algorithm equals ``"none"``, and finally calls ``show_image``.
    """
    # Order matches the positional parameters of ``generate_image``.
    controls = (
        prompt,
        model,
        negative_prompt,
        sampling_method,
        sampling_steps,
        cfg_scale,
        seed,
    )
    settings = [control.value for control in controls]
    result = generate_image(*settings)

    # Upscale only when an algorithm other than "none" is selected.
    algorithm = upscale_algorithm.value
    if algorithm != "none":
        result = upscale_image(result, algorithm)

    # Display the final image.
    show_image(result)

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()