# Source: Hugging Face Space (status when captured: Sleeping)
import gradio as gr | |
import torch | |
import torch.nn as nn | |
import torch.optim as optim | |
from torch.utils.data import Dataset, DataLoader | |
from diffusers import DiffusionPipeline | |
import requests | |
from bs4 import BeautifulSoup | |
import os | |
import time | |
import threading | |
from PIL import Image | |
import numpy as np | |
# ====================== | |
# Configuration | |
# ====================== | |
# Central settings table, grouped by the subsystem that reads each value.
CONFIG = dict(
    # Web scraper: search page template, request headers, download cap, and
    # the value passed as the `every=` interval for the scraping progress box.
    scraping=dict(
        search_url="https://www.pexels.com/search/{query}/",
        headers={
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36",
        },
        max_images=100,
        progress_interval=1,
    ),
    # Training loop hyper-parameters plus the `every=` interval for the
    # training progress box.
    training=dict(
        batch_size=4,
        epochs=10,
        lr=0.0002,
        latent_dim=100,
        img_size=64,
        num_workers=0,
        progress_interval=0.5,
    ),
    # Where downloaded images and the trained generator weights are written.
    paths=dict(
        dataset_dir="scraped_data",
        model_save="text2img_model.pth",
    ),
)
# ====================== | |
# Web Scraping Module | |
# ====================== | |
class WebScraper:
    """Download search-result images on a background thread and track progress.

    Attributes:
        stop_event: set it to make a running scrape stop before the next image.
        scraped_data: list of ``{"text": query, "image": file_path}`` records,
            one per image that was saved successfully.
        scraping_progress: percentage (0-100) of the current scrape.
        scraped_count: number of images saved during the current scrape.
        total_images: number of images the current scrape will attempt.
    """

    # Seconds to wait on any single HTTP request before giving up, so a
    # stalled connection cannot hang the worker thread forever.
    REQUEST_TIMEOUT = 10

    def __init__(self):
        self.stop_event = threading.Event()
        self.scraped_data = []
        self._lock = threading.Lock()
        self.scraping_progress = 0
        self.scraped_count = 0
        self.total_images = 0

    def __getstate__(self):
        """Return the instance state without the threading primitives.

        ``threading.Event`` and ``threading.Lock`` objects cannot be pickled,
        so they are left out here and rebuilt in ``__setstate__``.
        """
        state = self.__dict__.copy()
        state.pop("stop_event", None)
        state.pop("_lock", None)
        return state

    def __setstate__(self, state):
        """Restore ``state`` and create a fresh event and lock."""
        self.__dict__.update(state)
        self.stop_event = threading.Event()
        self._lock = threading.Lock()

    def _download_image(self, img_url, idx, query, headers):
        """Fetch one image, write it to the dataset directory, and record it.

        Raises:
            requests.RequestException: on network failure or a non-2xx status.
            OSError: if the file cannot be written.
        """
        response = requests.get(img_url, headers=headers, timeout=self.REQUEST_TIMEOUT)
        # Without this check an HTML error page would be saved as a ".jpg".
        response.raise_for_status()
        img_name = f"{int(time.time())}_{idx}.jpg"
        img_path = os.path.join(CONFIG["paths"]["dataset_dir"], img_name)
        with open(img_path, "wb") as f:
            f.write(response.content)
        with self._lock:
            self.scraped_data.append({"text": query, "image": img_path})
            self.scraped_count += 1

    def scrape_images(self, query):
        """Scrape up to ``max_images`` images for ``query`` (blocking).

        Errors are printed rather than raised: a failure on one image skips
        that image, a failure on the search page ends the scrape.
        ``scraping_progress`` is always set to 100 when the method returns.
        """
        from urllib.parse import quote

        settings = CONFIG["scraping"]
        headers = settings["headers"]
        with self._lock:
            self.scraping_progress = 0
            self.scraped_count = 0
            self.total_images = 0

        # Percent-encode the query so spaces, slashes, etc. cannot break or
        # redirect the search URL.
        search_url = settings["search_url"].format(query=quote(str(query).strip(), safe=""))
        try:
            response = requests.get(search_url, headers=headers, timeout=self.REQUEST_TIMEOUT)
            response.raise_for_status()
            soup = BeautifulSoup(response.content, "html.parser")
            img_tags = soup.find_all("img", {"class": "photo-item__img"})
            img_tags = img_tags[: settings["max_images"]]
            with self._lock:
                self.total_images = len(img_tags)

            for idx, img in enumerate(img_tags):
                if self.stop_event.is_set():
                    break
                img_url = img.get("src")
                if img_url:
                    try:
                        self._download_image(img_url, idx, query, headers)
                    except (requests.RequestException, OSError) as e:
                        print(f"Error downloading image: {e}")
                else:
                    print("Error downloading image: <img> tag has no src attribute")
                with self._lock:
                    # Progress counts processed images, successful or not,
                    # so it keeps moving when individual downloads fail.
                    self.scraping_progress = (idx + 1) / self.total_images * 100
                # Short pause between requests; returns early if a stop is requested.
                self.stop_event.wait(0.1)
        except Exception as e:
            # Best-effort boundary for the worker thread: report and finish.
            print(f"Scraping error: {e}")
        finally:
            with self._lock:
                self.scraping_progress = 100

    def start_scraping(self, query):
        """Start ``scrape_images(query)`` on a background thread.

        Returns:
            A status string: ``"Scraping started..."``, or an error message
            when ``query`` is empty or only whitespace.
        """
        if not query or not query.strip():
            return "Error: Enter a search query first."
        self.stop_event.clear()
        os.makedirs(CONFIG["paths"]["dataset_dir"], exist_ok=True)
        # Daemon thread: an in-flight scrape must not keep the process alive
        # after the main thread exits.
        thread = threading.Thread(target=self.scrape_images, args=(query,), daemon=True)
        thread.start()
        return "Scraping started..."
