Update app.py
Browse files
app.py
CHANGED
@@ -1,88 +1,107 @@
|
|
1 |
import gradio as gr
|
2 |
-
import torch
|
3 |
-
from PIL import Image
|
4 |
|
5 |
-
|
6 |
-
|
7 |
-
|
|
|
|
|
|
|
8 |
"DALL-E 2",
|
9 |
-
"VQGAN
|
|
|
|
|
10 |
"VQGAN",
|
11 |
-
"
|
12 |
-
"VQGAN
|
13 |
-
"
|
14 |
-
"VQGAN
|
15 |
-
"DALL-E Mini-Fusion",
|
16 |
-
"VQGAN-CLIP-Ada-Mini-Fusion",
|
17 |
-
"DALL-E 2 Mini",
|
18 |
-
"VQGAN-CLIP-Ada-Mini-2",
|
19 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
|
21 |
# Функция генерации изображения
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
|
|
|
|
|
|
|
|
32 |
elif model == "VQGAN":
|
33 |
-
|
34 |
-
|
35 |
-
elif model == "
|
36 |
-
|
37 |
-
|
38 |
-
elif model == "VQGAN
|
39 |
-
|
40 |
-
|
41 |
-
elif model == "
|
42 |
-
|
43 |
-
|
44 |
-
elif model == "VQGAN
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
image = vqgan_clip_ada_mini_2.generate(prompt, negative_prompt)
|
59 |
-
|
60 |
-
image = Image.fromarray(image)
|
61 |
return image
|
62 |
|
63 |
# Функция улучшения качества изображения
|
64 |
-
|
65 |
-
|
66 |
-
|
|
|
|
|
|
|
67 |
return image
|
68 |
|
69 |
-
#
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
|
|
|
|
|
|
86 |
|
87 |
if __name__ == "__main__":
|
88 |
-
|
|
|
import gradio as gr
from transformers import AutoModelForConditionalGeneration, AutoTokenizer

# Basic settings
# NOTE(review): gr.inputs.* is the legacy (pre-3.x) Gradio namespace; confirm
# the pinned gradio version before relying on it.
prompt = gr.inputs.Textbox(label="Промпт")

# Model names offered in the UI; these must match the branches/keys used by
# generate_image().
model_list = [
    "DALL-E 2",
    "VQGAN+CLIP",
    "BigGAN",
    "StyleGAN2",
    "VQGAN",
    "CLIP",
    "VQGAN+CLIP-Vanilla",
    "VQGAN+CLIP-Cutout",
    "VQGAN+CLIP-RandomizedCutout",
]
# BUG FIX: gr.widgets.ToggleButtons does not exist in Gradio; a Radio input
# (option list passed as `choices=`) is the supported equivalent.
model = gr.inputs.Radio(choices=model_list, label="Модель")

# Advanced settings
negative_prompt = gr.inputs.Textbox(label="Отрицательный промпт")
# BUG FIX: gr.inputs.RadioButtons does not exist; the class is Radio and its
# options are passed as `choices=`, not `options=`.
sampling_method = gr.inputs.Radio(choices=["greedy", "top-k", "nucleus"], label="Метод выборки")
# BUG FIX: gr.inputs.Number accepts no min/max arguments; a Slider expresses
# the bounded ranges the original clearly intended.
sampling_steps = gr.inputs.Slider(minimum=1, maximum=100, label="Количество шагов выборки")
cfg_scale = gr.inputs.Slider(minimum=0.1, maximum=1.0, label="Масштаб CFG")
seed = gr.inputs.Slider(minimum=0, maximum=2**31 - 1, label="Случайное число")

# Quality improvement (upscaling)
# CONSISTENCY FIX: main() compares this value against "none", so "none" must
# be an offered choice; the original offered only bicubic/lanczos.
upscale_algorithm = gr.inputs.Radio(choices=["none", "bicubic", "lanczos"], label="Алгоритм увеличения")
33 |
# Image-generation function

# Maps each UI model name to the checkpoint id the original code loaded for
# it.  Centralizing the mapping removes nine near-identical if/elif branches.
# NOTE(review): these checkpoint ids do not correspond to real Hugging Face
# Hub repositories — confirm where the weights actually live.
_MODEL_CHECKPOINTS = {
    "DALL-E 2": "google/dalle-2-1024",
    "VQGAN+CLIP": "openai/vqgan-clip",
    "BigGAN": "karras2022/biggan-deep-256",
    "StyleGAN2": "NVlabs/stylegan2-ada",
    "VQGAN": "vqgan/vqgan_imagenet_f16_1024",
    "CLIP": "openai/clip",
    "VQGAN+CLIP-Vanilla": "vqgan/vqgan_imagenet_f16_1024_clip_vanilla",
    "VQGAN+CLIP-Cutout": "vqgan/vqgan_imagenet_f16_1024_clip_cutout",
    "VQGAN+CLIP-RandomizedCutout": "vqgan/vqgan_imagenet_f16_1024_clip_randomized_cutout",
}


