Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -175,17 +175,17 @@ class DiffusionBuilder:
|
|
| 175 |
self.config = None
|
| 176 |
self.pipeline = None
|
| 177 |
self.model_type = None
|
| 178 |
-
def load_model(self, model_path: str, config: Optional[DiffusionConfig] = None, model_type: str = "StableDiffusion"):
|
| 179 |
-
with st.spinner(f"
|
| 180 |
if model_type == "StableDiffusion":
|
| 181 |
-
self.pipeline = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float32).to("cpu")
|
| 182 |
elif model_type == "DDPM":
|
| 183 |
-
self.pipeline = DDPMPipeline.from_pretrained(model_path, torch_dtype=torch.float32).to("cpu")
|
| 184 |
self.pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(self.pipeline.scheduler.config)
|
| 185 |
if config:
|
| 186 |
self.config = config
|
| 187 |
self.model_type = model_type
|
| 188 |
-
st.success(f"Diffusion model loaded! 🎨")
|
| 189 |
return self
|
| 190 |
def fine_tune_sft(self, images, texts, epochs=3):
|
| 191 |
dataset = DiffusionDataset(images, texts)
|
|
@@ -339,11 +339,35 @@ if selected_model != "None" and st.sidebar.button("Load Model 📂"):
|
|
| 339 |
st.session_state['model_loaded'] = True
|
| 340 |
st.rerun()
|
| 341 |
|
| 342 |
-
# Tabs
|
| 343 |
-
tab1, tab2, tab3, tab4
|
| 344 |
|
| 345 |
with tab1:
|
| 346 |
-
st.header("Camera Snap
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 347 |
slice_count = st.number_input("Image Slice Count", min_value=1, max_value=20, value=10)
|
| 348 |
video_length = st.number_input("Video Length (seconds)", min_value=1, max_value=30, value=10)
|
| 349 |
cols = st.columns(2)
|
|
@@ -352,24 +376,26 @@ with tab1:
|
|
| 352 |
cam0_img = st.camera_input("Take a picture - Cam 0", key="cam0")
|
| 353 |
if cam0_img:
|
| 354 |
filename = generate_filename(0)
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
|
|
|
| 361 |
if st.button(f"Capture {slice_count} Frames - Cam 0 📸"):
|
| 362 |
st.session_state['cam0_frames'] = []
|
| 363 |
for i in range(slice_count):
|
| 364 |
img = st.camera_input(f"Frame {i} - Cam 0", key=f"cam0_frame_{i}_{time.time()}")
|
| 365 |
if img:
|
| 366 |
filename = generate_filename(f"0_{i}")
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
|
|
|
| 373 |
update_gallery()
|
| 374 |
for frame in st.session_state['cam0_frames']:
|
| 375 |
st.image(Image.open(frame), caption=frame, use_container_width=True)
|
|
@@ -378,24 +404,26 @@ with tab1:
|
|
| 378 |
cam1_img = st.camera_input("Take a picture - Cam 1", key="cam1")
|
| 379 |
if cam1_img:
|
| 380 |
filename = generate_filename(1)
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
|
|
|
| 387 |
if st.button(f"Capture {slice_count} Frames - Cam 1 📸"):
|
| 388 |
st.session_state['cam1_frames'] = []
|
| 389 |
for i in range(slice_count):
|
| 390 |
img = st.camera_input(f"Frame {i} - Cam 1", key=f"cam1_frame_{i}_{time.time()}")
|
| 391 |
if img:
|
| 392 |
filename = generate_filename(f"1_{i}")
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
|
|
|
| 399 |
update_gallery()
|
| 400 |
for frame in st.session_state['cam1_frames']:
|
| 401 |
st.image(Image.open(frame), caption=frame, use_container_width=True)
|
|
@@ -444,28 +472,6 @@ with tab2:
|
|
| 444 |
st.markdown(get_download_link(csv_path, "text/csv", "Download SFT Dataset CSV"), unsafe_allow_html=True)
|
| 445 |
|
| 446 |
with tab3:
|
| 447 |
-
st.header("Build Titan 🌱")
|
| 448 |
-
model_type = st.selectbox("Model Type", ["Causal LM", "Diffusion"], key="build_type")
|
| 449 |
-
base_model_options = {
|
| 450 |
-
"Causal LM": ["HuggingFaceTB/SmolLM-135M", "Qwen/Qwen1.5-0.5B-Chat"],
|
| 451 |
-
"Diffusion": [
|
| 452 |
-
"OFA-Sys/small-stable-diffusion-v0 (LDM/Conditional)",
|
| 453 |
-
"google/ddpm-ema-celebahq-256 (DDPM/SDE/Autoregressive Proxy)"
|
| 454 |
-
]
|
| 455 |
-
}
|
| 456 |
-
base_model = st.selectbox("Select Tiny Model", base_model_options[model_type])
|
| 457 |
-
model_name = st.text_input("Model Name", f"tiny-titan-{int(time.time())}")
|
| 458 |
-
if st.button("Download Model ⬇️"):
|
| 459 |
-
config = (ModelConfig if model_type == "Causal LM" else DiffusionConfig)(name=model_name, base_model=base_model.split(" ")[0], size="small")
|
| 460 |
-
builder = ModelBuilder() if model_type == "Causal LM" else DiffusionBuilder()
|
| 461 |
-
model_type_for_diffusion = "StableDiffusion" if "small-stable-diffusion" in base_model else "DDPM"
|
| 462 |
-
builder.load_model(base_model.split(" ")[0], config, model_type_for_diffusion)
|
| 463 |
-
builder.save_model(config.model_path)
|
| 464 |
-
st.session_state['builder'] = builder
|
| 465 |
-
st.session_state['model_loaded'] = True
|
| 466 |
-
st.rerun()
|
| 467 |
-
|
| 468 |
-
with tab4:
|
| 469 |
st.header("Test Titan 🧪")
|
| 470 |
if 'builder' not in st.session_state or not st.session_state.get('model_loaded', False):
|
| 471 |
st.warning("Please build or load a Titan first! ⚠️")
|
|
@@ -487,7 +493,7 @@ with tab4:
|
|
| 487 |
image = st.session_state['builder'].generate(prompt)
|
| 488 |
st.image(image, caption=f"Generated from {selected_pipeline}")
|
| 489 |
|
| 490 |
-
with
|
| 491 |
st.header("Agentic RAG Party 🌐")
|
| 492 |
if 'builder' not in st.session_state or not st.session_state.get('model_loaded', False):
|
| 493 |
st.warning("Please build or load a Titan first! ⚠️")
|
|
|
|
| 175 |
self.config = None
|
| 176 |
self.pipeline = None
|
| 177 |
self.model_type = None
|
| 178 |
+
def load_model(self, model_path: str, config: Optional[DiffusionConfig] = None, model_type: str = "StableDiffusion", download: bool = True):
|
| 179 |
+
with st.spinner(f"{'Downloading' if download else 'Loading'} {model_path}... ⏳"):
|
| 180 |
if model_type == "StableDiffusion":
|
| 181 |
+
self.pipeline = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float32, use_safetensors=True, local_files_only=not download).to("cpu")
|
| 182 |
elif model_type == "DDPM":
|
| 183 |
+
self.pipeline = DDPMPipeline.from_pretrained(model_path, torch_dtype=torch.float32, use_safetensors=True, local_files_only=not download).to("cpu")
|
| 184 |
self.pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(self.pipeline.scheduler.config)
|
| 185 |
if config:
|
| 186 |
self.config = config
|
| 187 |
self.model_type = model_type
|
| 188 |
+
st.success(f"Diffusion model {'downloaded' if download else 'loaded'}! 🎨")
|
| 189 |
return self
|
| 190 |
def fine_tune_sft(self, images, texts, epochs=3):
|
| 191 |
dataset = DiffusionDataset(images, texts)
|
|
|
|
| 339 |
st.session_state['model_loaded'] = True
|
| 340 |
st.rerun()
|
| 341 |
|
| 342 |
+
# Tabs
|
| 343 |
+
tab1, tab2, tab3, tab4 = st.tabs(["Build Titan & Camera Snap 🌱📷", "Fine-Tune Titan 🔧", "Test Titan 🧪", "Agentic RAG Party 🌐"])
|
| 344 |
|
| 345 |
with tab1:
|
| 346 |
+
st.header("Build Titan & Camera Snap 🌱📷")
|
| 347 |
+
st.subheader("Build Titan 🌱")
|
| 348 |
+
model_type = st.selectbox("Model Type", ["Causal LM", "Diffusion"], key="build_type")
|
| 349 |
+
base_model_options = {
|
| 350 |
+
"Causal LM": ["HuggingFaceTB/SmolLM-135M", "Qwen/Qwen1.5-0.5B-Chat"],
|
| 351 |
+
"Diffusion": [
|
| 352 |
+
"OFA-Sys/small-stable-diffusion-v0 (LDM/Conditional, ~300 MB)",
|
| 353 |
+
"google/ddpm-ema-celebahq-256 (DDPM/SDE/Autoregressive Proxy, ~280 MB)"
|
| 354 |
+
]
|
| 355 |
+
}
|
| 356 |
+
base_model = st.selectbox("Select Tiny Model", base_model_options[model_type])
|
| 357 |
+
action = st.radio("Action", ["Use Model", "Download Model"], index=0 if "Causal LM" in model_type else 1)
|
| 358 |
+
model_name = st.text_input("Model Name (for Download)", f"tiny-titan-{int(time.time())}") if action == "Download Model" else None
|
| 359 |
+
if st.button(f"{action} ⬇️"):
|
| 360 |
+
config = (ModelConfig if model_type == "Causal LM" else DiffusionConfig)(name=model_name or base_model.split(" ")[0], base_model=base_model.split(" ")[0], size="small")
|
| 361 |
+
builder = ModelBuilder() if model_type == "Causal LM" else DiffusionBuilder()
|
| 362 |
+
model_type_for_diffusion = "StableDiffusion" if "small-stable-diffusion" in base_model else "DDPM"
|
| 363 |
+
builder.load_model(base_model.split(" ")[0], config, model_type_for_diffusion, download=action == "Download Model")
|
| 364 |
+
if action == "Download Model":
|
| 365 |
+
builder.save_model(config.model_path)
|
| 366 |
+
st.session_state['builder'] = builder
|
| 367 |
+
st.session_state['model_loaded'] = True
|
| 368 |
+
st.rerun()
|
| 369 |
+
|
| 370 |
+
st.subheader("Camera Snap 📷")
|
| 371 |
slice_count = st.number_input("Image Slice Count", min_value=1, max_value=20, value=10)
|
| 372 |
video_length = st.number_input("Video Length (seconds)", min_value=1, max_value=30, value=10)
|
| 373 |
cols = st.columns(2)
|
|
|
|
| 376 |
cam0_img = st.camera_input("Take a picture - Cam 0", key="cam0")
|
| 377 |
if cam0_img:
|
| 378 |
filename = generate_filename(0)
|
| 379 |
+
if filename not in st.session_state['captured_images']:
|
| 380 |
+
with open(filename, "wb") as f:
|
| 381 |
+
f.write(cam0_img.getvalue())
|
| 382 |
+
st.image(Image.open(filename), caption=filename, use_container_width=True)
|
| 383 |
+
logger.info(f"Saved snapshot from Camera 0: {filename}")
|
| 384 |
+
st.session_state['captured_images'].append(filename)
|
| 385 |
+
update_gallery()
|
| 386 |
if st.button(f"Capture {slice_count} Frames - Cam 0 📸"):
|
| 387 |
st.session_state['cam0_frames'] = []
|
| 388 |
for i in range(slice_count):
|
| 389 |
img = st.camera_input(f"Frame {i} - Cam 0", key=f"cam0_frame_{i}_{time.time()}")
|
| 390 |
if img:
|
| 391 |
filename = generate_filename(f"0_{i}")
|
| 392 |
+
if filename not in st.session_state['captured_images']:
|
| 393 |
+
with open(filename, "wb") as f:
|
| 394 |
+
f.write(img.getvalue())
|
| 395 |
+
st.session_state['cam0_frames'].append(filename)
|
| 396 |
+
logger.info(f"Saved frame {i} from Camera 0: {filename}")
|
| 397 |
+
time.sleep(1.0 / slice_count)
|
| 398 |
+
st.session_state['captured_images'].extend([f for f in st.session_state['cam0_frames'] if f not in st.session_state['captured_images']])
|
| 399 |
update_gallery()
|
| 400 |
for frame in st.session_state['cam0_frames']:
|
| 401 |
st.image(Image.open(frame), caption=frame, use_container_width=True)
|
|
|
|
| 404 |
cam1_img = st.camera_input("Take a picture - Cam 1", key="cam1")
|
| 405 |
if cam1_img:
|
| 406 |
filename = generate_filename(1)
|
| 407 |
+
if filename not in st.session_state['captured_images']:
|
| 408 |
+
with open(filename, "wb") as f:
|
| 409 |
+
f.write(cam1_img.getvalue())
|
| 410 |
+
st.image(Image.open(filename), caption=filename, use_container_width=True)
|
| 411 |
+
logger.info(f"Saved snapshot from Camera 1: {filename}")
|
| 412 |
+
st.session_state['captured_images'].append(filename)
|
| 413 |
+
update_gallery()
|
| 414 |
if st.button(f"Capture {slice_count} Frames - Cam 1 📸"):
|
| 415 |
st.session_state['cam1_frames'] = []
|
| 416 |
for i in range(slice_count):
|
| 417 |
img = st.camera_input(f"Frame {i} - Cam 1", key=f"cam1_frame_{i}_{time.time()}")
|
| 418 |
if img:
|
| 419 |
filename = generate_filename(f"1_{i}")
|
| 420 |
+
if filename not in st.session_state['captured_images']:
|
| 421 |
+
with open(filename, "wb") as f:
|
| 422 |
+
f.write(img.getvalue())
|
| 423 |
+
st.session_state['cam1_frames'].append(filename)
|
| 424 |
+
logger.info(f"Saved frame {i} from Camera 1: {filename}")
|
| 425 |
+
time.sleep(1.0 / slice_count)
|
| 426 |
+
st.session_state['captured_images'].extend([f for f in st.session_state['cam1_frames'] if f not in st.session_state['captured_images']])
|
| 427 |
update_gallery()
|
| 428 |
for frame in st.session_state['cam1_frames']:
|
| 429 |
st.image(Image.open(frame), caption=frame, use_container_width=True)
|
|
|
|
| 472 |
st.markdown(get_download_link(csv_path, "text/csv", "Download SFT Dataset CSV"), unsafe_allow_html=True)
|
| 473 |
|
| 474 |
with tab3:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 475 |
st.header("Test Titan 🧪")
|
| 476 |
if 'builder' not in st.session_state or not st.session_state.get('model_loaded', False):
|
| 477 |
st.warning("Please build or load a Titan first! ⚠️")
|
|
|
|
| 493 |
image = st.session_state['builder'].generate(prompt)
|
| 494 |
st.image(image, caption=f"Generated from {selected_pipeline}")
|
| 495 |
|
| 496 |
+
with tab4:
|
| 497 |
st.header("Agentic RAG Party 🌐")
|
| 498 |
if 'builder' not in st.session_state or not st.session_state.get('model_loaded', False):
|
| 499 |
st.warning("Please build or load a Titan first! ⚠️")
|