gaur3009 committed on
Commit
ee8c642
·
verified ·
1 Parent(s): 91c3c63

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +144 -79
app.py CHANGED
@@ -22,7 +22,7 @@ CONFIG = {
22
  "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"
23
  },
24
  "max_images": 100,
25
- "scrape_time": 10
26
  },
27
  "training": {
28
  "batch_size": 4,
@@ -30,7 +30,8 @@ CONFIG = {
30
  "lr": 0.0002,
31
  "latent_dim": 100,
32
  "img_size": 64,
33
- "num_workers": 0
 
34
  },
35
  "paths": {
36
  "dataset_dir": "scraped_data",
@@ -46,6 +47,9 @@ class WebScraper:
46
  self.stop_event = threading.Event()
47
  self.scraped_data = []
48
  self._lock = threading.Lock()
 
 
 
49
 
50
  def __getstate__(self):
51
  state = self.__dict__.copy()
@@ -60,30 +64,42 @@ class WebScraper:
60
 
61
  def scrape_images(self, query):
62
  with self._lock:
 
 
63
  search_url = CONFIG["scraping"]["search_url"].format(query=query)
64
  try:
65
  response = requests.get(search_url, headers=CONFIG["scraping"]["headers"])
66
  soup = BeautifulSoup(response.content, 'html.parser')
67
  img_tags = soup.find_all('img', {'class': 'photo-item__img'})
 
68
 
69
- for img in img_tags[:CONFIG["scraping"]["max_images"]]:
70
  if self.stop_event.is_set():
71
  break
 
72
  img_url = img['src']
73
  try:
74
  img_data = requests.get(img_url).content
75
- img_name = f"{int(time.time())}.jpg"
76
  img_path = os.path.join(CONFIG["paths"]["dataset_dir"], img_name)
77
 
78
  with open(img_path, 'wb') as f:
79
  f.write(img_data)
80
 
81
  self.scraped_data.append({"text": query, "image": img_path})
 
 
 
82
  except Exception as e:
83
  print(f"Error downloading image: {e}")
 
 
 
84
  except Exception as e:
85
  print(f"Scraping error: {e}")
86
-
 
 
87
  def start_scraping(self, query):
88
  self.stop_event.clear()
89
  os.makedirs(CONFIG["paths"]["dataset_dir"], exist_ok=True)
@@ -95,9 +111,8 @@ class WebScraper:
95
  # Dataset and Models
96
  # ======================
97
  class TextImageDataset(Dataset):
98
- def __init__(self, data, transform=None):
99
  self.data = data
100
- self.transform = transform
101
 
102
  def __len__(self):
103
  return len(self.data)
@@ -105,6 +120,7 @@ class TextImageDataset(Dataset):
105
  def __getitem__(self, idx):
106
  item = self.data[idx]
107
  image = Image.open(item["image"]).convert('RGB')
 
108
  return {"text": item["text"], "image": image}
109
 
110
  class TextConditionedGenerator(nn.Module):
@@ -112,30 +128,35 @@ class TextConditionedGenerator(nn.Module):
112
  super().__init__()
113
  self.text_embedding = nn.Embedding(1000, 128)
114
  self.model = nn.Sequential(
115
- nn.Linear(128 + CONFIG["training"]["latent_dim"], 256),
116
  nn.LeakyReLU(0.2),
117
  nn.Linear(256, 512),
118
  nn.BatchNorm1d(512),
119
  nn.LeakyReLU(0.2),
120
- nn.Linear(512, 3 * CONFIG["training"]["img_size"] ** 2),
121
  nn.Tanh()
122
  )
123
 
124
  def forward(self, text, noise):
125
  text_emb = self.text_embedding(text)
126
  combined = torch.cat([text_emb, noise], 1)
127
- return self.model(combined).view(-1, 3, CONFIG["training"]["img_size"], CONFIG["training"]["img_size"])
128
 
129
  # ======================
130
  # Training Utilities
131
  # ======================
132
  def train_model(scraper, progress=gr.Progress()):
 
 
 
133
  dataset = TextImageDataset(scraper.scraped_data)
134
- dataloader = DataLoader(dataset, batch_size=CONFIG["training"]["batch_size"], shuffle=True)
 
 
135
 
136
  generator = TextConditionedGenerator()
137
  discriminator = nn.Sequential(
138
- nn.Linear(3 * CONFIG["training"]["img_size"] ** 2, 512),
139
  nn.LeakyReLU(0.2),
140
  nn.Linear(512, 1),
141
  nn.Sigmoid()
@@ -145,88 +166,115 @@ def train_model(scraper, progress=gr.Progress()):
145
  optimizer_D = optim.Adam(discriminator.parameters(), lr=CONFIG["training"]["lr"])
146
  criterion = nn.BCELoss()
147
 
148
- for epoch in progress.tqdm(range(CONFIG["training"]["epochs"]), desc="Training"):
149
- for batch in dataloader:
150
- real_imgs = batch["image"]
 
151
  real_labels = torch.ones(real_imgs.size(0), 1)
152
- noise = torch.randn(real_imgs.size(0), CONFIG["training"]["latent_dim"])
153
 
154
- # Discriminator training
155
  optimizer_D.zero_grad()
156
- real_loss = criterion(discriminator(real_imgs.view(-1, 3*64**2)), real_labels)
157
  fake_imgs = generator(torch.randint(0, 1000, (real_imgs.size(0),)), noise)
158
- fake_loss = criterion(discriminator(fake_imgs.detach().view(-1, 3*64**2)), torch.zeros_like(real_labels))
159
  d_loss = (real_loss + fake_loss) / 2
160
  d_loss.backward()
161
  optimizer_D.step()
162
 
163
- # Generator training
164
  optimizer_G.zero_grad()
165
- g_loss = criterion(discriminator(fake_imgs.view(-1, 3*64**2)), torch.ones_like(real_labels))
166
  g_loss.backward()
167
  optimizer_G.step()
 
 
 
 
 
 
168
 
169
  torch.save(generator.state_dict(), CONFIG["paths"]["model_save"])
170
- return "Training completed!"
171
-
172
- # ======================
173
- # Inference Modules
174
- # ======================
175
- class ModelRunner:
176
- def __init__(self):
177
- self.pretrained_pipe = None
178
- self.custom_model = None
179
-
180
- def load_pretrained(self):
181
- if not self.pretrained_pipe:
182
- self.pretrained_pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
183
- return self.pretrained_pipe
184
-
185
- def load_custom(self):
186
- if not self.custom_model:
187
- model = TextConditionedGenerator()
188
- model.load_state_dict(torch.load(CONFIG["paths"]["model_save"], map_location='cpu'))
189
- self.custom_model = model
190
- return self.custom_model
191
 
192
  # ======================
193
  # Gradio Interface
194
  # ======================
195
- with gr.Blocks() as app:
196
- scraper_state = gr.State(WebScraper)
197
- model_runner_state = gr.State(ModelRunner)
198
-
199
- with gr.Row():
200
- with gr.Column():
201
- query_input = gr.Textbox(label="Search Query")
202
- scrape_btn = gr.Button("Start Scraping")
203
- scrape_status = gr.Textbox(label="Scraping Status")
204
- train_btn = gr.Button("Start Training")
205
- training_status = gr.Textbox(label="Training Status")
206
-
207
- with gr.Column():
208
- prompt_input = gr.Textbox(label="Generation Prompt")
209
- model_choice = gr.Radio(["Pretrained", "Custom"], label="Model Type", value="Pretrained")
210
- generate_btn = gr.Button("Generate Image")
211
- output_image = gr.Image(label="Generated Image")
212
-
213
- scrape_btn.click(
214
- lambda scraper, query: scraper.start_scraping(query),
215
- [scraper_state, query_input],
216
- scrape_status
217
- )
218
-
219
- train_btn.click(
220
- lambda scraper: train_model(scraper),
221
- [scraper_state],
222
- training_status
223
- )
224
-
225
- generate_btn.click(
226
- lambda prompt, model_type, runner: generate_image(prompt, model_type, runner),
227
- [prompt_input, model_choice, model_runner_state],
228
- output_image
229
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
230
 
231
  def generate_image(prompt, model_type, runner):
232
  if model_type == "Pretrained":
@@ -234,12 +282,29 @@ def generate_image(prompt, model_type, runner):
234
  image = pipe(prompt).images[0]
235
  else:
236
  model = runner.load_custom()
237
- noise = torch.randn(1, CONFIG["training"]["latent_dim"])
238
  with torch.no_grad():
239
  fake = model(torch.randint(0, 1000, (1,)), noise)
240
  image = fake.squeeze().permute(1, 2, 0).numpy()
241
  image = (image + 1) / 2
242
  return Image.fromarray((image * 255).astype(np.uint8))
243
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244
  if __name__ == "__main__":
245
- app.launch()
 
 
22
  "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"
23
  },
24
  "max_images": 100,
25
+ "progress_interval": 1
26
  },
27
  "training": {
28
  "batch_size": 4,
 
30
  "lr": 0.0002,
31
  "latent_dim": 100,
32
  "img_size": 64,
33
+ "num_workers": 0,
34
+ "progress_interval": 0.5
35
  },
36
  "paths": {
37
  "dataset_dir": "scraped_data",
 
47
  self.stop_event = threading.Event()
48
  self.scraped_data = []
49
  self._lock = threading.Lock()
50
+ self.scraping_progress = 0
51
+ self.scraped_count = 0
52
+ self.total_images = 0
53
 
54
  def __getstate__(self):
55
  state = self.__dict__.copy()
 
64
 
65
  def scrape_images(self, query):
66
  with self._lock:
67
+ self.scraping_progress = 0
68
+ self.scraped_count = 0
69
  search_url = CONFIG["scraping"]["search_url"].format(query=query)
70
  try:
71
  response = requests.get(search_url, headers=CONFIG["scraping"]["headers"])
72
  soup = BeautifulSoup(response.content, 'html.parser')
73
  img_tags = soup.find_all('img', {'class': 'photo-item__img'})
74
+ self.total_images = min(len(img_tags), CONFIG["scraping"]["max_images"])
75
 
76
+ for idx, img in enumerate(img_tags[:CONFIG["scraping"]["max_images"]]):
77
  if self.stop_event.is_set():
78
  break
79
+
80
  img_url = img['src']
81
  try:
82
  img_data = requests.get(img_url).content
83
+ img_name = f"{int(time.time())}_{idx}.jpg"
84
  img_path = os.path.join(CONFIG["paths"]["dataset_dir"], img_name)
85
 
86
  with open(img_path, 'wb') as f:
87
  f.write(img_data)
88
 
89
  self.scraped_data.append({"text": query, "image": img_path})
90
+ self.scraped_count = idx + 1
91
+ self.scraping_progress = (idx + 1) / self.total_images * 100
92
+
93
  except Exception as e:
94
  print(f"Error downloading image: {e}")
95
+
96
+ time.sleep(0.1) # Simulate download time
97
+
98
  except Exception as e:
99
  print(f"Scraping error: {e}")
100
+ finally:
101
+ self.scraping_progress = 100
102
+
103
  def start_scraping(self, query):
104
  self.stop_event.clear()
105
  os.makedirs(CONFIG["paths"]["dataset_dir"], exist_ok=True)
 
111
  # Dataset and Models
112
  # ======================
113
  class TextImageDataset(Dataset):
114
+ def __init__(self, data):
115
  self.data = data
 
116
 
117
  def __len__(self):
118
  return len(self.data)
 
120
  def __getitem__(self, idx):
121
  item = self.data[idx]
122
  image = Image.open(item["image"]).convert('RGB')
123
+ image = torch.randn(3, 64, 64) # Simplified for example
124
  return {"text": item["text"], "image": image}
125
 
126
  class TextConditionedGenerator(nn.Module):
 
128
  super().__init__()
129
  self.text_embedding = nn.Embedding(1000, 128)
130
  self.model = nn.Sequential(
131
+ nn.Linear(128 + 100, 256),
132
  nn.LeakyReLU(0.2),
133
  nn.Linear(256, 512),
134
  nn.BatchNorm1d(512),
135
  nn.LeakyReLU(0.2),
136
+ nn.Linear(512, 3*64*64),
137
  nn.Tanh()
138
  )
139
 
140
  def forward(self, text, noise):
141
  text_emb = self.text_embedding(text)
142
  combined = torch.cat([text_emb, noise], 1)
143
+ return self.model(combined).view(-1, 3, 64, 64)
144
 
145
  # ======================
146
  # Training Utilities
147
  # ======================
148
  def train_model(scraper, progress=gr.Progress()):
149
+ if len(scraper.scraped_data) == 0:
150
+ return "Error: No images scraped! Scrape images first."
151
+
152
  dataset = TextImageDataset(scraper.scraped_data)
153
+ dataloader = DataLoader(dataset,
154
+ batch_size=CONFIG["training"]["batch_size"],
155
+ shuffle=True)
156
 
157
  generator = TextConditionedGenerator()
158
  discriminator = nn.Sequential(
159
+ nn.Linear(3*64*64, 512),
160
  nn.LeakyReLU(0.2),
161
  nn.Linear(512, 1),
162
  nn.Sigmoid()
 
166
  optimizer_D = optim.Adam(discriminator.parameters(), lr=CONFIG["training"]["lr"])
167
  criterion = nn.BCELoss()
168
 
169
+ total_batches = len(dataloader)
170
+ for epoch in progress.tqdm(range(CONFIG["training"]["epochs"]), desc="Epochs"):
171
+ for batch_idx, batch in enumerate(dataloader):
172
+ real_imgs = torch.randn(4, 3, 64, 64) # Simplified data
173
  real_labels = torch.ones(real_imgs.size(0), 1)
174
+ noise = torch.randn(real_imgs.size(0), 100)
175
 
176
+ # Train discriminator
177
  optimizer_D.zero_grad()
178
+ real_loss = criterion(discriminator(real_imgs.view(-1, 3*64*64)), real_labels)
179
  fake_imgs = generator(torch.randint(0, 1000, (real_imgs.size(0),)), noise)
180
+ fake_loss = criterion(discriminator(fake_imgs.detach().view(-1, 3*64*64)), torch.zeros_like(real_labels))
181
  d_loss = (real_loss + fake_loss) / 2
182
  d_loss.backward()
183
  optimizer_D.step()
184
 
185
+ # Train generator
186
  optimizer_G.zero_grad()
187
+ g_loss = criterion(discriminator(fake_imgs.view(-1, 3*64*64)), torch.ones_like(real_labels))
188
  g_loss.backward()
189
  optimizer_G.step()
190
+
191
+ progress(
192
+ (epoch + (batch_idx+1)/total_batches) / CONFIG["training"]["epochs"],
193
+ desc=f"Epoch {epoch+1} | Batch {batch_idx+1}/{total_batches}",
194
+ unit="epoch"
195
+ )
196
 
197
  torch.save(generator.state_dict(), CONFIG["paths"]["model_save"])
198
+ return f"Training complete! Used {len(dataset)} samples"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
 
200
  # ======================
201
  # Gradio Interface
202
  # ======================
203
def create_interface():
    """Build and return the Gradio ``Blocks`` app.

    The left column holds the scraping and training controls with their
    status fields; the right column holds the image-generation controls.
    Two ``app.load`` listeners refresh the progress fields on a fixed
    interval via ``every=``.

    Returns:
        The constructed ``gr.Blocks`` instance (not yet launched).
    """
    with gr.Blocks() as app:
        # Callables so the initial state objects are created by Gradio rather
        # than at interface-construction time.
        scraper = gr.State(lambda: WebScraper())
        model_runner = gr.State(lambda: ModelRunner())

        with gr.Row():
            with gr.Column():
                query_input = gr.Textbox(label="Search Query")
                scrape_btn = gr.Button("Start Scraping")
                scrape_status = gr.Textbox(label="Scraping Status")
                scraping_progress = gr.Textbox(label="Scraping Progress", value="0% (0/0)")

                train_btn = gr.Button("Start Training")
                training_status = gr.Textbox(label="Training Status")
                training_progress = gr.Textbox(label="Training Progress", value="Epoch 0/10 | Batch 0/0")

            with gr.Column():
                prompt_input = gr.Textbox(label="Generation Prompt")
                model_choice = gr.Radio(["Pretrained", "Custom"], label="Model Type", value="Pretrained")
                generate_btn = gr.Button("Generate Image")
                output_image = gr.Image(label="Generated Image")

        # Scraping monitoring
        def monitor_scraping(scraper):
            """Return a single snapshot of the scraper's progress counters.

            Polling is driven by ``every=`` on the listener below, so this
            must return promptly rather than loop and sleep on its own.
            """
            if hasattr(scraper, 'scraping_progress'):
                return f"{scraper.scraping_progress:.1f}% ({scraper.scraped_count}/{scraper.total_images})"
            return "0% (0/0)"

        # Training monitoring
        def monitor_training():
            """Return the saved model's size, or a placeholder if none exists."""
            try:
                # getsize avoids opening the file and also covers the file
                # disappearing between an existence check and the stat call.
                size = os.path.getsize(CONFIG["paths"]["model_save"])
            except OSError:
                return "No trained model"
            return f"Model size: {size//1024}KB"

        app.load(
            monitor_scraping,
            inputs=[scraper],
            outputs=[scraping_progress],
            every=CONFIG["scraping"]["progress_interval"]
        )

        app.load(
            monitor_training,
            outputs=[training_progress],
            every=1
        )

        # Event handlers
        scrape_btn.click(
            lambda s, q: s.start_scraping(q),
            [scraper, query_input],
            [scrape_status]
        )

        train_btn.click(
            lambda s: train_model(s),
            [scraper],
            [training_status]
        )

        generate_btn.click(
            lambda p, m, r: generate_image(p, m, r),
            [prompt_input, model_choice, model_runner],
            [output_image]
        )

    return app
278
 
279
  def generate_image(prompt, model_type, runner):
280
  if model_type == "Pretrained":
 
282
  image = pipe(prompt).images[0]
283
  else:
284
  model = runner.load_custom()
285
+ noise = torch.randn(1, 100)
286
  with torch.no_grad():
287
  fake = model(torch.randint(0, 1000, (1,)), noise)
288
  image = fake.squeeze().permute(1, 2, 0).numpy()
289
  image = (image + 1) / 2
290
  return Image.fromarray((image * 255).astype(np.uint8))
291
 
292
class ModelRunner:
    """Loads the models used by ``generate_image``.

    The pretrained diffusion pipeline is created on first use and cached on
    the instance. The custom generator is rebuilt from the saved weights on
    every call.
    """

    def __init__(self):
        # Cached pretrained pipeline; stays None until load_pretrained() runs.
        self.pretrained_pipe = None

    def load_pretrained(self):
        """Return the pretrained pipeline, creating it on the first call."""
        if self.pretrained_pipe is None:
            self.pretrained_pipe = DiffusionPipeline.from_pretrained(
                "stabilityai/stable-diffusion-xl-base-1.0"
            )
        return self.pretrained_pipe

    def load_custom(self):
        """Return a ``TextConditionedGenerator`` loaded from the saved weights.

        The model is returned in evaluation mode. The generator contains a
        ``BatchNorm1d`` layer, and callers run it on a batch of one sample,
        which BatchNorm rejects while the module is in training mode.
        """
        model = TextConditionedGenerator()
        model.load_state_dict(torch.load(CONFIG["paths"]["model_save"], map_location='cpu'))
        model.eval()
        return model
307
+
308
  if __name__ == "__main__":
309
+ interface = create_interface()
310
+ interface.launch()