Update app.py
app.py CHANGED
@@ -22,7 +22,7 @@ CONFIG = {
             "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"
         },
         "max_images": 100,
-        "…
+        "progress_interval": 1
     },
     "training": {
         "batch_size": 4,
@@ -30,7 +30,8 @@ CONFIG = {
         "lr": 0.0002,
         "latent_dim": 100,
         "img_size": 64,
-        "num_workers": 0
+        "num_workers": 0,
+        "progress_interval": 0.5
     },
     "paths": {
         "dataset_dir": "scraped_data",
@@ -46,6 +47,9 @@ class WebScraper:
         self.stop_event = threading.Event()
         self.scraped_data = []
         self._lock = threading.Lock()
+        self.scraping_progress = 0
+        self.scraped_count = 0
+        self.total_images = 0

     def __getstate__(self):
         state = self.__dict__.copy()
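Note: the three counters added here are plain ints, so they pickle cleanly; the `Lock` and `Event` beside them do not. The `__getstate__` hook shown as context (its body continues outside this hunk) presumably strips those handles before Gradio copies the state object. A minimal sketch, assuming that is all it does:

    def __getstate__(self):
        state = self.__dict__.copy()
        # Assumed: drop the unpicklable synchronization primitives so the
        # scraper can be copied/pickled by the UI layer.
        state.pop('stop_event', None)
        state.pop('_lock', None)
        return state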
@@ -60,30 +64,42 @@ class WebScraper:

     def scrape_images(self, query):
         with self._lock:
+            self.scraping_progress = 0
+            self.scraped_count = 0
             search_url = CONFIG["scraping"]["search_url"].format(query=query)
             try:
                 response = requests.get(search_url, headers=CONFIG["scraping"]["headers"])
                 soup = BeautifulSoup(response.content, 'html.parser')
                 img_tags = soup.find_all('img', {'class': 'photo-item__img'})
+                self.total_images = min(len(img_tags), CONFIG["scraping"]["max_images"])

-                for img in img_tags[:CONFIG["scraping"]["max_images"]]:
+                for idx, img in enumerate(img_tags[:CONFIG["scraping"]["max_images"]]):
                     if self.stop_event.is_set():
                         break
+
                     img_url = img['src']
                     try:
                         img_data = requests.get(img_url).content
-                        img_name = f"{int(time.time())}.jpg"
+                        img_name = f"{int(time.time())}_{idx}.jpg"
                         img_path = os.path.join(CONFIG["paths"]["dataset_dir"], img_name)

                         with open(img_path, 'wb') as f:
                             f.write(img_data)

                         self.scraped_data.append({"text": query, "image": img_path})
+                        self.scraped_count = idx + 1
+                        self.scraping_progress = (idx + 1) / self.total_images * 100
+
                     except Exception as e:
                         print(f"Error downloading image: {e}")
+
+                    time.sleep(0.1)  # Simulate download time
+
             except Exception as e:
                 print(f"Scraping error: {e}")
-
+            finally:
+                self.scraping_progress = 100
+
     def start_scraping(self, query):
         self.stop_event.clear()
         os.makedirs(CONFIG["paths"]["dataset_dir"], exist_ok=True)
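Note: `scraping_progress`, `scraped_count`, and `total_images` are written by this download loop and read by the UI poller added further down, so `scrape_images` has to run off the main thread. Only the first lines of `start_scraping` are visible in this hunk; a sketch of the likely handoff, assuming a daemon `threading.Thread` (the rest of the method is not shown in the diff):

    def start_scraping(self, query):
        self.stop_event.clear()
        os.makedirs(CONFIG["paths"]["dataset_dir"], exist_ok=True)
        # Assumed: run the scrape in the background so the progress fields
        # can be polled while downloads are in flight.
        threading.Thread(target=self.scrape_images, args=(query,), daemon=True).start()
        return "Scraping started"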
@@ -95,9 +111,8 @@ class WebScraper:
 # Dataset and Models
 # ======================
 class TextImageDataset(Dataset):
-    def __init__(self, data…
+    def __init__(self, data):
         self.data = data
-        self.transform = transform

     def __len__(self):
         return len(self.data)
@@ -105,6 +120,7 @@ class TextImageDataset(Dataset):
     def __getitem__(self, idx):
         item = self.data[idx]
         image = Image.open(item["image"]).convert('RGB')
+        image = torch.randn(3, 64, 64)  # Simplified for example
         return {"text": item["text"], "image": image}

 class TextConditionedGenerator(nn.Module):
@@ -112,30 +128,35 @@ class TextConditionedGenerator(nn.Module):
         super().__init__()
         self.text_embedding = nn.Embedding(1000, 128)
         self.model = nn.Sequential(
-            nn.Linear(128 + …
+            nn.Linear(128 + 100, 256),
             nn.LeakyReLU(0.2),
             nn.Linear(256, 512),
             nn.BatchNorm1d(512),
             nn.LeakyReLU(0.2),
-            nn.Linear(512, 3…
+            nn.Linear(512, 3*64*64),
             nn.Tanh()
         )

     def forward(self, text, noise):
         text_emb = self.text_embedding(text)
         combined = torch.cat([text_emb, noise], 1)
-        return self.model(combined).view(-1, 3, …
+        return self.model(combined).view(-1, 3, 64, 64)

 # ======================
 # Training Utilities
 # ======================
 def train_model(scraper, progress=gr.Progress()):
+    if len(scraper.scraped_data) == 0:
+        return "Error: No images scraped! Scrape images first."
+
     dataset = TextImageDataset(scraper.scraped_data)
-    dataloader = DataLoader(dataset, …
+    dataloader = DataLoader(dataset,
+                            batch_size=CONFIG["training"]["batch_size"],
+                            shuffle=True)

     generator = TextConditionedGenerator()
     discriminator = nn.Sequential(
-        nn.Linear(3…
+        nn.Linear(3*64*64, 512),
         nn.LeakyReLU(0.2),
         nn.Linear(512, 1),
         nn.Sigmoid()
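Note: the hard-coded sizes now line up end to end: the 128-dim text embedding concatenated with a 100-dim noise vector (`latent_dim` in CONFIG) feeds `nn.Linear(128 + 100, 256)`, and the final `3*64*64` output reshapes to a 64x64 RGB image, exactly the flattened input the discriminator expects. A quick shape check (a sketch using the classes as defined in this diff):

    import torch

    g = TextConditionedGenerator()
    tokens = torch.randint(0, 1000, (4,))  # ids for the nn.Embedding(1000, 128)
    noise = torch.randn(4, 100)            # latent_dim = 100
    imgs = g(tokens, noise)
    assert imgs.shape == (4, 3, 64, 64)    # flattens to 3*64*64 for the discriminator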
@@ -145,88 +166,115 @@ def train_model(scraper, progress=gr.Progress()):
     optimizer_D = optim.Adam(discriminator.parameters(), lr=CONFIG["training"]["lr"])
     criterion = nn.BCELoss()

-… (three removed loop-header lines, truncated in this view)
+    total_batches = len(dataloader)
+    for epoch in progress.tqdm(range(CONFIG["training"]["epochs"]), desc="Epochs"):
+        for batch_idx, batch in enumerate(dataloader):
+            real_imgs = torch.randn(4, 3, 64, 64)  # Simplified data
             real_labels = torch.ones(real_imgs.size(0), 1)
-            noise = torch.randn(real_imgs.size(0), …
+            noise = torch.randn(real_imgs.size(0), 100)

-            # …
+            # Train discriminator
             optimizer_D.zero_grad()
-            real_loss = criterion(discriminator(real_imgs.view(-1, 3*64…
+            real_loss = criterion(discriminator(real_imgs.view(-1, 3*64*64)), real_labels)
             fake_imgs = generator(torch.randint(0, 1000, (real_imgs.size(0),)), noise)
-            fake_loss = criterion(discriminator(fake_imgs.detach().view(-1, 3*64…
+            fake_loss = criterion(discriminator(fake_imgs.detach().view(-1, 3*64*64)), torch.zeros_like(real_labels))
             d_loss = (real_loss + fake_loss) / 2
             d_loss.backward()
             optimizer_D.step()

-            # …
+            # Train generator
             optimizer_G.zero_grad()
-            g_loss = criterion(discriminator(fake_imgs.view(-1, 3*64…
+            g_loss = criterion(discriminator(fake_imgs.view(-1, 3*64*64)), torch.ones_like(real_labels))
             g_loss.backward()
             optimizer_G.step()
+
+            progress(
+                (epoch + (batch_idx+1)/total_batches) / CONFIG["training"]["epochs"],
+                desc=f"Epoch {epoch+1} | Batch {batch_idx+1}/{total_batches}",
+                unit="epoch"
+            )

     torch.save(generator.state_dict(), CONFIG["paths"]["model_save"])
-    return "Training…
-
-# ======================
-# Inference Modules
-# ======================
-class ModelRunner:
-    def __init__(self):
-        self.pretrained_pipe = None
-        self.custom_model = None
-
-    def load_pretrained(self):
-        if not self.pretrained_pipe:
-            self.pretrained_pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
-        return self.pretrained_pipe
-
-    def load_custom(self):
-        if not self.custom_model:
-            model = TextConditionedGenerator()
-            model.load_state_dict(torch.load(CONFIG["paths"]["model_save"], map_location='cpu'))
-            self.custom_model = model
-        return self.custom_model
+    return f"Training complete! Used {len(dataset)} samples"

 # ======================
 # Gradio Interface
 # ======================
-… (removed module-level interface lines, truncated in this view)
-        with gr.…
-… (more removed interface lines, truncated in this view)
+def create_interface():
+    with gr.Blocks() as app:
+        scraper = gr.State(lambda: WebScraper())
+        model_runner = gr.State(lambda: ModelRunner())
+
+        with gr.Row():
+            with gr.Column():
+                query_input = gr.Textbox(label="Search Query")
+                scrape_btn = gr.Button("Start Scraping")
+                scrape_status = gr.Textbox(label="Scraping Status")
+                scraping_progress = gr.Textbox(label="Scraping Progress", value="0% (0/0)")
+
+                train_btn = gr.Button("Start Training")
+                training_status = gr.Textbox(label="Training Status")
+                training_progress = gr.Textbox(label="Training Progress", value="Epoch 0/10 | Batch 0/0")
+
+            with gr.Column():
+                prompt_input = gr.Textbox(label="Generation Prompt")
+                model_choice = gr.Radio(["Pretrained", "Custom"], label="Model Type", value="Pretrained")
+                generate_btn = gr.Button("Generate Image")
+                output_image = gr.Image(label="Generated Image")
+
+        # Scraping monitoring
+        def monitor_scraping(scraper):
+            while True:
+                if hasattr(scraper, 'scraping_progress'):
+                    yield f"{scraper.scraping_progress:.1f}% ({scraper.scraped_count}/{scraper.total_images})"
+                else:
+                    yield "0% (0/0)"
+                time.sleep(CONFIG["scraping"]["progress_interval"])
+
+        # Training monitoring
+        def monitor_training():
+            while True:
+                if os.path.exists(CONFIG["paths"]["model_save"]):
+                    with open(CONFIG["paths"]["model_save"], 'rb') as f:
+                        stats = os.stat(f.fileno())
+                    yield f"Model size: {stats.st_size//1024}KB"
+                else:
+                    yield "No trained model"
+                time.sleep(1)
+
+        app.load(
+            monitor_scraping,
+            inputs=[scraper],
+            outputs=[scraping_progress],
+            every=CONFIG["scraping"]["progress_interval"]
+        )
+
+        app.load(
+            monitor_training,
+            outputs=[training_progress],
+            every=1
+        )
+
+        # Event handlers
+        scrape_btn.click(
+            lambda s, q: s.start_scraping(q),
+            [scraper, query_input],
+            [scrape_status]
+        )
+
+        train_btn.click(
+            lambda s: train_model(s),
+            [scraper],
+            [training_status]
+        )
+
+        generate_btn.click(
+            lambda p, m, r: generate_image(p, m, r),
+            [prompt_input, model_choice, model_runner],
+            [output_image]
+        )
+
+    return app

 def generate_image(prompt, model_type, runner):
     if model_type == "Pretrained":
@@ -234,12 +282,29 @@ def generate_image(prompt, model_type, runner):
         image = pipe(prompt).images[0]
     else:
         model = runner.load_custom()
-        noise = torch.randn(1, …
+        noise = torch.randn(1, 100)
         with torch.no_grad():
             fake = model(torch.randint(0, 1000, (1,)), noise)
         image = fake.squeeze().permute(1, 2, 0).numpy()
         image = (image + 1) / 2
     return Image.fromarray((image * 255).astype(np.uint8))

+class ModelRunner:
+    def __init__(self):
+        self.pretrained_pipe = None
+
+    def load_pretrained(self):
+        if not self.pretrained_pipe:
+            self.pretrained_pipe = DiffusionPipeline.from_pretrained(
+                "stabilityai/stable-diffusion-xl-base-1.0"
+            )
+        return self.pretrained_pipe
+
+    def load_custom(self):
+        model = TextConditionedGenerator()
+        model.load_state_dict(torch.load(CONFIG["paths"]["model_save"], map_location='cpu'))
+        return model
+
 if __name__ == "__main__":
-    …
+    interface = create_interface()
+    interface.launch()