Spaces:
Sleeping
Sleeping
import os
import threading
import time
import uuid
from urllib.parse import quote_plus

import gradio as gr
import numpy as np
import requests
import torch
import torch.nn as nn
import torch.optim as optim
from diffusers import DiffusionPipeline
from PIL import Image
from torch.utils.data import DataLoader, Dataset
# ====================== | |
# Configuration | |
# ====================== | |
CONFIG = {
    # Pexels API key, read from the environment (e.g. a Space secret) so it is
    # never committed to source. Treat any key that was previously hard-coded
    # here as leaked and rotate it.
    "pexels_api_key": os.environ.get("PEXELS_API_KEY", ""),
    "scraping": {
        # `{query}` is substituted with the search text.
        "search_url": "https://api.pexels.com/v1/search?query={query}&per_page=80",
        # Upper bound on downloads. Only one page of 80 results is requested,
        # so in practice at most 80 images are downloaded.
        "max_images": 100,
        "progress_interval": 1,
        "request_timeout": 30,  # seconds, per HTTP request
    },
    "training": {
        "batch_size": 4,
        "epochs": 10,
        "lr": 0.0002,  # Adam learning rate for generator and discriminator
        # NOTE(review): latent_dim and img_size are not read by the visible
        # code; the models and dataset hard-code 100 and 64 respectively.
        "latent_dim": 100,
        "img_size": 64,
        "num_workers": 0,
        "progress_interval": 0.5,
    },
    "paths": {
        "dataset_dir": "scraped_data",  # directory downloaded images are written to
        "model_save": "text2img_model.pth",  # generator state_dict checkpoint
    },
}
# ====================== | |
# Web Scraping Module (Pexels API)
# ====================== | |
class WebScraper:
    """Download Pexels search results to disk on a background thread.

    Successful downloads are appended to ``scraped_data`` as
    ``{"text": query, "image": path}``. ``scraping_progress`` (0-100),
    ``scraped_count`` (successful downloads) and ``total_images`` expose
    progress of the current scrape.
    """

    def __init__(self):
        self.stop_event = threading.Event()  # set to ask the worker to stop early
        self.scraped_data = []
        self._lock = threading.Lock()
        self._thread = None  # worker thread started by start_scraping, if any
        self.scraping_progress = 0
        self.scraped_count = 0
        self.total_images = 0

    def __getstate__(self):
        # Events, locks and threads cannot be pickled or deep-copied; drop them
        # here and recreate fresh ones in __setstate__.
        state = self.__dict__.copy()
        for unpicklable in ("stop_event", "_lock", "_thread"):
            state.pop(unpicklable, None)
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self.stop_event = threading.Event()
        self._lock = threading.Lock()
        self._thread = None

    def scrape_images(self, query):
        """Download up to ``max_images`` Pexels results for *query* (blocking).

        Errors are printed rather than raised. ``scraping_progress`` is always
        left at 100 when this method returns.
        """
        with self._lock:
            self.scraping_progress = 0
            self.scraped_count = 0
            self.total_images = 0
        try:
            scrape_cfg = CONFIG["scraping"]
            timeout = scrape_cfg.get("request_timeout", 30)  # seconds
            dataset_dir = CONFIG["paths"]["dataset_dir"]
            # URL-encode the query so '&', '#', '+' etc. cannot break out of
            # the `query` parameter.
            url = scrape_cfg["search_url"].format(query=quote_plus(query))
            headers = {"Authorization": CONFIG["pexels_api_key"]}

            response = requests.get(url, headers=headers, timeout=timeout)
            # Surface 401/429/5xx instead of treating an error body as results.
            response.raise_for_status()
            photos = response.json().get("photos", [])
            self.total_images = min(len(photos), scrape_cfg["max_images"])
            if self.total_images:
                os.makedirs(dataset_dir, exist_ok=True)

            downloaded = 0
            for idx, photo in enumerate(photos[: self.total_images]):
                if self.stop_event.is_set():
                    break
                try:
                    img_response = requests.get(photo["src"]["large"], timeout=timeout)
                    img_response.raise_for_status()
                    # Random suffix: concurrent sessions share dataset_dir and
                    # would otherwise overwrite each other's "{sec}_{idx}" files.
                    img_name = f"{int(time.time())}_{idx}_{uuid.uuid4().hex[:8]}.jpg"
                    img_path = os.path.join(dataset_dir, img_name)
                    with open(img_path, 'wb') as f:
                        f.write(img_response.content)
                    self.scraped_data.append({"text": query, "image": img_path})
                    downloaded += 1
                    self.scraped_count = downloaded
                except Exception as e:
                    # Best effort: skip this photo and continue with the rest.
                    print(f"Error downloading image: {e}")
                # Count failed attempts too so progress still reaches 100%.
                self.scraping_progress = (idx + 1) / self.total_images * 100
                time.sleep(0.1)  # be gentle with the API / CDN
        except Exception as e:
            print(f"API scraping error: {e}")
        finally:
            self.scraping_progress = 100

    def start_scraping(self, query):
        """Run scrape_images(query) on a new background thread.

        A scrape still running for this scraper is stopped and joined first,
        so only one worker ever appends to ``scraped_data``.

        Returns:
            A status message for the UI.
        """
        previous = getattr(self, "_thread", None)
        if previous is not None and previous.is_alive():
            self.stop_event.set()
            # The worker checks stop_event before each download and its
            # requests use a timeout, so this wait is bounded.
            previous.join()
        self.scraped_data.clear()
        self.stop_event.clear()
        self._thread = threading.Thread(target=self.scrape_images, args=(query,))
        self._thread.start()
        return "Scraping started..."
