ImedHa commited on
Commit
ee412eb
·
verified ·
1 Parent(s): 02bb97a

Upload 7 files

Browse files
Files changed (6) hide show
  1. about_page.py +148 -0
  2. app.py +29 -128
  3. datasets_page.py +91 -0
  4. main_dashboard.py +37 -0
  5. s2-swinunetr-weights.pth +3 -0
  6. system_test_page.py +262 -0
about_page.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import os
3
+
4
+ def show():
5
+ st.markdown('<div class="main-header">ℹ️ About This Project</div>', unsafe_allow_html=True)
6
+
7
+ # ACVSS Hackathon Information
8
+ st.markdown("## ACVSS 2025 Summer School Hackathon Project")
9
+
10
+ st.info(
11
+ "This project was developed by **Team SATOR** as part of the **ACVSS 2025 - The 4th Summer School on Advanced Computer Vision** hackathon. "
12
+ "Our goal was to build a functional prototype for surgical scene understanding in a limited time frame."
13
+ )
14
+
15
+ # ACVSS Description
16
+ st.markdown("""
17
+ ### About ACVSS
18
+
19
+ The **African Computer Vision Summer School (ACVSS)** is an intensive program designed to advance computer vision research and applications across Africa. The summer school brings together researchers, students, and industry professionals to explore cutting-edge technologies in computer vision, machine learning, and artificial intelligence.
20
+
21
+ **Learn more**: [acvss.ai](https://www.acvss.ai/) | **Year**: 2025 | **Edition**: 4th Summer School
22
+ """)
23
+
24
+ st.markdown("---")
25
+
26
+ # Team Section
27
+ st.markdown("## 👥 Meet Team SATOR")
28
+
29
+ # Add team description
30
+ st.markdown("""
31
+ **Team SATOR** is a diverse group of professionals brought together for the ACVSS 2025 hackathon.
32
+ Our team combines expertise in AI/ML, software engineering, data science, and quality assurance to deliver
33
+ innovative solutions in surgical scene understanding.
34
+ """)
35
+
36
+ st.markdown("### Team Members")
37
+
38
+ # Team Member Profiles
39
+ team_members = [
40
+ {
41
+ "name": "MEM1",
42
+ "role": "Team Lead & System Architect",
43
+ "desc": "Led the project, designed the overall system architecture, and ensured seamless integration of all components. Her vision guided the project's success.",
44
+ "email": "[email protected]",
45
+ "linkedin": "https://www.linkedin.com/in/evelyn-reed-acvss",
46
+ "github": "https://github.com/evelyn-reed",
47
+ "img": "https://i.pravatar.cc/150?img=1"
48
+ },
49
+ {
50
+ "name": "MEM2",
51
+ "role": "AI/ML Specialist",
52
+ "desc": "Focused on developing and training the core SwinUnet and scene understanding models. Responsible for the AI-powered analysis and insights.",
53
+ "email": "[email protected]",
54
+ "linkedin": "https://www.linkedin.com/in/kenji-tanaka-ml",
55
+ "github": "https://github.com/kenji-tanaka",
56
+ "img": "https://i.pravatar.cc/150?img=2"
57
+ },
58
+ {
59
+ "name": "MEM3",
60
+ "role": "UI/UX & Frontend Developer",
61
+ "desc": "Designed and built the Streamlit dashboard, focusing on creating an intuitive and informative user interface for surgeons and researchers.",
62
+ "email": "[email protected]",
63
+ "linkedin": "https://www.linkedin.com/in/sofia-rossi-ui",
64
+ "github": "https://github.com/sofia-rossi",
65
+ "img": "https://i.pravatar.cc/150?img=3"
66
+ },
67
+ {
68
+ "name": "MEM4",
69
+ "role": "Data Engineer",
70
+ "desc": "Managed the data pipeline, from processing the MM-OR dataset to ensuring the models received clean, well-structured data for training and testing.",
71
+ "email": "[email protected]",
72
+ "linkedin": "https://www.linkedin.com/in/david-chen-data",
73
+ "github": "https://github.com/david-chen",
74
+ "img": "https://i.pravatar.cc/150?img=4"
75
+ },
76
+ {
77
+ "name": "MEM5",
78
+ "role": "QA & Testing Lead",
79
+ "desc": "Oversaw the testing and validation of the entire pipeline, ensuring the system was robust, accurate, and met the project's objectives.",
80
+ "email": "[email protected]",
81
+ "linkedin": "https://www.linkedin.com/in/aisha-bello-qa",
82
+ "github": "https://github.com/aisha-bello",
83
+ "img": "https://i.pravatar.cc/150?img=5"
84
+ }
85
+ ]
86
+
87
+ # Display team members in columns
88
+ # Display team members in a responsive grid
89
+ cols = st.columns(5)
90
+ for i, member in enumerate(team_members):
91
+ with cols[i]:
92
+ st.markdown(f"##### {member['name']}")
93
+ st.image(member['img'], width=120)
94
+ st.markdown(f"**{member['role']}**")
95
+ st.caption(member['desc'])
96
+ st.markdown(f"✉️ [{member['email']}](mailto:{member['email']})")
97
+ st.markdown(f"💼 [LinkedIn]({member['linkedin']})")
98
+ st.markdown(f"💻 [GitHub]({member['github']})")
99
+
100
+ st.markdown("---")
101
+
102
+ # Project Overview Section
103
+ st.markdown("## 🎯 Project Overview")
104
+
105
+ col1, col2 = st.columns(2)
106
+
107
+ with col1:
108
+ st.markdown("""
109
+ ### 🏥 Video Surgical Scene Understanding
110
+
111
+ Our project focuses on developing an advanced computer vision system capable of:
112
+
113
+ - **Scene Analysis**: Understanding surgical environments
114
+ - **Tool Recognition**: Identifying medical instruments
115
+ - **Workflow Tracking**: Monitoring surgical procedures
116
+ - **Real-time Processing**: Immediate analysis and feedback
117
+ """)
118
+
119
+ with col2:
120
+ st.markdown("""
121
+ ### 🛠️ Technical Stack
122
+
123
+ - **Frontend**: Streamlit Dashboard
124
+ - **Backend**: Python
125
+ - **ML Models**: SwinUnet, Scene Graphs
126
+ - **Dataset**: MM-OR (Multimodal Operating Room)
127
+ - **Version**: v1.0 (July 2025)
128
+ """)
129
+
130
+ st.markdown("---")
131
+
132
+ # Hackathon Achievement Section
133
+ st.markdown("## 🏆 Hackathon Achievement")
134
+
135
+ achievement_col1, achievement_col2, achievement_col3 = st.columns(3)
136
+
137
+ with achievement_col1:
138
+ st.metric("Pipeline Version", "v1.0", "Completed")
139
+
140
+ with achievement_col2:
141
+ st.metric("Models Integrated", "2/2", "✅ Working")
142
+
143
+ with achievement_col3:
144
+ st.metric("Development Time", "Hackathon", "July 2025")
145
+
146
+ st.markdown("---")
147
+
148
+ st.markdown("© 2025 Team SATOR - ACVSS Hackathon. All Rights Reserved.")
app.py CHANGED
@@ -1,128 +1,29 @@
1
-
2
- import streamlit as st
3
- from PIL import Image
4
- import torch
5
- import os
6
- from io import StringIO
7
- import sys
8
-
9
- # --- TorchDynamo Fix for Unsloth/MedGemma ---
10
- import torch._dynamo
11
- torch._dynamo.config.capture_scalar_outputs = True
12
- torch.compiler.disable()
13
-
14
- # --- Dependency Handling ---
15
- try:
16
- from unsloth import FastVisionModel
17
- from transformers import TextStreamer
18
- except ImportError as e:
19
- st.error(f"A required library is not installed. Please install dependencies. Error: {e}")
20
- st.stop()
21
-
22
- @st.cache_resource
23
- def load_medgemma_model():
24
- """Loads the MedGemma vision-language model in eager mode."""
25
- try:
26
- model, processor = FastVisionModel.from_pretrained(
27
- "fiqqy/MedGemma-MM-OR-FT10",
28
- load_in_4bit=False,
29
- use_gradient_checkpointing="unsloth",
30
- )
31
- return model, processor
32
- except Exception as e:
33
- st.error(f"Error loading MedGemma model: {e}")
34
- return None, None
35
-
36
- def run_captioning(medgemma_model, processor, frames, instruction):
37
- """Runs MedGemma inference using 3 frames and an instruction."""
38
- st.write("Preparing inputs for MedGemma...")
39
- images = [f.convert("RGB") for f in frames]
40
- messages = [
41
- {"role": "user", "content": [
42
- {"type": "image"}, {"type": "image"}, {"type": "image"},
43
- {"type": "text", "text": instruction},
44
- ]},
45
- ]
46
- input_text = processor.apply_chat_template(messages, add_generation_prompt=True)
47
- device = "cuda" if torch.cuda.is_available() else "cpu"
48
- inputs = processor(
49
- images, input_text, add_special_tokens=False, return_tensors="pt",
50
- ).to(device)
51
-
52
- text_streamer = TextStreamer(processor, skip_prompt=True)
53
- old_stdout = sys.stdout
54
- sys.stdout = captured_output = StringIO()
55
-
56
- st.write("Running MedGemma Analysis...")
57
- torch._dynamo.disable()
58
- medgemma_model.generate(
59
- **inputs, streamer=text_streamer, max_new_tokens=768,
60
- use_cache=True, temperature=1.0, top_p=0.95, top_k=64
61
- )
62
-
63
- sys.stdout = old_stdout
64
- result = captured_output.getvalue()
65
- return result
66
-
67
- def show():
68
- """Main function to render the Streamlit UI."""
69
- st.title("MedGemma Scene Analysis System")
70
- st.write("A system to test MedGemma vision-language captioning model.")
71
-
72
- st.header("1. Load MedGemma Model")
73
- if "medgemma_model" not in st.session_state:
74
- st.session_state.medgemma_model, st.session_state.processor = None, None
75
- if st.button("Load MedGemma Model"):
76
- with st.spinner("Loading MedGemma... This can take several minutes."):
77
- st.session_state.medgemma_model, st.session_state.processor = load_medgemma_model()
78
-
79
- if st.session_state.get("medgemma_model") and st.session_state.get("processor"):
80
- st.success("MedGemma model is loaded.")
81
- else:
82
- st.warning("MedGemma model is not loaded.")
83
-
84
- st.header("2. Upload Data")
85
- st.subheader("Upload Three Sequential Surgical Video Frames")
86
- col1, col2, col3 = st.columns(3)
87
- uploaded_files = [
88
- col1.file_uploader("Upload Frame 1", type=["png", "jpg", "jpeg"], key="frame1"),
89
- col2.file_uploader("Upload Frame 2", type=["png", "jpg", "jpeg"], key="frame2"),
90
- col3.file_uploader("Upload Frame 3", type=["png", "jpg", "jpeg"], key="frame3")
91
- ]
92
- frames = [Image.open(f) for f in uploaded_files if f is not None]
93
-
94
- display_size = (256, 256)
95
- if len(frames) == 3:
96
- st.success("All three frames have been uploaded successfully.")
97
- img_cols = st.columns(3)
98
- for i, frame in enumerate(frames):
99
- img_cols[i].image(frame.resize(display_size), caption=f"Frame {i+1}", use_container_width=True)
100
- else:
101
- st.info("Please upload all three frames to proceed.")
102
-
103
- st.header("3. Generate Scene Analysis")
104
- instruction_prompt = st.text_area(
105
- "Enter your custom instruction prompt:",
106
- "Provide a detailed summary of the surgical action, noting the instruments used and their interactions."
107
- )
108
-
109
- can_run_analysis = (
110
- st.session_state.get("medgemma_model") is not None and
111
- len(frames) == 3 and
112
- bool(instruction_prompt)
113
- )
114
-
115
- if st.button("Run Analysis", disabled=not can_run_analysis):
116
- with st.spinner("Running MedGemma analysis... This may take a moment."):
117
- result = run_captioning(
118
- st.session_state.medgemma_model, st.session_state.processor,
119
- frames, instruction_prompt
120
- )
121
- st.subheader("Analysis Result")
122
- st.write(result)
123
-
124
- if not can_run_analysis:
125
- st.warning("Please ensure the MedGemma model is loaded, three frames are uploaded, and a prompt is provided.")
126
-
127
- if __name__ == "__main__":
128
- show()
 