# ====================== | |
# Dataset and Models | |
# ====================== | |
class TextImageDataset(Dataset):
    """Dataset over ``{"text": caption, "image": file_path}`` records.

    Each item is returned as ``{"text": str, "image": FloatTensor}`` where the
    tensor is the RGB image resized to ``img_size`` x ``img_size``, laid out
    as (channels, height, width) and scaled to the range [-1, 1].

    Args:
        data: sequence of record dicts with "text" and "image" keys.
        img_size: edge length in pixels that every image is resized to.
    """

    def __init__(self, data, img_size=64):
        self.data = data
        self.img_size = img_size

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _pixels_to_tensor(pixels):
        """Convert an (H, W, C) uint8 array to a (C, H, W) float tensor in [-1, 1]."""
        chw = torch.from_numpy(pixels).permute(2, 0, 1).contiguous().float()
        # 0 -> -1.0 and 255 -> 1.0, matching a Tanh generator output range.
        return chw / 127.5 - 1.0

    def __getitem__(self, idx):
        """Load, resize and normalise the image for record ``idx``.

        Any exception raised by ``Image.open`` (missing or unreadable file)
        propagates to the caller.
        """
        item = self.data[idx]
        # Context manager closes the underlying file once the pixels are copied.
        with Image.open(item["image"]) as img:
            rgb = img.convert("RGB").resize((self.img_size, self.img_size))
        pixels = np.array(rgb, dtype=np.uint8)
        return {"text": item["text"], "image": self._pixels_to_tensor(pixels)}
class TextConditionedGenerator(nn.Module):
    """MLP generator mapping (token id, noise vector) to an image tensor.

    The token id is embedded, concatenated with the noise vector, and passed
    through three linear layers ending in ``Tanh``, so outputs lie in [-1, 1].
    The defaults reproduce the original fixed architecture, so the layer
    layout and ``state_dict`` keys are unchanged.

    Args:
        vocab_size: number of rows in the token embedding table.
        embed_dim: width of each token embedding.
        latent_dim: width of the noise vector expected by ``forward``.
        img_size: height and width of the generated image.
        channels: number of image channels.
    """

    def __init__(self, vocab_size=1000, embed_dim=128, latent_dim=100, img_size=64, channels=3):
        super().__init__()
        self.latent_dim = latent_dim
        self.img_shape = (channels, img_size, img_size)
        self.text_embedding = nn.Embedding(vocab_size, embed_dim)
        self.model = nn.Sequential(
            nn.Linear(embed_dim + latent_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, channels * img_size * img_size),
            nn.Tanh(),
        )

    def forward(self, text, noise):
        """Generate images.

        Args:
            text: integer tensor of token ids, one per sample.
            noise: float tensor of shape (batch, latent_dim).

        Returns:
            Tensor of shape (batch, channels, img_size, img_size).
        """
        text_emb = self.text_embedding(text)
        combined = torch.cat([text_emb, noise], dim=1)
        return self.model(combined).view(-1, *self.img_shape)