# ====================== | |
# Dataset and Models
# ====================== | |
class TextImageDataset(Dataset):
    """Map-style dataset over scraped ``{"text": ..., "image": path}`` records.

    Each item is ``{"text": str, "image": float32 tensor of shape (3, 64, 64)}``
    with pixels mapped from [0, 255] to [-1, 1]. An image that cannot be loaded
    is replaced by standard-normal noise instead of raising.
    """

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _load_image(path):
        """Read *path* as RGB, resize to 64x64 and return a CHW tensor in [-1, 1]."""
        rgb = Image.open(path).convert('RGB').resize((64, 64))
        channels_first = np.moveaxis(np.array(rgb), -1, 0)  # HWC -> CHW
        return torch.tensor(channels_first / 127.5 - 1, dtype=torch.float32)

    def __getitem__(self, idx):
        record = self.data[idx]
        try:
            pixels = self._load_image(record["image"])
        except Exception as e:
            print(f"Error loading image: {e}")
            pixels = torch.randn(3, 64, 64)
        return {"text": record["text"], "image": pixels}
class TextConditionedGenerator(nn.Module):
    """Fully connected generator mapping (token id, noise) to a 3x64x64 image.

    The output passes through Tanh, so pixel values lie in [-1, 1]. Submodule
    names (``text_embedding``, ``model``) and Sequential indices determine the
    state_dict keys of saved checkpoints and must not change.
    """

    def __init__(self):
        super().__init__()
        # 1000-entry token vocabulary, 128-dimensional embeddings.
        self.text_embedding = nn.Embedding(1000, 128)
        layers = [
            nn.Linear(128 + 100, 256),  # embedding concatenated with 100-dim noise
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 3 * 64 * 64),
            nn.Tanh(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, text, noise):
        """Return a (batch, 3, 64, 64) tensor for token ids *text* and *noise*."""
        conditioned = torch.cat((self.text_embedding(text), noise), dim=1)
        flat_pixels = self.model(conditioned)
        return flat_pixels.view(-1, 3, 64, 64)
