# Spaces: Sleeping (page-header text captured with the source; not part of the program)
import io
import os
import threading
import time
from urllib.parse import quote

import gradio as gr
import numpy as np
import requests
import torch
import torch.nn as nn
import torch.optim as optim
from bs4 import BeautifulSoup
from diffusers import DiffusionPipeline
from PIL import Image
from torch.utils.data import Dataset, DataLoader
# ======================
# Configuration
# ======================
# Central settings for the scraper, the toy GAN trainer and file locations.
CONFIG = {
    "scraping": {
        # "{query}" is filled in with the user's search text.
        "search_url": "https://www.pexels.com/search/{query}/",
        "headers": {
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"
        },
        # Upper bound on how many <img> tags are downloaded per scrape run.
        "max_images": 100,
        # Scrape duration in seconds. 10 is a short stand-in for testing;
        # a real 3-hour run would be 3 * 60 * 60 = 10800.
        # NOTE(review): nothing in the visible code reads this key.
        "scrape_time": 10,
    },
    "training": {
        "batch_size": 4,
        "epochs": 10,
        "lr": 0.0002,
        # Length of the random noise vector fed to the generator.
        "latent_dim": 100,
        # Generated images are img_size x img_size RGB.
        "img_size": 64,
        "num_workers": 0,
    },
    "paths": {
        # Directory the scraper writes downloaded images into.
        "dataset_dir": "scraped_data",
        # File the trained generator's state dict is saved to / loaded from.
        "model_save": "text2img_model.pth",
    },
}
# ======================
# Web Scraping Module
# ======================
class WebScraper:
    """Download images for a search query on a background thread.

    Instance state:
      stop_event   -- threading.Event checked before each download; when it
                      is set the running scrape stops. Nothing in this class
                      sets it; ``start_scraping`` clears it.
      scraped_data -- list of ``{"text": query, "image": file_path}`` records,
                      appended to as each image is written to disk.
    """

    # Seconds to wait on a single HTTP request; without a timeout a stalled
    # server would block the scraping thread indefinitely.
    REQUEST_TIMEOUT = 10

    def __init__(self):
        self.stop_event = threading.Event()
        self.scraped_data = []

    @staticmethod
    def _unique_image_name(index):
        """Return a ``.jpg`` file name that differs for every image of a run.

        A whole-second timestamp alone collides when several images are saved
        within the same second (each would overwrite the previous one), so the
        name combines a nanosecond timestamp with the image's loop index.
        """
        return f"{time.time_ns()}_{index}.jpg"

    def scrape_images(self, query):
        """Fetch the search page for *query* and save the images it lists.

        Each saved image is recorded in ``self.scraped_data`` with the query
        as its text label. Errors are printed rather than raised because this
        method runs as a thread target.
        """
        scraping_cfg = CONFIG["scraping"]
        dataset_dir = CONFIG["paths"]["dataset_dir"]
        # Percent-encode the user's text so spaces, "/" or "?" cannot change
        # the structure of the URL.
        search_url = scraping_cfg["search_url"].format(query=quote(str(query), safe=""))
        try:
            response = requests.get(
                search_url,
                headers=scraping_cfg["headers"],
                timeout=self.REQUEST_TIMEOUT,
            )
            response.raise_for_status()
            soup = BeautifulSoup(response.content, "html.parser")
            # Extract image URLs (example selector - needs adjustment for actual site)
            img_tags = soup.find_all("img", {"class": "photo-item__img"})
            for index, img in enumerate(img_tags[:scraping_cfg["max_images"]]):
                if self.stop_event.is_set():
                    break
                img_url = img.get("src")
                if not img_url:
                    # A tag without a usable src is skipped instead of
                    # aborting the whole scrape.
                    continue
                try:
                    img_response = requests.get(img_url, timeout=self.REQUEST_TIMEOUT)
                    # Do not save an HTTP error page under a .jpg name.
                    img_response.raise_for_status()
                    img_path = os.path.join(dataset_dir, self._unique_image_name(index))
                    with open(img_path, "wb") as f:
                        f.write(img_response.content)
                    # Store text-image pair (text = query)
                    self.scraped_data.append({"text": query, "image": img_path})
                except (requests.RequestException, OSError) as e:
                    print(f"Error downloading image: {e}")
        except Exception as e:
            # Thread boundary: report and end the run rather than let the
            # exception escape the worker thread.
            print(f"Scraping error: {e}")

    def start_scraping(self, query):
        """Start ``scrape_images(query)`` on a new thread and return a status string."""
        self.stop_event.clear()
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(CONFIG["paths"]["dataset_dir"], exist_ok=True)
        thread = threading.Thread(target=self.scrape_images, args=(query,))
        thread.start()
        return "Scraping started..."
