"""Gradio demo: scrape images, train a toy text-conditioned GAN, generate images.

The app has three stages wired into one UI:

1. ``WebScraper`` downloads images for a search query in a background thread.
2. ``train_model`` trains a small fully-connected GAN on the scraped images.
3. ``generate_image`` produces an image either from a pretrained diffusion
   pipeline or from the custom generator saved by stage 2.
"""

import io
import logging
import os
import threading
import time
import zlib
from urllib.parse import quote, urljoin

import gradio as gr
import numpy as np
import requests
import torch
import torch.nn as nn
import torch.optim as optim
from bs4 import BeautifulSoup
from diffusers import DiffusionPipeline
from PIL import Image
from torch.utils.data import DataLoader, Dataset

logger = logging.getLogger(__name__)

# ======================
# Configuration
# ======================
CONFIG = {
    "scraping": {
        "search_url": "https://www.pexels.com/search/{query}/",
        "headers": {
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"
        },
        "max_images": 100,
        "progress_interval": 1,
        # Seconds before an HTTP request is abandoned; without a timeout a
        # stalled connection would block the scraper thread forever.
        "request_timeout": 10,
    },
    "training": {
        "batch_size": 4,
        "epochs": 10,
        "lr": 0.0002,
        "latent_dim": 100,
        "img_size": 64,
        "num_workers": 0,
        "progress_interval": 0.5,
        # Size of the text-embedding table and of each embedding vector.
        "vocab_size": 1000,
        "text_embed_dim": 128,
    },
    "paths": {
        "dataset_dir": "scraped_data",
        "model_save": "text2img_model.pth",
    },
}


def text_to_token(text: str) -> int:
    """Map *text* to a stable embedding index in ``[0, vocab_size)``.

    CRC32 is used instead of ``hash()`` because ``hash()`` of a string is
    randomized per process, which would make a saved model's embeddings
    meaningless after a restart.
    """
    normalized = (text or "").strip().lower()
    return zlib.crc32(normalized.encode("utf-8")) % CONFIG["training"]["vocab_size"]


def _is_decodable_image(payload: bytes) -> bool:
    """Return True when Pillow recognizes *payload* as an image file."""
    try:
        with Image.open(io.BytesIO(payload)) as candidate:
            candidate.verify()
    except Exception:  # Pillow raises format-specific exception types here.
        return False
    return True