# ====================== | |
# Training Utilities
# ====================== | |
def train_model(scraper, progress=gr.Progress()):
    """Train the small text-conditioned GAN on the images collected by *scraper*.

    Args:
        scraper: WebScraper whose ``scraped_data`` holds
            ``{"text": ..., "image": path}`` records.
        progress: Gradio progress tracker used to report epoch progress.

    Returns:
        A status string for the UI. On success, the generator's state_dict is
        written to ``CONFIG["paths"]["model_save"]``.
    """
    # Snapshot the records: a scrape may still be running on a background
    # thread and appending to scraper.scraped_data.
    data = list(scraper.scraped_data)
    if not data:
        return "Error: No images scraped! Scrape images first."
    if len(data) < 2:
        # The generator's BatchNorm1d cannot train on a single sample.
        return "Error: Need at least 2 scraped images to train."

    train_cfg = CONFIG["training"]
    flat_dim = 3 * 64 * 64  # image size produced by the dataset/generator, flattened

    dataset = TextImageDataset(data)
    dataloader = DataLoader(
        dataset,
        batch_size=train_cfg["batch_size"],
        shuffle=True,
        num_workers=train_cfg["num_workers"],
    )
    generator = TextConditionedGenerator()
    discriminator = nn.Sequential(
        nn.Linear(flat_dim, 512),
        nn.LeakyReLU(0.2),
        nn.Linear(512, 1),
        nn.Sigmoid(),
    )
    optimizer_G = optim.Adam(generator.parameters(), lr=train_cfg["lr"])
    optimizer_D = optim.Adam(discriminator.parameters(), lr=train_cfg["lr"])
    criterion = nn.BCELoss()

    for _ in progress.tqdm(range(train_cfg["epochs"])):
        for batch in dataloader:
            real_imgs = batch["image"]
            batch_size = real_imgs.size(0)
            if batch_size < 2:
                # A trailing batch of one (e.g. 5 samples, batch_size 4) makes
                # BatchNorm1d raise "Expected more than 1 value per channel".
                continue
            # NOTE(review): token ids are random rather than derived from
            # batch["text"], so the generator is not really text-conditioned.
            text_tokens = torch.randint(0, 1000, (batch_size,))
            noise = torch.randn(batch_size, 100)
            real_labels = torch.ones(batch_size, 1)
            fake_labels = torch.zeros(batch_size, 1)

            # Discriminator step: real -> 1, generated (detached) -> 0.
            optimizer_D.zero_grad()
            real_loss = criterion(discriminator(real_imgs.view(-1, flat_dim)), real_labels)
            fake_imgs = generator(text_tokens, noise)
            fake_loss = criterion(discriminator(fake_imgs.detach().view(-1, flat_dim)), fake_labels)
            d_loss = (real_loss + fake_loss) / 2
            d_loss.backward()
            optimizer_D.step()

            # Generator step (non-saturating loss): push D(fake) toward 1.
            optimizer_G.zero_grad()
            g_loss = criterion(discriminator(fake_imgs.view(-1, flat_dim)), real_labels)
            g_loss.backward()
            optimizer_G.step()

    torch.save(generator.state_dict(), CONFIG["paths"]["model_save"])
    return f"Training complete! Used {len(dataset)} samples"
