Kvikontent commited on
Commit
9396be5
·
1 Parent(s): 5169d9a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -0
app.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ from tensorflow.keras.datasets import mnist
5
+ from tensorflow.keras.models import Sequential
6
+ from tensorflow.keras.layers import Dense, Reshape, Flatten
7
+ from tensorflow.keras.layers import Conv2D, Conv2DTranspose
8
+ from tensorflow.keras.layers import LeakyReLU, Dropout
9
+ from tensorflow.keras.optimizers import Adam
10
+
11
+ np.random.seed(42)
12
+
13
+ # Загрузка и предобработка данных MNIST
14
+ (X_train, _), (_, _) = mnist.load_data()
15
+ X_train = (X_train.astype(np.float32) - 127.5) / 127.5
16
+ X_train = np.expand_dims(X_train, axis=3)
17
+
18
+ # Создание и компиляция модели GAN
19
+ generator = Sequential()
20
+ generator.add(Dense(7 * 7 * 256, input_dim=100))
21
+ generator.add(LeakyReLU(alpha=0.2))
22
+ generator.add(Reshape((7, 7, 256)))
23
+ generator.add(Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same'))
24
+ generator.add(LeakyReLU(alpha=0.2))
25
+ generator.add(Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same'))
26
+ generator.add(LeakyReLU(alpha=0.2))
27
+ generator.add(Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same'))
28
+
29
+ discriminator = Sequential()
30
+ discriminator.add(Conv2D(64, (5, 5), strides=(2, 2), padding='same', input_shape=(28, 28, 1)))
31
+ discriminator.add(LeakyReLU(alpha=0.2))
32
+ discriminator.add(Dropout(0.3))
33
+ discriminator.add(Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
34
+ discriminator.add(LeakyReLU(alpha=0.2))
35
+ discriminator.add(Dropout(0.3))
36
+ discriminator.add(Flatten())
37
+ discriminator.add(Dense(1, activation='sigmoid'))
38
+
39
+ discriminator.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=0.0002, beta_1=0.5), metrics=['accuracy'])
40
+ discriminator.trainable = False
41
+ gan = Sequential()
42
+ gan.add(generator)
43
+ gan.add(discriminator)
44
+ gan.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=0.0002, beta_1=0.5))
45
+
46
+ # Функция для генерации изображений на основе запроса
47
+ def generate_images(prompt):
48
+ num_images = 10
49
+ prompt_vector = preprocess_prompt(prompt) # Предобработка запроса, если необходимо
50
+ noise = np.repeat([prompt_vector], num_images, axis=0) # Генерация шума на основе запроса
51
+ generated_images = generator.predict(noise)
52
+ generated_images = (generated_images * 127.5 + 127.5).astype(np.uint8)
53
+ return generated_images.reshape((num_images, 28, 28))
54
+
55
+ # Адаптация функции под Gradio
56
+ def gradio_generate_images(prompt):
57
+ images = generate_images(prompt)
58
+ image_list = []
59
+ for img in images:
60
+ image_list.append(img)
61
+ return image_list
62
+
63
+ # Запуск Gradio приложения
64
+ iface = gr.Interface(
65
+ fn=gradio_generate_images,
66
+ inputs="text",
67
+ outputs="image",
68
+ interpretation="default",
69
+ title="GAN Image Generation Demo",
70
+ description="Enter a prompt and generate images based on the prompt using GAN.",
71
+ example="smiley face"
72
+ )
73
+ iface.launch()