1
+ import streamlit as st
2
+ import main_dashboard
3
+ import about_page
4
+ import datasets_page
5
+ import system_test_page
6
+
7
+ st.set_page_config(page_title="Surgical Scene Understanding", page_icon="🩺", layout="wide")
8
+
9
+ with st.sidebar:
10
+ st.markdown("## 🩺 Surgical Scene Understanding")
11
+ page = st.radio(
12
+ "Navigation",
13
+ [
14
+ "🏠 Main Dashboard",
15
+ "🧪 Test System",
16
+ "📂 Dataset",
17
+ "ℹ️ About"
18
+ ],
19
+ label_visibility="collapsed"
20
+ )
21
+
22
+ if page.startswith("🏠"):
23
+ main_dashboard.show()
24
+ elif page.startswith("🧪"):
25
+ system_test_page.show()
26
+ elif page.startswith("📂"):
27
+ datasets_page.show()
28
+ elif page.startswith("ℹ️"):
29
+ about_page.show()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
datasets_page.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import numpy as np
4
+ import plotly.express as px
5
+ import plotly.graph_objects as go
6
+
7
+ def show():
8
+ st.markdown('<div class="main-header">📁 Dataset: MM-OR</div>', unsafe_allow_html=True)
9
+
10
+ st.markdown("---")
11
+
12
+ st.markdown("## 🗂️ MM-OR: A Large-scale Multimodal Operating Room Dataset")
13
+ st.markdown("""
14
+ This project utilizes the **MM-OR** dataset, a comprehensive collection of data recorded in a realistic operating room environment.
15
+ It is designed to support research in surgical workflow analysis, human activity recognition, and context-aware systems in healthcare.
16
+ """)
17
+
18
+ # Dataset overview
19
+ st.markdown("### 📊 Dataset High-Level Statistics")
20
+
21
+ col1, col2, col3, col4 = st.columns(4)
22
+
23
+ with col1:
24
+ st.metric(
25
+ label="📹 Surgical Procedures",
26
+ value="10",
27
+ )
28
+
29
+ with col2:
30
+ st.metric(
31
+ label="⏱️ Total Duration",
32
+ value=">100 hours",
33
+ )
34
+
35
+ with col3:
36
+ st.metric(
37
+ label="🏷️ Modalities",
38
+ value="3 (Video, Audio, Depth)",
39
+ )
40
+
41
+ with col4:
42
+ st.metric(
43
+ label="📂 Total Size",
44
+ value="~12 TB",
45
+ )
46
+
47
+ st.markdown("---")
48
+
49
+ # Dataset categories
50
+ st.markdown("### 🏥 Dataset Details")
51
+
52
+ st.info("The MM-OR dataset is the primary source of data for training and evaluating the models in this system.")
53
+
54
+ col1, col2 = st.columns(2)
55
+
56
+ with col1:
57
+ st.markdown("#### Key Features")
58
+ st.markdown("""
59
+ - **Multimodal Data**: Includes synchronized video, multi-channel audio, and depth information.
60
+ - **Multiple Views**: Video captured from multiple camera perspectives to provide a comprehensive view of the operating room.
61
+ - **Rich Annotations**: Detailed annotations of:
62
+ - Surgical roles (e.g., primary surgeon, assistant, nurse).
63
+ - Atomic actions and complex activities.
64
+ - Interactions between team members.
65
+ - **Realistic Environment**: Data was collected in a high-fidelity simulated operating room.
66
+ """)
67
+
68
+ with col2:
69
+ st.markdown("#### Data Modalities")
70
+ st.image("https://www.researchgate.net/publication/359174963/figure/fig1/AS:1143128108556288@1649553881835/An-overview-of-our-data-acquisition-system-in-the-operating-room-OR-We-record.jpg",
71
+ caption="Overview of the data acquisition system in the operating room.")
72
+
73
+ st.markdown("---")
74
+ st.markdown("### 📈 Data Distribution")
75
+
76
+ # Create sample data for visualization
77
+ procedure_data = {
78
+ 'Surgical Procedure': [f'Procedure {i+1}' for i in range(10)],
79
+ 'Duration (hours)': np.random.uniform(8, 12, 10).round(1),
80
+ 'Number of Annotations': np.random.randint(1500, 3000, 10)
81
+ }
82
+ df_procedures = pd.DataFrame(procedure_data)
83
+
84
+ fig = px.bar(df_procedures, x='Surgical Procedure', y='Duration (hours)',
85
+ title='Duration per Surgical Procedure',
86
+ labels={'Duration (hours)': 'Duration (hours)'},
87
+ color='Surgical Procedure')
88
+ st.plotly_chart(fig, use_container_width=True)
89
+
90
+ st.markdown("For more information, please refer to the original publication: *MM-OR: A Large-scale Multimodal Operating Room Dataset for Human Activity Recognition*.")
91
+ st.markdown("The dataset is available on GitHub: [MM-OR Dataset](https://github.com/egeozsoy/MM-OR)")
main_dashboard.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+ def show():
4
+ st.markdown('<div class="main-header">🏥 Video Surgical Scene Understanding Dashboard</div>', unsafe_allow_html=True)
5
+ st.markdown("---")
6
+
7
+ # Welcome and overall description
8
+ st.markdown("## Welcome to the Surgical Scene Analysis Platform")
9
+ st.markdown("""
10
+ This platform demonstrates an end-to-end pipeline for automated understanding of surgical scenes from video data.
11
+ The system leverages advanced computer vision and AI models to analyze surgical workflows, recognize tools, and generate scene-level captions.
12
+ Navigate through the sidebar to test the system, explore datasets, or learn more about the project.
13
+ """)
14
+
15
+ st.markdown("---")
16
+ st.markdown("## 🔄 Pipeline Overview")
17
+ st.markdown("""
18
+ The surgical scene understanding pipeline consists of the following main steps:
19
+ 1. **Frame Extraction**: Select or upload three consecutive frames from a surgical video.
20
+ 2. **Segmentation**: Use the SwinUNETR model to generate a segmentation mask for the scene.
21
+ 3. **Captioning**: Input the frames and mask into the MedGemma model to generate a descriptive caption or scene graph.
22
+ 4. **Results & Analysis**: Review the generated mask and caption to understand the surgical context.
23
+ """)
24
+
25
+ st.markdown("---")
26
+ st.markdown("## 📚 Project Description")
27
+ st.markdown("""
28
+ This project was developed by **Team SATOR** for the ACVSS 2025 Hackathon.
29
+ Our goal is to provide an accessible, interactive demonstration of state-of-the-art surgical scene understanding using deep learning.
30
+ - **Frontend**: Streamlit Dashboard
31
+ - **Backend**: Python, PyTorch, MONAI, HuggingFace Transformers
32
+ - **Models**: SwinUNETR (segmentation), MedGemma (captioning)
33
+ - **Dataset**: MM-OR (Multimodal Operating Room)
34
+ """)
35
+
36
+ st.markdown("---")
37
+ st.info("Use the sidebar to start testing the system or to learn more about the dataset and team.")
s2-swinunetr-weights.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:af70d2fd82d8184036623e936723bca2c80305b3b2b4e6d3c32692adc17866c7
3
+ size 114911598
system_test_page.py ADDED
@@ -0,0 +1,262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from PIL import Image
3
+ import torch
4
+ import numpy as np
5
+ import os
6
+ from io import StringIO
7
+ import sys
8
+ import torch.nn as nn
9
+
10
+ # --- TorchDynamo Fix for Unsloth/MedGemma ---
11
+ import torch._dynamo
12
+ torch._dynamo.config.capture_scalar_outputs = True
13
+
14
+ # --- DEFINITIVE FIX FOR JIT COMPILER ERRORS ---
15
+ torch.compiler.disable()
16
+
17
+ # --- Dependency Handling ---
18
+ try:
19
+ from monai.networks.nets import SwinUNETR
20
+ import torchvision.transforms as T
21
+ from unsloth import FastVisionModel
22
+ from transformers import TextStreamer
23
+ from s2wrapper import forward as multiscale_forward
24
+ except ImportError as e:
25
+ st.error(f"A required library is not installed. Please install dependencies. Error: {e}")
26
+ st.stop()
27
+
28
+ # --- Config and Model Definition ---
29
+ class Config:
30
+ ORIGINAL_LABELS = [0,3,6,9,12,15,18,21,24,27,30,33,36,39,42,45,48,51,54,57,60]
31
+ LABEL_MAP = {val: i for i, val in enumerate(ORIGINAL_LABELS)}
32
+ NUM_CLASSES = len(ORIGINAL_LABELS)
33
+ IMG_SIZE = (256, 256)
34
+ FEATURE_SIZE = 48
35
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
36
+
37
+ class multiscaleSwinUNETR(nn.Module):
38
+ def __init__(self, num_classes, scales=[1]):
39
+ super().__init__()
40
+ self.scales = scales
41
+ self.num_classes = num_classes
42
+ self.model = SwinUNETR(
43
+ spatial_dims=2,
44
+ in_channels=3,
45
+ out_channels=num_classes,
46
+ feature_size=Config.FEATURE_SIZE,
47
+ drop_rate=0.0,
48
+ attn_drop_rate=0.0,
49
+ dropout_path_rate=0.0,
50
+ use_checkpoint=True,
51
+ use_v2=True
52
+ )
53
+ self.segmentation_head = nn.Sequential(
54
+ nn.Conv2d(len(scales)*num_classes, num_classes, 3, padding=1),
55
+ nn.BatchNorm2d(num_classes),
56
+ nn.ReLU(inplace=True),
57
+ nn.Conv2d(num_classes, num_classes, 1)
58
+ )
59
+ def forward(self, x):
60
+ outs = multiscale_forward(self.model, x, scales=self.scales, output_shape="bchw")
61
+ if isinstance(outs, (list, tuple)):
62
+ normed = []
63
+ for f in outs:
64
+ f = f / (f.std(dim=(2, 3), keepdim=True) + 1e-6)
65
+ normed.append(f)
66
+ feats = torch.cat(normed, dim=1)
67
+ elif isinstance(outs, torch.Tensor) and outs.dim() == 4:
68
+ if len(self.scales) == 1:
69
+ return outs
70
+ feats = outs / (outs.std(dim=(2, 3), keepdim=True) + 1e-6)
71
+ else:
72
+ raise ValueError(f"Unexpected output shape/type from multiscale_forward: {type(outs)}, {getattr(outs,'shape',None)}")
73
+ logits = self.segmentation_head(feats)
74
+ return logits
75
+
76
+ # --- Model Loading ---
77
+ @st.cache_resource
78
+ def load_swinunetr_model():
79
+ """Loads the multiscale SwinUNETR segmentation model."""
80
+ model_path = 's2-swinunetr-weights.pth'
81
+ if not os.path.exists(model_path):
82
+ st.error(f"Segmentation model file not found at {model_path}")
83
+ return None, None
84
+ try:
85
+ model = multiscaleSwinUNETR(num_classes=Config.NUM_CLASSES, scales=[1])
86
+ model.load_state_dict(torch.load(model_path, map_location=Config.DEVICE))
87
+ model.eval()
88
+ return model, Config
89
+ except Exception as e:
90
+ st.error(f"Error loading segmentation model: {e}")
91
+ return None, None
92
+
93
+ @st.cache_resource
94
+ def load_medgemma_model():
95
+ """Loads the MedGemma vision-language model in eager mode."""
96
+ try:
97
+ model, processor = FastVisionModel.from_pretrained(
98
+ "fiqqy/MedGemma-MM-OR-FT10",
99
+ load_in_4bit=False,
100
+ use_gradient_checkpointing="unsloth",
101
+ )
102
+ return model, processor
103
+ except Exception as e:
104
+ st.error(f"Error loading MedGemma model: {e}")
105
+ return None, None
106
+
107
+ # --- Preprocessing ---
108
+ def preprocess_frames(frames, config):
109
+ """Prepares image frames for the segmentation model."""
110
+ transform = T.Compose([
111
+ T.Resize(config.IMG_SIZE, antialias=True),
112
+ T.ToTensor(),
113
+ T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
114
+ ])
115
+ tensors = [transform(frame.convert("RGB")) for frame in frames]
116
+ batch = torch.stack(tensors)
117
+ return batch
118
+
119
+ # --- Color Palette for Mask Visualization ---
120
+ def make_palette(num_classes):
121
+ rng = np.random.default_rng(0)
122
+ colors = rng.integers(0, 255, size=(num_classes, 3), dtype=np.uint8)
123
+ colors[0] = np.array([0, 0, 0])
124
+ return colors
125
+
126
+ # --- Inference ---
127
+ def run_segmentation(model, config, frames):
128
+ """Runs segmentation on the uploaded frames and visualizes with a color palette."""
129
+ st.write("Running segmentation...")
130
+ batch = preprocess_frames(frames, config)
131
+ device = config.DEVICE
132
+ batch = batch.to(device)
133
+ model = model.to(device)
134
+ with torch.no_grad():
135
+ logits = model(batch)
136
+ preds = torch.argmax(logits, 1).cpu().numpy()
137
+ mask = preds[0]
138
+ st.write(f"Mask unique values: {np.unique(mask)}")
139
+ palette = make_palette(config.NUM_CLASSES)
140
+ color_mask = palette[mask]
141
+ mask_img = Image.fromarray(color_mask.astype(np.uint8))
142
+ return mask_img
143
+
144
+ # --- MedGemma Captioning ---
145
+ def run_captioning(medgemma_model, processor, frames, mask_img, instruction):
146
+ """Runs MedGemma inference using 3 frames, 1 mask, and an instruction."""
147
+ st.write("Preparing inputs for MedGemma...")
148
+ images = [f.convert("RGB") for f in frames]
149
+ mask_img = mask_img.convert("RGB")
150
+ messages = [
151
+ {"role": "user", "content": [
152
+ {"type": "image"}, {"type": "image"}, {"type": "image"}, {"type": "image"},
153
+ {"type": "text", "text": instruction},
154
+ ]},
155
+ ]
156
+ input_text = processor.apply_chat_template(messages, add_generation_prompt=True)
157
+ device = "cuda" if torch.cuda.is_available() else "cpu"
158
+ all_images = images + [mask_img]
159
+ inputs = processor(
160
+ all_images, input_text, add_special_tokens=False, return_tensors="pt",
161
+ ).to(device)
162
+
163
+ text_streamer = TextStreamer(processor, skip_prompt=True)
164
+ old_stdout = sys.stdout
165
+ sys.stdout = captured_output = StringIO()
166
+
167
+ st.write("Running MedGemma Analysis...")
168
+ torch._dynamo.disable()
169
+ medgemma_model.generate(
170
+ **inputs, streamer=text_streamer, max_new_tokens=768,
171
+ use_cache=True, temperature=1.0, top_p=0.95, top_k=64
172
+ )
173
+
174
+ sys.stdout = old_stdout
175
+ result = captured_output.getvalue()
176
+ return result
177
+
178
+ # --- Streamlit UI ---
179
+ def show():
180
+ """Main function to render the Streamlit UI."""
181
+ st.title("Surgical Scene Analysis System")
182
+ st.write("A system to test surgical scene segmentation and captioning models.")
183
+
184
+ st.header("1. Load Models")
185
+ if "seg_model" not in st.session_state or "seg_config" not in st.session_state:
186
+ st.session_state.seg_model, st.session_state.seg_config = None, None
187
+ if st.button("Load Segmentation Model"):
188
+ with st.spinner("Loading SwinUNETR..."):
189
+ st.session_state.seg_model, st.session_state.seg_config = load_swinunetr_model()
190
+
191
+ if st.session_state.seg_model is not None:
192
+ st.success("Segmentation model is loaded.")
193
+ else:
194
+ st.warning("Segmentation model is not loaded.")
195
+
196
+ if "medgemma_model" not in st.session_state:
197
+ st.session_state.medgemma_model, st.session_state.processor = None, None
198
+ if st.button("Load MedGemma Model"):
199
+ with st.spinner("Loading MedGemma... This can take several minutes."):
200
+ st.session_state.medgemma_model, st.session_state.processor = load_medgemma_model()
201
+
202
+ if st.session_state.get("medgemma_model") and st.session_state.get("processor"):
203
+ st.success("MedGemma model is loaded.")
204
+ else:
205
+ st.warning("MedGemma model is not loaded.")
206
+
207
+ st.header("2. Upload Data & Generate Mask")
208
+ st.subheader("Upload Three Sequential Surgical Video Frames")
209
+ col1, col2, col3 = st.columns(3)
210
+ uploaded_files = [
211
+ col1.file_uploader("Upload Frame 1", type=["png", "jpg", "jpeg"], key="frame1"),
212
+ col2.file_uploader("Upload Frame 2", type=["png", "jpg", "jpeg"], key="frame2"),
213
+ col3.file_uploader("Upload Frame 3", type=["png", "jpg", "jpeg"], key="frame3")
214
+ ]
215
+ frames = [Image.open(f) for f in uploaded_files if f is not None]
216
+
217
+ display_size = (256, 256)
218
+ if "mask_img" not in st.session_state:
219
+ st.session_state.mask_img = None
220
+
221
+ if len(frames) == 3:
222
+ st.success("All three frames have been uploaded successfully.")
223
+ img_cols = st.columns(4)
224
+ for i, frame in enumerate(frames):
225
+ img_cols[i].image(frame.resize(display_size), caption=f"Frame {i+1}", use_container_width=True)
226
+
227
+ if st.session_state.seg_model and st.session_state.seg_config and st.button("Run Segmentation"):
228
+ with st.spinner("Generating segmentation mask..."):
229
+ st.session_state.mask_img = run_segmentation(st.session_state.seg_model, st.session_state.seg_config, frames)
230
+
231
+ if st.session_state.mask_img is not None:
232
+ img_cols[3].image(st.session_state.mask_img.resize(display_size), caption="Segmentation Mask", use_container_width=True)
233
+ else:
234
+ st.info("Please upload all three frames to proceed.")
235
+
236
+ st.header("3. Generate Scene Analysis")
237
+ instruction_prompt = st.text_area(
238
+ "Enter your custom instruction prompt:",
239
+ "Provide a detailed summary of the surgical action, noting the instruments used and their interactions."
240
+ )
241
+
242
+ can_run_analysis = (
243
+ st.session_state.get("medgemma_model") is not None and
244
+ len(frames) == 3 and
245
+ st.session_state.get("mask_img") is not None and
246
+ bool(instruction_prompt)
247
+ )
248
+
249
+ if st.button("Run Analysis", disabled=not can_run_analysis):
250
+ with st.spinner("Running MedGemma analysis... This may take a moment."):
251
+ result = run_captioning(
252
+ st.session_state.medgemma_model, st.session_state.processor,
253
+ frames, st.session_state.mask_img, instruction_prompt
254
+ )
255
+ st.subheader("Analysis Result")
256
+ st.write(result)
257
+
258
+ if not can_run_analysis:
259
+ st.warning("Please ensure the MedGemma model is loaded, three frames are uploaded, segmentation is complete, and a prompt is provided.")
260
+
261
+ if __name__ == "__main__":
262
+ show()