# ====================== | |
# Image Generation
# ====================== | |
class ModelRunner:
    """Lazily load and cache the models used by generate_image."""

    # Shared by all instances. gr.State deep-copies its initial ModelRunner for
    # every browser session, so a per-instance cache would load a separate copy
    # of the large SDXL pipeline for each session.
    _shared_pretrained_pipe = None

    def __init__(self):
        self.pretrained_pipe = None
        self.custom_model = None
        self._custom_mtime = None  # mtime of the checkpoint custom_model was loaded from

    def load_pretrained(self):
        """Return the Stable Diffusion XL base pipeline, loading it on first use."""
        if self.pretrained_pipe is None:
            if ModelRunner._shared_pretrained_pipe is None:
                ModelRunner._shared_pretrained_pipe = DiffusionPipeline.from_pretrained(
                    "stabilityai/stable-diffusion-xl-base-1.0"
                )
            self.pretrained_pipe = ModelRunner._shared_pretrained_pipe
        return self.pretrained_pipe

    def load_custom(self):
        """Return the generator saved by train_model in eval mode.

        The checkpoint is reloaded whenever the file has been rewritten since
        it was last loaded (e.g. after retraining), so a stale model is never
        served.

        Raises:
            FileNotFoundError: if no checkpoint exists yet.
        """
        path = CONFIG["paths"]["model_save"]
        try:
            mtime = os.path.getmtime(path)
        except FileNotFoundError as err:
            raise FileNotFoundError(
                f"No custom model found at {path!r}; run training first."
            ) from err
        if self.custom_model is None or mtime != self._custom_mtime:
            model = TextConditionedGenerator()
            # map_location keeps loading working regardless of the device
            # the tensors were saved from.
            model.load_state_dict(torch.load(path, map_location="cpu"))
            model.eval()
            self.custom_model = model
            self._custom_mtime = mtime
        return self.custom_model
def generate_image(prompt, model_type, runner):
    """Produce one image for *prompt* using the model selected in the UI.

    Args:
        prompt: Text prompt. Only the "Pretrained" branch uses it; the custom
            generator is fed a random token id instead.
        model_type: "Pretrained" selects the diffusion pipeline; any other
            value selects the custom generator.
        runner: ModelRunner that supplies (and caches) the models.

    Returns:
        The pipeline's first image, or a 64x64 RGB PIL image from the custom
        generator.
    """
    if model_type != "Pretrained":
        generator = runner.load_custom()
        latent = torch.randn(1, 100)
        with torch.no_grad():
            sample = generator(torch.randint(0, 1000, (1,)), latent)
        # (1, 3, H, W) in [-1, 1]  ->  (H, W, 3) uint8 in [0, 255]
        hwc = sample.squeeze().permute(1, 2, 0).numpy()
        rgb = ((hwc + 1) / 2 * 255).astype(np.uint8)
        return Image.fromarray(rgb)

    pipeline = runner.load_pretrained()
    return pipeline(prompt).images[0]
# ====================== | |
# Gradio Interface
# ====================== | |
def create_interface():
    """Build the Gradio UI: scraping and training on the left, generation on the right.

    The WebScraper and ModelRunner are held in gr.State so each browser
    session works with its own instances.

    Returns:
        The (not yet launched) gr.Blocks app.
    """
    with gr.Blocks() as app:
        scraper = gr.State(WebScraper())
        runner = gr.State(ModelRunner())
        with gr.Row():
            with gr.Column():
                query_input = gr.Textbox(label="Search Query")
                scrape_btn = gr.Button("Start Scraping")
                scrape_status = gr.Textbox(label="Scraping Status")
                train_btn = gr.Button("Start Training")
                training_status = gr.Textbox(label="Training Status")
            with gr.Column():
                prompt_input = gr.Textbox(label="Generation Prompt")
                model_choice = gr.Radio(["Pretrained", "Custom"], label="Model Type", value="Pretrained")
                generate_btn = gr.Button("Generate Image")
                output_image = gr.Image(label="Generated Image")
        scrape_btn.click(lambda s, q: s.start_scraping(q), [scraper, query_input], [scrape_status])
        # Register train_model itself rather than a lambda wrapper: Gradio only
        # injects a progress tracker when the handler's own signature declares
        # a `progress=gr.Progress()` parameter.
        train_btn.click(train_model, [scraper], [training_status])
        generate_btn.click(generate_image, [prompt_input, model_choice, runner], [output_image])
    return app
# ====================== | |
# Launch | |
# ====================== | |
app = create_interface()

# Start the server only when run as a script (e.g. `python app.py`), so that
# importing this module builds the UI without launching it as a side effect.
if __name__ == "__main__":
    app.launch()