"""Gradio demo: scrape images for a query, train a tiny text-conditioned GAN, generate images."""

import io
import os
import threading
import time
import uuid
import zlib
from urllib.parse import quote

import gradio as gr
import numpy as np
import requests
import torch
import torch.nn as nn
import torch.optim as optim
from bs4 import BeautifulSoup
from diffusers import DiffusionPipeline
from PIL import Image
from torch.utils.data import DataLoader, Dataset

# ======================
# Configuration
# ======================
CONFIG = {
    "scraping": {
        "search_url": "https://www.pexels.com/search/{query}/",
        "headers": {
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"
        },
        "max_images": 100,
        "scrape_time": 10,
        # Seconds before an HTTP request is abandoned (requests has no default timeout).
        "request_timeout": 15,
    },
    "training": {
        "batch_size": 4,
        "epochs": 10,
        "lr": 0.0002,
        "latent_dim": 100,
        "img_size": 64,
        "num_workers": 0,
        # Number of text-embedding slots and their width; 1000/128 match the
        # original hard-coded values, so previously saved state dicts still load.
        "vocab_size": 1000,
        "text_dim": 128,
    },
    "paths": {
        "dataset_dir": "scraped_data",
        "model_save": "text2img_model.pth",
    },
}


def text_to_token(text):
    """Map *text* to a stable embedding index in ``[0, vocab_size)``.

    Uses CRC32 instead of the builtin ``hash`` because ``hash(str)`` is
    randomized per process, which would break the link between training
    and generation across restarts. Text is stripped and lower-cased so
    that "Cats" and "cats " share an index.
    """
    normalized = (text or "").strip().lower()
    return zlib.crc32(normalized.encode("utf-8")) % CONFIG["training"]["vocab_size"]


# ======================
# Web Scraping Module
# ======================
class WebScraper:
    """Download images for a search query on a background thread.

    Every saved image is recorded in ``scraped_data`` as
    ``{"text": query, "image": path}``.
    """

    def __init__(self):
        self.stop_event = threading.Event()
        self.scraped_data = []
        self._lock = threading.Lock()

    def __getstate__(self):
        # Events and locks cannot be pickled or deep-copied, so drop them here
        # and recreate them in __setstate__ (presumably so gr.State can copy
        # this object — confirm for the pinned Gradio version).
        state = self.__dict__.copy()
        del state["stop_event"]
        del state["_lock"]
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self.stop_event = threading.Event()
        self._lock = threading.Lock()

    def scrape_images(self, query):
        """Fetch the search page for *query* and save up to ``max_images`` images.

        Holds ``_lock`` for the whole run, so two scrapes on the same object
        never interleave. Checks ``stop_event`` between downloads. This runs
        on a worker thread, so failures are printed rather than raised.
        """
        scraping_cfg = CONFIG["scraping"]
        headers = scraping_cfg["headers"]
        timeout = scraping_cfg["request_timeout"]
        with self._lock:
            # safe="" also escapes "/" and "?" so the query stays one path segment.
            search_url = scraping_cfg["search_url"].format(query=quote(query, safe=""))
            try:
                response = requests.get(search_url, headers=headers, timeout=timeout)
                response.raise_for_status()
            except requests.RequestException as e:
                print(f"Scraping error: {e}")
                return

            soup = BeautifulSoup(response.content, "html.parser")
            img_tags = soup.find_all("img", {"class": "photo-item__img"})
            for img in img_tags[: scraping_cfg["max_images"]]:
                if self.stop_event.is_set():
                    break
                img_url = img.get("src")
                if not img_url:
                    # A tag without src used to raise KeyError and abort the whole loop.
                    continue
                saved_path = self._download_image(img_url, headers, timeout)
                if saved_path is not None:
                    self.scraped_data.append({"text": query, "image": saved_path})

    def _download_image(self, img_url, headers, timeout):
        """Download one image, check that it decodes, and save it; return its path or None."""
        try:
            img_response = requests.get(img_url, headers=headers, timeout=timeout)
            img_response.raise_for_status()
        except requests.RequestException as e:
            print(f"Error downloading image: {e}")
            return None

        content = img_response.content
        try:
            # Reject non-image payloads now, so training never hits an undecodable file.
            with Image.open(io.BytesIO(content)) as probe:
                probe.verify()
        except Exception as e:  # PIL raises several exception types for corrupt data
            print(f"Error downloading image: {e}")
            return None

        # A uuid suffix keeps images saved within the same second from overwriting each other.
        img_name = f"{int(time.time())}_{uuid.uuid4().hex}.jpg"
        img_path = os.path.join(CONFIG["paths"]["dataset_dir"], img_name)
        try:
            with open(img_path, "wb") as f:
                f.write(content)
        except OSError as e:
            print(f"Error downloading image: {e}")
            return None
        return img_path

    def start_scraping(self, query):
        """Start ``scrape_images`` on a daemon thread and return a status message."""
        query = (query or "").strip()
        if not query:
            return "Please enter a search query."
        self.stop_event.clear()
        os.makedirs(CONFIG["paths"]["dataset_dir"], exist_ok=True)
        # daemon=True so an in-progress scrape does not block interpreter shutdown.
        thread = threading.Thread(target=self.scrape_images, args=(query,), daemon=True)
        thread.start()
        return "Scraping started..."


