Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -10,7 +10,6 @@ import os
|
|
10 |
import time
|
11 |
import threading
|
12 |
from PIL import Image
|
13 |
-
import io
|
14 |
import numpy as np
|
15 |
|
16 |
# ======================
|
@@ -23,7 +22,7 @@ CONFIG = {
|
|
23 |
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"
|
24 |
},
|
25 |
"max_images": 100,
|
26 |
-
"scrape_time": 10
|
27 |
},
|
28 |
"training": {
|
29 |
"batch_size": 4,
|
@@ -46,39 +45,48 @@ class WebScraper:
|
|
46 |
def __init__(self):
|
47 |
self.stop_event = threading.Event()
|
48 |
self.scraped_data = []
|
|
|
49 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
def scrape_images(self, query):
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
print(f"Scraping error: {e}")
|
76 |
|
77 |
def start_scraping(self, query):
|
78 |
self.stop_event.clear()
|
79 |
-
|
80 |
-
os.makedirs(CONFIG["paths"]["dataset_dir"])
|
81 |
-
|
82 |
thread = threading.Thread(target=self.scrape_images, args=(query,))
|
83 |
thread.start()
|
84 |
return "Scraping started..."
|
@@ -97,17 +105,12 @@ class TextImageDataset(Dataset):
|
|
97 |
def __getitem__(self, idx):
|
98 |
item = self.data[idx]
|
99 |
image = Image.open(item["image"]).convert('RGB')
|
100 |
-
|
101 |
-
if self.transform:
|
102 |
-
image = self.transform(image)
|
103 |
-
|
104 |
return {"text": item["text"], "image": image}
|
105 |
|
106 |
-
# Simplified Text-to-Image Generator
|
107 |
class TextConditionedGenerator(nn.Module):
|
108 |
def __init__(self):
|
109 |
super().__init__()
|
110 |
-
self.text_embedding = nn.Embedding(1000, 128)
|
111 |
self.model = nn.Sequential(
|
112 |
nn.Linear(128 + CONFIG["training"]["latent_dim"], 256),
|
113 |
nn.LeakyReLU(0.2),
|
@@ -121,8 +124,7 @@ class TextConditionedGenerator(nn.Module):
|
|
121 |
def forward(self, text, noise):
|
122 |
text_emb = self.text_embedding(text)
|
123 |
combined = torch.cat([text_emb, noise], 1)
|
124 |
-
|
125 |
-
return img.view(-1, 3, CONFIG["training"]["img_size"], CONFIG["training"]["img_size"])
|
126 |
|
127 |
# ======================
|
128 |
# Training Utilities
|
@@ -144,29 +146,26 @@ def train_model(scraper, progress=gr.Progress()):
|
|
144 |
criterion = nn.BCELoss()
|
145 |
|
146 |
for epoch in progress.tqdm(range(CONFIG["training"]["epochs"]), desc="Training"):
|
147 |
-
for
|
148 |
-
# Train discriminator
|
149 |
real_imgs = batch["image"]
|
150 |
real_labels = torch.ones(real_imgs.size(0), 1)
|
151 |
-
|
152 |
noise = torch.randn(real_imgs.size(0), CONFIG["training"]["latent_dim"])
|
153 |
-
fake_imgs = generator(torch.randint(0, 1000, (real_imgs.size(0),)), noise)
|
154 |
-
fake_labels = torch.zeros(real_imgs.size(0), 1)
|
155 |
|
|
|
156 |
optimizer_D.zero_grad()
|
157 |
real_loss = criterion(discriminator(real_imgs.view(-1, 3*64**2)), real_labels)
|
158 |
-
|
159 |
-
|
|
|
160 |
d_loss.backward()
|
161 |
optimizer_D.step()
|
162 |
|
163 |
-
#
|
164 |
optimizer_G.zero_grad()
|
165 |
-
|
166 |
-
g_loss = criterion(validity, torch.ones_like(validity))
|
167 |
g_loss.backward()
|
168 |
optimizer_G.step()
|
169 |
-
|
170 |
torch.save(generator.state_dict(), CONFIG["paths"]["model_save"])
|
171 |
return "Training completed!"
|
172 |
|
@@ -179,16 +178,14 @@ class ModelRunner:
|
|
179 |
self.custom_model = None
|
180 |
|
181 |
def load_pretrained(self):
|
182 |
-
if self.pretrained_pipe
|
183 |
-
self.pretrained_pipe = DiffusionPipeline.from_pretrained(
|
184 |
-
"stabilityai/stable-diffusion-xl-base-1.0"
|
185 |
-
)
|
186 |
return self.pretrained_pipe
|
187 |
|
188 |
def load_custom(self):
|
189 |
-
if self.custom_model
|
190 |
model = TextConditionedGenerator()
|
191 |
-
model.load_state_dict(torch.load(CONFIG["paths"]["model_save"]))
|
192 |
self.custom_model = model
|
193 |
return self.custom_model
|
194 |
|
@@ -196,42 +193,39 @@ class ModelRunner:
|
|
196 |
# Gradio Interface
|
197 |
# ======================
|
198 |
with gr.Blocks() as app:
|
199 |
-
|
200 |
-
|
201 |
-
model_runner_state = gr.State(ModelRunner())
|
202 |
|
203 |
with gr.Row():
|
204 |
with gr.Column():
|
205 |
query_input = gr.Textbox(label="Search Query")
|
206 |
scrape_btn = gr.Button("Start Scraping")
|
207 |
scrape_status = gr.Textbox(label="Scraping Status")
|
208 |
-
|
209 |
train_btn = gr.Button("Start Training")
|
210 |
training_status = gr.Textbox(label="Training Status")
|
211 |
|
212 |
with gr.Column():
|
213 |
prompt_input = gr.Textbox(label="Generation Prompt")
|
214 |
-
model_choice = gr.Radio(["Pretrained", "Custom"], label="Model Type")
|
215 |
generate_btn = gr.Button("Generate Image")
|
216 |
output_image = gr.Image(label="Generated Image")
|
217 |
-
|
218 |
-
# Event Handlers
|
219 |
scrape_btn.click(
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
)
|
224 |
|
225 |
train_btn.click(
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
)
|
230 |
|
231 |
generate_btn.click(
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
)
|
236 |
|
237 |
def generate_image(prompt, model_type, runner):
|
@@ -240,12 +234,11 @@ def generate_image(prompt, model_type, runner):
|
|
240 |
image = pipe(prompt).images[0]
|
241 |
else:
|
242 |
model = runner.load_custom()
|
243 |
-
# Simplified generation process
|
244 |
noise = torch.randn(1, CONFIG["training"]["latent_dim"])
|
245 |
-
|
246 |
-
|
247 |
-
image = (
|
248 |
-
|
249 |
return Image.fromarray((image * 255).astype(np.uint8))
|
250 |
|
251 |
if __name__ == "__main__":
|
|
|
10 |
import time
|
11 |
import threading
|
12 |
from PIL import Image
|
|
|
13 |
import numpy as np
|
14 |
|
15 |
# ======================
|
|
|
22 |
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"
|
23 |
},
|
24 |
"max_images": 100,
|
25 |
+
"scrape_time": 10
|
26 |
},
|
27 |
"training": {
|
28 |
"batch_size": 4,
|
|
|
45 |
def __init__(self):
|
46 |
self.stop_event = threading.Event()
|
47 |
self.scraped_data = []
|
48 |
+
self._lock = threading.Lock()
|
49 |
|
50 |
+
def __getstate__(self):
|
51 |
+
state = self.__dict__.copy()
|
52 |
+
del state['stop_event']
|
53 |
+
del state['_lock']
|
54 |
+
return state
|
55 |
+
|
56 |
+
def __setstate__(self, state):
|
57 |
+
self.__dict__.update(state)
|
58 |
+
self.stop_event = threading.Event()
|
59 |
+
self._lock = threading.Lock()
|
60 |
+
|
61 |
def scrape_images(self, query):
|
62 |
+
with self._lock:
|
63 |
+
search_url = CONFIG["scraping"]["search_url"].format(query=query)
|
64 |
+
try:
|
65 |
+
response = requests.get(search_url, headers=CONFIG["scraping"]["headers"])
|
66 |
+
soup = BeautifulSoup(response.content, 'html.parser')
|
67 |
+
img_tags = soup.find_all('img', {'class': 'photo-item__img'})
|
68 |
+
|
69 |
+
for img in img_tags[:CONFIG["scraping"]["max_images"]]:
|
70 |
+
if self.stop_event.is_set():
|
71 |
+
break
|
72 |
+
img_url = img['src']
|
73 |
+
try:
|
74 |
+
img_data = requests.get(img_url).content
|
75 |
+
img_name = f"{int(time.time())}.jpg"
|
76 |
+
img_path = os.path.join(CONFIG["paths"]["dataset_dir"], img_name)
|
77 |
+
|
78 |
+
with open(img_path, 'wb') as f:
|
79 |
+
f.write(img_data)
|
80 |
+
|
81 |
+
self.scraped_data.append({"text": query, "image": img_path})
|
82 |
+
except Exception as e:
|
83 |
+
print(f"Error downloading image: {e}")
|
84 |
+
except Exception as e:
|
85 |
+
print(f"Scraping error: {e}")
|
|
|
86 |
|
87 |
def start_scraping(self, query):
|
88 |
self.stop_event.clear()
|
89 |
+
os.makedirs(CONFIG["paths"]["dataset_dir"], exist_ok=True)
|
|
|
|
|
90 |
thread = threading.Thread(target=self.scrape_images, args=(query,))
|
91 |
thread.start()
|
92 |
return "Scraping started..."
|
|
|
105 |
def __getitem__(self, idx):
|
106 |
item = self.data[idx]
|
107 |
image = Image.open(item["image"]).convert('RGB')
|
|
|
|
|
|
|
|
|
108 |
return {"text": item["text"], "image": image}
|
109 |
|
|
|
110 |
class TextConditionedGenerator(nn.Module):
|
111 |
def __init__(self):
|
112 |
super().__init__()
|
113 |
+
self.text_embedding = nn.Embedding(1000, 128)
|
114 |
self.model = nn.Sequential(
|
115 |
nn.Linear(128 + CONFIG["training"]["latent_dim"], 256),
|
116 |
nn.LeakyReLU(0.2),
|
|
|
124 |
def forward(self, text, noise):
|
125 |
text_emb = self.text_embedding(text)
|
126 |
combined = torch.cat([text_emb, noise], 1)
|
127 |
+
return self.model(combined).view(-1, 3, CONFIG["training"]["img_size"], CONFIG["training"]["img_size"])
|
|
|
128 |
|
129 |
# ======================
|
130 |
# Training Utilities
|
|
|
146 |
criterion = nn.BCELoss()
|
147 |
|
148 |
for epoch in progress.tqdm(range(CONFIG["training"]["epochs"]), desc="Training"):
|
149 |
+
for batch in dataloader:
|
|
|
150 |
real_imgs = batch["image"]
|
151 |
real_labels = torch.ones(real_imgs.size(0), 1)
|
|
|
152 |
noise = torch.randn(real_imgs.size(0), CONFIG["training"]["latent_dim"])
|
|
|
|
|
153 |
|
154 |
+
# Discriminator training
|
155 |
optimizer_D.zero_grad()
|
156 |
real_loss = criterion(discriminator(real_imgs.view(-1, 3*64**2)), real_labels)
|
157 |
+
fake_imgs = generator(torch.randint(0, 1000, (real_imgs.size(0),)), noise)
|
158 |
+
fake_loss = criterion(discriminator(fake_imgs.detach().view(-1, 3*64**2)), torch.zeros_like(real_labels))
|
159 |
+
d_loss = (real_loss + fake_loss) / 2
|
160 |
d_loss.backward()
|
161 |
optimizer_D.step()
|
162 |
|
163 |
+
# Generator training
|
164 |
optimizer_G.zero_grad()
|
165 |
+
g_loss = criterion(discriminator(fake_imgs.view(-1, 3*64**2)), torch.ones_like(real_labels))
|
|
|
166 |
g_loss.backward()
|
167 |
optimizer_G.step()
|
168 |
+
|
169 |
torch.save(generator.state_dict(), CONFIG["paths"]["model_save"])
|
170 |
return "Training completed!"
|
171 |
|
|
|
178 |
self.custom_model = None
|
179 |
|
180 |
def load_pretrained(self):
|
181 |
+
if not self.pretrained_pipe:
|
182 |
+
self.pretrained_pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
|
|
|
|
|
183 |
return self.pretrained_pipe
|
184 |
|
185 |
def load_custom(self):
|
186 |
+
if not self.custom_model:
|
187 |
model = TextConditionedGenerator()
|
188 |
+
model.load_state_dict(torch.load(CONFIG["paths"]["model_save"], map_location='cpu'))
|
189 |
self.custom_model = model
|
190 |
return self.custom_model
|
191 |
|
|
|
193 |
# Gradio Interface
|
194 |
# ======================
|
195 |
with gr.Blocks() as app:
|
196 |
+
scraper_state = gr.State(WebScraper)
|
197 |
+
model_runner_state = gr.State(ModelRunner)
|
|
|
198 |
|
199 |
with gr.Row():
|
200 |
with gr.Column():
|
201 |
query_input = gr.Textbox(label="Search Query")
|
202 |
scrape_btn = gr.Button("Start Scraping")
|
203 |
scrape_status = gr.Textbox(label="Scraping Status")
|
|
|
204 |
train_btn = gr.Button("Start Training")
|
205 |
training_status = gr.Textbox(label="Training Status")
|
206 |
|
207 |
with gr.Column():
|
208 |
prompt_input = gr.Textbox(label="Generation Prompt")
|
209 |
+
model_choice = gr.Radio(["Pretrained", "Custom"], label="Model Type", value="Pretrained")
|
210 |
generate_btn = gr.Button("Generate Image")
|
211 |
output_image = gr.Image(label="Generated Image")
|
212 |
+
|
|
|
213 |
scrape_btn.click(
|
214 |
+
lambda scraper, query: scraper.start_scraping(query),
|
215 |
+
[scraper_state, query_input],
|
216 |
+
scrape_status
|
217 |
)
|
218 |
|
219 |
train_btn.click(
|
220 |
+
lambda scraper: train_model(scraper),
|
221 |
+
[scraper_state],
|
222 |
+
training_status
|
223 |
)
|
224 |
|
225 |
generate_btn.click(
|
226 |
+
lambda prompt, model_type, runner: generate_image(prompt, model_type, runner),
|
227 |
+
[prompt_input, model_choice, model_runner_state],
|
228 |
+
output_image
|
229 |
)
|
230 |
|
231 |
def generate_image(prompt, model_type, runner):
|
|
|
234 |
image = pipe(prompt).images[0]
|
235 |
else:
|
236 |
model = runner.load_custom()
|
|
|
237 |
noise = torch.randn(1, CONFIG["training"]["latent_dim"])
|
238 |
+
with torch.no_grad():
|
239 |
+
fake = model(torch.randint(0, 1000, (1,)), noise)
|
240 |
+
image = fake.squeeze().permute(1, 2, 0).numpy()
|
241 |
+
image = (image + 1) / 2
|
242 |
return Image.fromarray((image * 255).astype(np.uint8))
|
243 |
|
244 |
if __name__ == "__main__":
|