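"""Streamlit app for red spider mite detection on leaves.

Pipeline, as implemented below: sample frames from an uploaded or example video,
generate candidate segment masks with the SAM2 automatic mask generator, prune
overlapping boxes with NMS, filter candidates by aspect ratio, leaf-like HSV colour
and Laplacian sharpness, then classify each surviving crop as Infected / Not Infected
with a fine-tuned MobileNetV3-Small.
"""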
import streamlit as st
import torch
import torchvision.models as models
import torchvision.transforms as T
from PIL import Image, ImageDraw
import numpy as np
import cv2  # OpenCV for video processing
import os
import sys
import time
import tempfile
from torchvision.ops import nms
# Imports for manual model loading
from omegaconf import OmegaConf
from hydra.utils import instantiate
# Need _load_checkpoint from the local sam2 copy
from sam2.build_sam import _load_checkpoint
from sam2 import automatic_mask_generator
# --- Configuration ---
# Assuming paths relative to the app file location
_APP_DIR = os.path.dirname(os.path.abspath(__file__))
SAM_CHECKPOINT_PATH = os.path.join(_APP_DIR, "checkpoints", "sam2.1_hiera_large.pt")
SAM_MODEL_CONFIG_PATH = os.path.join(_APP_DIR, "sam2", "configs", "sam2.1", "sam2.1_hiera_l.yaml")
CLASSIFICATION_MODEL_PATH = os.path.join(_APP_DIR, "models", "best_mobilenet_v3_small.pth")
NUM_CLASSES = 2
CLASS_NAMES = ['Not Infected', 'Infected']
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
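# NOTE: CLASS_NAMES maps the classifier's output indices to labels, so the order above
# is assumed to match the label encoding used when the MobileNetV3 checkpoint was
# trained (index 0 = 'Not Infected', index 1 = 'Infected').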
# --- Model Loading (Cached with Streamlit) ---
@st.cache_resource  # cache the loaded models across Streamlit reruns
def load_models():
    sam_model = None
    mask_generator = None
    classification_model = None
    print("--- Loading Models (cached) ---")  # Print to see when the cache is missed
    # --- Load SAM Model Manually ---
    print("Loading SAM model manually...")
    try:
        # 1. Check paths
        if not os.path.exists(SAM_MODEL_CONFIG_PATH):
            raise FileNotFoundError(f"SAM config file not found at {SAM_MODEL_CONFIG_PATH}")
        if not os.path.exists(SAM_CHECKPOINT_PATH):
            raise FileNotFoundError(f"SAM checkpoint file not found at {SAM_CHECKPOINT_PATH}")
        print(f" Config Path: {SAM_MODEL_CONFIG_PATH}")
        print(f" Checkpoint Path: {SAM_CHECKPOINT_PATH}")
        print(f" Device: {DEVICE}")
        # 2. Load config directly
        cfg = OmegaConf.load(SAM_MODEL_CONFIG_PATH)
        OmegaConf.resolve(cfg)  # Resolve any interpolations
        # 3. Instantiate model from config
        print(" Instantiating SAM model from config...")
        sam_model = instantiate(cfg.model, _recursive_=True)
        print(" SAM model instantiated.")
        # 4. Load checkpoint weights using the imported helper
        print(" Loading checkpoint weights...")
        _load_checkpoint(sam_model, SAM_CHECKPOINT_PATH)
        print(" Checkpoint loaded.")
        # 5. Set device and eval mode
        sam_model.to(DEVICE)
        sam_model.eval()
        print("SAM model moved to device and set to eval mode.")
        # 6. Configure the automatic mask generator
        print(" Configuring SAM2AutomaticMaskGenerator...")
        mask_generator = automatic_mask_generator.SAM2AutomaticMaskGenerator(
            model=sam_model,
            points_per_side=32,
            pred_iou_thresh=0.88,
            stability_score_thresh=0.95,
            min_mask_region_area=100,
            output_mode="binary_mask",
        )
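        # The automatic mask generator prompts SAM2 with a points_per_side x points_per_side
        # grid of points and keeps only masks whose predicted IoU and stability score clear
        # the thresholds above; min_mask_region_area drops tiny regions (in pixels).
        # Lowering points_per_side trades recall for speed if per-frame processing is too slow.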
print("SAM model and Mask Generator loaded successfully.") | |
except FileNotFoundError as e: | |
print(f"ERROR loading SAM (File Not Found): {e}") | |
# Use st.error for user visibility in Streamlit | |
st.error(f"SAM Model/Config File Not Found: {e}") | |
sam_model = None | |
mask_generator = None | |
except ImportError as e: | |
print(f"ERROR: A dependency (like OmegaConf) or sam2 import failed: {e}") | |
st.error(f"Import Error Loading SAM: {e}") | |
sam_model = None | |
mask_generator = None | |
except Exception as e: | |
print(f"ERROR loading SAM model manually: {e}") | |
import traceback | |
traceback.print_exc() | |
st.error(f"General Error Loading SAM Model: {e}") | |
sam_model = None | |
mask_generator = None | |
# --- End SAM Model Loading --- | |
    # --- Load Classification Model ---
    print("Loading Classification model...")
    if os.path.exists(CLASSIFICATION_MODEL_PATH):
        try:
            classification_model = models.mobilenet_v3_small(weights=None)
            classification_model.classifier[3] = torch.nn.Linear(classification_model.classifier[3].in_features, NUM_CLASSES)
            checkpoint = torch.load(CLASSIFICATION_MODEL_PATH, map_location=DEVICE, weights_only=False)
            if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
                state_dict = checkpoint['state_dict']
                print(" Extracted state_dict from 'state_dict' key.")
            elif isinstance(checkpoint, dict) and 'model' in checkpoint:
                state_dict = checkpoint['model']
                print(" Extracted state_dict from 'model' key.")
            elif isinstance(checkpoint, dict):
                state_dict = checkpoint
                print(" Using loaded checkpoint dictionary as state_dict.")
            else:
                state_dict = checkpoint
                print(" Using loaded object directly as state_dict.")
            if not isinstance(state_dict, dict):
                raise TypeError(f"Could not extract a state dictionary (dict) from checkpoint. Got type: {type(state_dict)}")
            cleaned_state_dict = {}
            prefix_to_remove = "backbone."
            needs_cleaning = any(key.startswith(prefix_to_remove) for key in state_dict.keys())
            if needs_cleaning:
                print(f" Cleaning keys: Removing '{prefix_to_remove}' prefix.")
                for k, v in state_dict.items():
                    if k.startswith(prefix_to_remove):
                        cleaned_state_dict[k[len(prefix_to_remove):]] = v
                    else:
                        cleaned_state_dict[k] = v
            else:
                print(" No 'backbone.' prefix found, using state_dict keys as is.")
                cleaned_state_dict = state_dict
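            # NOTE: a "backbone." prefix typically appears when the checkpoint was saved from
            # a wrapper module (e.g. a training harness that held the classifier as
            # self.backbone); stripping it lets the keys line up with a plain
            # mobilenet_v3_small state_dict. This is an assumption about how the checkpoint
            # was produced, not something recorded in the file itself.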
            # strict=False so that any key mismatches are reported and warned about below
            # instead of raising inside load_state_dict
            report = classification_model.load_state_dict(cleaned_state_dict, strict=False)
            print(f" Classification model load report - Missing: {report.missing_keys}, Unexpected: {report.unexpected_keys}")
            if report.missing_keys or report.unexpected_keys:
                st.warning(f"Classification Model Issues - Missing: {report.missing_keys}, Unexpected: {report.unexpected_keys}")
                # Decide if this is fatal - maybe proceed anyway?
                # For now, we proceed but warn the user.
            classification_model.to(DEVICE)
            classification_model.eval()
            print("Classification model loaded and ready.")
        except Exception as e:
            print(f"Error loading classification model: {e}")
            import traceback
            traceback.print_exc()
            st.error(f"Error Loading Classification Model: {e}")
            classification_model = None
    else:
        print(f"Classification model not found at {CLASSIFICATION_MODEL_PATH}")
        st.error(f"Classification Model Not Found at {CLASSIFICATION_MODEL_PATH}")
        classification_model = None
    # --- End Classification Model Loading ---

    print("--- Model Loading Complete Check ---")
    if sam_model is None or mask_generator is None or classification_model is None:
        # Error messages were already shown above
        print(" ERROR: One or more models failed to load properly.")
        return None, None, None  # Return a None tuple if loading failed
    print("--- Model Loading Fully Succeeded ---")
    return sam_model, mask_generator, classification_model

# --- Preprocessing Definition (Identical) ---
preprocess_transform = T.Compose([
    T.Resize(256, interpolation=T.InterpolationMode.BILINEAR),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
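# The Resize(256) -> CenterCrop(224) -> Normalize chain is the standard torchvision
# ImageNet evaluation recipe; the mean/std above are the usual ImageNet statistics.
# Each cropped leaf patch (a PIL image) becomes a 3x224x224 float tensor, e.g.:
#   tensor = preprocess_transform(Image.open("leaf.jpg").convert("RGB"))  # "leaf.jpg" is a hypothetical path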
# --- Video Processing Function (Streamlit Adaptation) ---
def process_video_streamlit(video_path, sam_model, mask_generator, classification_model):
    # --- Actual logic (adapted from the Gradio app.py); placeholder demo code removed ---
    segment_results = []        # List of (PIL_Image, Label_String)
    annotated_full_frames = []  # List of PIL_Image
    try:
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            st.error("Error opening video file.")
            return [], []
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        st.write(f"Input video ({os.path.basename(video_path)}): {frame_count} frames @ {fps:.2f} FPS")
        # --- Filtering Parameters (Same as before) ---
        process_every_n_frames = 100
        min_aspect_ratio = 0.2
        max_aspect_ratio = 5.0
        lower_leaf_hsv = np.array([15, 40, 40])
        upper_leaf_hsv = np.array([80, 255, 230])
        min_leaf_color_ratio = 0.15
        min_laplacian_variance = 150.0
        nms_iou_threshold = 0.5
        print(f"Processing every {process_every_n_frames} frames.")
        print(f"Filtering masks with aspect ratio outside ({min_aspect_ratio}, {max_aspect_ratio}).")
        print(f"Filtering segments with less than {min_leaf_color_ratio*100:.0f}% leaf-like pixels.")
        print(f"Filtering segments with Laplacian variance < {min_laplacian_variance:.1f}.")
        print(f"Applying NMS with IoU threshold: {nms_iou_threshold}")
        # --- Streamlit Progress Bar ---
        progress_bar = st.progress(0, text="Starting video processing...")
        processed_frame_count_for_display = 0
        for frame_idx in range(frame_count):
            ret, frame = cap.read()
            if not ret:
                break
            # Update progress bar
            progress_text = f"Processing frame {frame_idx + 1}/{frame_count}"
            progress_bar.progress((frame_idx + 1) / frame_count, text=progress_text)
            # --- Apply Frame Sampling ---
            if frame_idx % process_every_n_frames != 0:
                continue
            # --- Process this sampled frame ---
            processed_frame_count_for_display += 1
            start_time = time.time()
            print(f"\nProcessing sampled frame index {frame_idx} (Display Frame {processed_frame_count_for_display})...")
            # --- SAM Mask Generation ---
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            initial_masks = mask_generator.generate(frame_rgb)
            print(f" Found {len(initial_masks)} potential masks initially.")
            if not initial_masks: continue
            # --- NMS ---
            boxes_xywh = np.array([ann['bbox'] for ann in initial_masks])
            scores = np.array([ann['predicted_iou'] for ann in initial_masks])
            boxes_xyxy = boxes_xywh.copy()
            boxes_xyxy[:, 2] = boxes_xywh[:, 0] + boxes_xywh[:, 2]
            boxes_xyxy[:, 3] = boxes_xywh[:, 1] + boxes_xywh[:, 3]
            boxes_tensor = torch.as_tensor(boxes_xyxy, dtype=torch.float32)
            scores_tensor = torch.as_tensor(scores, dtype=torch.float32)
            keep_indices = nms(boxes_tensor, scores_tensor, nms_iou_threshold)
            filtered_masks = [initial_masks[i] for i in keep_indices.tolist()]
            print(f" {len(filtered_masks)} masks remaining after NMS.")
            if not filtered_masks: continue
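            # NOTE: torchvision.ops.nms expects boxes as (x1, y1, x2, y2) plus scores, and
            # returns the indices of the kept boxes in order of decreasing score, which is
            # why the SAM-style XYWH bboxes are converted above. Tiny illustration:
            #   nms(torch.tensor([[0., 0., 10., 10.], [1., 1., 11., 11.]]),
            #       torch.tensor([0.9, 0.8]), 0.5)   # -> tensor([0]); the boxes overlap with IoU ~0.68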
# --- Classification & Filtering --- | |
processed_mask_count = 0 | |
aspect_ratio_passed_count = 0 | |
color_filter_passed_count = 0 | |
sharpness_filter_passed_count = 0 | |
infected_count = 0 | |
annotated_frame_for_output = frame_rgb.copy() # Start with clean frame for annotations | |
for ann_idx, ann in enumerate(filtered_masks): | |
processed_mask_count += 1 | |
if 'bbox' in ann and isinstance(ann['bbox'], (list, tuple)) and len(ann['bbox']) == 4: | |
bbox = ann['bbox'] | |
try: | |
x_min, y_min, width, height = map(int, bbox) | |
img_h, img_w, _ = frame_rgb.shape | |
x_max = min(x_min + width, img_w) | |
y_max = min(y_min + height, img_h) | |
x_min = max(x_min, 0) | |
y_min = max(y_min, 0) | |
clipped_width = x_max - x_min | |
clipped_height = y_max - y_min | |
if clipped_width <= 0 or clipped_height <= 0: continue | |
# 1. Aspect Ratio Filter | |
aspect_ratio = clipped_width / clipped_height | |
if not (min_aspect_ratio < aspect_ratio < max_aspect_ratio): continue | |
aspect_ratio_passed_count += 1 | |
# 2. Color Filter | |
cropped_patch_np = frame_rgb[y_min:y_max, x_min:x_max] | |
if cropped_patch_np.size == 0: continue | |
segmentation_mask = ann['segmentation'] | |
if segmentation_mask.dtype == bool: | |
cropped_segmentation_mask_bool = segmentation_mask[y_min:y_max, x_min:x_max] | |
else: | |
cropped_segmentation_mask_bool = segmentation_mask[y_min:y_max, x_min:x_max].astype(bool) | |
if cropped_segmentation_mask_bool.shape[:2] != cropped_patch_np.shape[:2]: | |
print(f"Warning: Cropped mask shape {cropped_segmentation_mask_bool.shape[:2]} doesn't match patch shape {cropped_patch_np.shape[:2]}. Skipping.") | |
continue | |
cropped_patch_hsv = cv2.cvtColor(cropped_patch_np, cv2.COLOR_RGB2HSV) | |
color_range_mask = cv2.inRange(cropped_patch_hsv, lower_leaf_hsv, upper_leaf_hsv) | |
cropped_segmentation_mask_uint8 = cropped_segmentation_mask_bool.astype(np.uint8) | |
pixels_in_segment_and_range = cv2.bitwise_and(color_range_mask, color_range_mask, mask=cropped_segmentation_mask_uint8) | |
total_pixels_in_segment = np.count_nonzero(cropped_segmentation_mask_uint8) | |
if total_pixels_in_segment == 0: continue | |
leaf_pixel_ratio = np.count_nonzero(pixels_in_segment_and_range) / total_pixels_in_segment | |
if leaf_pixel_ratio < min_leaf_color_ratio: continue | |
color_filter_passed_count += 1 | |
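                        # NOTE: the variance of the Laplacian below is a common blur heuristic:
                        # the Laplacian responds to edges, so a sharp crop has high variance while
                        # a defocused or motion-blurred one scores low. The 150.0 threshold is a
                        # tuning choice for this footage, not a universal constant.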
                        # 3. Sharpness Filter
                        gray_crop = cv2.cvtColor(cropped_patch_np, cv2.COLOR_RGB2GRAY)
                        laplacian_var = cv2.Laplacian(gray_crop, cv2.CV_64F).var()
                        if laplacian_var < min_laplacian_variance: continue
                        sharpness_filter_passed_count += 1
                        # 4. Classification
                        cropped_patch_pil = Image.fromarray(cropped_patch_np)
                        input_tensor = preprocess_transform(cropped_patch_pil)
                        input_batch = input_tensor.unsqueeze(0).to(DEVICE)
                        with torch.no_grad():
                            output = classification_model(input_batch)
                        probabilities = torch.softmax(output[0], dim=0)
                        confidence, predicted_class_idx = torch.max(probabilities, 0)
                        predicted_class_name = CLASS_NAMES[predicted_class_idx.item()]
                        confidence_score = confidence.item()
                        label = f"{predicted_class_name} ({confidence_score:.2f})"
                        segment_results.append((cropped_patch_pil, label))
                        if predicted_class_name == 'Infected' and confidence_score > 0.5:
                            infected_count += 1
                            cv2.rectangle(annotated_frame_for_output, (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)
                            cv2.putText(annotated_frame_for_output, label, (x_min, y_min - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)
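                        # NOTE: the model outputs raw logits for the two classes; softmax turns
                        # them into probabilities and torch.max returns (confidence, class index).
                        # Boxes are drawn only for 'Infected' predictions above 0.5 confidence,
                        # which with two classes is effectively "predicted Infected".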
                    except Exception as e:
                        original_mask_index = keep_indices[ann_idx].item()
                        print(f" Error processing mask {original_mask_index} (after NMS index {ann_idx}) with bbox {bbox}: {e}")
            print(f" Processed {processed_mask_count} masks after NMS.")
            print(f" {aspect_ratio_passed_count} passed aspect ratio filter.")
            print(f" {color_filter_passed_count} passed color filter.")
            print(f" {sharpness_filter_passed_count} passed sharpness filter (considered leaves).")
            if infected_count > 0:
                print(f" Detected {infected_count} infected leaf segments in this frame.")
            # Add the annotated frame
            annotated_full_frames.append(Image.fromarray(annotated_frame_for_output))
            end_time = time.time()
            print(f" Sampled frame processing time: {end_time - start_time:.2f}s")
            # --- ADDED BREAK: Stop after processing the first sampled frame (remove to process the whole video) ---
            print(f" DEBUG: Breaking loop after processing first sampled frame (index {frame_idx}).")
            break
            # --- END ADDED BREAK ---
        cap.release()
        progress_bar.progress(1.0, text="Video processing complete!")
        print(f"Finished processing. Returning {len(segment_results)} detected leaf segments and {len(annotated_full_frames)} processed frames.")
    except Exception as e:
        st.error(f"An error occurred during video processing: {e}")
        import traceback
        traceback.print_exc()
        # Ensure the progress bar completes even on error
        if 'progress_bar' in locals():
            progress_bar.progress(1.0, text="Processing failed!")
        return [], []  # Return empty lists on failure
    # Handle cases where no segments or frames were processed
    if not segment_results:
        st.info("No leaf-like segments found or processed after filtering.")
        # Optionally return placeholder images or just empty lists
        # Returning empty lists for consistency here
        return [], []
    return segment_results, annotated_full_frames
# --- Streamlit App UI ---
st.set_page_config(layout="wide")
st.title("Red Spider Mite Detection (Streamlit)")
# Load models (will be cached)
sam_model, mask_generator, classification_model = load_models()
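# Because load_models() is cached (st.cache_resource), the heavy model objects are built
# once per server process; subsequent reruns of this script (every widget interaction
# re-executes it top to bottom) reuse the same instances.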
st.markdown("Upload a video file OR select an example below to detect red spider mites on leaves.") | |
uploaded_file = st.file_uploader("Choose a video...", type=["mp4", "mov", "avi"]) | |
# --- Example Videos Section --- | |
st.markdown("**Or use an Example Video:**") | |
example_video_dir = os.path.join(_APP_DIR, "test_videos") | |
example_files = [] | |
if os.path.isdir(example_video_dir): | |
example_files = [f for f in os.listdir(example_video_dir) if f.lower().endswith(('.mp4', '.avi', '.mov'))] | |
clicked_example = None | |
if not example_files: | |
st.warning("No example videos found in ./test_videos directory.") | |
else: | |
# Create columns for examples | |
cols = st.columns(len(example_files)) | |
for i, example_file in enumerate(example_files): | |
example_full_path = os.path.join(example_video_dir, example_file) | |
with cols[i]: | |
# Use nested columns to control width (e.g., 2/3 width for video column) | |
vid_col, _ = st.columns([2, 1]) | |
with vid_col: | |
st.markdown(f"**{example_file}**") # Display filename | |
st.video(example_full_path) # Display the video | |
if st.button(f"Use Example: {example_file}", key=f"ex_{i}"): # Use unique key for buttons | |
clicked_example = example_full_path | |
# --- End Example Videos Section --- | |
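# NOTE: st.button returns True only on the script rerun triggered by its click, so
# clicked_example is non-None for exactly one run after an example button is pressed;
# on later reruns it falls back to None and the uploaded-file path below takes over.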
process_button_main = st.button("Detect Mites from Uploaded Video")
# Placeholders for results
results_placeholder = st.empty()
# Determine which action triggered the run
video_path_to_process = None
if clicked_example:
    video_path_to_process = clicked_example
    print(f"Processing triggered by example button: {video_path_to_process}")
elif process_button_main:
    if uploaded_file is not None:
        # Save the uploaded file temporarily only if the main button was clicked and a file exists
        with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(uploaded_file.name)[1]) as tmp_file:
            tmp_file.write(uploaded_file.getvalue())
            video_path_to_process = tmp_file.name
        print(f"Processing triggered by main button with uploaded file: {video_path_to_process}")
    else:
        st.warning("Please upload a video file before clicking 'Detect Mites from Uploaded Video'.")
# --- Run Processing if a path is determined ---
if video_path_to_process and sam_model is not None:
    is_temp_file = (process_button_main and uploaded_file is not None)  # Flag to know if we need to delete later
    with results_placeholder.container():
        st.write(f"Processing: {os.path.basename(video_path_to_process)}... Please wait.")
        with st.spinner('Analyzing video frames...'):
            # Call processing function
            segment_results, annotated_full_frames = process_video_streamlit(
                video_path_to_process, sam_model, mask_generator, classification_model
            )
        # Display results (use_container_width already updated)
        st.subheader("Detected Leaf Segments (Filtered)")
        if segment_results:
            num_cols = 6
            cols_disp = st.columns(num_cols)
            for i, (img, label) in enumerate(segment_results):
                with cols_disp[i % num_cols]:
                    st.image(img, caption=label, use_container_width=True)
        else:
            st.info("No leaf-like segments found or processed.")
        st.subheader("Processed Frames with Infected Detections")
        if annotated_full_frames:
            num_cols_frames = 2
            cols_frames_disp = st.columns(num_cols_frames)
            for i, img in enumerate(annotated_full_frames):
                frame_num_approx = i * 100  # Approximate source frame index based on the 1-in-100 sampling rate
                with cols_frames_disp[i % num_cols_frames]:
                    st.image(img, caption=f"Processed Frame ~{frame_num_approx}", use_container_width=True)
        else:
            st.info("No frames processed or no infected segments found in frames.")
    # Clean up the temporary file only if it was created from an upload
    if is_temp_file:
        try:
            os.unlink(video_path_to_process)
            print(f"Cleaned up temp file: {video_path_to_process}")
        except Exception as e:
            st.warning(f"Could not delete temporary file {video_path_to_process}: {e}")
elif (process_button_main or clicked_example) and sam_model is None:
    st.error("Models could not be loaded. Cannot process video.")