def generate_image(prompt, model, negative_prompt, sampling_method, sampling_steps, cfg_scale, seed):
    """Generate an image for *prompt* using the selected *model*.

    Parameters mirror the Gradio inputs declared at module level: *model*
    is one of the names in ``model_list``; the remaining arguments are
    forwarded to the model's ``generate`` call.

    Raises:
        ValueError: if *model* is not a known model name.
    """
    # BUG FIX: the original if/elif chain had no else branch, so an unknown
    # model name fell through and crashed later with UnboundLocalError.
    # Fail fast with a clear error instead.
    try:
        checkpoint = _MODEL_CHECKPOINTS[model]
    except KeyError:
        raise ValueError(f"Unknown model: {model!r}") from None

    # NOTE(review): transformers does not export a class named
    # AutoModelForConditionalGeneration — this load path cannot work as
    # written; confirm the intended Auto* class.
    model = AutoModelForConditionalGeneration.from_pretrained(checkpoint)
    # Loaded but never used, exactly as in the original — flagged, not removed.
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    # NOTE(review): appending the negative prompt to the positive prompt is
    # almost certainly not the intended semantics of a negative prompt;
    # preserved from the original — confirm.
    prompt = f"{prompt} {negative_prompt}"

    image = model.generate(
        text=prompt,
        sampling_method=sampling_method,
        sampling_steps=sampling_steps,
        cfg_scale=cfg_scale,
        seed=seed,
    )
    return image
|
75 |
|
76 |
# Image-quality-improvement (upscaling) function

def upscale_image(image, upscale_algorithm):
    """Double *image*'s width and height using the chosen algorithm.

    ``image`` is assumed to be a NumPy/OpenCV array of shape (H, W[, C]) —
    TODO confirm against the producer.  Any algorithm other than
    "bicubic"/"lanczos" (e.g. "none") returns the image unchanged, which
    matches the original's fall-through behavior.
    """
    if upscale_algorithm not in ("bicubic", "lanczos"):
        return image

    # BUG FIX: the module never imports cv2, so the original raised
    # NameError on first use; import it locally where it is needed.
    import cv2

    interpolation = cv2.INTER_CUBIC if upscale_algorithm == "bicubic" else cv2.INTER_LANCZOS4
    return cv2.resize(
        image,
        dsize=(image.shape[1] * 2, image.shape[0] * 2),
        interpolation=interpolation,
    )
|
84 |
|
85 |
+
# Image-display function

def show_image(image):
    """Display *image* in a blocking OpenCV window until a key is pressed."""
    # BUG FIX: the module never imports cv2, so the original raised
    # NameError on first use; import it locally where it is needed.
    import cv2

    # NOTE(review): cv2.imshow expects BGR channel order; converting to RGB
    # first (as the original does) displays swapped channels unless the
    # incoming array is already BGR-as-RGB — confirm the producer's order.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    cv2.imshow("Image", image)
    cv2.waitKey(0)
    # BUG FIX: release the window after the keypress; the original leaked it.
    cv2.destroyAllWindows()
|
91 |
+
|
92 |
+
# Main function

def main():
    """Run one generate → (optionally) upscale → display cycle.

    NOTE(review): Gradio input components do not expose a live ``.value``
    attribute the way this reads them, and no ``gr.Interface`` is ever built
    or launched — as written this would read only static widget state and
    never serve a UI.  Confirm the intended Gradio wiring before shipping.
    """
    # Generate the image from the current widget values.
    image = generate_image(prompt.value, model.value, negative_prompt.value, sampling_method.value, sampling_steps.value, cfg_scale.value, seed.value)

    # If image-quality improvement (upscaling) was selected
    # NOTE(review): "none" is not among upscale_algorithm's declared options
    # ("bicubic"/"lanczos"), so this condition is always true — confirm.
    if upscale_algorithm.value != "none":
        image = upscale_image(image, upscale_algorithm.value)

    # Display the image

    show_image(image)

if __name__ == "__main__":
    main()
|