# ======================
# Web Scraping Module
# ======================
class WebScraper:
    """Download images for a search query on a background thread.

    Progress is exposed through ``scraping_progress`` (percent),
    ``scraped_count`` (images saved so far) and ``total_images`` (images the
    current run will attempt). Saved samples accumulate in ``scraped_data`` as
    ``{"text": query, "image": path}`` dicts.
    """

    def __init__(self):
        self.stop_event = threading.Event()
        self.scraped_data = []
        self._lock = threading.Lock()
        self.scraping_progress = 0
        self.scraped_count = 0
        self.total_images = 0

    def __getstate__(self):
        # Events and locks cannot be copied/pickled; drop them so the object
        # can be copied and recreate them in __setstate__.
        state = self.__dict__.copy()
        del state["stop_event"]
        del state["_lock"]
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self.stop_event = threading.Event()
        self._lock = threading.Lock()

    def scrape_images(self, query):
        """Fetch the search page for *query* and download the images it lists.

        Runs until all images are processed or ``stop_event`` is set. Errors
        are logged rather than raised so one bad download does not end the
        run; ``scraping_progress`` is always set to 100 on exit so pollers can
        tell the run has finished.
        """
        scraping_cfg = CONFIG["scraping"]
        timeout = scraping_cfg.get("request_timeout", 10)
        with self._lock:
            self.scraping_progress = 0
            self.scraped_count = 0
            self.total_images = 0

        # Quote the query so characters such as "/" or "?" cannot change the
        # URL path that is requested.
        search_url = scraping_cfg["search_url"].format(query=quote(query, safe=""))
        try:
            os.makedirs(CONFIG["paths"]["dataset_dir"], exist_ok=True)
            response = requests.get(
                search_url, headers=scraping_cfg["headers"], timeout=timeout
            )
            response.raise_for_status()
            soup = BeautifulSoup(response.content, "html.parser")
            img_tags = soup.find_all("img", {"class": "photo-item__img"})
            img_tags = img_tags[: scraping_cfg["max_images"]]
            with self._lock:
                self.total_images = len(img_tags)

            for idx, img in enumerate(img_tags):
                if self.stop_event.is_set():
                    break
                self._download_image(img.get("src"), search_url, query, idx, timeout)
                with self._lock:
                    self.scraping_progress = (idx + 1) / len(img_tags) * 100
                time.sleep(0.1)  # Short pause between requests.
        except Exception:
            # Thread boundary: log with traceback instead of dying silently.
            logger.exception("Scraping error")
        finally:
            with self._lock:
                self.scraping_progress = 100

    def _download_image(self, src, base_url, query, idx, timeout):
        """Download one image and record it; log and skip on any failure."""
        if not src:
            logger.warning("Skipping <img> tag without a src attribute")
            return
        img_url = urljoin(base_url, src)  # Resolves relative/protocol-relative URLs.
        try:
            response = requests.get(
                img_url, headers=CONFIG["scraping"]["headers"], timeout=timeout
            )
            response.raise_for_status()
            payload = response.content
            if not _is_decodable_image(payload):
                logger.warning("Skipping %s: response is not a valid image", img_url)
                return
            img_name = f"{time.time_ns()}_{idx}.jpg"
            img_path = os.path.join(CONFIG["paths"]["dataset_dir"], img_name)
            with open(img_path, "wb") as f:
                f.write(payload)
        except (requests.RequestException, OSError) as exc:
            logger.warning("Error downloading image %s: %s", img_url, exc)
            return
        with self._lock:
            self.scraped_data.append({"text": query, "image": img_path})
            self.scraped_count += 1

    def start_scraping(self, query):
        """Start ``scrape_images`` for *query* on a background thread.

        Returns a status string for display in the UI.
        """
        if not query or not query.strip():
            return "Error: enter a search query first."
        self.stop_event.clear()
        os.makedirs(CONFIG["paths"]["dataset_dir"], exist_ok=True)
        # Daemon thread: a running scrape must not keep the process alive
        # after the app is shut down.
        thread = threading.Thread(
            target=self.scrape_images, args=(query.strip(),), daemon=True
        )
        thread.start()
        return "Scraping started..."


# ======================
# Dataset and Models
# ======================
class TextImageDataset(Dataset):
    """Dataset over ``{"text": str, "image": path}`` records.

    Each item yields the caption, its embedding index (``"token"``) and the
    image as a float tensor of shape ``(3, img_size, img_size)`` scaled to
    ``[-1, 1]`` to match the generator's ``Tanh`` output range.
    """

    def __init__(self, data, img_size=None):
        self.data = data
        self.img_size = CONFIG["training"]["img_size"] if img_size is None else img_size

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        with Image.open(item["image"]) as img:
            rgb = img.convert("RGB").resize((self.img_size, self.img_size))
        # np.array copies, so the tensor does not alias a read-only buffer.
        pixels = np.array(rgb, dtype=np.float32)
        image = torch.from_numpy(pixels).permute(2, 0, 1).contiguous() / 127.5 - 1.0
        return {
            "text": item["text"],
            "token": text_to_token(item["text"]),
            "image": image,
        }


class TextConditionedGenerator(nn.Module):
    """Fully-connected generator conditioned on a text embedding index.

    Layer sizes are read from ``CONFIG["training"]``; with the default
    configuration they equal the previously hard-coded 1000/128/100/64.
    """

    def __init__(self):
        super().__init__()
        cfg = CONFIG["training"]
        self.img_size = cfg["img_size"]
        self.text_embedding = nn.Embedding(cfg["vocab_size"], cfg["text_embed_dim"])
        self.model = nn.Sequential(
            nn.Linear(cfg["text_embed_dim"] + cfg["latent_dim"], 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 3 * self.img_size * self.img_size),
            nn.Tanh(),
        )

    def forward(self, text, noise):
        """Return images of shape ``(batch, 3, img_size, img_size)`` in [-1, 1].

        ``text`` holds one embedding index per sample and ``noise`` is
        ``(batch, latent_dim)``.
        """
        text_emb = self.text_embedding(text)
        combined = torch.cat([text_emb, noise], 1)
        return self.model(combined).view(-1, 3, self.img_size, self.img_size)


