# SAM-2-RedSpiderMites / streamlit_app.py
# (Debug build: only the first sampled frame of each video is processed; see the break in the frame loop.)
import streamlit as st
import torch
import torchvision.models as models
import torchvision.transforms as T
from PIL import Image, ImageDraw
import numpy as np
import cv2 # OpenCV for video processing
import os
import sys
import time
import tempfile
from torchvision.ops import nms
# Imports for manual model loading
from omegaconf import OmegaConf
from hydra.utils import instantiate
# Need _load_checkpoint from the local sam2 copy
from sam2.build_sam import _load_checkpoint
from sam2 import automatic_mask_generator
# --- Configuration ---
# Assuming paths relative to the app file location
_APP_DIR = os.path.dirname(os.path.abspath(__file__))
SAM_CHECKPOINT_PATH = os.path.join(_APP_DIR, "checkpoints", "sam2.1_hiera_large.pt")
SAM_MODEL_CONFIG_PATH = os.path.join(_APP_DIR, "sam2", "configs", "sam2.1", "sam2.1_hiera_l.yaml")
CLASSIFICATION_MODEL_PATH = os.path.join(_APP_DIR, "models", "best_mobilenet_v3_small.pth")
NUM_CLASSES = 2
CLASS_NAMES = ['Not Infected', 'Infected']
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# --- Model Loading (Cached with Streamlit) ---
@st.cache_resource
def load_models():
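    """Load the SAM-2 model, its automatic mask generator, and the MobileNetV3 classifier.

    Cached with st.cache_resource, so the models are built once per process rather than on
    every Streamlit rerun. Returns (sam_model, mask_generator, classification_model), or
    (None, None, None) if any component fails to load.
    """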
sam_model = None
mask_generator = None
classification_model = None
print("--- Loading Models (cached) ---") # Add print to see when cache is missed
# --- Load SAM Model Manually ---
print("Loading SAM model manually...")
try:
# 1. Check paths
if not os.path.exists(SAM_MODEL_CONFIG_PATH):
raise FileNotFoundError(f"SAM config file not found at {SAM_MODEL_CONFIG_PATH}")
if not os.path.exists(SAM_CHECKPOINT_PATH):
raise FileNotFoundError(f"SAM checkpoint file not found at {SAM_CHECKPOINT_PATH}")
print(f" Config Path: {SAM_MODEL_CONFIG_PATH}")
print(f" Checkpoint Path: {SAM_CHECKPOINT_PATH}")
print(f" Device: {DEVICE}")
# 2. Load config directly
cfg = OmegaConf.load(SAM_MODEL_CONFIG_PATH)
OmegaConf.resolve(cfg) # Resolve any interpolations
# 3. Instantiate model from config
print(" Instantiating SAM model from config...")
sam_model = instantiate(cfg.model, _recursive_=True)
print(" SAM model instantiated.")
# 4. Load checkpoint weights using imported function
print(" Loading checkpoint weights...")
_load_checkpoint(sam_model, SAM_CHECKPOINT_PATH)
print(" Checkpoint loaded.")
# 5. Set device and eval mode
sam_model.to(DEVICE)
sam_model.eval()
print("SAM model moved to device and set to eval mode.")
# 6. Configure the automatic mask generator
print(" Configuring SamAutomaticMaskGenerator...")
mask_generator = automatic_mask_generator.SAM2AutomaticMaskGenerator(
model=sam_model,
points_per_side=32,
pred_iou_thresh=0.88,
stability_score_thresh=0.95,
min_mask_region_area=100,
output_mode="binary_mask"
)
print("SAM model and Mask Generator loaded successfully.")
except FileNotFoundError as e:
print(f"ERROR loading SAM (File Not Found): {e}")
# Use st.error for user visibility in Streamlit
st.error(f"SAM Model/Config File Not Found: {e}")
sam_model = None
mask_generator = None
except ImportError as e:
print(f"ERROR: A dependency (like OmegaConf) or sam2 import failed: {e}")
st.error(f"Import Error Loading SAM: {e}")
sam_model = None
mask_generator = None
except Exception as e:
print(f"ERROR loading SAM model manually: {e}")
import traceback
traceback.print_exc()
st.error(f"General Error Loading SAM Model: {e}")
sam_model = None
mask_generator = None
# --- End SAM Model Loading ---
# --- Load Classification Model ---
print("Loading Classification model...")
if os.path.exists(CLASSIFICATION_MODEL_PATH):
try:
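            # torchvision's mobilenet_v3_small classifier head is [Linear, Hardswish, Dropout, Linear];
            # classifier[3] is its final Linear layer, replaced below so the head emits NUM_CLASSES logits.
            # weights=None builds the architecture without pretrained weights; everything is loaded from the checkpoint.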
classification_model = models.mobilenet_v3_small(weights=None)
classification_model.classifier[3] = torch.nn.Linear(classification_model.classifier[3].in_features, NUM_CLASSES)
checkpoint = torch.load(CLASSIFICATION_MODEL_PATH, map_location=DEVICE, weights_only=False)
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
print(" Extracted state_dict from 'state_dict' key.")
elif isinstance(checkpoint, dict) and 'model' in checkpoint:
state_dict = checkpoint['model']
print(" Extracted state_dict from 'model' key.")
elif isinstance(checkpoint, dict):
state_dict = checkpoint
print(" Using loaded checkpoint dictionary as state_dict.")
else:
state_dict = checkpoint
print(" Using loaded object directly as state_dict.")
if not isinstance(state_dict, dict):
raise TypeError(f"Could not extract a state dictionary (dict) from checkpoint. Got type: {type(state_dict)}")
cleaned_state_dict = {}
prefix_to_remove = "backbone."
needs_cleaning = any(key.startswith(prefix_to_remove) for key in state_dict.keys())
if needs_cleaning:
print(f" Cleaning keys: Removing '{prefix_to_remove}' prefix.")
for k, v in state_dict.items():
if k.startswith(prefix_to_remove):
cleaned_state_dict[k[len(prefix_to_remove):]] = v
else:
cleaned_state_dict[k] = v
else:
print(" No 'backbone.' prefix found, using state_dict keys as is.")
cleaned_state_dict = state_dict
            # strict=False so mismatched keys are reported (and warned about) instead of raising.
            report = classification_model.load_state_dict(cleaned_state_dict, strict=False)
            print(f" Classification model load report - Missing: {report.missing_keys}, Unexpected: {report.unexpected_keys}")
            if report.missing_keys or report.unexpected_keys:
                st.warning(f"Classification Model Issues - Missing: {report.missing_keys}, Unexpected: {report.unexpected_keys}")
                # Not treated as fatal: proceed with the loaded weights, but warn the user.
classification_model.to(DEVICE)
classification_model.eval()
print("Classification model loaded and ready.")
except Exception as e:
print(f"Error loading classification model: {e}")
import traceback
traceback.print_exc()
st.error(f"Error Loading Classification Model: {e}")
classification_model = None
else:
print(f"Classification model not found at {CLASSIFICATION_MODEL_PATH}")
st.error(f"Classification Model Not Found at {CLASSIFICATION_MODEL_PATH}")
classification_model = None
# --- End Classification Model Loading ---
print("--- Model Loading Complete Check ---")
if sam_model is None or mask_generator is None or classification_model is None:
# Error messages were already shown above
print(" ERROR: One or more models failed to load properly.")
return None, None, None # Return None tuple if loading failed
print("--- Model Loading Fully Succeeded ---")
return sam_model, mask_generator, classification_model
# --- Preprocessing Definition (Identical) ---
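# Resize + center-crop to 224x224 and normalise with the standard ImageNet mean/std,
# the preprocessing torchvision's MobileNetV3 models expect.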
preprocess_transform = T.Compose([
T.Resize(256, interpolation=T.InterpolationMode.BILINEAR),
T.CenterCrop(224),
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# --- Video Processing Function (Streamlit Adaptation) ---
def process_video_streamlit(video_path, sam_model, mask_generator, classification_model):
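    """Segment sampled video frames with SAM-2 and classify leaf-like segments for infection.

    Returns a tuple (segment_results, annotated_full_frames):
      - segment_results: list of (cropped PIL.Image, "label (confidence)") pairs
      - annotated_full_frames: list of full frames (PIL.Image) with infected detections drawn
    """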
    # --- Actual processing logic (adapted from the Gradio app.py) ---
segment_results = [] # List of (PIL_Image, Label_String)
annotated_full_frames = [] # List of PIL_Image
try:
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
st.error("Error opening video file.")
return [], []
frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
fps = cap.get(cv2.CAP_PROP_FPS)
st.write(f"Input video ({os.path.basename(video_path)}): {frame_count} frames @ {fps:.2f} FPS")
# --- Filtering Parameters (Same as before) ---
process_every_n_frames = 100
min_aspect_ratio = 0.2
max_aspect_ratio = 5.0
lower_leaf_hsv = np.array([15, 40, 40])
upper_leaf_hsv = np.array([80, 255, 230])
min_leaf_color_ratio = 0.15
min_laplacian_variance = 150.0
nms_iou_threshold = 0.5
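        # All of the above are hand-tuned heuristics: the HSV window targets leaf-like
        # yellow-green to green hues (OpenCV hue runs 0-179), and the Laplacian-variance
        # cutoff rejects blurry crops before they reach the classifier.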
print(f"Processing every {process_every_n_frames} frames.")
print(f"Filtering masks with aspect ratio outside ({min_aspect_ratio}, {max_aspect_ratio}).")
print(f"Filtering segments with less than {min_leaf_color_ratio*100:.0f}% leaf-like pixels.")
print(f"Filtering segments with Laplacian variance < {min_laplacian_variance:.1f}.")
print(f"Applying NMS with IoU threshold: {nms_iou_threshold}")
# --- Streamlit Progress Bar ---
progress_bar = st.progress(0, text="Starting video processing...")
processed_frame_count_for_display = 0
for frame_idx in range(frame_count):
ret, frame = cap.read()
if not ret:
break
# Update progress bar
progress_text = f"Processing frame {frame_idx + 1}/{frame_count}"
progress_bar.progress( (frame_idx + 1) / frame_count, text=progress_text)
# --- Apply Frame Sampling ---
if frame_idx % process_every_n_frames != 0:
continue
# --- Process this sampled frame ---
processed_frame_count_for_display += 1
start_time = time.time()
print(f"\nProcessing sampled frame index {frame_idx} (Display Frame {processed_frame_count_for_display})...")
# --- SAM Mask Generation ---
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
initial_masks = mask_generator.generate(frame_rgb)
print(f" Found {len(initial_masks)} potential masks initially.")
if not initial_masks: continue
# --- NMS ---
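            # SAM's mask records store 'bbox' in XYWH pixel coordinates, while torchvision.ops.nms
            # expects XYXY boxes, so convert first; each mask's 'predicted_iou' serves as its NMS score.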
boxes_xywh = np.array([ann['bbox'] for ann in initial_masks])
scores = np.array([ann['predicted_iou'] for ann in initial_masks])
boxes_xyxy = boxes_xywh.copy()
boxes_xyxy[:, 2] = boxes_xywh[:, 0] + boxes_xywh[:, 2]
boxes_xyxy[:, 3] = boxes_xywh[:, 1] + boxes_xywh[:, 3]
boxes_tensor = torch.as_tensor(boxes_xyxy, dtype=torch.float32)
scores_tensor = torch.as_tensor(scores, dtype=torch.float32)
keep_indices = nms(boxes_tensor, scores_tensor, nms_iou_threshold)
filtered_masks = [initial_masks[i] for i in keep_indices.tolist()]
print(f" {len(filtered_masks)} masks remaining after NMS.")
if not filtered_masks: continue
# --- Classification & Filtering ---
processed_mask_count = 0
aspect_ratio_passed_count = 0
color_filter_passed_count = 0
sharpness_filter_passed_count = 0
infected_count = 0
annotated_frame_for_output = frame_rgb.copy() # Start with clean frame for annotations
for ann_idx, ann in enumerate(filtered_masks):
processed_mask_count += 1
if 'bbox' in ann and isinstance(ann['bbox'], (list, tuple)) and len(ann['bbox']) == 4:
bbox = ann['bbox']
try:
x_min, y_min, width, height = map(int, bbox)
img_h, img_w, _ = frame_rgb.shape
x_max = min(x_min + width, img_w)
y_max = min(y_min + height, img_h)
x_min = max(x_min, 0)
y_min = max(y_min, 0)
clipped_width = x_max - x_min
clipped_height = y_max - y_min
if clipped_width <= 0 or clipped_height <= 0: continue
# 1. Aspect Ratio Filter
aspect_ratio = clipped_width / clipped_height
if not (min_aspect_ratio < aspect_ratio < max_aspect_ratio): continue
aspect_ratio_passed_count += 1
# 2. Color Filter
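                        # Keep the segment only if enough of its masked pixels fall inside the
                        # leaf-like HSV range defined above; this is meant to discard non-leaf regions.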
cropped_patch_np = frame_rgb[y_min:y_max, x_min:x_max]
if cropped_patch_np.size == 0: continue
segmentation_mask = ann['segmentation']
if segmentation_mask.dtype == bool:
cropped_segmentation_mask_bool = segmentation_mask[y_min:y_max, x_min:x_max]
else:
cropped_segmentation_mask_bool = segmentation_mask[y_min:y_max, x_min:x_max].astype(bool)
if cropped_segmentation_mask_bool.shape[:2] != cropped_patch_np.shape[:2]:
print(f"Warning: Cropped mask shape {cropped_segmentation_mask_bool.shape[:2]} doesn't match patch shape {cropped_patch_np.shape[:2]}. Skipping.")
continue
cropped_patch_hsv = cv2.cvtColor(cropped_patch_np, cv2.COLOR_RGB2HSV)
color_range_mask = cv2.inRange(cropped_patch_hsv, lower_leaf_hsv, upper_leaf_hsv)
cropped_segmentation_mask_uint8 = cropped_segmentation_mask_bool.astype(np.uint8)
pixels_in_segment_and_range = cv2.bitwise_and(color_range_mask, color_range_mask, mask=cropped_segmentation_mask_uint8)
total_pixels_in_segment = np.count_nonzero(cropped_segmentation_mask_uint8)
if total_pixels_in_segment == 0: continue
leaf_pixel_ratio = np.count_nonzero(pixels_in_segment_and_range) / total_pixels_in_segment
if leaf_pixel_ratio < min_leaf_color_ratio: continue
color_filter_passed_count += 1
# 3. Sharpness Filter
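                        # Variance of the Laplacian is a standard focus measure: blurry crops have few
                        # edges, so the Laplacian response is nearly flat and its variance stays low.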
gray_crop = cv2.cvtColor(cropped_patch_np, cv2.COLOR_RGB2GRAY)
laplacian_var = cv2.Laplacian(gray_crop, cv2.CV_64F).var()
if laplacian_var < min_laplacian_variance: continue
sharpness_filter_passed_count += 1
# 4. Classification
cropped_patch_pil = Image.fromarray(cropped_patch_np)
input_tensor = preprocess_transform(cropped_patch_pil)
input_batch = input_tensor.unsqueeze(0).to(DEVICE)
with torch.no_grad():
output = classification_model(input_batch)
probabilities = torch.softmax(output[0], dim=0)
confidence, predicted_class_idx = torch.max(probabilities, 0)
predicted_class_name = CLASS_NAMES[predicted_class_idx.item()]
confidence_score = confidence.item()
label = f"{predicted_class_name} ({confidence_score:.2f})"
segment_results.append((cropped_patch_pil, label))
if predicted_class_name == 'Infected' and confidence_score > 0.5:
infected_count += 1
cv2.rectangle(annotated_frame_for_output, (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)
cv2.putText(annotated_frame_for_output, label, (x_min, y_min - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)
except Exception as e:
original_mask_index = keep_indices[ann_idx].item()
print(f" Error processing mask {original_mask_index} (after NMS index {ann_idx}) with bbox {bbox}: {e}")
print(f" Processed {processed_mask_count} masks after NMS.")
print(f" {aspect_ratio_passed_count} passed aspect ratio filter.")
print(f" {color_filter_passed_count} passed color filter.")
print(f" {sharpness_filter_passed_count} passed sharpness filter (considered leaves)." )
if infected_count > 0:
print(f" Detected {infected_count} infected leaf segments in this frame.")
# Add the annotated frame
annotated_full_frames.append(Image.fromarray(annotated_frame_for_output))
end_time = time.time()
print(f" Sampled frame processing time: {end_time - start_time:.2f}s")
# --- ADDED BREAK: Stop after processing the first sampled frame ---
print(f" DEBUG: Breaking loop after processing first sampled frame (index {frame_idx}).")
break
# --- END ADDED BREAK ---
cap.release()
progress_bar.progress(1.0, text="Video processing complete!")
print(f"Finished processing. Returning {len(segment_results)} detected leaf segments and {len(annotated_full_frames)} processed frames.")
except Exception as e:
st.error(f"An error occurred during video processing: {e}")
import traceback
traceback.print_exc()
        # Release the capture and complete the progress bar even on error
        if 'cap' in locals() and cap.isOpened():
            cap.release()
        if 'progress_bar' in locals():
            progress_bar.progress(1.0, text="Processing failed!")
        return [], []  # Return empty lists on failure
# Handle cases where no segments or frames were processed
if not segment_results:
st.info("No leaf-like segments found or processed after filtering.")
# Optionally return placeholder images or just empty lists
# Returning empty lists for consistency here
return [], []
return segment_results, annotated_full_frames
# --- Streamlit App UI ---
st.set_page_config(layout="wide")
st.title("Red Spider Mite Detection (Streamlit)")
# Load models (will be cached)
sam_model, mask_generator, classification_model = load_models()
st.markdown("Upload a video file OR select an example below to detect red spider mites on leaves.")
uploaded_file = st.file_uploader("Choose a video...", type=["mp4", "mov", "avi"])
# --- Example Videos Section ---
st.markdown("**Or use an Example Video:**")
example_video_dir = os.path.join(_APP_DIR, "test_videos")
example_files = []
if os.path.isdir(example_video_dir):
example_files = [f for f in os.listdir(example_video_dir) if f.lower().endswith(('.mp4', '.avi', '.mov'))]
clicked_example = None
if not example_files:
st.warning("No example videos found in ./test_videos directory.")
else:
# Create columns for examples
cols = st.columns(len(example_files))
for i, example_file in enumerate(example_files):
example_full_path = os.path.join(example_video_dir, example_file)
with cols[i]:
# Use nested columns to control width (e.g., 2/3 width for video column)
vid_col, _ = st.columns([2, 1])
with vid_col:
st.markdown(f"**{example_file}**") # Display filename
st.video(example_full_path) # Display the video
if st.button(f"Use Example: {example_file}", key=f"ex_{i}"): # Use unique key for buttons
clicked_example = example_full_path
# --- End Example Videos Section ---
process_button_main = st.button("Detect Mites from Uploaded Video")
# Placeholders for results
results_placeholder = st.empty()
# Determine which action triggered the run
video_path_to_process = None
if clicked_example:
video_path_to_process = clicked_example
print(f"Processing triggered by example button: {video_path_to_process}")
elif process_button_main:
if uploaded_file is not None:
# Save uploaded file temporarily only if main button clicked and file exists
with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(uploaded_file.name)[1]) as tmp_file:
tmp_file.write(uploaded_file.getvalue())
video_path_to_process = tmp_file.name
print(f"Processing triggered by main button with uploaded file: {video_path_to_process}")
else:
st.warning("Please upload a video file before clicking 'Detect Mites from Uploaded Video'.")
# --- Run Processing if a path is determined ---
if video_path_to_process and sam_model is not None:
is_temp_file = (process_button_main and uploaded_file is not None) # Flag to know if we need to delete later
with results_placeholder.container():
st.write(f"Processing: {os.path.basename(video_path_to_process)}... Please wait.")
with st.spinner('Analyzing video frames...'):
# Call processing function
segment_results, annotated_full_frames = process_video_streamlit(
video_path_to_process, sam_model, mask_generator, classification_model
)
# Display results (use_container_width already updated)
st.subheader("Detected Leaf Segments (Filtered)")
if segment_results:
num_cols = 6
cols_disp = st.columns(num_cols)
for i, (img, label) in enumerate(segment_results):
with cols_disp[i % num_cols]:
st.image(img, caption=label, use_container_width=True)
else:
st.info("No leaf-like segments found or processed.")
st.subheader("Processed Frames with Infected Detections")
if annotated_full_frames:
num_cols_frames = 2
cols_frames_disp = st.columns(num_cols_frames)
            for i, img in enumerate(annotated_full_frames):
                # Frames are sampled every 100 frames starting at index 0, so approximate the source frame index
                frame_num_approx = i * 100
                with cols_frames_disp[i % num_cols_frames]:
                    st.image(img, caption=f"Processed Frame ~{frame_num_approx}", use_container_width=True)
else:
st.info("No frames processed or no infected segments found in frames.")
# Clean up temporary file only if it was created from an upload
if is_temp_file:
try:
os.unlink(video_path_to_process)
print(f"Cleaned up temp file: {video_path_to_process}")
except Exception as e:
st.warning(f"Could not delete temporary file {video_path_to_process}: {e}")
elif (process_button_main or clicked_example) and sam_model is None:
st.error("Models could not be loaded. Cannot process video.")