# ====================== | |
# Training Utilities | |
# ====================== | |
def train_model(scraper, progress=gr.Progress()):
    """Train the text-conditioned GAN on the scraper's images and save the generator.

    Args:
        scraper: object exposing ``scraped_data``, a list of
            ``{"text", "image"}`` records.
        progress: Gradio progress tracker used to report epoch/batch progress.

    Returns:
        A status string: an error message when there is not enough data,
        otherwise a completion message with the number of samples used.
    """
    import zlib

    # Snapshot the list so a scrape that is still running cannot change the
    # dataset length in the middle of training.
    samples = list(scraper.scraped_data)
    if not samples:
        return "Error: No images scraped! Scrape images first."
    if len(samples) < 2:
        # BatchNorm1d in the generator cannot normalise a single-sample batch
        # while in training mode.
        return "Error: Need at least 2 scraped images to train."

    settings = CONFIG["training"]
    dataset = TextImageDataset(samples)
    dataloader = DataLoader(
        dataset,
        batch_size=settings["batch_size"],
        shuffle=True,
        num_workers=settings["num_workers"],
    )

    generator = TextConditionedGenerator()
    discriminator = nn.Sequential(
        nn.Linear(3 * 64 * 64, 512),
        nn.LeakyReLU(0.2),
        nn.Linear(512, 1),
        nn.Sigmoid(),
    )
    optimizer_G = optim.Adam(generator.parameters(), lr=settings["lr"])
    optimizer_D = optim.Adam(discriminator.parameters(), lr=settings["lr"])
    criterion = nn.BCELoss()

    vocab_size = generator.text_embedding.num_embeddings
    latent_dim = 100  # must match the noise width the generator was built with
    total_batches = len(dataloader)

    for epoch in progress.tqdm(range(settings["epochs"]), desc="Epochs"):
        for batch_idx, batch in enumerate(dataloader):
            real_imgs = batch["image"]
            batch_size = real_imgs.size(0)
            if batch_size < 2:
                # Trailing one-sample batch: skip it (see BatchNorm note above).
                continue

            real_flat = real_imgs.reshape(batch_size, -1)
            real_labels = torch.ones(batch_size, 1)
            fake_labels = torch.zeros_like(real_labels)
            noise = torch.randn(batch_size, latent_dim)
            # Map each caption to a stable token id. crc32 is used instead of
            # hash() because str hashes differ between interpreter runs.
            token_ids = torch.tensor(
                [zlib.crc32(text.encode("utf-8")) % vocab_size for text in batch["text"]],
                dtype=torch.long,
            )

            # Train discriminator
            optimizer_D.zero_grad()
            real_loss = criterion(discriminator(real_flat), real_labels)
            fake_imgs = generator(token_ids, noise)
            fake_flat = fake_imgs.reshape(batch_size, -1)
            # detach(): the discriminator step must not push gradients into the generator.
            fake_loss = criterion(discriminator(fake_flat.detach()), fake_labels)
            d_loss = (real_loss + fake_loss) / 2
            d_loss.backward()
            optimizer_D.step()

            # Train generator: it is rewarded when fakes are classified as real.
            optimizer_G.zero_grad()
            g_loss = criterion(discriminator(fake_flat), real_labels)
            g_loss.backward()
            optimizer_G.step()

            progress(
                (epoch + (batch_idx + 1) / total_batches) / settings["epochs"],
                desc=f"Epoch {epoch+1} | Batch {batch_idx+1}/{total_batches}",
                unit="epoch",
            )

    torch.save(generator.state_dict(), CONFIG["paths"]["model_save"])
    return f"Training complete! Used {len(dataset)} samples"
