gaur3009 committed on
Commit
f2c8695
·
verified ·
1 Parent(s): 4250d40

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -76
app.py CHANGED
@@ -10,7 +10,6 @@ import os
10
  import time
11
  import threading
12
  from PIL import Image
13
- import io
14
  import numpy as np
15
 
16
  # ======================
@@ -23,7 +22,7 @@ CONFIG = {
23
  "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"
24
  },
25
  "max_images": 100,
26
- "scrape_time": 10 # 3 hours in seconds (simulated for testing)
27
  },
28
  "training": {
29
  "batch_size": 4,
@@ -46,39 +45,48 @@ class WebScraper:
46
  def __init__(self):
47
  self.stop_event = threading.Event()
48
  self.scraped_data = []
 
49
 
 
 
 
 
 
 
 
 
 
 
 
50
  def scrape_images(self, query):
51
- search_url = CONFIG["scraping"]["search_url"].format(query=query)
52
- try:
53
- response = requests.get(search_url, headers=CONFIG["scraping"]["headers"])
54
- soup = BeautifulSoup(response.content, 'html.parser')
55
-
56
- # Extract image URLs (example selector - needs adjustment for actual site)
57
- img_tags = soup.find_all('img', {'class': 'photo-item__img'})
58
- for img in img_tags[:CONFIG["scraping"]["max_images"]]:
59
- if self.stop_event.is_set():
60
- break
61
- img_url = img['src']
62
- try:
63
- img_data = requests.get(img_url).content
64
- img_name = f"{int(time.time())}.jpg"
65
- img_path = os.path.join(CONFIG["paths"]["dataset_dir"], img_name)
66
-
67
- with open(img_path, 'wb') as f:
68
- f.write(img_data)
69
-
70
- # Store text-image pair (text = query)
71
- self.scraped_data.append({"text": query, "image": img_path})
72
- except Exception as e:
73
- print(f"Error downloading image: {e}")
74
- except Exception as e:
75
- print(f"Scraping error: {e}")
76
 
77
  def start_scraping(self, query):
78
  self.stop_event.clear()
79
- if not os.path.exists(CONFIG["paths"]["dataset_dir"]):
80
- os.makedirs(CONFIG["paths"]["dataset_dir"])
81
-
82
  thread = threading.Thread(target=self.scrape_images, args=(query,))
83
  thread.start()
84
  return "Scraping started..."
@@ -97,17 +105,12 @@ class TextImageDataset(Dataset):
97
  def __getitem__(self, idx):
98
  item = self.data[idx]
99
  image = Image.open(item["image"]).convert('RGB')
100
-
101
- if self.transform:
102
- image = self.transform(image)
103
-
104
  return {"text": item["text"], "image": image}
105
 
106
- # Simplified Text-to-Image Generator
107
  class TextConditionedGenerator(nn.Module):
108
  def __init__(self):
109
  super().__init__()
110
- self.text_embedding = nn.Embedding(1000, 128) # Simplified text embedding
111
  self.model = nn.Sequential(
112
  nn.Linear(128 + CONFIG["training"]["latent_dim"], 256),
113
  nn.LeakyReLU(0.2),
@@ -121,8 +124,7 @@ class TextConditionedGenerator(nn.Module):
121
  def forward(self, text, noise):
122
  text_emb = self.text_embedding(text)
123
  combined = torch.cat([text_emb, noise], 1)
124
- img = self.model(combined)
125
- return img.view(-1, 3, CONFIG["training"]["img_size"], CONFIG["training"]["img_size"])
126
 
127
  # ======================
128
  # Training Utilities
@@ -144,29 +146,26 @@ def train_model(scraper, progress=gr.Progress()):
144
  criterion = nn.BCELoss()
145
 
146
  for epoch in progress.tqdm(range(CONFIG["training"]["epochs"]), desc="Training"):
147
- for i, batch in enumerate(dataloader):
148
- # Train discriminator
149
  real_imgs = batch["image"]
150
  real_labels = torch.ones(real_imgs.size(0), 1)
151
-
152
  noise = torch.randn(real_imgs.size(0), CONFIG["training"]["latent_dim"])
153
- fake_imgs = generator(torch.randint(0, 1000, (real_imgs.size(0),)), noise)
154
- fake_labels = torch.zeros(real_imgs.size(0), 1)
155
 
 
156
  optimizer_D.zero_grad()
157
  real_loss = criterion(discriminator(real_imgs.view(-1, 3*64**2)), real_labels)
158
- fake_loss = criterion(discriminator(fake_imgs.detach().view(-1, 3*64**2)), fake_labels)
159
- d_loss = real_loss + fake_loss
 
160
  d_loss.backward()
161
  optimizer_D.step()
162
 
163
- # Train generator
164
  optimizer_G.zero_grad()
165
- validity = discriminator(fake_imgs.view(-1, 3*64**2))
166
- g_loss = criterion(validity, torch.ones_like(validity))
167
  g_loss.backward()
168
  optimizer_G.step()
169
-
170
  torch.save(generator.state_dict(), CONFIG["paths"]["model_save"])
171
  return "Training completed!"
172
 
@@ -179,16 +178,14 @@ class ModelRunner:
179
  self.custom_model = None
180
 
181
  def load_pretrained(self):
182
- if self.pretrained_pipe is None:
183
- self.pretrained_pipe = DiffusionPipeline.from_pretrained(
184
- "stabilityai/stable-diffusion-xl-base-1.0"
185
- )
186
  return self.pretrained_pipe
187
 
188
  def load_custom(self):
189
- if self.custom_model is None:
190
  model = TextConditionedGenerator()
191
- model.load_state_dict(torch.load(CONFIG["paths"]["model_save"]))
192
  self.custom_model = model
193
  return self.custom_model
194
 
@@ -196,42 +193,39 @@ class ModelRunner:
196
  # Gradio Interface
197
  # ======================
198
  with gr.Blocks() as app:
199
- # Use Gradio's state management
200
- scraper_state = gr.State(WebScraper())
201
- model_runner_state = gr.State(ModelRunner())
202
 
203
  with gr.Row():
204
  with gr.Column():
205
  query_input = gr.Textbox(label="Search Query")
206
  scrape_btn = gr.Button("Start Scraping")
207
  scrape_status = gr.Textbox(label="Scraping Status")
208
-
209
  train_btn = gr.Button("Start Training")
210
  training_status = gr.Textbox(label="Training Status")
211
 
212
  with gr.Column():
213
  prompt_input = gr.Textbox(label="Generation Prompt")
214
- model_choice = gr.Radio(["Pretrained", "Custom"], label="Model Type")
215
  generate_btn = gr.Button("Generate Image")
216
  output_image = gr.Image(label="Generated Image")
217
-
218
- # Event Handlers
219
  scrape_btn.click(
220
- fn=lambda scraper, query: scraper.start_scraping(query),
221
- inputs=[scraper_state, query_input],
222
- outputs=scrape_status
223
  )
224
 
225
  train_btn.click(
226
- fn=lambda scraper: train_model(scraper),
227
- inputs=[scraper_state],
228
- outputs=training_status
229
  )
230
 
231
  generate_btn.click(
232
- fn=lambda prompt, model_type, runner: generate_image(prompt, model_type, runner),
233
- inputs=[prompt_input, model_choice, model_runner_state],
234
- outputs=output_image
235
  )
236
 
237
  def generate_image(prompt, model_type, runner):
@@ -240,12 +234,11 @@ def generate_image(prompt, model_type, runner):
240
  image = pipe(prompt).images[0]
241
  else:
242
  model = runner.load_custom()
243
- # Simplified generation process
244
  noise = torch.randn(1, CONFIG["training"]["latent_dim"])
245
- fake = model(torch.randint(0, 1000, (1,)), noise).detach()
246
- image = fake.squeeze().permute(1,2,0).numpy()
247
- image = (image + 1) / 2 # Scale to [0,1]
248
-
249
  return Image.fromarray((image * 255).astype(np.uint8))
250
 
251
  if __name__ == "__main__":
 
10
  import time
11
  import threading
12
  from PIL import Image
 
13
  import numpy as np
14
 
15
  # ======================
 
22
  "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"
23
  },
24
  "max_images": 100,
25
+ "scrape_time": 10
26
  },
