import streamlit as st
import torch
import torchvision.models as models
import torchvision.transforms as T
from PIL import Image, ImageDraw
import numpy as np
import cv2  # OpenCV for video processing
import os
import sys
import time
import tempfile
from torchvision.ops import nms

# Imports for manual model loading
from omegaconf import OmegaConf
from hydra.utils import instantiate

# Need _load_checkpoint from the local sam2 copy
from sam2.build_sam import _load_checkpoint
from sam2 import automatic_mask_generator

# --- Configuration ---
# Paths are resolved relative to this app file's location.
_APP_DIR = os.path.dirname(os.path.abspath(__file__))
SAM_CHECKPOINT_PATH = os.path.join(_APP_DIR, "checkpoints", "sam2.1_hiera_large.pt")
SAM_MODEL_CONFIG_PATH = os.path.join(_APP_DIR, "sam2", "configs", "sam2.1", "sam2.1_hiera_l.yaml")
CLASSIFICATION_MODEL_PATH = os.path.join(_APP_DIR, "models", "best_mobilenet_v3_small.pth")

NUM_CLASSES = 2
CLASS_NAMES = ['Not Infected', 'Infected']
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# --- Model Loading (Cached with Streamlit) ---
@st.cache_resource
def load_models():
    sam_model = None
    mask_generator = None
    classification_model = None
    print("--- Loading Models (cached) ---")  # Printed only on a cache miss

    # --- Load SAM Model Manually ---
    print("Loading SAM model manually...")
    try:
        # 1. Check paths
        if not os.path.exists(SAM_MODEL_CONFIG_PATH):
            raise FileNotFoundError(f"SAM config file not found at {SAM_MODEL_CONFIG_PATH}")
        if not os.path.exists(SAM_CHECKPOINT_PATH):
            raise FileNotFoundError(f"SAM checkpoint file not found at {SAM_CHECKPOINT_PATH}")
        print(f"  Config Path: {SAM_MODEL_CONFIG_PATH}")
        print(f"  Checkpoint Path: {SAM_CHECKPOINT_PATH}")
        print(f"  Device: {DEVICE}")

        # 2. Load config directly
        cfg = OmegaConf.load(SAM_MODEL_CONFIG_PATH)
        OmegaConf.resolve(cfg)  # Resolve any interpolations

        # 3. Instantiate model from config
        print("  Instantiating SAM model from config...")
        sam_model = instantiate(cfg.model, _recursive_=True)
        print("  SAM model instantiated.")

        # 4. Load checkpoint weights using the imported helper
        print("  Loading checkpoint weights...")
        _load_checkpoint(sam_model, SAM_CHECKPOINT_PATH)
        print("  Checkpoint loaded.")

        # 5. Set device and eval mode
        sam_model.to(DEVICE)
        sam_model.eval()
        print("SAM model moved to device and set to eval mode.")
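        # The generator settings below are this app's tuning choices:
        # points_per_side sets the density of the point-prompt grid,
        # pred_iou_thresh and stability_score_thresh filter low-quality masks,
        # and min_mask_region_area removes disconnected regions and holes
        # smaller than the given pixel area.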
        # 6. Configure the automatic mask generator
        print("  Configuring SAM2AutomaticMaskGenerator...")
        mask_generator = automatic_mask_generator.SAM2AutomaticMaskGenerator(
            model=sam_model,
            points_per_side=32,
            pred_iou_thresh=0.88,
            stability_score_thresh=0.95,
            min_mask_region_area=100,
            output_mode="binary_mask",
        )
        print("SAM model and Mask Generator loaded successfully.")
    except FileNotFoundError as e:
        print(f"ERROR loading SAM (File Not Found): {e}")
        # Use st.error for user visibility in Streamlit
        st.error(f"SAM Model/Config File Not Found: {e}")
        sam_model = None
        mask_generator = None
    except ImportError as e:
        print(f"ERROR: A dependency (like OmegaConf) or sam2 import failed: {e}")
        st.error(f"Import Error Loading SAM: {e}")
        sam_model = None
        mask_generator = None
    except Exception as e:
        print(f"ERROR loading SAM model manually: {e}")
        import traceback
        traceback.print_exc()
        st.error(f"General Error Loading SAM Model: {e}")
        sam_model = None
        mask_generator = None
    # --- End SAM Model Loading ---

    # --- Load Classification Model ---
    print("Loading Classification model...")
    if os.path.exists(CLASSIFICATION_MODEL_PATH):
        try:
            classification_model = models.mobilenet_v3_small(weights=None)
            classification_model.classifier[3] = torch.nn.Linear(
                classification_model.classifier[3].in_features, NUM_CLASSES)
            checkpoint = torch.load(CLASSIFICATION_MODEL_PATH, map_location=DEVICE, weights_only=False)
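            # Training scripts save checkpoints in different layouts: a bare
            # state_dict, or a dict wrapping it, e.g. (illustrative only):
            #   {'epoch': 12, 'state_dict': {...}, 'optimizer': {...}}
            # The branches below try each known layout in turn.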
            if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
                state_dict = checkpoint['state_dict']
                print("  Extracted state_dict from 'state_dict' key.")
            elif isinstance(checkpoint, dict) and 'model' in checkpoint:
                state_dict = checkpoint['model']
                print("  Extracted state_dict from 'model' key.")
            elif isinstance(checkpoint, dict):
                state_dict = checkpoint
                print("  Using loaded checkpoint dictionary as state_dict.")
            else:
                state_dict = checkpoint
                print("  Using loaded object directly as state_dict.")

            if not isinstance(state_dict, dict):
                raise TypeError(f"Could not extract a state dictionary (dict) from checkpoint. Got type: {type(state_dict)}")

            cleaned_state_dict = {}
            prefix_to_remove = "backbone."
            needs_cleaning = any(key.startswith(prefix_to_remove) for key in state_dict.keys())
            if needs_cleaning:
                print(f"  Cleaning keys: Removing '{prefix_to_remove}' prefix.")
                for k, v in state_dict.items():
                    if k.startswith(prefix_to_remove):
                        cleaned_state_dict[k[len(prefix_to_remove):]] = v
                    else:
                        cleaned_state_dict[k] = v
            else:
                print("  No 'backbone.' prefix found, using state_dict keys as is.")
                cleaned_state_dict = state_dict

            # strict=False so mismatches surface in the report (and the warning
            # below) instead of raising; not treated as fatal for now.
            report = classification_model.load_state_dict(cleaned_state_dict, strict=False)
            print(f"  Classification model load report - Missing: {report.missing_keys}, Unexpected: {report.unexpected_keys}")
            if report.missing_keys or report.unexpected_keys:
                st.warning(f"Classification Model Issues - Missing: {report.missing_keys}, Unexpected: {report.unexpected_keys}")

            classification_model.to(DEVICE)
            classification_model.eval()
            print("Classification model loaded and ready.")
        except Exception as e:
            print(f"Error loading classification model: {e}")
            import traceback
            traceback.print_exc()
            st.error(f"Error Loading Classification Model: {e}")
            classification_model = None
    else:
        print(f"Classification model not found at {CLASSIFICATION_MODEL_PATH}")
        st.error(f"Classification Model Not Found at {CLASSIFICATION_MODEL_PATH}")
        classification_model = None
    # --- End Classification Model Loading ---

    print("--- Model Loading Complete Check ---")
    if sam_model is None or mask_generator is None or classification_model is None:
        # Error messages were already shown above.
        print("  ERROR: One or more models failed to load properly.")
        return None, None, None  # Return a None tuple if loading failed
    print("--- Model Loading Fully Succeeded ---")
    return sam_model, mask_generator, classification_model


# --- Preprocessing Definition (identical to the Gradio version) ---
preprocess_transform = T.Compose([
    T.Resize(256, interpolation=T.InterpolationMode.BILINEAR),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # ImageNet statistics
])


# --- Video Processing Function (Streamlit Adaptation) ---
def process_video_streamlit(video_path, sam_model, mask_generator, classification_model):
    # --- Actual processing logic (adapted from the Gradio app.py) ---
    segment_results = []        # List of (PIL_Image, Label_String)
    annotated_full_frames = []  # List of PIL_Image
    try:
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            st.error("Error opening video file.")
            return [], []
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        st.write(f"Input video ({os.path.basename(video_path)}): {frame_count} frames @ {fps:.2f} FPS")

        # --- Filtering Parameters (same as before) ---
        process_every_n_frames = 100
        min_aspect_ratio = 0.2
        max_aspect_ratio = 5.0
        lower_leaf_hsv = np.array([15, 40, 40])
        upper_leaf_hsv = np.array([80, 255, 230])
        min_leaf_color_ratio = 0.15
        min_laplacian_variance = 150.0
        nms_iou_threshold = 0.5
        print(f"Processing every {process_every_n_frames} frames.")
        print(f"Filtering masks with aspect ratio outside ({min_aspect_ratio}, {max_aspect_ratio}).")
        print(f"Filtering segments with less than {min_leaf_color_ratio*100:.0f}% leaf-like pixels.")
        print(f"Filtering segments with Laplacian variance < {min_laplacian_variance:.1f}.")
        print(f"Applying NMS with IoU threshold: {nms_iou_threshold}")

        # --- Streamlit Progress Bar ---
        progress_bar = st.progress(0, text="Starting video processing...")
        processed_frame_count_for_display = 0

        for frame_idx in range(frame_count):
            ret, frame = cap.read()
            if not ret:
                break

            # Update progress bar
            progress_text = f"Processing frame {frame_idx + 1}/{frame_count}"
            progress_bar.progress((frame_idx + 1) / frame_count, text=progress_text)

            # --- Apply Frame Sampling ---
            if frame_idx % process_every_n_frames != 0:
                continue

            # --- Process this sampled frame ---
            processed_frame_count_for_display += 1
            start_time = time.time()
            print(f"\nProcessing sampled frame index {frame_idx} (Display Frame {processed_frame_count_for_display})...")

            # --- SAM Mask Generation ---
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            initial_masks = mask_generator.generate(frame_rgb)
            print(f"  Found {len(initial_masks)} potential masks initially.")
            if not initial_masks:
                continue
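            # Each annotation from the automatic mask generator follows the
            # segment-anything format: 'segmentation' is a full-frame binary mask
            # (with output_mode="binary_mask"), 'bbox' is [x, y, w, h] in pixels,
            # and 'predicted_iou' is the model's own quality estimate, reused
            # below as the NMS score.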
            # --- NMS ---
            boxes_xywh = np.array([ann['bbox'] for ann in initial_masks])
            scores = np.array([ann['predicted_iou'] for ann in initial_masks])
            # Convert [x, y, w, h] boxes to [x1, y1, x2, y2] for torchvision's nms
            boxes_xyxy = boxes_xywh.copy()
            boxes_xyxy[:, 2] = boxes_xywh[:, 0] + boxes_xywh[:, 2]
            boxes_xyxy[:, 3] = boxes_xywh[:, 1] + boxes_xywh[:, 3]
            boxes_tensor = torch.as_tensor(boxes_xyxy, dtype=torch.float32)
            scores_tensor = torch.as_tensor(scores, dtype=torch.float32)
            keep_indices = nms(boxes_tensor, scores_tensor, nms_iou_threshold)
            filtered_masks = [initial_masks[i] for i in keep_indices.tolist()]
            print(f"  {len(filtered_masks)} masks remaining after NMS.")
            if not filtered_masks:
                continue

            # --- Classification & Filtering ---
            processed_mask_count = 0
            aspect_ratio_passed_count = 0
            color_filter_passed_count = 0
            sharpness_filter_passed_count = 0
            infected_count = 0
            annotated_frame_for_output = frame_rgb.copy()  # Start with a clean frame for annotations

            for ann_idx, ann in enumerate(filtered_masks):
                processed_mask_count += 1
                if 'bbox' in ann and isinstance(ann['bbox'], (list, tuple)) and len(ann['bbox']) == 4:
                    bbox = ann['bbox']
                    try:
                        # Clip the box to the image bounds
                        x_min, y_min, width, height = map(int, bbox)
                        img_h, img_w, _ = frame_rgb.shape
                        x_max = min(x_min + width, img_w)
                        y_max = min(y_min + height, img_h)
                        x_min = max(x_min, 0)
                        y_min = max(y_min, 0)
                        clipped_width = x_max - x_min
                        clipped_height = y_max - y_min
                        if clipped_width <= 0 or clipped_height <= 0:
                            continue

                        # 1. Aspect Ratio Filter
                        aspect_ratio = clipped_width / clipped_height
                        if not (min_aspect_ratio < aspect_ratio < max_aspect_ratio):
                            continue
                        aspect_ratio_passed_count += 1

                        # 2. Color Filter: require a minimum fraction of
                        # leaf-colored (HSV-range) pixels inside the segment.
                        cropped_patch_np = frame_rgb[y_min:y_max, x_min:x_max]
                        if cropped_patch_np.size == 0:
                            continue
                        segmentation_mask = ann['segmentation']
                        if segmentation_mask.dtype == bool:
                            cropped_segmentation_mask_bool = segmentation_mask[y_min:y_max, x_min:x_max]
                        else:
                            cropped_segmentation_mask_bool = segmentation_mask[y_min:y_max, x_min:x_max].astype(bool)
                        if cropped_segmentation_mask_bool.shape[:2] != cropped_patch_np.shape[:2]:
                            print(f"Warning: Cropped mask shape {cropped_segmentation_mask_bool.shape[:2]} doesn't match patch shape {cropped_patch_np.shape[:2]}. Skipping.")
                            continue
                        cropped_patch_hsv = cv2.cvtColor(cropped_patch_np, cv2.COLOR_RGB2HSV)
                        color_range_mask = cv2.inRange(cropped_patch_hsv, lower_leaf_hsv, upper_leaf_hsv)
                        cropped_segmentation_mask_uint8 = cropped_segmentation_mask_bool.astype(np.uint8)
                        pixels_in_segment_and_range = cv2.bitwise_and(color_range_mask, color_range_mask, mask=cropped_segmentation_mask_uint8)
                        total_pixels_in_segment = np.count_nonzero(cropped_segmentation_mask_uint8)
                        if total_pixels_in_segment == 0:
                            continue
                        leaf_pixel_ratio = np.count_nonzero(pixels_in_segment_and_range) / total_pixels_in_segment
                        if leaf_pixel_ratio < min_leaf_color_ratio:
                            continue
                        color_filter_passed_count += 1

                        # 3. Sharpness Filter: variance of the Laplacian is a cheap
                        # blur measure; low variance means few edges, i.e. a blurry crop.
                        gray_crop = cv2.cvtColor(cropped_patch_np, cv2.COLOR_RGB2GRAY)
                        laplacian_var = cv2.Laplacian(gray_crop, cv2.CV_64F).var()
                        if laplacian_var < min_laplacian_variance:
                            continue
                        sharpness_filter_passed_count += 1
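                        # The surviving crop is classified with the ImageNet-style
                        # pipeline in preprocess_transform (resize, center-crop,
                        # normalize), assumed to match the MobileNetV3 training setup.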
                        # 4. Classification
                        cropped_patch_pil = Image.fromarray(cropped_patch_np)
                        input_tensor = preprocess_transform(cropped_patch_pil)
                        input_batch = input_tensor.unsqueeze(0).to(DEVICE)
                        with torch.no_grad():
                            output = classification_model(input_batch)
                        probabilities = torch.softmax(output[0], dim=0)
                        confidence, predicted_class_idx = torch.max(probabilities, 0)
                        predicted_class_name = CLASS_NAMES[predicted_class_idx.item()]
                        confidence_score = confidence.item()
                        label = f"{predicted_class_name} ({confidence_score:.2f})"
                        segment_results.append((cropped_patch_pil, label))

                        if predicted_class_name == 'Infected' and confidence_score > 0.5:
                            infected_count += 1
                            cv2.rectangle(annotated_frame_for_output, (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)
                            cv2.putText(annotated_frame_for_output, label, (x_min, y_min - 10),
                                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)
                    except Exception as e:
                        original_mask_index = keep_indices[ann_idx].item()
                        print(f"  Error processing mask {original_mask_index} (after-NMS index {ann_idx}) with bbox {bbox}: {e}")

            print(f"  Processed {processed_mask_count} masks after NMS.")
            print(f"  {aspect_ratio_passed_count} passed aspect ratio filter.")
            print(f"  {color_filter_passed_count} passed color filter.")
            print(f"  {sharpness_filter_passed_count} passed sharpness filter (considered leaves).")
            if infected_count > 0:
                print(f"  Detected {infected_count} infected leaf segments in this frame.")

            # Add the annotated frame
            annotated_full_frames.append(Image.fromarray(annotated_frame_for_output))
            end_time = time.time()
            print(f"  Sampled frame processing time: {end_time - start_time:.2f}s")

            # --- DEBUG BREAK: Stop after processing the first sampled frame ---
            print(f"  DEBUG: Breaking loop after processing first sampled frame (index {frame_idx}).")
            break
            # --- END DEBUG BREAK ---
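        # NOTE: with the debug break above in place, only the first sampled frame
        # of each video is analyzed; remove the break to process the whole video.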
        cap.release()
        progress_bar.progress(1.0, text="Video processing complete!")
        print(f"Finished processing. Returning {len(segment_results)} detected leaf segments and {len(annotated_full_frames)} processed frames.")

    except Exception as e:
        st.error(f"An error occurred during video processing: {e}")
        import traceback
        traceback.print_exc()
        # Ensure the progress bar completes even on error
        if 'progress_bar' in locals():
            progress_bar.progress(1.0, text="Processing failed!")
        return [], []  # Return empty lists on failure

    # Handle the case where no segments survived the filters;
    # return empty lists for consistency.
    if not segment_results:
        st.info("No leaf-like segments found or processed after filtering.")
        return [], []

    return segment_results, annotated_full_frames


# --- Streamlit App UI ---
st.set_page_config(layout="wide")
st.title("Red Spider Mite Detection (Streamlit)")

# Load models (cached across reruns by st.cache_resource)
sam_model, mask_generator, classification_model = load_models()

st.markdown("Upload a video file OR select an example below to detect red spider mites on leaves.")
uploaded_file = st.file_uploader("Choose a video...", type=["mp4", "mov", "avi"])

# --- Example Videos Section ---
st.markdown("**Or use an Example Video:**")
example_video_dir = os.path.join(_APP_DIR, "test_videos")
example_files = []
if os.path.isdir(example_video_dir):
    example_files = [f for f in os.listdir(example_video_dir) if f.lower().endswith(('.mp4', '.avi', '.mov'))]

clicked_example = None
if not example_files:
    st.warning("No example videos found in ./test_videos directory.")
else:
    # One column per example
    cols = st.columns(len(example_files))
    for i, example_file in enumerate(example_files):
        example_full_path = os.path.join(example_video_dir, example_file)
        with cols[i]:
            # Use nested columns to control width (e.g., 2/3 width for the video column)
            vid_col, _ = st.columns([2, 1])
            with vid_col:
                st.markdown(f"**{example_file}**")  # Display filename
                st.video(example_full_path)         # Display the video
            if st.button(f"Use Example: {example_file}", key=f"ex_{i}"):  # Unique key per button
                clicked_example = example_full_path
# --- End Example Videos Section ---

process_button_main = st.button("Detect Mites from Uploaded Video")

# Placeholder for results
results_placeholder = st.empty()

# Determine which action triggered the run
video_path_to_process = None
if clicked_example:
    video_path_to_process = clicked_example
    print(f"Processing triggered by example button: {video_path_to_process}")
elif process_button_main:
    if uploaded_file is not None:
        # Save the upload to a temporary file only when the main button was
        # clicked and a file exists
        with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(uploaded_file.name)[1]) as tmp_file:
            tmp_file.write(uploaded_file.getvalue())
            video_path_to_process = tmp_file.name
        print(f"Processing triggered by main button with uploaded file: {video_path_to_process}")
    else:
        st.warning("Please upload a video file before clicking 'Detect Mites from Uploaded Video'.")

# --- Run Processing if a path is determined ---
if video_path_to_process and sam_model is not None:
    is_temp_file = (process_button_main and uploaded_file is not None)  # Flag: delete the temp file later
    with results_placeholder.container():
        st.write(f"Processing: {os.path.basename(video_path_to_process)}... Please wait.")
Please wait.") with st.spinner('Analyzing video frames...'): # Call processing function segment_results, annotated_full_frames = process_video_streamlit( video_path_to_process, sam_model, mask_generator, classification_model ) # Display results (use_container_width already updated) st.subheader("Detected Leaf Segments (Filtered)") if segment_results: num_cols = 6 cols_disp = st.columns(num_cols) for i, (img, label) in enumerate(segment_results): with cols_disp[i % num_cols]: st.image(img, caption=label, use_container_width=True) else: st.info("No leaf-like segments found or processed.") st.subheader("Processed Frames with Infected Detections") if annotated_full_frames: num_cols_frames = 2 cols_frames_disp = st.columns(num_cols_frames) processed_frame_indices = [i for i, f_idx in enumerate(range(0, 10000, 100)) if i < len(annotated_full_frames)] # Crude way to estimate frame# for i, img in enumerate(annotated_full_frames): frame_num_approx = (i+1) * 100 # Approximate frame number based on sampling rate with cols_frames_disp[i % num_cols_frames]: st.image(img, caption=f"Processed Frame ~{frame_num_approx}", use_container_width=True) else: st.info("No frames processed or no infected segments found in frames.") # Clean up temporary file only if it was created from an upload if is_temp_file: try: os.unlink(video_path_to_process) print(f"Cleaned up temp file: {video_path_to_process}") except Exception as e: st.warning(f"Could not delete temporary file {video_path_to_process}: {e}") elif (process_button_main or clicked_example) and sam_model is None: st.error("Models could not be loaded. Cannot process video.")