# ====================== | |
# Gradio Interface | |
# ====================== | |
def create_interface():
    """Build the Gradio Blocks app: scraping/training controls and image generation.

    Returns:
        The configured ``gr.Blocks`` instance (not yet launched).
    """
    with gr.Blocks() as app:
        # Passed as callables; NOTE(review): relies on gr.State invoking a
        # callable value to build the per-session object — confirm for the
        # installed Gradio version.
        scraper = gr.State(WebScraper)
        model_runner = gr.State(ModelRunner)

        with gr.Row():
            with gr.Column():
                query_input = gr.Textbox(label="Search Query")
                scrape_btn = gr.Button("Start Scraping")
                scrape_status = gr.Textbox(label="Scraping Status")
                scraping_progress = gr.Textbox(label="Scraping Progress", value="0% (0/0)")
                train_btn = gr.Button("Start Training")
                training_status = gr.Textbox(label="Training Status")
                training_progress = gr.Textbox(label="Training Progress", value="Not started")
            with gr.Column():
                prompt_input = gr.Textbox(label="Generation Prompt")
                model_choice = gr.Radio(["Pretrained", "Custom"], label="Model Type", value="Pretrained")
                generate_btn = gr.Button("Generate Image")
                output_image = gr.Image(label="Generated Image")

        def update_scraping_progress(scraper):
            """Format the scraper's counters as ``"<pct>% (<done>/<total>)"``."""
            return f"{scraper.scraping_progress:.1f}% ({scraper.scraped_count}/{scraper.total_images})"

        def update_training_progress():
            """Report the saved model's size in KB, or that no model exists yet."""
            if os.path.exists(CONFIG["paths"]["model_save"]):
                stats = os.stat(CONFIG["paths"]["model_save"])
                return f"Model size: {stats.st_size//1024}KB"
            return "No trained model"

        # Periodic refresh is started from the page-load event. Hooking it to
        # each textbox's own `.change` event would never start the polling,
        # because nothing changes the textbox in the first place.
        app.load(
            update_scraping_progress,
            inputs=[scraper],
            outputs=[scraping_progress],
            every=CONFIG["scraping"]["progress_interval"],
        )
        app.load(
            update_training_progress,
            outputs=[training_progress],
            every=CONFIG["training"]["progress_interval"],
        )

        # Event handlers
        scrape_btn.click(
            lambda s, q: s.start_scraping(q),
            [scraper, query_input],
            [scrape_status],
        )
        # train_model is passed directly (not wrapped in a lambda) so Gradio
        # can see its `progress=gr.Progress()` parameter. It returns a single
        # string, so exactly one output component is listed.
        train_btn.click(
            train_model,
            [scraper],
            [training_status],
        )
        generate_btn.click(
            generate_image,
            [prompt_input, model_choice, model_runner],
            [output_image],
        )

    # Progress tracking and `every=` polling run through Gradio's queue.
    app.queue()
    return app
def generate_image(prompt, model_type, runner):
    """Generate one image for ``prompt`` with the selected model.

    Args:
        prompt: text prompt.
        model_type: ``"Pretrained"`` to use ``runner.load_pretrained()``;
            any other value uses the custom generator from ``runner.load_custom()``.
        runner: object providing ``load_pretrained()`` and ``load_custom()``.

    Returns:
        For ``"Pretrained"``, the first entry of the pipeline result's
        ``images`` (presumably a PIL image — confirm against the pipeline).
        Otherwise a PIL image built from the custom generator's output.
    """
    import zlib

    if model_type == "Pretrained":
        pipe = runner.load_pretrained()
        # Return the pipeline's image as-is; the tensor post-processing below
        # only applies to the custom generator's raw output.
        return pipe(prompt).images[0]

    model = runner.load_custom()
    # Inference mode: BatchNorm1d must use its running statistics, otherwise
    # a batch of one sample raises an error.
    model.eval()
    vocab_size = model.text_embedding.num_embeddings
    # Stable prompt -> token id mapping (str hash() varies between runs).
    token = torch.tensor(
        [zlib.crc32((prompt or "").encode("utf-8")) % vocab_size],
        dtype=torch.long,
    )
    noise = torch.randn(1, 100)
    with torch.no_grad():
        fake = model(token, noise)
    # (1, C, H, W) -> (H, W, C), then map the Tanh range [-1, 1] to [0, 1].
    pixels = fake.squeeze(0).permute(1, 2, 0).numpy()
    pixels = np.clip((pixels + 1) / 2, 0.0, 1.0)
    return Image.fromarray((pixels * 255).astype(np.uint8))
class ModelRunner:
    """Load the models used for image generation.

    Attributes:
        pretrained_pipe: cached diffusion pipeline, or ``None`` until
            ``load_pretrained`` is first called.
    """

    def __init__(self):
        self.pretrained_pipe = None

    def load_pretrained(self):
        """Return the diffusion pipeline, creating it on first use."""
        # `is None` rather than truthiness: only "never loaded" should
        # trigger a (re)load.
        if self.pretrained_pipe is None:
            self.pretrained_pipe = DiffusionPipeline.from_pretrained(
                "stabilityai/stable-diffusion-xl-base-1.0"
            )
        return self.pretrained_pipe

    def load_custom(self):
        """Build the generator and load the weights saved by training.

        The weights are re-read from disk on every call and mapped to CPU.
        The model is returned in eval mode so its BatchNorm layer accepts a
        single-sample batch at inference time.

        Raises:
            Whatever ``torch.load`` raises when the weights file is missing
            or unreadable.
        """
        model = TextConditionedGenerator()
        model.load_state_dict(torch.load(CONFIG["paths"]["model_save"], map_location="cpu"))
        model.eval()
        return model
if __name__ == "__main__":
    # Script entry point: build the Blocks app and serve it with the default
    # launch settings.
    create_interface().launch()