# ======================
# Dataset and Models
# ======================
def _default_image_transform(image):
    """Resize a PIL RGB image to img_size x img_size and return a (3, H, W) float tensor in [-1, 1].

    The [-1, 1] range matches the generator's final Tanh.
    """
    size = CONFIG["training"]["img_size"]
    resized = image.resize((size, size), Image.BILINEAR)
    array = np.asarray(resized, dtype=np.float32) / 127.5 - 1.0  # (H, W, 3)
    return torch.from_numpy(array).permute(2, 0, 1).contiguous()


class TextImageDataset(Dataset):
    """Dataset over ``{"text": str, "image": path}`` records.

    Each item is ``{"text": str, "token": int, "image": tensor}``, where
    ``token`` is ``text_to_token(text)``.
    """

    def __init__(self, data, transform=None):
        # Snapshot the records so a scrape still appending to the source list
        # cannot change the length mid-epoch.
        self.data = list(data)
        self.transform = transform if transform is not None else _default_image_transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        with Image.open(item["image"]) as img:
            image = img.convert("RGB")
        return {
            "text": item["text"],
            "token": text_to_token(item["text"]),
            "image": self.transform(image),
        }


class TextConditionedGenerator(nn.Module):
    """MLP generator: (text token, noise) -> (N, 3, img_size, img_size) image in [-1, 1]."""

    def __init__(self):
        super().__init__()
        cfg = CONFIG["training"]
        self.text_embedding = nn.Embedding(cfg["vocab_size"], cfg["text_dim"])
        self.model = nn.Sequential(
            nn.Linear(cfg["text_dim"] + cfg["latent_dim"], 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 3 * cfg["img_size"] ** 2),
            nn.Tanh(),
        )

    def forward(self, text, noise):
        """Generate images.

        Args:
            text: LongTensor of token indices, shape (N,).
            noise: FloatTensor of shape (N, latent_dim).
        """
        size = CONFIG["training"]["img_size"]
        text_emb = self.text_embedding(text)
        combined = torch.cat([text_emb, noise], dim=1)
        return self.model(combined).view(-1, 3, size, size)


# ======================
# Training Utilities
# ======================
def train_model(scraper, progress=gr.Progress()):
    """Train a generator/discriminator pair on ``scraper.scraped_data`` and save the generator.

    Returns a status string for the UI. Gradio detects the ``gr.Progress``
    default and reports per-epoch progress, but only when this function is
    registered directly as the event handler.
    """
    cfg = CONFIG["training"]
    data = list(scraper.scraped_data)
    # BatchNorm1d needs at least 2 samples per training batch.
    if len(data) < 2:
        return "Need at least 2 scraped images before training - run scraping first."

    flat_dim = 3 * cfg["img_size"] ** 2
    dataset = TextImageDataset(data)
    dataloader = DataLoader(
        dataset,
        batch_size=cfg["batch_size"],
        shuffle=True,
        num_workers=cfg["num_workers"],
    )

    generator = TextConditionedGenerator()
    discriminator = nn.Sequential(
        nn.Linear(flat_dim, 512),
        nn.LeakyReLU(0.2),
        nn.Linear(512, 1),
        nn.Sigmoid(),
    )
    optimizer_G = optim.Adam(generator.parameters(), lr=cfg["lr"])
    optimizer_D = optim.Adam(discriminator.parameters(), lr=cfg["lr"])
    criterion = nn.BCELoss()

    for _ in progress.tqdm(range(cfg["epochs"]), desc="Training"):
        for batch in dataloader:
            real_imgs = batch["image"]
            batch_size = real_imgs.size(0)
            if batch_size < 2:
                # A trailing batch of 1 would make BatchNorm1d raise in training mode.
                continue
            tokens = batch["token"]
            real_labels = torch.ones(batch_size, 1)
            fake_labels = torch.zeros(batch_size, 1)
            noise = torch.randn(batch_size, cfg["latent_dim"])

            # Discriminator step: real -> 1, generated -> 0.
            optimizer_D.zero_grad()
            real_loss = criterion(discriminator(real_imgs.flatten(start_dim=1)), real_labels)
            fake_imgs = generator(tokens, noise)
            fake_loss = criterion(
                discriminator(fake_imgs.detach().flatten(start_dim=1)), fake_labels
            )
            d_loss = (real_loss + fake_loss) / 2
            d_loss.backward()
            optimizer_D.step()

            # Generator step: try to make the discriminator output 1 for fakes.
            optimizer_G.zero_grad()
            g_loss = criterion(discriminator(fake_imgs.flatten(start_dim=1)), real_labels)
            g_loss.backward()
            optimizer_G.step()

    torch.save(generator.state_dict(), CONFIG["paths"]["model_save"])
    return "Training completed!"