# ======================
# Dataset and Models
# ======================
class TextImageDataset(Dataset):
    """Dataset over ``{"text": ..., "image": path}`` records.

    Args:
        data: Sequence of records; each record's ``"image"`` value is a path
            that is opened when the item is requested.
        transform: Optional callable applied to the RGB image before it is
            returned. When omitted, items carry the RGB PIL image itself.
    """

    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``{"text": label, "image": image}`` for record *idx*."""
        item = self.data[idx]
        # convert() produces a new, fully loaded image, so the file opened
        # by Image.open can be closed right away instead of lingering until
        # garbage collection.
        with Image.open(item["image"]) as source:
            image = source.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return {"text": item["text"], "image": image}
# Simplified Text-to-Image Generator
class TextConditionedGenerator(nn.Module):
    """MLP generator mapping (token id, noise vector) to an RGB image.

    Args:
        vocab_size: Number of rows in the text embedding table; token ids
            passed to ``forward`` must be in ``[0, vocab_size)``.
        embed_dim: Width of each text embedding.
        latent_dim: Length of the noise vector. Defaults to
            ``CONFIG["training"]["latent_dim"]``.
        img_size: Side length of the square output image. Defaults to
            ``CONFIG["training"]["img_size"]``.

    Calling the constructor with no arguments builds the same layers as
    before, so previously saved state dicts still load.
    """

    def __init__(self, vocab_size=1000, embed_dim=128, latent_dim=None, img_size=None):
        super().__init__()
        if latent_dim is None:
            latent_dim = CONFIG["training"]["latent_dim"]
        if img_size is None:
            img_size = CONFIG["training"]["img_size"]
        self.latent_dim = latent_dim
        # Remembered so forward() reshapes to the size the layers were built for.
        self.img_size = img_size
        self.text_embedding = nn.Embedding(vocab_size, embed_dim)  # Simplified text embedding
        self.model = nn.Sequential(
            nn.Linear(embed_dim + latent_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 3 * img_size ** 2),
            nn.Tanh(),  # outputs lie in [-1, 1]
        )

    def forward(self, text, noise):
        """Return images of shape ``(batch, 3, img_size, img_size)``.

        Args:
            text: Integer tensor of token ids, one per sample.
            noise: Float tensor with ``latent_dim`` values per sample.
        """
        text_emb = self.text_embedding(text)
        combined = torch.cat([text_emb, noise], 1)
        img = self.model(combined)
        return img.view(-1, 3, self.img_size, self.img_size)
# ======================
# Training Utilities
# ======================
def _image_to_training_tensor(image):
    """Resize an RGB PIL image and return a float (3, H, W) tensor in [-1, 1]."""
    size = CONFIG["training"]["img_size"]
    resized = image.resize((size, size))
    # 0..255 -> -1..1, the same range as the generator's Tanh output.
    pixels = np.asarray(resized, dtype=np.float32) / 127.5 - 1.0
    return torch.from_numpy(pixels).permute(2, 0, 1).contiguous()


def train_model(scraper, progress=gr.Progress()):
    """Train the toy GAN on the scraper's images and save the generator.

    Args:
        scraper: Object exposing ``scraped_data``, a list of
            ``{"text": ..., "image": path}`` records.
        progress: Gradio progress tracker; its ``tqdm`` wraps the epoch loop.

    Returns:
        A status string: a completion message, or a notice that there is
        nothing to train on.

    The generator is conditioned on random token ids, not on the records'
    text, so the text labels do not influence training.
    """
    # Snapshot the list: the scraping thread may still be appending to it.
    records = list(scraper.scraped_data)
    if not records:
        # A shuffling DataLoader cannot be built over an empty dataset.
        return "No scraped images available. Run scraping before training."

    train_cfg = CONFIG["training"]
    latent_dim = train_cfg["latent_dim"]
    # Flattened image length, derived from the configured size rather than
    # a hard-coded 64.
    flat_dim = 3 * train_cfg["img_size"] ** 2

    # The transform turns PIL images into equally sized tensors; without it
    # the DataLoader cannot stack a batch.
    dataset = TextImageDataset(records, transform=_image_to_training_tensor)
    dataloader = DataLoader(
        dataset,
        batch_size=train_cfg["batch_size"],
        shuffle=True,
        num_workers=train_cfg["num_workers"],
    )

    generator = TextConditionedGenerator()
    discriminator = nn.Sequential(
        nn.Linear(flat_dim, 512),
        nn.LeakyReLU(0.2),
        nn.Linear(512, 1),
        nn.Sigmoid(),
    )
    optimizer_G = optim.Adam(generator.parameters(), lr=train_cfg["lr"])
    optimizer_D = optim.Adam(discriminator.parameters(), lr=train_cfg["lr"])
    criterion = nn.BCELoss()

    for _epoch in progress.tqdm(range(train_cfg["epochs"]), desc="Training"):
        for batch in dataloader:
            real_imgs = batch["image"]
            batch_len = real_imgs.size(0)
            if batch_len < 2:
                # The generator's BatchNorm1d cannot run in training mode on
                # a single sample, so a trailing batch of one is skipped.
                continue

            real_labels = torch.ones(batch_len, 1)
            fake_labels = torch.zeros(batch_len, 1)
            noise = torch.randn(batch_len, latent_dim)
            text_ids = torch.randint(0, 1000, (batch_len,))
            fake_imgs = generator(text_ids, noise)

            # Train discriminator
            optimizer_D.zero_grad()
            real_loss = criterion(discriminator(real_imgs.reshape(-1, flat_dim)), real_labels)
            # detach(): the discriminator step must not push gradients into
            # the generator.
            fake_loss = criterion(
                discriminator(fake_imgs.detach().reshape(-1, flat_dim)), fake_labels
            )
            d_loss = real_loss + fake_loss
            d_loss.backward()
            optimizer_D.step()

            # Train generator
            optimizer_G.zero_grad()
            validity = discriminator(fake_imgs.reshape(-1, flat_dim))
            g_loss = criterion(validity, torch.ones_like(validity))
            g_loss.backward()
            optimizer_G.step()

    torch.save(generator.state_dict(), CONFIG["paths"]["model_save"])
    return "Training completed!"
# ======================
# Inference Modules
# ======================
class ModelRunner:
    """Lazily load and cache the two image generators used by the UI.

    Instance state:
      pretrained_pipe -- cached diffusers pipeline, or None until first use.
      custom_model    -- cached TextConditionedGenerator, or None until first use.
    """

    def __init__(self):
        self.pretrained_pipe = None
        self.custom_model = None

    def load_pretrained(self):
        """Return the pretrained pipeline, creating it on first call."""
        if self.pretrained_pipe is None:
            self.pretrained_pipe = DiffusionPipeline.from_pretrained(
                "stabilityai/stable-diffusion-xl-base-1.0"
            )
        return self.pretrained_pipe

    def load_custom(self):
        """Return the locally trained generator in eval mode, loading it on first call.

        The weights are read from ``CONFIG["paths"]["model_save"]``. The
        loaded model is cached, so weights saved by a later training run are
        not picked up by an existing ModelRunner.
        """
        if self.custom_model is None:
            model = TextConditionedGenerator()
            # NOTE: torch.load unpickles the file; only load checkpoints this
            # application wrote itself. map_location keeps loading working on
            # a CPU-only machine.
            state = torch.load(CONFIG["paths"]["model_save"], map_location="cpu")
            model.load_state_dict(state)
            # Inference mode: BatchNorm1d uses its running statistics, which
            # is required for generating a single image at a time.
            model.eval()
            self.custom_model = model
        return self.custom_model
# ======================
# Gradio Interface
# ======================
# Build the UI: scraping and training controls on the left, image generation
# on the right. `scraper` and `model_runner` are shared by all handlers.
with gr.Blocks() as app:
    scraper = WebScraper()
    model_runner = ModelRunner()

    with gr.Row():
        with gr.Column():
            query_input = gr.Textbox(label="Search Query")
            scrape_btn = gr.Button("Start Scraping")
            scrape_status = gr.Textbox(label="Scraping Status")
            train_btn = gr.Button("Start Training")
            training_status = gr.Textbox(label="Training Status")
        with gr.Column():
            prompt_input = gr.Textbox(label="Generation Prompt")
            model_choice = gr.Radio(["Pretrained", "Custom"], label="Model Type")
            generate_btn = gr.Button("Generate Image")
            output_image = gr.Image(label="Generated Image")

    def _run_training(progress=gr.Progress()):
        """Train on the shared scraper's data, forwarding Gradio's progress tracker."""
        return train_model(scraper, progress)

    # Event Handlers
    scrape_btn.click(
        fn=scraper.start_scraping,
        inputs=query_input,
        outputs=scrape_status,
    )
    # Event inputs must be Gradio components, so the scraper object is
    # captured by the wrapper above instead of being listed as an input.
    train_btn.click(
        fn=_run_training,
        inputs=None,
        outputs=training_status,
    )
    generate_btn.click(
        fn=lambda prompt, model_type: generate_image(prompt, model_type, model_runner),
        inputs=[prompt_input, model_choice],
        outputs=output_image,
    )
def generate_image(prompt, model_type, runner):
    """Generate one image with the selected model.

    Args:
        prompt: Text prompt. Passed to the pretrained pipeline; the custom
            generator is driven by a random token id and ignores it.
        model_type: ``"Pretrained"`` selects the diffusion pipeline; any
            other value selects the locally trained generator.
        runner: Object providing ``load_pretrained()`` and ``load_custom()``.

    Returns:
        The first image of the pipeline output, or a PIL image built from
        the custom generator's output.
    """
    if model_type == "Pretrained":
        pipe = runner.load_pretrained()
        # Return the pipeline's image as-is; the array rescaling below only
        # applies to the custom generator's [-1, 1] output.
        return pipe(prompt).images[0]

    model = runner.load_custom()
    # Eval mode lets the generator's BatchNorm1d handle a batch of one;
    # no_grad avoids tracking gradients during inference.
    model.eval()
    with torch.no_grad():
        # Simplified generation process
        noise = torch.randn(1, CONFIG["training"]["latent_dim"])
        text_ids = torch.randint(0, 1000, (1,))
        fake = model(text_ids, noise)
    # Drop the batch axis and move channels last for PIL.
    image = fake.squeeze(0).permute(1, 2, 0).numpy()
    image = (image + 1) / 2  # Scale to [0,1]
    return Image.fromarray((image * 255).astype(np.uint8))
if __name__ == "__main__":
    # Start the Gradio server only when run as a script, not on import.
    app.launch()