Update app.py

app.py CHANGED
@@ -50,13 +50,13 @@ class WebScraper:
        self.scraping_progress = 0
        self.scraped_count = 0
        self.total_images = 0
-
+
    def __getstate__(self):
        state = self.__dict__.copy()
        del state['stop_event']
        del state['_lock']
        return state
-
+
    def __setstate__(self, state):
        self.__dict__.update(state)
        self.stop_event = threading.Event()
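Note on the `__getstate__`/`__setstate__` pair kept here: threading primitives are not picklable, and this commit later hands a `WebScraper` instance to `gr.State`, which deep-copies session values, so the `Lock` and `Event` must be dropped on copy and rebuilt on restore. A minimal self-contained sketch of the pattern (the `Worker` name is illustrative, not from the app):

```python
import copy
import threading

class Worker:
    """Survives pickle/deepcopy despite holding sync primitives."""
    def __init__(self):
        self._lock = threading.Lock()
        self.stop_event = threading.Event()
        self.progress = 0

    def __getstate__(self):
        state = self.__dict__.copy()
        del state['_lock']        # Lock objects raise TypeError under pickle
        del state['stop_event']   # Event wraps a lock, same problem
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self._lock = threading.Lock()        # fresh primitives on restore;
        self.stop_event = threading.Event()  # prior waiters are not carried over

clone = copy.deepcopy(Worker())  # deepcopy honors __getstate__/__setstate__
```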
@@ -72,37 +72,32 @@ class WebScraper:
            soup = BeautifulSoup(response.content, 'html.parser')
            img_tags = soup.find_all('img', {'class': 'photo-item__img'})
            self.total_images = min(len(img_tags), CONFIG["scraping"]["max_images"])
-
-            for idx, img in enumerate(img_tags[:
+
+            for idx, img in enumerate(img_tags[:self.total_images]):
                if self.stop_event.is_set():
                    break
-
+
                img_url = img['src']
                try:
                    img_data = requests.get(img_url).content
                    img_name = f"{int(time.time())}_{idx}.jpg"
+                    os.makedirs(CONFIG["paths"]["dataset_dir"], exist_ok=True)
                    img_path = os.path.join(CONFIG["paths"]["dataset_dir"], img_name)
-
                    with open(img_path, 'wb') as f:
                        f.write(img_data)
-
                    self.scraped_data.append({"text": query, "image": img_path})
                    self.scraped_count = idx + 1
                    self.scraping_progress = (idx + 1) / self.total_images * 100
-
                except Exception as e:
                    print(f"Error downloading image: {e}")
-
-                time.sleep(0.1)  # Simulate download time
-
+                time.sleep(0.1)
        except Exception as e:
            print(f"Scraping error: {e}")
        finally:
            self.scraping_progress = 100
-
+
    def start_scraping(self, query):
        self.stop_event.clear()
-        os.makedirs(CONFIG["paths"]["dataset_dir"], exist_ok=True)
        thread = threading.Thread(target=self.scrape_images, args=(query,))
        thread.start()
        return "Scraping started..."
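A follow-up worth considering for this loop: `requests.get(img_url)` runs with no timeout and no status check, so a single stalled response can hang the scraper thread that `start_scraping` spawns. A hedged sketch of a stricter helper (`fetch_image` is illustrative, not part of the app):

```python
import requests

def fetch_image(url, timeout=10.0):
    """Download one image; return None on any common failure."""
    try:
        resp = requests.get(url, timeout=timeout)
        resp.raise_for_status()  # turn 4xx/5xx into exceptions
        if not resp.headers.get("Content-Type", "").startswith("image/"):
            return None          # skip HTML error pages served with 200
        return resp.content
    except requests.RequestException as exc:
        print(f"Error downloading image: {exc}")
        return None
```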
@@ -113,10 +108,10 @@ class WebScraper:
class TextImageDataset(Dataset):
    def __init__(self, data):
        self.data = data
-
+
    def __len__(self):
        return len(self.data)
-
+
    def __getitem__(self, idx):
        item = self.data[idx]
        image = Image.open(item["image"]).convert('RGB')
@@ -136,7 +131,7 @@ class TextConditionedGenerator(nn.Module):
            nn.Linear(512, 3*64*64),
            nn.Tanh()
        )
-
+
    def forward(self, text, noise):
        text_emb = self.text_embedding(text)
        combined = torch.cat([text_emb, noise], 1)
@@ -148,12 +143,10 @@ class TextConditionedGenerator(nn.Module):
def train_model(scraper, progress=gr.Progress()):
    if len(scraper.scraped_data) == 0:
        return "Error: No images scraped! Scrape images first."
-
+
    dataset = TextImageDataset(scraper.scraped_data)
-    dataloader = DataLoader(dataset,
-
-                            shuffle=True)
-
+    dataloader = DataLoader(dataset, batch_size=CONFIG["training"]["batch_size"], shuffle=True)
+
    generator = TextConditionedGenerator()
    discriminator = nn.Sequential(
        nn.Linear(3*64*64, 512),
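The multi-line DataLoader call becomes a one-liner with `batch_size` pulled from CONFIG. One dependency this creates: the default collate function cannot stack PIL images, so the tail of `TextImageDataset.__getitem__` (cut off by the hunk above) has to return tensors. If it doesn't already, a transform along these lines would be needed (sizes chosen to match the generator's 3*64*64 output; this is an assumption, not code from the diff):

```python
from torchvision import transforms

# Assumed preprocessing for __getitem__; the real code is truncated in this diff.
to_tensor = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),                       # PIL -> float tensor in [0, 1]
    transforms.Normalize([0.5] * 3, [0.5] * 3),  # rescale to [-1, 1] to match Tanh
])
```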
@@ -161,19 +154,18 @@ def train_model(scraper, progress=gr.Progress()):
        nn.Linear(512, 1),
        nn.Sigmoid()
    )
-
+
    optimizer_G = optim.Adam(generator.parameters(), lr=CONFIG["training"]["lr"])
    optimizer_D = optim.Adam(discriminator.parameters(), lr=CONFIG["training"]["lr"])
    criterion = nn.BCELoss()
-
+
    total_batches = len(dataloader)
-    for epoch in progress.tqdm(range(CONFIG["training"]["epochs"])
+    for epoch in progress.tqdm(range(CONFIG["training"]["epochs"])):
        for batch_idx, batch in enumerate(dataloader):
-            real_imgs = torch.randn(4, 3, 64, 64)
+            real_imgs = torch.randn(4, 3, 64, 64)
            real_labels = torch.ones(real_imgs.size(0), 1)
            noise = torch.randn(real_imgs.size(0), 100)
-
-            # Train discriminator
+
            optimizer_D.zero_grad()
            real_loss = criterion(discriminator(real_imgs.view(-1, 3*64*64)), real_labels)
            fake_imgs = generator(torch.randint(0, 1000, (real_imgs.size(0),)), noise)
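Flagging something this hunk leaves untouched: `real_imgs = torch.randn(4, 3, 64, 64)` feeds the discriminator pure noise instead of the scraped batch, so the GAN never trains on the downloaded images. If that's unintended, the loop head would presumably read the batch instead, roughly (assuming `__getitem__` yields `(text, image)` pairs, which the truncated hunks don't show):

```python
for batch_idx, batch in enumerate(dataloader):
    texts, real_imgs = batch                       # assumed batch structure
    real_labels = torch.ones(real_imgs.size(0), 1)
    noise = torch.randn(real_imgs.size(0), 100)
```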
@@ -181,130 +173,81 @@ def train_model(scraper, progress=gr.Progress()):
            d_loss = (real_loss + fake_loss) / 2
            d_loss.backward()
            optimizer_D.step()
-
-            # Train generator
+
            optimizer_G.zero_grad()
            g_loss = criterion(discriminator(fake_imgs.view(-1, 3*64*64)), torch.ones_like(real_labels))
            g_loss.backward()
            optimizer_G.step()
-
-            progress(
-                (epoch + (batch_idx+1)/total_batches) / CONFIG["training"]["epochs"],
-                desc=f"Epoch {epoch+1} | Batch {batch_idx+1}/{total_batches}",
-                unit="epoch"
-            )
-
+
    torch.save(generator.state_dict(), CONFIG["paths"]["model_save"])
    return f"Training complete! Used {len(dataset)} samples"
199 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
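One operational note on the new ModelRunner: `DiffusionPipeline.from_pretrained` with default arguments loads SDXL base in float32, which is a lot of memory for a Space (this one shows as Sleeping, so likely CPU-basic hardware). When a GPU is available, diffusers' documented half-precision loading is much lighter; a sketch:

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,   # halves memory vs. float32
    variant="fp16",              # fetch the fp16 weight files
    use_safetensors=True,
)
pipe.to("cuda")                  # fp16 inference effectively requires a GPU
```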
# ======================
# Gradio Interface
# ======================
def create_interface():
    with gr.Blocks() as app:
-        scraper = gr.State(
-
-
+        scraper = gr.State(WebScraper())
+        runner = gr.State(ModelRunner())
+
        with gr.Row():
            with gr.Column():
                query_input = gr.Textbox(label="Search Query")
                scrape_btn = gr.Button("Start Scraping")
                scrape_status = gr.Textbox(label="Scraping Status")
-
-
+
                train_btn = gr.Button("Start Training")
                training_status = gr.Textbox(label="Training Status")
-
-
+
            with gr.Column():
                prompt_input = gr.Textbox(label="Generation Prompt")
                model_choice = gr.Radio(["Pretrained", "Custom"], label="Model Type", value="Pretrained")
                generate_btn = gr.Button("Generate Image")
                output_image = gr.Image(label="Generated Image")
-
-        # Scraping monitoring
-        def monitor_scraping(scraper):
-            while True:
-                if hasattr(scraper, 'scraping_progress'):
-                    yield f"{scraper.scraping_progress:.1f}% ({scraper.scraped_count}/{scraper.total_images})"
-                else:
-                    yield "0% (0/0)"
-                time.sleep(CONFIG["scraping"]["progress_interval"])
-
-        # Training monitoring
-        def monitor_training():
-            while True:
-                if os.path.exists(CONFIG["paths"]["model_save"]):
-                    with open(CONFIG["paths"]["model_save"], 'rb') as f:
-                        stats = os.stat(f.fileno())
-                    yield f"Model size: {stats.st_size//1024}KB"
-                else:
-                    yield "No trained model"
-                time.sleep(1)
-
-        app.load(
-            monitor_scraping,
-            inputs=[scraper],
-            outputs=[scraping_progress],
-            every=CONFIG["scraping"]["progress_interval"]
-        )
-
-        app.load(
-            monitor_training,
-            outputs=[training_progress],
-            every=1
-        )
-
-        # Event handlers
-        scrape_btn.click(
-            lambda s, q: s.start_scraping(q),
-            [scraper, query_input],
-            [scrape_status]
-        )
-
-        train_btn.click(
-            lambda s: train_model(s),
-            [scraper],
-            [training_status]
-        )
-
-        generate_btn.click(
-            lambda p, m, r: generate_image(p, m, r),
-            [prompt_input, model_choice, model_runner],
-            [output_image]
-        )
-
-        return app
+        scrape_btn.click(lambda s, q: s.start_scraping(q), [scraper, query_input], [scrape_status])
+        train_btn.click(lambda s: train_model(s), [scraper], [training_status])
+        generate_btn.click(lambda p, m, r: generate_image(p, m, r), [prompt_input, model_choice, runner], [output_image])

+    return app

-
-
-
-        image = pipe(prompt).images[0]
-    else:
-        model = runner.load_custom()
-        noise = torch.randn(1, 100)
-        with torch.no_grad():
-            fake = model(torch.randint(0, 1000, (1,)), noise)
-        image = fake.squeeze().permute(1, 2, 0).numpy()
-        image = (image + 1) / 2
-        return Image.fromarray((image * 255).astype(np.uint8))

-
-    def __init__(self):
-        self.pretrained_pipe = None
-
-    def load_pretrained(self):
-        if not self.pretrained_pipe:
-            self.pretrained_pipe = DiffusionPipeline.from_pretrained(
-                "stabilityai/stable-diffusion-xl-base-1.0"
-            )
-        return self.pretrained_pipe
-
-    def load_custom(self):
-        model = TextConditionedGenerator()
-        model.load_state_dict(torch.load(CONFIG["paths"]["model_save"], map_location='cpu'))
-        return model
-
-
-
+# ======================
+# Launch
+# ======================
+app = create_interface()
+app.launch()
|