# ======================
# Training Utilities
# ======================
# NOTE: the ``gr.Progress()`` default is the Gradio idiom for requesting a
# progress tracker; it is not a mutable-default bug.
def train_model(scraper, progress=gr.Progress()):
    """Train the GAN on ``scraper.scraped_data`` and save the generator.

    Returns a status string. The generator's ``state_dict`` is written to
    ``CONFIG["paths"]["model_save"]`` when training finishes.
    """
    cfg = CONFIG["training"]
    # Snapshot: a scrape may still be appending while training runs.
    samples = list(scraper.scraped_data)
    if not samples:
        return "Error: No images scraped! Scrape images first."
    if len(samples) < 2:
        return "Error: Need at least 2 images to train."

    # BatchNorm1d cannot run in training mode on a single sample, so keep
    # batches at 2+ and drop a trailing batch that would hold exactly one.
    batch_size = max(2, cfg["batch_size"])
    dataset = TextImageDataset(samples)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=cfg["num_workers"],
        drop_last=len(samples) % batch_size == 1,
    )

    flat_pixels = 3 * cfg["img_size"] * cfg["img_size"]
    generator = TextConditionedGenerator()
    discriminator = nn.Sequential(
        nn.Linear(flat_pixels, 512),
        nn.LeakyReLU(0.2),
        nn.Linear(512, 1),
        nn.Sigmoid(),
    )

    optimizer_G = optim.Adam(generator.parameters(), lr=cfg["lr"])
    optimizer_D = optim.Adam(discriminator.parameters(), lr=cfg["lr"])
    criterion = nn.BCELoss()

    epochs = cfg["epochs"]
    total_batches = len(dataloader)
    for epoch in range(epochs):
        for batch_idx, batch in enumerate(dataloader):
            real_imgs = batch["image"]
            tokens = batch["token"]
            n_samples = real_imgs.size(0)
            real_labels = torch.ones(n_samples, 1)
            fake_labels = torch.zeros(n_samples, 1)
            noise = torch.randn(n_samples, cfg["latent_dim"])

            # Train discriminator
            optimizer_D.zero_grad()
            real_loss = criterion(discriminator(real_imgs.flatten(start_dim=1)), real_labels)
            fake_imgs = generator(tokens, noise)
            # detach(): the discriminator step must not update the generator.
            fake_loss = criterion(
                discriminator(fake_imgs.detach().flatten(start_dim=1)), fake_labels
            )
            d_loss = (real_loss + fake_loss) / 2
            d_loss.backward()
            optimizer_D.step()

            # Train generator
            optimizer_G.zero_grad()
            g_loss = criterion(discriminator(fake_imgs.flatten(start_dim=1)), real_labels)
            g_loss.backward()
            optimizer_G.step()

            progress(
                (epoch + (batch_idx + 1) / total_batches) / epochs,
                desc=f"Epoch {epoch+1} | Batch {batch_idx+1}/{total_batches}",
                unit="epoch",
            )

    torch.save(generator.state_dict(), CONFIG["paths"]["model_save"])
    return f"Training complete! Used {len(dataset)} samples"


# ======================
# Inference
# ======================
class ModelRunner:
    """Load and cache the models used by ``generate_image``."""

    # The pretrained pipeline is shared by all instances so that each browser
    # session does not load its own copy of the weights.
    _shared_pipe = None
    _load_lock = threading.Lock()

    def __init__(self):
        self.pretrained_pipe = None

    def load_pretrained(self):
        """Return the pretrained diffusion pipeline, loading it on first use."""
        if self.pretrained_pipe is None:
            with ModelRunner._load_lock:
                if ModelRunner._shared_pipe is None:
                    ModelRunner._shared_pipe = DiffusionPipeline.from_pretrained(
                        "stabilityai/stable-diffusion-xl-base-1.0"
                    )
            self.pretrained_pipe = ModelRunner._shared_pipe
        return self.pretrained_pipe

    def load_custom(self):
        """Return the trained generator in eval mode.

        Raises:
            gr.Error: if no trained model file exists yet.
        """
        model_path = CONFIG["paths"]["model_save"]
        if not os.path.exists(model_path):
            raise gr.Error("No trained model found. Train a model first.")
        model = TextConditionedGenerator()
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # written by this application.
        model.load_state_dict(torch.load(model_path, map_location="cpu"))
        # eval(): BatchNorm1d in training mode rejects the batch of one used
        # for inference.
        model.eval()
        return model


