Update app.py
app.py CHANGED
@@ -5,7 +5,6 @@ import torch.optim as optim
 from torch.utils.data import Dataset, DataLoader
 from diffusers import DiffusionPipeline
 import requests
-from bs4 import BeautifulSoup
 import os
 import time
 import threading
@@ -16,11 +15,9 @@ import numpy as np
 # Configuration
 # ======================
 CONFIG = {
+    "pexels_api_key": "HSknLvmKmOXuqXsE89NXzu6ysOqPr7FmHGObjaSdhTTmpFSuK5K7OaHn",
     "scraping": {
-        "search_url": "https://
-        "headers": {
-            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64)"
-        },
+        "search_url": "https://api.pexels.com/v1/search?query={query}&per_page=80",
         "max_images": 100,
         "progress_interval": 1
     },
@@ -40,7 +37,7 @@ CONFIG = {
 }
 
 # ======================
-# Web Scraping Module
+# Web Scraping Module (Now using Pexels API)
 # ======================
 class WebScraper:
     def __init__(self):
@@ -66,21 +63,22 @@ class WebScraper:
         with self._lock:
             self.scraping_progress = 0
             self.scraped_count = 0
-
+        url = CONFIG["scraping"]["search_url"].format(query=query)
+        headers = {
+            "Authorization": CONFIG["pexels_api_key"]
+        }
+
         try:
-            response = requests.get(
-
-
-            self.total_images = min(len(
+            response = requests.get(url, headers=headers)
+            data = response.json()
+            photos = data.get("photos", [])
+            self.total_images = min(len(photos), CONFIG["scraping"]["max_images"])
 
-            for idx,
+            for idx, photo in enumerate(photos[:self.total_images]):
                 if self.stop_event.is_set():
                     break
 
-                img_url =
-                if not img_url:
-                    continue
-
+                img_url = photo["src"]["large"]
                 try:
                     img_data = requests.get(img_url).content
                     img_name = f"{int(time.time())}_{idx}.jpg"
@@ -95,7 +93,7 @@ class WebScraper:
                     print(f"Error downloading image: {e}")
                 time.sleep(0.1)
         except Exception as e:
-            print(f"
+            print(f"API scraping error: {e}")
         finally:
             self.scraping_progress = 100
 
@@ -107,7 +105,7 @@ class WebScraper:
         return "Scraping started..."
 
 # ======================
-# Dataset and Models
+# Dataset and Models (Unchanged)
 # ======================
 class TextImageDataset(Dataset):
     def __init__(self, data):
@@ -148,7 +146,7 @@ class TextConditionedGenerator(nn.Module):
         return self.model(combined).view(-1, 3, 64, 64)
 
 # ======================
-# Training Utilities
+# Training Utilities (Unchanged)
 # ======================
 def train_model(scraper, progress=gr.Progress()):
     if len(scraper.scraped_data) == 0:
@@ -196,7 +194,7 @@ def train_model(scraper, progress=gr.Progress()):
     return f"Training complete! Used {len(dataset)} samples"
 
 # ======================
-# Image Generation
+# Image Generation (Unchanged)
 # ======================
 class ModelRunner:
     def __init__(self):
@@ -231,7 +229,7 @@ def generate_image(prompt, model_type, runner):
     return Image.fromarray((image * 255).astype(np.uint8))
 
 # ======================
-# Gradio Interface
+# Gradio Interface (Unchanged)
 # ======================
 def create_interface():
     with gr.Blocks() as app:
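For reference, here is a minimal standalone sketch of the Pexels search flow the new scraper depends on. It assumes the documented v1 response shape, where each entry in "photos" carries a "src" dict of sized image URLs ("large" is the size this diff selects). Two deliberate departures from the committed code, plainly named: the key is read from an environment variable rather than hard-coded in CONFIG, and the query is passed via requests' params so it gets URL-encoded, which the raw .format() in the diff does not do.

import os
import requests

# Assumed to be set in the environment rather than committed in CONFIG.
API_KEY = os.environ["PEXELS_API_KEY"]

def search_photo_urls(query: str, max_images: int = 100) -> list:
    """Return up to max_images direct image URLs for a search query."""
    response = requests.get(
        "https://api.pexels.com/v1/search",
        params={"query": query, "per_page": 80},  # requests URL-encodes the query
        headers={"Authorization": API_KEY},
        timeout=30,
    )
    # Surface HTTP errors (bad key, rate limit) instead of parsing an error body.
    response.raise_for_status()
    photos = response.json().get("photos", [])
    # Each photo's "src" dict maps size names to URLs; "large" matches the diff.
    return [photo["src"]["large"] for photo in photos[:max_images]]

if __name__ == "__main__":
    for url in search_photo_urls("mountains", max_images=5):
        print(url)

The timeout and raise_for_status() call keep a failed or rate-limited search from being silently treated as an empty photo list, which is what the bare response.json() in the committed version would do.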