Update app.py
app.py CHANGED
@@ -19,7 +19,7 @@ CONFIG = {
     "scraping": {
         "search_url": "https://www.pexels.com/search/{query}/",
         "headers": {
-            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64)
+            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64)"
         },
         "max_images": 100,
         "progress_interval": 1
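The unterminated string literal here was a SyntaxError at import time, so the closing quote is the whole fix. One thing this hunk leaves alone: the image-download call later in the scrape loop (requests.get(img_url)) never passes these headers, so only requests that explicitly reference CONFIG["scraping"]["headers"] send the custom User-Agent. A minimal sketch of threading them into the download as well (the timeout is an addition, not in the original):

    # Hedged sketch: reuse the configured headers for the image download too.
    img_data = requests.get(img_url, headers=CONFIG["scraping"]["headers"], timeout=10).content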
@@ -77,7 +77,10 @@ class WebScraper:
             if self.stop_event.is_set():
                 break

-            img_url = img
+            img_url = img.get('src')
+            if not img_url:
+                continue
+
             try:
                 img_data = requests.get(img_url).content
                 img_name = f"{int(time.time())}_{idx}.jpg"
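The old line assigned the parsed tag object itself to img_url, so requests.get(img_url) would fail on a non-string URL; img.get('src') returns the attribute value, or None when it is absent, which the new guard handles. The hunk does not show where img comes from; below is a sketch of the presumed extraction step, assuming BeautifulSoup, and noting that lazy-loading pages such as Pexels often carry the real URL in data-src rather than src (the data-src fallback is an assumption, not in the original):

    import requests
    from bs4 import BeautifulSoup

    # Presumed extraction step; the real scrape_images body is outside this diff.
    headers = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64)"}
    html = requests.get("https://www.pexels.com/search/cats/", headers=headers).text
    soup = BeautifulSoup(html, "html.parser")
    for img in soup.find_all("img"):
        img_url = img.get("src") or img.get("data-src")  # lazy-loaded images may lack src
        if not img_url:
            continue
        print(img_url)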
@@ -97,6 +100,7 @@ class WebScraper:
         self.scraping_progress = 100

     def start_scraping(self, query):
+        self.scraped_data.clear()
         self.stop_event.clear()
         thread = threading.Thread(target=self.scrape_images, args=(query,))
         thread.start()
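Clearing scraped_data at the start of each run keeps results from one search out of the next. The worker already checks stop_event at the top of its loop, so the matching stop side presumably just sets the event; a hypothetical sketch (stop_scraping does not appear in this diff):

    def stop_scraping(self):
        # scrape_images checks this event each iteration and breaks out.
        self.stop_event.set()

One caveat with the new clear(): if a previous scrape thread is still running, it can keep appending to the freshly cleared list, so setting the event and joining the old thread before clearing would close that race.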
@@ -114,8 +118,14 @@ class TextImageDataset(Dataset):

     def __getitem__(self, idx):
         item = self.data[idx]
-
-
+        try:
+            image = Image.open(item["image"]).convert('RGB')
+            image = image.resize((64, 64))
+            image = np.array(image).transpose(2, 0, 1) / 127.5 - 1
+            image = torch.tensor(image, dtype=torch.float32)
+        except Exception as e:
+            print(f"Error loading image: {e}")
+            image = torch.randn(3, 64, 64)
         return {"text": item["text"], "image": image}

 class TextConditionedGenerator(nn.Module):
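np.array(image).transpose(2, 0, 1) / 127.5 - 1 reorders HWC to CHW for PyTorch and maps uint8 pixels from [0, 255] into [-1, 1], the usual range for a Tanh generator output; the torch.randn fallback keeps one corrupt file from killing the whole epoch, at the cost of injecting a noise sample. Displaying a generated image means inverting that transform; a sketch assuming a CHW float tensor in [-1, 1]:

    import numpy as np
    from PIL import Image

    def to_pil(img_tensor):
        # Undo the dataset normalization: [-1, 1] -> [0, 255], CHW -> HWC.
        arr = ((img_tensor.detach().cpu().numpy() + 1) * 127.5).clip(0, 255).astype(np.uint8)
        return Image.fromarray(arr.transpose(1, 2, 0))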
@@ -159,23 +169,26 @@ def train_model(scraper, progress=gr.Progress()):
     optimizer_D = optim.Adam(discriminator.parameters(), lr=CONFIG["training"]["lr"])
     criterion = nn.BCELoss()

-    total_batches = len(dataloader)
     for epoch in progress.tqdm(range(CONFIG["training"]["epochs"])):
-        for
-        real_imgs =
-
+        for batch in dataloader:
+            real_imgs = batch["image"]
+            text_tokens = torch.randint(0, 1000, (real_imgs.size(0),))
             noise = torch.randn(real_imgs.size(0), 100)
+            real_labels = torch.ones(real_imgs.size(0), 1)
+            fake_labels = torch.zeros(real_imgs.size(0), 1)

+            # Discriminator update
             optimizer_D.zero_grad()
             real_loss = criterion(discriminator(real_imgs.view(-1, 3*64*64)), real_labels)
-            fake_imgs = generator(
-            fake_loss = criterion(discriminator(fake_imgs.detach().view(-1, 3*64*64)),
+            fake_imgs = generator(text_tokens, noise)
+            fake_loss = criterion(discriminator(fake_imgs.detach().view(-1, 3*64*64)), fake_labels)
             d_loss = (real_loss + fake_loss) / 2
             d_loss.backward()
             optimizer_D.step()

+            # Generator update
             optimizer_G.zero_grad()
-            g_loss = criterion(discriminator(fake_imgs.view(-1, 3*64*64)),
+            g_loss = criterion(discriminator(fake_imgs.view(-1, 3*64*64)), real_labels)
             g_loss.backward()
             optimizer_G.step()

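The rebuilt loop follows the standard GAN recipe: the discriminator is trained on real images against ones and on detached fakes against zeros, then the generator is trained to push the same fakes toward the real label. detach() keeps d_loss.backward() from reaching generator weights; the second, non-detached pass through the discriminator is what lets generator gradients flow. Note that text_tokens here are random placeholder integers rather than tokenized captions, so the conditioning is effectively noise. The model definitions sit outside this diff; a minimal stand-in that is shape-compatible with the calls above (layer sizes and the embedding are assumptions, not the app's real architecture):

    import torch
    import torch.nn as nn

    class TextConditionedGenerator(nn.Module):
        # Stand-in only; the real class is defined elsewhere in app.py.
        def __init__(self, vocab_size=1000, embed_dim=64, noise_dim=100):
            super().__init__()
            self.embed = nn.Embedding(vocab_size, embed_dim)
            self.net = nn.Sequential(
                nn.Linear(embed_dim + noise_dim, 256),
                nn.ReLU(),
                nn.Linear(256, 3 * 64 * 64),
                nn.Tanh(),  # outputs in [-1, 1], matching the dataset normalization
            )

        def forward(self, text_tokens, noise):
            x = torch.cat([self.embed(text_tokens), noise], dim=1)
            return self.net(x).view(-1, 3, 64, 64)

    # Flat discriminator matching the .view(-1, 3*64*64) calls above; Sigmoid output
    # because nn.BCELoss expects probabilities.
    discriminator = nn.Sequential(
        nn.Linear(3 * 64 * 64, 256),
        nn.LeakyReLU(0.2),
        nn.Linear(256, 1),
        nn.Sigmoid(),
    )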