# ======================
# Inference Modules
# ======================
class ModelRunner:
    """Lazily load and cache the pretrained pipeline and the custom generator."""

    def __init__(self):
        self.pretrained_pipe = None
        self.custom_model = None
        # mtime of the checkpoint behind custom_model; used to detect retraining.
        self._custom_mtime = None

    def load_pretrained(self):
        """Return the SDXL pipeline, downloading and loading it on first use."""
        # NOTE(review): each ModelRunner holds its own pipeline, so every
        # session that picks "Pretrained" loads a separate copy.
        if self.pretrained_pipe is None:
            self.pretrained_pipe = DiffusionPipeline.from_pretrained(
                "stabilityai/stable-diffusion-xl-base-1.0"
            )
        return self.pretrained_pipe

    def load_custom(self):
        """Return the trained generator in eval mode.

        Reloads it when the checkpoint file changed since the last load.

        Raises:
            FileNotFoundError: if no checkpoint has been saved yet.
        """
        path = CONFIG["paths"]["model_save"]
        mtime = os.path.getmtime(path)  # raises FileNotFoundError if not trained yet
        if self.custom_model is None or mtime != self._custom_mtime:
            model = TextConditionedGenerator()
            model.load_state_dict(torch.load(path, map_location="cpu"))
            # eval() makes BatchNorm1d use running stats; in train mode a batch of 1 raises.
            model.eval()
            self.custom_model = model
            self._custom_mtime = mtime
        return self.custom_model


def generate_image(prompt, model_type, runner):
    """Generate one PIL image for *prompt* using the chosen model type."""
    if model_type == "Pretrained":
        pipe = runner.load_pretrained()
        return pipe(prompt).images[0]

    try:
        model = runner.load_custom()
    except FileNotFoundError as e:
        raise gr.Error("No custom model found. Train a model first.") from e

    token = torch.tensor([text_to_token(prompt)], dtype=torch.long)
    noise = torch.randn(1, CONFIG["training"]["latent_dim"])
    with torch.no_grad():
        fake = model(token, noise)
    # fake[0] is (3, H, W); convert to HWC and map Tanh's [-1, 1] range to [0, 1].
    image = fake[0].permute(1, 2, 0).cpu().numpy()
    image = np.clip((image + 1) / 2, 0.0, 1.0)
    return Image.fromarray((image * 255).astype(np.uint8))


# ======================
# Gradio Interface
# ======================
with gr.Blocks() as app:
    # The classes themselves are passed as State values; Gradio calls a
    # callable initial value per session, presumably giving each session its
    # own instances — confirm for the pinned Gradio version.
    scraper_state = gr.State(WebScraper)
    model_runner_state = gr.State(ModelRunner)

    with gr.Row():
        with gr.Column():
            query_input = gr.Textbox(label="Search Query")
            scrape_btn = gr.Button("Start Scraping")
            scrape_status = gr.Textbox(label="Scraping Status")
            train_btn = gr.Button("Start Training")
            training_status = gr.Textbox(label="Training Status")
        with gr.Column():
            prompt_input = gr.Textbox(label="Generation Prompt")
            model_choice = gr.Radio(
                ["Pretrained", "Custom"], label="Model Type", value="Pretrained"
            )
            generate_btn = gr.Button("Generate Image")
            output_image = gr.Image(label="Generated Image")

    scrape_btn.click(
        lambda scraper, query: scraper.start_scraping(query),
        [scraper_state, query_input],
        scrape_status,
    )
    # Registered directly (not via a lambda) so Gradio sees the gr.Progress default.
    train_btn.click(train_model, [scraper_state], training_status)
    generate_btn.click(
        generate_image,
        [prompt_input, model_choice, model_runner_state],
        output_image,
    )

if __name__ == "__main__":
    app.launch()