27
  "training": {
28
  "batch_size": 4,
 
45
  def __init__(self):
46
  self.stop_event = threading.Event()
47
  self.scraped_data = []
48
+ self._lock = threading.Lock()
49
 
50
+ def __getstate__(self):
51
+ state = self.__dict__.copy()
52
+ del state['stop_event']
53
+ del state['_lock']
54
+ return state
55
+
56
+ def __setstate__(self, state):
57
+ self.__dict__.update(state)
58
+ self.stop_event = threading.Event()
59
+ self._lock = threading.Lock()
60
+
61
  def scrape_images(self, query):
62
+ with self._lock:
63
+ search_url = CONFIG["scraping"]["search_url"].format(query=query)
64
+ try:
65
+ response = requests.get(search_url, headers=CONFIG["scraping"]["headers"])
66
+ soup = BeautifulSoup(response.content, 'html.parser')
67
+ img_tags = soup.find_all('img', {'class': 'photo-item__img'})
68
+
69
+ for img in img_tags[:CONFIG["scraping"]["max_images"]]:
70
+ if self.stop_event.is_set():
71
+ break
72
+ img_url = img['src']
73
+ try:
74
+ img_data = requests.get(img_url).content
75
+ img_name = f"{int(time.time())}.jpg"
76
+ img_path = os.path.join(CONFIG["paths"]["dataset_dir"], img_name)
77
+
78
+ with open(img_path, 'wb') as f:
79
+ f.write(img_data)
80
+
81
+ self.scraped_data.append({"text": query, "image": img_path})
82
+ except Exception as e:
83
+ print(f"Error downloading image: {e}")
84
+ except Exception as e:
85
+ print(f"Scraping error: {e}")
 
86
 
87
  def start_scraping(self, query):
88
  self.stop_event.clear()
89
+ os.makedirs(CONFIG["paths"]["dataset_dir"], exist_ok=True)
 
 
90
  thread = threading.Thread(target=self.scrape_images, args=(query,))
91
  thread.start()
92
  return "Scraping started..."
 
105
  def __getitem__(self, idx):
106
  item = self.data[idx]
107
  image = Image.open(item["image"]).convert('RGB')
 
 
 
 
108
  return {"text": item["text"], "image": image}
109
 
 
110
  class TextConditionedGenerator(nn.Module):
111
  def __init__(self):
112
  super().__init__()
113
+ self.text_embedding = nn.Embedding(1000, 128)
114
  self.model = nn.Sequential(
115
  nn.Linear(128 + CONFIG["training"]["latent_dim"], 256),
116
  nn.LeakyReLU(0.2),
 
124
  def forward(self, text, noise):
125
  text_emb = self.text_embedding(text)
126
  combined = torch.cat([text_emb, noise], 1)
127
+ return self.model(combined).view(-1, 3, CONFIG["training"]["img_size"], CONFIG["training"]["img_size"])
 
128
 
129
  # ======================
130
  # Training Utilities
 
146
  criterion = nn.BCELoss()
147
 
148
  for epoch in progress.tqdm(range(CONFIG["training"]["epochs"]), desc="Training"):
149
+ for batch in dataloader:
 
150
  real_imgs = batch["image"]
151
  real_labels = torch.ones(real_imgs.size(0), 1)
 
152
  noise = torch.randn(real_imgs.size(0), CONFIG["training"]["latent_dim"])
 
 
153
 
154
+ # Discriminator training
155
  optimizer_D.zero_grad()
156
  real_loss = criterion(discriminator(real_imgs.view(-1, 3*64**2)), real_labels)
157
+ fake_imgs = generator(torch.randint(0, 1000, (real_imgs.size(0),)), noise)
158
+ fake_loss = criterion(discriminator(fake_imgs.detach().view(-1, 3*64**2)), torch.zeros_like(real_labels))
159
+ d_loss = (real_loss + fake_loss) / 2
160
  d_loss.backward()
161
  optimizer_D.step()
162
 
163
+ # Generator training
164
  optimizer_G.zero_grad()
165
+ g_loss = criterion(discriminator(fake_imgs.view(-1, 3*64**2)), torch.ones_like(real_labels))
 
166
  g_loss.backward()
167
  optimizer_G.step()
168
+
169
  torch.save(generator.state_dict(), CONFIG["paths"]["model_save"])
170
  return "Training completed!"
171
 
 
178
  self.custom_model = None
179
 
180
  def load_pretrained(self):
181
+ if not self.pretrained_pipe:
182
+ self.pretrained_pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
 
 
183
  return self.pretrained_pipe
184
 
185
  def load_custom(self):
186
+ if not self.custom_model:
187
  model = TextConditionedGenerator()
188
+ model.load_state_dict(torch.load(CONFIG["paths"]["model_save"], map_location='cpu'))
189
  self.custom_model = model
190
  return self.custom_model
191
 
 
193
  # Gradio Interface
194
  # ======================
195
  with gr.Blocks() as app:
196
+ scraper_state = gr.State(WebScraper)
197
+ model_runner_state = gr.State(ModelRunner)
 
198
 
199
  with gr.Row():
200
  with gr.Column():
201
  query_input = gr.Textbox(label="Search Query")
202
  scrape_btn = gr.Button("Start Scraping")
203
  scrape_status = gr.Textbox(label="Scraping Status")
 
204
  train_btn = gr.Button("Start Training")
205
  training_status = gr.Textbox(label="Training Status")
206
 
207
  with gr.Column():
208
  prompt_input = gr.Textbox(label="Generation Prompt")
209
+ model_choice = gr.Radio(["Pretrained", "Custom"], label="Model Type", value="Pretrained")
210
  generate_btn = gr.Button("Generate Image")
211
  output_image = gr.Image(label="Generated Image")
212
+
 
213
  scrape_btn.click(
214
+ lambda scraper, query: scraper.start_scraping(query),
215
+ [scraper_state, query_input],
216
+ scrape_status
217
  )
218
 
219
  train_btn.click(
220
+ lambda scraper: train_model(scraper),
221
+ [scraper_state],
222
+ training_status
223
  )
224
 
225
  generate_btn.click(
226
+ lambda prompt, model_type, runner: generate_image(prompt, model_type, runner),
227
+ [prompt_input, model_choice, model_runner_state],
228
+ output_image
229
  )
230
 
231
  def generate_image(prompt, model_type, runner):
 
234
  image = pipe(prompt).images[0]
235
  else:
236
  model = runner.load_custom()
 
237
  noise = torch.randn(1, CONFIG["training"]["latent_dim"])
238
+ with torch.no_grad():
239
+ fake = model(torch.randint(0, 1000, (1,)), noise)
240
+ image = fake.squeeze().permute(1, 2, 0).numpy()
241
+ image = (image + 1) / 2
242
  return Image.fromarray((image * 255).astype(np.uint8))
243
 
244
  if __name__ == "__main__":