gaur3009 committed (verified)
Commit 5db3cdf · 1 Parent(s): 7461c43

Update app.py

Files changed (1)
  1. app.py +70 -127
app.py CHANGED
@@ -50,13 +50,13 @@ class WebScraper:
         self.scraping_progress = 0
         self.scraped_count = 0
         self.total_images = 0
-
+
     def __getstate__(self):
         state = self.__dict__.copy()
         del state['stop_event']
         del state['_lock']
         return state
-
+
     def __setstate__(self, state):
         self.__dict__.update(state)
         self.stop_event = threading.Event()
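
For readers unfamiliar with the pattern, the __getstate__/__setstate__ pair above is the standard way to keep an object picklable when it owns threading primitives, which pickle cannot serialize: drop them on the way out and recreate fresh ones on the way back in. A minimal standalone sketch of the same idea (the Worker class is illustrative, not part of app.py):

import pickle
import threading

class Worker:
    def __init__(self):
        self.count = 0
        self.stop_event = threading.Event()  # wraps a lock, so it cannot be pickled
        self._lock = threading.Lock()        # cannot be pickled either

    def __getstate__(self):
        # Copy the instance dict and drop the live threading objects.
        state = self.__dict__.copy()
        del state['stop_event']
        del state['_lock']
        return state

    def __setstate__(self, state):
        # Restore the plain attributes and recreate fresh primitives.
        self.__dict__.update(state)
        self.stop_event = threading.Event()
        self._lock = threading.Lock()

clone = pickle.loads(pickle.dumps(Worker()))
assert clone.count == 0 and not clone.stop_event.is_set()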
@@ -72,37 +72,32 @@ class WebScraper:
             soup = BeautifulSoup(response.content, 'html.parser')
             img_tags = soup.find_all('img', {'class': 'photo-item__img'})
             self.total_images = min(len(img_tags), CONFIG["scraping"]["max_images"])
-
-            for idx, img in enumerate(img_tags[:CONFIG["scraping"]["max_images"]]):
+
+            for idx, img in enumerate(img_tags[:self.total_images]):
                 if self.stop_event.is_set():
                     break
-
+
                 img_url = img['src']
                 try:
                     img_data = requests.get(img_url).content
                     img_name = f"{int(time.time())}_{idx}.jpg"
+                    os.makedirs(CONFIG["paths"]["dataset_dir"], exist_ok=True)
                     img_path = os.path.join(CONFIG["paths"]["dataset_dir"], img_name)
-
                     with open(img_path, 'wb') as f:
                         f.write(img_data)
-
                     self.scraped_data.append({"text": query, "image": img_path})
                     self.scraped_count = idx + 1
                     self.scraping_progress = (idx + 1) / self.total_images * 100
-
                 except Exception as e:
                     print(f"Error downloading image: {e}")
-
-                time.sleep(0.1)  # Simulate download time
-
+                time.sleep(0.1)
         except Exception as e:
             print(f"Scraping error: {e}")
         finally:
             self.scraping_progress = 100
-
+
     def start_scraping(self, query):
         self.stop_event.clear()
-        os.makedirs(CONFIG["paths"]["dataset_dir"], exist_ok=True)
        thread = threading.Thread(target=self.scrape_images, args=(query,))
         thread.start()
         return "Scraping started..."
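
start_scraping hands the work to a background thread and returns immediately, so callers poll the counters that scrape_images keeps updated rather than block on the call. A rough usage sketch, assuming WebScraper and its CONFIG paths come from app.py:

import time

scraper = WebScraper()
print(scraper.start_scraping("mountains"))  # returns at once: "Scraping started..."

# Poll the fields the worker thread keeps up to date.
while scraper.scraping_progress < 100:
    print(f"{scraper.scraped_count}/{scraper.total_images} images downloaded")
    time.sleep(1)

# To cancel early instead, set the event checked once per image:
# scraper.stop_event.set()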
@@ -113,10 +108,10 @@ class WebScraper:
 class TextImageDataset(Dataset):
     def __init__(self, data):
         self.data = data
-
+
     def __len__(self):
         return len(self.data)
-
+
     def __getitem__(self, idx):
         item = self.data[idx]
         image = Image.open(item["image"]).convert('RGB')
@@ -136,7 +131,7 @@ class TextConditionedGenerator(nn.Module):
             nn.Linear(512, 3*64*64),
             nn.Tanh()
         )
-
+
     def forward(self, text, noise):
         text_emb = self.text_embedding(text)
         combined = torch.cat([text_emb, noise], 1)
@@ -148,12 +143,10 @@ class TextConditionedGenerator(nn.Module):
 def train_model(scraper, progress=gr.Progress()):
     if len(scraper.scraped_data) == 0:
         return "Error: No images scraped! Scrape images first."
-
+
     dataset = TextImageDataset(scraper.scraped_data)
-    dataloader = DataLoader(dataset,
-                            batch_size=CONFIG["training"]["batch_size"],
-                            shuffle=True)
-
+    dataloader = DataLoader(dataset, batch_size=CONFIG["training"]["batch_size"], shuffle=True)
+
     generator = TextConditionedGenerator()
     discriminator = nn.Sequential(
         nn.Linear(3*64*64, 512),
@@ -161,19 +154,18 @@ def train_model(scraper, progress=gr.Progress()):
         nn.Linear(512, 1),
         nn.Sigmoid()
     )
-
+
     optimizer_G = optim.Adam(generator.parameters(), lr=CONFIG["training"]["lr"])
     optimizer_D = optim.Adam(discriminator.parameters(), lr=CONFIG["training"]["lr"])
     criterion = nn.BCELoss()
-
+
     total_batches = len(dataloader)
-    for epoch in progress.tqdm(range(CONFIG["training"]["epochs"]), desc="Epochs"):
+    for epoch in progress.tqdm(range(CONFIG["training"]["epochs"])):
         for batch_idx, batch in enumerate(dataloader):
-            real_imgs = torch.randn(4, 3, 64, 64)  # Simplified data
+            real_imgs = torch.randn(4, 3, 64, 64)
             real_labels = torch.ones(real_imgs.size(0), 1)
             noise = torch.randn(real_imgs.size(0), 100)
-
-            # Train discriminator
+
             optimizer_D.zero_grad()
             real_loss = criterion(discriminator(real_imgs.view(-1, 3*64*64)), real_labels)
             fake_imgs = generator(torch.randint(0, 1000, (real_imgs.size(0),)), noise)
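
The progress.tqdm call above is Gradio's mechanism for reporting progress from an event handler: declaring a gr.Progress() default parameter makes Gradio inject a tracker, and wrapping an iterable in progress.tqdm updates the UI on every iteration. A self-contained sketch of that pattern (slow_task and the Blocks wiring are illustrative, not part of app.py):

import time
import gradio as gr

def slow_task(progress=gr.Progress()):
    # Gradio injects a tracker for the gr.Progress() default parameter.
    for _ in progress.tqdm(range(50), desc="Working"):
        time.sleep(0.05)
    return "done"

with gr.Blocks() as demo:
    out = gr.Textbox(label="Result")
    gr.Button("Run").click(slow_task, outputs=out)

demo.launch()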
@@ -181,130 +173,81 @@ def train_model(scraper, progress=gr.Progress()):
             d_loss = (real_loss + fake_loss) / 2
             d_loss.backward()
             optimizer_D.step()
-
-            # Train generator
+
             optimizer_G.zero_grad()
             g_loss = criterion(discriminator(fake_imgs.view(-1, 3*64*64)), torch.ones_like(real_labels))
             g_loss.backward()
             optimizer_G.step()
-
-            progress(
-                (epoch + (batch_idx+1)/total_batches) / CONFIG["training"]["epochs"],
-                desc=f"Epoch {epoch+1} | Batch {batch_idx+1}/{total_batches}",
-                unit="epoch"
-            )
-
+
     torch.save(generator.state_dict(), CONFIG["paths"]["model_save"])
     return f"Training complete! Used {len(dataset)} samples"
 
+# ======================
+# Image Generation
+# ======================
+class ModelRunner:
+    def __init__(self):
+        self.pretrained_pipe = None
+        self.custom_model = None
+
+    def load_pretrained(self):
+        if not self.pretrained_pipe:
+            self.pretrained_pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
+        return self.pretrained_pipe
+
+    def load_custom(self):
+        if not self.custom_model:
+            model = TextConditionedGenerator()
+            model.load_state_dict(torch.load(CONFIG["paths"]["model_save"]))
+            model.eval()
+            self.custom_model = model
+        return self.custom_model
+
+def generate_image(prompt, model_type, runner):
+    if model_type == "Pretrained":
+        pipe = runner.load_pretrained()
+        image = pipe(prompt).images[0]
+        return image
+    else:
+        model = runner.load_custom()
+        noise = torch.randn(1, 100)
+        with torch.no_grad():
+            fake = model(torch.randint(0, 1000, (1,)), noise)
+        image = fake.squeeze().permute(1, 2, 0).numpy()
+        image = (image + 1) / 2
+        return Image.fromarray((image * 255).astype(np.uint8))
+
 # ======================
 # Gradio Interface
 # ======================
 def create_interface():
     with gr.Blocks() as app:
-        scraper = gr.State(lambda: WebScraper())
-        model_runner = gr.State(lambda: ModelRunner())
-
+        scraper = gr.State(WebScraper())
+        runner = gr.State(ModelRunner())
+
         with gr.Row():
             with gr.Column():
                 query_input = gr.Textbox(label="Search Query")
                 scrape_btn = gr.Button("Start Scraping")
                 scrape_status = gr.Textbox(label="Scraping Status")
-                scraping_progress = gr.Textbox(label="Scraping Progress", value="0% (0/0)")
-
+
                 train_btn = gr.Button("Start Training")
                 training_status = gr.Textbox(label="Training Status")
-                training_progress = gr.Textbox(label="Training Progress", value="Epoch 0/10 | Batch 0/0")
-
+
             with gr.Column():
                 prompt_input = gr.Textbox(label="Generation Prompt")
                 model_choice = gr.Radio(["Pretrained", "Custom"], label="Model Type", value="Pretrained")
                 generate_btn = gr.Button("Generate Image")
                 output_image = gr.Image(label="Generated Image")
-
-        # Scraping monitoring
-        def monitor_scraping(scraper):
-            while True:
-                if hasattr(scraper, 'scraping_progress'):
-                    yield f"{scraper.scraping_progress:.1f}% ({scraper.scraped_count}/{scraper.total_images})"
-                else:
-                    yield "0% (0/0)"
-                time.sleep(CONFIG["scraping"]["progress_interval"])
-
-        # Training monitoring
-        def monitor_training():
-            while True:
-                if os.path.exists(CONFIG["paths"]["model_save"]):
-                    with open(CONFIG["paths"]["model_save"], 'rb') as f:
-                        stats = os.stat(f.fileno())
-                        yield f"Model size: {stats.st_size//1024}KB"
-                else:
-                    yield "No trained model"
-                time.sleep(1)
-
-        app.load(
-            monitor_scraping,
-            inputs=[scraper],
-            outputs=[scraping_progress],
-            every=CONFIG["scraping"]["progress_interval"]
-        )
-
-        app.load(
-            monitor_training,
-            outputs=[training_progress],
-            every=1
-        )
-
-        # Event handlers
-        scrape_btn.click(
-            lambda s, q: s.start_scraping(q),
-            [scraper, query_input],
-            [scrape_status]
-        )
-
-        train_btn.click(
-            lambda s: train_model(s),
-            [scraper],
-            [training_status]
-        )
-
-        generate_btn.click(
-            lambda p, m, r: generate_image(p, m, r),
-            [prompt_input, model_choice, model_runner],
-            [output_image]
-        )
-
-    return app
 
-def generate_image(prompt, model_type, runner):
-    if model_type == "Pretrained":
-        pipe = runner.load_pretrained()
-        image = pipe(prompt).images[0]
-    else:
-        model = runner.load_custom()
-        noise = torch.randn(1, 100)
-        with torch.no_grad():
-            fake = model(torch.randint(0, 1000, (1,)), noise)
-        image = fake.squeeze().permute(1, 2, 0).numpy()
-        image = (image + 1) / 2
-    return Image.fromarray((image * 255).astype(np.uint8))
+        scrape_btn.click(lambda s, q: s.start_scraping(q), [scraper, query_input], [scrape_status])
+        train_btn.click(lambda s: train_model(s), [scraper], [training_status])
+        generate_btn.click(lambda p, m, r: generate_image(p, m, r), [prompt_input, model_choice, runner], [output_image])
 
-class ModelRunner:
-    def __init__(self):
-        self.pretrained_pipe = None
-
-    def load_pretrained(self):
-        if not self.pretrained_pipe:
-            self.pretrained_pipe = DiffusionPipeline.from_pretrained(
-                "stabilityai/stable-diffusion-xl-base-1.0"
-            )
-        return self.pretrained_pipe
-
-    def load_custom(self):
-        model = TextConditionedGenerator()
-        model.load_state_dict(torch.load(CONFIG["paths"]["model_save"], map_location='cpu'))
-        return model
+    return app
 
-if __name__ == "__main__":
-    interface = create_interface()
-    interface.launch()
+# ======================
+# Launch
+# ======================
+app = create_interface()
+app.launch()
 
 
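In the "Custom" branch of generate_image, the generator ends in Tanh, so its output lies in [-1, 1] and must be rescaled before it can become an 8-bit PIL image. The same conversion shown in isolation, with a random tensor standing in for a real generator output:

import numpy as np
import torch
from PIL import Image

fake = torch.tanh(torch.randn(3, 64, 64))   # stand-in for a generator output in [-1, 1]
image = fake.permute(1, 2, 0).numpy()       # CHW -> HWC, the layout PIL expects
image = (image + 1) / 2                     # [-1, 1] -> [0, 1]
pil_img = Image.fromarray((image * 255).astype(np.uint8))
print(pil_img.size)                         # (64, 64)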