Spaces:
Runtime error
Runtime error
import gradio as gr | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from tensorflow.keras.datasets import mnist | |
from tensorflow.keras.models import Sequential | |
from tensorflow.keras.layers import Dense, Reshape, Flatten | |
from tensorflow.keras.layers import Conv2D, Conv2DTranspose | |
from tensorflow.keras.layers import LeakyReLU, Dropout | |
from tensorflow.keras.optimizers import Adam | |
# Seed NumPy's global RNG so the latent noise drawn for generated images is reproducible
# across restarts. NOTE(review): np.random.seed only affects NumPy; Keras weight
# initialisation presumably uses its own RNG and is not covered by this — confirm if
# reproducible weights matter.
np.random.seed(42)

# Load MNIST and preprocess it. Only the training images are kept; the training labels
# and the whole test split are discarded.
(X_train, _), (_, _) = mnist.load_data()
# Rescale pixels from [0, 255] to [-1.0, 1.0] and append a trailing single-channel axis,
# matching the (28, 28, 1) input shape the discriminator is built with.
X_train = np.expand_dims((X_train.astype(np.float32) - 127.5) / 127.5, axis=-1)
# Create and compile the GAN models: generator, discriminator and the stacked GAN.
#
# Generator: maps a 100-dim latent vector to a 28x28x1 image.
# Spatial path: Dense -> reshape to 7x7x256 -> 7x7 (stride 1) -> 14x14 (stride 2)
# -> 28x28x1 (stride 2).
# NOTE(review): no training step is visible in this file, so the generator appears to run
# with its initial random weights — confirm whether weights are trained or loaded elsewhere.
generator = Sequential()
generator.add(Dense(7 * 7 * 256, input_dim=100))
generator.add(LeakyReLU(alpha=0.2))
generator.add(Reshape((7, 7, 256)))
generator.add(Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same'))
generator.add(LeakyReLU(alpha=0.2))
generator.add(Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same'))
generator.add(LeakyReLU(alpha=0.2))
# tanh bounds the output to [-1, 1]: the same range the MNIST pixels are scaled to, and
# the range generate_images() assumes when mapping samples back to [0, 255]. Without an
# activation the output was unbounded.
generator.add(Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', activation='tanh'))
# Discriminator: a 28x28x1 image -> a single sigmoid output in (0, 1), trained with binary
# cross-entropy (presumably real vs. generated). Two strided conv blocks (64, then 128
# filters) each halve the spatial size, with dropout after each activation.
discriminator = Sequential([
    Conv2D(64, (5, 5), strides=(2, 2), padding='same', input_shape=(28, 28, 1)),
    LeakyReLU(alpha=0.2),
    Dropout(0.3),
    Conv2D(128, (5, 5), strides=(2, 2), padding='same'),
    LeakyReLU(alpha=0.2),
    Dropout(0.3),
    Flatten(),
    Dense(1, activation='sigmoid'),
])
discriminator.compile(
    loss='binary_crossentropy',
    optimizer=Adam(learning_rate=0.0002, beta_1=0.5),
    metrics=['accuracy'],
)
# Stacked GAN: generator followed by the frozen discriminator, so training `gan` is meant
# to update only the generator's weights.
# NOTE(review): tf.keras 2.x captures `trainable` at compile time (the separately compiled
# discriminator stays trainable); Keras 3 reads it at train time instead — confirm the
# behaviour for the installed Keras version before adding a training loop.
discriminator.trainable = False
gan = Sequential([generator, discriminator])
gan.compile(
    loss='binary_crossentropy',
    optimizer=Adam(learning_rate=0.0002, beta_1=0.5),
)
# Image generation used by the Gradio callback.
def generate_images(prompt, num_images=10):
    """Sample images from the generator as uint8 grayscale arrays.

    Args:
        prompt: Text from the UI. Currently unused: the generator is unconditional, so
            the output depends only on random latent noise, not on the prompt.
        num_images: Number of images to sample (default 10). Must be >= 1.

    Returns:
        A uint8 array of shape (num_images, 28, 28) with values in [0, 255].

    Raises:
        ValueError: If ``num_images`` is less than 1.
    """
    if num_images < 1:
        raise ValueError(f"num_images must be >= 1, got {num_images}")
    # One 100-dim standard-normal latent vector per image (matches the generator's input_dim).
    noise = np.random.normal(0, 1, (num_images, 100))
    generated_images = generator.predict(noise, verbose=0)
    # Map [-1, 1] back to [0, 255]. Clip before casting: converting out-of-range floats to
    # uint8 wraps around (platform-dependent) instead of saturating, which would turn any
    # value outside [-1, 1] into a wrong-brightness pixel.
    pixels = np.clip(generated_images * 127.5 + 127.5, 0, 255).astype(np.uint8)
    # Drop the trailing channel axis: (num_images, 28, 28, 1) -> (num_images, 28, 28).
    return pixels.reshape((num_images, 28, 28))
# Adapter between generate_images() and the Gradio output component.
def gradio_generate_images(prompt):
    """Return the generated images as a plain Python list, one 2-D array per image.

    Gradio callback: ``prompt`` is forwarded unchanged to generate_images().
    """
    # Iterating the batch array yields one (28, 28) array per image.
    return list(generate_images(prompt))
# Build and launch the Gradio app.
iface = gr.Interface(
    fn=gradio_generate_images,
    inputs="text",
    # The callback returns a *list* of images. A single "image" output cannot display a
    # list, so a Gallery is used instead.
    outputs=gr.Gallery(label="Generated images"),
    # No `interpretation=`: it does not apply to a Gallery output and was removed in Gradio 4.
    title="GAN Image Generation Demo",
    description="Enter a prompt and generate images based on the prompt using GAN.",
    # `examples` takes a list of example rows, one value per input component.
    examples=[["smiley face"]],
)
iface.launch()