def generate_image(prompt, model_type, runner):
    """Generate one PIL image for *prompt*.

    ``model_type`` is ``"Pretrained"`` for the diffusion pipeline; any other
    value uses the custom generator loaded by ``runner.load_custom()``.
    """
    if model_type == "Pretrained":
        # The pipeline already returns PIL images; no rescaling is needed.
        return runner.load_pretrained()(prompt).images[0]

    model = runner.load_custom()
    noise = torch.randn(1, CONFIG["training"]["latent_dim"])
    token = torch.tensor([text_to_token(prompt)], dtype=torch.long)
    with torch.no_grad():
        fake = model(token, noise)
    # Map Tanh output from [-1, 1] to [0, 1]; clamp guards the uint8 cast.
    pixels = ((fake.squeeze(0).permute(1, 2, 0) + 1) / 2).clamp(0, 1).numpy()
    return Image.fromarray((pixels * 255).astype(np.uint8))


# ======================
# Gradio Interface
# ======================
def create_interface():
    """Build and return the Gradio Blocks app."""
    with gr.Blocks() as app:
        scraper = gr.State(WebScraper)
        model_runner = gr.State(ModelRunner)

        with gr.Row():
            with gr.Column():
                query_input = gr.Textbox(label="Search Query")
                scrape_btn = gr.Button("Start Scraping")
                scrape_status = gr.Textbox(label="Scraping Status")
                scraping_progress = gr.Textbox(label="Scraping Progress", value="0% (0/0)")

                train_btn = gr.Button("Start Training")
                training_status = gr.Textbox(label="Training Status")
                training_progress = gr.Textbox(
                    label="Training Progress", value="Epoch 0/10 | Batch 0/0"
                )

            with gr.Column():
                prompt_input = gr.Textbox(label="Generation Prompt")
                model_choice = gr.Radio(
                    ["Pretrained", "Custom"], label="Model Type", value="Pretrained"
                )
                generate_btn = gr.Button("Generate Image")
                output_image = gr.Image(label="Generated Image")

        # The monitors return a single snapshot: `every=` already re-invokes
        # them periodically, so they must not loop forever themselves.
        def monitor_scraping(state):
            """Return the current scraping progress as display text."""
            if not hasattr(state, "scraping_progress"):
                return "0% (0/0)"
            return (
                f"{state.scraping_progress:.1f}% "
                f"({state.scraped_count}/{state.total_images})"
            )

        def monitor_training():
            """Return the saved model's size, or a placeholder if absent."""
            try:
                size_bytes = os.path.getsize(CONFIG["paths"]["model_save"])
            except OSError:
                return "No trained model"
            return f"Model size: {size_bytes//1024}KB"

        app.load(
            monitor_scraping,
            inputs=[scraper],
            outputs=[scraping_progress],
            every=CONFIG["scraping"]["progress_interval"],
        )

        app.load(
            monitor_training,
            outputs=[training_progress],
            every=1,
        )

        # Event handlers
        scrape_btn.click(
            lambda s, q: s.start_scraping(q),
            [scraper, query_input],
            [scrape_status],
        )

        # train_model is registered directly (not via a lambda) so that Gradio
        # sees its gr.Progress default and supplies a live progress tracker.
        train_btn.click(
            train_model,
            [scraper],
            [training_status],
        )

        generate_btn.click(
            generate_image,
            [prompt_input, model_choice, model_runner],
            [output_image],
        )

    return app


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    interface = create_interface()
    # Periodic (`every=`) events and progress tracking need the queue.
    interface.queue()
    interface.launch()