gaur3009 committed on
Commit
afbe0d3
·
verified ·
1 Parent(s): 5db3cdf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -11
app.py CHANGED
@@ -19,7 +19,7 @@ CONFIG = {
19
  "scraping": {
20
  "search_url": "https://www.pexels.com/search/{query}/",
21
  "headers": {
22
- "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"
23
  },
24
  "max_images": 100,
25
  "progress_interval": 1
@@ -77,7 +77,10 @@ class WebScraper:
77
  if self.stop_event.is_set():
78
  break
79
 
80
- img_url = img['src']
 
 
 
81
  try:
82
  img_data = requests.get(img_url).content
83
  img_name = f"{int(time.time())}_{idx}.jpg"
@@ -97,6 +100,7 @@ class WebScraper:
97
  self.scraping_progress = 100
98
 
99
  def start_scraping(self, query):
 
100
  self.stop_event.clear()
101
  thread = threading.Thread(target=self.scrape_images, args=(query,))
102
  thread.start()
@@ -114,8 +118,14 @@ class TextImageDataset(Dataset):
114
 
115
  def __getitem__(self, idx):
116
  item = self.data[idx]
117
- image = Image.open(item["image"]).convert('RGB')
118
- image = torch.randn(3, 64, 64) # Simplified for example
 
 
 
 
 
 
119
  return {"text": item["text"], "image": image}
120
 
121
  class TextConditionedGenerator(nn.Module):
@@ -159,23 +169,26 @@ def train_model(scraper, progress=gr.Progress()):
159
  optimizer_D = optim.Adam(discriminator.parameters(), lr=CONFIG["training"]["lr"])
160
  criterion = nn.BCELoss()
161
 
162
- total_batches = len(dataloader)
163
  for epoch in progress.tqdm(range(CONFIG["training"]["epochs"])):
164
- for batch_idx, batch in enumerate(dataloader):
165
- real_imgs = torch.randn(4, 3, 64, 64)
166
- real_labels = torch.ones(real_imgs.size(0), 1)
167
  noise = torch.randn(real_imgs.size(0), 100)
 
 
168
 
 
169
  optimizer_D.zero_grad()
170
  real_loss = criterion(discriminator(real_imgs.view(-1, 3*64*64)), real_labels)
171
- fake_imgs = generator(torch.randint(0, 1000, (real_imgs.size(0),)), noise)
172
- fake_loss = criterion(discriminator(fake_imgs.detach().view(-1, 3*64*64)), torch.zeros_like(real_labels))
173
  d_loss = (real_loss + fake_loss) / 2
174
  d_loss.backward()
175
  optimizer_D.step()
176
 
 
177
  optimizer_G.zero_grad()
178
- g_loss = criterion(discriminator(fake_imgs.view(-1, 3*64*64)), torch.ones_like(real_labels))
179
  g_loss.backward()
180
  optimizer_G.step()
181
 
 
19
  "scraping": {
20
  "search_url": "https://www.pexels.com/search/{query}/",
21
  "headers": {
22
+ "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64)"
23
  },
24
  "max_images": 100,
25
  "progress_interval": 1
 
77
  if self.stop_event.is_set():
78
  break
79
 
80
+ img_url = img.get('src')
81
+ if not img_url:
82
+ continue
83
+
84
  try:
85
  img_data = requests.get(img_url).content
86
  img_name = f"{int(time.time())}_{idx}.jpg"
 
100
  self.scraping_progress = 100
101
 
102
  def start_scraping(self, query):
103
+ self.scraped_data.clear()
104
  self.stop_event.clear()
105
  thread = threading.Thread(target=self.scrape_images, args=(query,))
106
  thread.start()
 
118
 
119
  def __getitem__(self, idx):
120
  item = self.data[idx]
121
+ try:
122
+ image = Image.open(item["image"]).convert('RGB')
123
+ image = image.resize((64, 64))
124
+ image = np.array(image).transpose(2, 0, 1) / 127.5 - 1
125
+ image = torch.tensor(image, dtype=torch.float32)
126
+ except Exception as e:
127
+ print(f"Error loading image: {e}")
128
+ image = torch.randn(3, 64, 64)
129
  return {"text": item["text"], "image": image}
130
 
131
  class TextConditionedGenerator(nn.Module):
 
169
  optimizer_D = optim.Adam(discriminator.parameters(), lr=CONFIG["training"]["lr"])
170
  criterion = nn.BCELoss()
171
 
 
172
  for epoch in progress.tqdm(range(CONFIG["training"]["epochs"])):
173
+ for batch in dataloader:
174
+ real_imgs = batch["image"]
175
+ text_tokens = torch.randint(0, 1000, (real_imgs.size(0),))
176
  noise = torch.randn(real_imgs.size(0), 100)
177
+ real_labels = torch.ones(real_imgs.size(0), 1)
178
+ fake_labels = torch.zeros(real_imgs.size(0), 1)
179
 
180
+ # Discriminator update
181
  optimizer_D.zero_grad()
182
  real_loss = criterion(discriminator(real_imgs.view(-1, 3*64*64)), real_labels)
183
+ fake_imgs = generator(text_tokens, noise)
184
+ fake_loss = criterion(discriminator(fake_imgs.detach().view(-1, 3*64*64)), fake_labels)
185
  d_loss = (real_loss + fake_loss) / 2
186
  d_loss.backward()
187
  optimizer_D.step()
188
 
189
+ # Generator update
190
  optimizer_G.zero_grad()
191
+ g_loss = criterion(discriminator(fake_imgs.view(-1, 3*64*64)), real_labels)
192
  g_loss.backward()
193
  optimizer_G.step()
194