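"""CarViz demo: text-prompted car segmentation.

Gradio app that detects objects matching a text prompt with GroundingDINO and
segments them with SAM-HQ, returning a binary mask overlay.
"""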
import os
import sys
import warnings
import random
import time
import logging
from typing import Dict, List, Tuple, Union, Optional
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[logging.StreamHandler()]
)
logger = logging.getLogger(__name__)
# Download model weights only if they don't exist
if not os.path.exists("groundingdino_swint_ogc.pth"):
os.system("wget https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swint_ogc.pth")
if not os.path.exists("sam_hq_vit_l.pth"):
os.system("wget https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_l.pth")
# Add paths
sys.path.append(os.path.join(os.getcwd(), "GroundingDINO"))
sys.path.append(os.path.join(os.getcwd(), "sam-hq"))
warnings.filterwarnings("ignore")
import numpy as np
import torch
import torchvision
import gradio as gr
import argparse
from PIL import Image, ImageDraw, ImageFont
from scipy import ndimage
# Grounding DINO
import GroundingDINO.groundingdino.datasets.transforms as T
from GroundingDINO.groundingdino.models import build_model
from GroundingDINO.groundingdino.util.slconfig import SLConfig
from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
# Segment Anything (SAM-HQ fork, imported via the sam-hq path added above)
from segment_anything import build_sam_vit_l, SamPredictor
# # BLIP
# from transformers import BlipProcessor, BlipForConditionalGeneration
# Constants
CONFIG_FILE = 'GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py'
GROUNDINGDINO_CHECKPOINT = "groundingdino_swint_ogc.pth"
SAM_CHECKPOINT = 'sam_hq_vit_l.pth'
OUTPUT_DIR = "outputs"
# Global variables for model caching
_models = {
'groundingdino': None,
'sam_predictor': None
}
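# Note: load_model() accepts 'sam' as the model name but stores the predictor
# under the 'sam_predictor' key, which is the key get_model() expects.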
# Select GPU if available, falling back to CPU if device detection fails
try:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
logger.info(f"Using device: {device}")
except Exception as e:
logger.warning(f"Error detecting GPU, falling back to CPU: {e}")
device = 'cpu'
class ModelManager:
"""Manages model loading, unloading, and provides error handling"""
@staticmethod
def load_model(model_name: str) -> None:
"""Load a model if not already loaded"""
try:
if model_name == 'groundingdino' and _models['groundingdino'] is None:
logger.info("Loading GroundingDINO model...")
start_time = time.time()
if not os.path.exists(GROUNDINGDINO_CHECKPOINT):
raise FileNotFoundError(f"GroundingDINO checkpoint not found at {GROUNDINGDINO_CHECKPOINT}")
args = SLConfig.fromfile(CONFIG_FILE)
args.device = device
model = build_model(args)
checkpoint = torch.load(GROUNDINGDINO_CHECKPOINT, map_location="cpu")
load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
logger.info(f"GroundingDINO load result: {load_res}")
                # Move the model to the selected device so inference can use the GPU when available
                model = model.to(device)
                _ = model.eval()
_models['groundingdino'] = model
logger.info(f"GroundingDINO model loaded in {time.time() - start_time:.2f} seconds")
elif model_name == 'sam' and _models['sam_predictor'] is None:
logger.info("Loading SAM-HQ model...")
start_time = time.time()
if not os.path.exists(SAM_CHECKPOINT):
raise FileNotFoundError(f"SAM checkpoint not found at {SAM_CHECKPOINT}")
sam = build_sam_vit_l(checkpoint=SAM_CHECKPOINT)
sam.to(device=device)
_models['sam_predictor'] = SamPredictor(sam)
logger.info(f"SAM-HQ model loaded in {time.time() - start_time:.2f} seconds")
# elif model_name == 'blip' and (_models['blip_processor'] is None or _models['blip_model'] is None):
# logger.info("Loading BLIP model...")
# start_time = time.time()
# _models['blip_processor'] = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
# _models['blip_model'] = BlipForConditionalGeneration.from_pretrained(
# "Salesforce/blip-image-captioning-large", torch_dtype=torch.float16
# ).to(device)
# logger.info(f"BLIP model loaded in {time.time() - start_time:.2f} seconds")
except Exception as e:
logger.error(f"Error loading {model_name} model: {e}")
raise RuntimeError(f"Failed to load {model_name} model: {e}")
@staticmethod
def get_model(model_name: str):
"""Get a model, loading it if necessary"""
if model_name not in _models or _models[model_name] is None:
ModelManager.load_model(model_name)
return _models[model_name]
@staticmethod
def unload_model(model_name: str) -> None:
"""Unload a model to free memory"""
if model_name in _models and _models[model_name] is not None:
logger.info(f"Unloading {model_name} model")
_models[model_name] = None
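            # Release cached GPU memory now that the reference has been dropped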
if device == 'cuda':
torch.cuda.empty_cache()
# def generate_caption(raw_image: Image.Image) -> str:
# """Generate image caption using BLIP"""
# try:
# blip_processor = ModelManager.get_model('blip_processor')
# blip_model = ModelManager.get_model('blip_model')
# inputs = blip_processor(raw_image, return_tensors="pt").to(device, torch.float16)
# out = blip_model.generate(**inputs)
# caption = blip_processor.decode(out[0], skip_special_tokens=True)
# logger.info(f"Generated caption: {caption}")
# return caption
# except Exception as e:
# logger.error(f"Error generating caption: {e}")
# return "Failed to generate caption."
def transform_image(image_pil: Image.Image) -> torch.Tensor:
"""Transform PIL image for GroundingDINO"""
transform = T.Compose([
T.RandomResize([800], max_size=1333),
T.ToTensor(),
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
image, _ = transform(image_pil, None) # 3, h, w
return image
def get_grounding_output(
image: torch.Tensor,
caption: str,
box_threshold: float,
text_threshold: float,
with_logits: bool = True
) -> Tuple[torch.Tensor, torch.Tensor, List[str]]:
"""Run GroundingDINO to get bounding boxes from text prompt"""
try:
model = ModelManager.get_model('groundingdino')
# Format caption
caption = caption.lower().strip()
if not caption.endswith("."):
caption = caption + "."
        # Run inference on whichever device the model is on
        image = image.to(next(model.parameters()).device)
        with torch.no_grad():
            outputs = model(image[None], captions=[caption])
logits = outputs["pred_logits"].cpu().sigmoid()[0] # (nq, 256)
boxes = outputs["pred_boxes"].cpu()[0] # (nq, 4)
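        # pred_boxes are normalized (cx, cy, w, h) in [0, 1]; they are converted to
        # absolute corner coordinates later in run_grounded_sam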
# Filter output
logits_filt = logits.clone()
boxes_filt = boxes.clone()
filt_mask = logits_filt.max(dim=1)[0] > box_threshold
logits_filt = logits_filt[filt_mask] # num_filt, 256
boxes_filt = boxes_filt[filt_mask] # num_filt, 4
# Get phrases
tokenizer = model.tokenizer
tokenized = tokenizer(caption)
pred_phrases = []
scores = []
for logit, box in zip(logits_filt, boxes_filt):
pred_phrase = get_phrases_from_posmap(
logit > text_threshold, tokenized, tokenizer)
if with_logits:
pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
else:
pred_phrases.append(pred_phrase)
scores.append(logit.max().item())
return boxes_filt, torch.Tensor(scores), pred_phrases
except Exception as e:
logger.error(f"Error in grounding output: {e}")
# Return empty results instead of crashing
return torch.Tensor([]), torch.Tensor([]), []
def draw_mask(mask: np.ndarray, draw: ImageDraw.Draw) -> None:
    """Paint every nonzero pixel of a binary mask white onto an RGBA overlay"""
    # if random_color:
    #     color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 153)
    # else:
    #     color = (30, 144, 255, 153)
    color = (255, 255, 255, 255)
    nonzero_coords = np.transpose(np.nonzero(mask))
    for coord in nonzero_coords:
        # np.nonzero gives (row, col); PIL expects (x, y), so reverse each pair
        draw.point(tuple(coord[::-1]), fill=color)
def draw_box(box: torch.Tensor, draw: ImageDraw.Draw, label: Optional[str]) -> None:
"""Draw bounding box on image"""
color = tuple(np.random.randint(0, 255, size=3).tolist())
draw.rectangle(((box[0], box[1]), (box[2], box[3])), outline=color, width=2)
if label:
font = ImageFont.load_default()
if hasattr(font, "getbbox"):
bbox = draw.textbbox((box[0], box[1]), str(label), font)
else:
w, h = draw.textsize(str(label), font)
bbox = (box[0], box[1], w + box[0], box[1] + h)
draw.rectangle(bbox, fill=color)
draw.text((box[0], box[1]), str(label), fill="white")
# def draw_point(point: np.ndarray, draw: ImageDraw.Draw, r: int = 10) -> None:
# """Draw points on image"""
# for p in point:
# x, y = p
# draw.ellipse((x-r, y-r, x+r, y+r), fill='green')
# def process_scribble_points(scribble: np.ndarray) -> np.ndarray:
# """Process scribble mask to get point coordinates"""
# # Transpose to get the correct orientation
# scribble = scribble.transpose(2, 1, 0)[0]
# # Label connected components
# labeled_array, num_features = ndimage.label(scribble >= 255)
# if num_features == 0:
# logger.warning("No points detected in scribble")
# return np.array([])
# # Get center of mass for each component
# centers = ndimage.center_of_mass(scribble, labeled_array, range(1, num_features + 1))
# return np.array(centers)
# def process_scribble_box(scribble: np.ndarray) -> torch.Tensor:
# """Process scribble mask to get bounding box"""
# # Get point coordinates first
# centers = process_scribble_points(scribble)
# if len(centers) < 2:
# logger.warning("Not enough points for bounding box, need at least 2")
# # Return a default small box in the center if not enough points
# return torch.tensor([[0.4, 0.4, 0.6, 0.6]])
# # Define bounding box from scribble centers: (x_min, y_min, x_max, y_max)
# x_min = centers[:, 0].min()
# x_max = centers[:, 0].max()
# y_min = centers[:, 1].min()
# y_max = centers[:, 1].max()
# bbox = np.array([x_min, y_min, x_max, y_max])
# return torch.tensor(bbox).unsqueeze(0)
def run_grounded_sam(
input_image
# text_prompt: str,
# task_type: str,
# box_threshold: float,
# text_threshold: float,
# iou_threshold: float,
# hq_token_only
) -> List[Image.Image]:
"""Main function to run GroundingDINO and SAM-HQ"""
try:
# Create output directory
os.makedirs(OUTPUT_DIR, exist_ok=True)
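        # Parameters are hardcoded for the car-segmentation demo
        # (previously exposed as UI controls; see the commented-out widgets in create_ui)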
text_prompt = 'car'
task_type = 'text'
box_threshold = 0.3
text_threshold = 0.25
iou_threshold = 0.8
hq_token_only = True
# Process input image
if isinstance(input_image, dict):
# Input from gradio sketch component
scribble = np.array(input_image["mask"])
image_pil = input_image["image"].convert("RGB")
else:
# Direct image input
image_pil = input_image.convert("RGB") if input_image else None
scribble = None
if image_pil is None:
logger.error("No input image provided")
return [Image.new('RGB', (400, 300), color='gray')]
# # Prepare for scribble tasks
# if task_type == 'scribble_box' or task_type == 'scribble_point':
# if scribble is None:
# logger.warning(f"No scribble provided for {task_type} task")
# scribble = np.zeros((image_pil.height, image_pil.width, 3), dtype=np.uint8)
# Transform image for GroundingDINO
transformed_image = transform_image(image_pil)
# Load models as needed
ModelManager.load_model('groundingdino')
size = image_pil.size
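        # PIL reports size as (width, height)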
H, W = size[1], size[0]
# Run GroundingDINO with provided text
boxes_filt, scores, pred_phrases = get_grounding_output(
transformed_image, text_prompt, box_threshold, text_threshold
)
# # Process based on task type
# if task_type == 'automatic':
# # Generate caption with BLIP
# ModelManager.load_model('blip')
# text_prompt = generate_caption(image_pil)
# logger.info(f"Automatic caption: {text_prompt}")
# # Run GroundingDINO
# boxes_filt, scores, pred_phrases = get_grounding_output(
# transformed_image, text_prompt, box_threshold, text_threshold
# )
# elif task_type == 'text':
# if not text_prompt:
# logger.warning("No text prompt provided for 'text' task")
# return [image_pil, Image.new('RGBA', size, color=(0, 0, 0, 0))]
# # Run GroundingDINO with provided text
# boxes_filt, scores, pred_phrases = get_grounding_output(
# transformed_image, text_prompt, box_threshold, text_threshold
# )
# elif task_type == 'scribble_box':
# # No need for GroundingDINO, get box from scribble
# boxes_filt = process_scribble_box(scribble)
# scores = torch.ones(boxes_filt.size(0))
# pred_phrases = ["scribble_box"] * boxes_filt.size(0)
# elif task_type == 'scribble_point':
# # Will handle differently with SAM
# point_coords = process_scribble_points(scribble)
# if len(point_coords) == 0:
# logger.warning("No points detected in scribble")
# return [image_pil, Image.new('RGBA', size, color=(0, 0, 0, 0))]
# boxes_filt = None # Not needed for point-based segmentation
# else:
# logger.error(f"Unknown task type: {task_type}")
# return [image_pil, Image.new('RGBA', size, color=(0, 0, 0, 0))]
# Process boxes if present (not for scribble_point)
if boxes_filt is not None:
            # Convert boxes from normalized (cx, cy, w, h) to absolute (x1, y1, x2, y2) pixel coordinates
for i in range(boxes_filt.size(0)):
boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
boxes_filt[i][2:] += boxes_filt[i][:2]
# Apply non-maximum suppression if we have multiple boxes
if boxes_filt.size(0) > 1:
logger.info(f"Before NMS: {boxes_filt.shape[0]} boxes")
nms_idx = torchvision.ops.nms(boxes_filt, scores, iou_threshold).numpy().tolist()
boxes_filt = boxes_filt[nms_idx]
pred_phrases = [pred_phrases[idx] for idx in nms_idx]
logger.info(f"After NMS: {boxes_filt.shape[0]} boxes")
# Load SAM model
ModelManager.load_model('sam')
sam_predictor = ModelManager.get_model('sam_predictor')
        # Set image for SAM (computes the image embedding once; reused for every box prompt)
image = np.array(image_pil)
sam_predictor.set_image(image)
# # Convert string to boolean
# if isinstance(hq_token_only, str):
# hq_token_only = (hq_token_only.lower() == 'true')
        # Run SAM-HQ on the detected boxes (bail out first if nothing was detected)
if boxes_filt.size(0) == 0:
logger.warning("No boxes detected")
return [image_pil, Image.new('RGBA', size, color=(0, 0, 0, 0))]
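        # Map the boxes from original-image coordinates into SAM's resized input frame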
transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2]).to(device)
masks, _, _ = sam_predictor.predict_torch(
point_coords=None,
point_labels=None,
boxes=transformed_boxes,
multimask_output=False,
hq_token_only=hq_token_only,
)
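        # masks: (num_boxes, 1, H, W) boolean tensor; hq_token_only=True keeps only
        # SAM-HQ's high-quality mask token output instead of blending it with the base SAM output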
# elif task_type == 'scribble_point':
# # Use points for this task type
# point_labels = np.ones(point_coords.shape[0])
# masks, _, _ = sam_predictor.predict(
# point_coords=point_coords,
# point_labels=point_labels,
# box=None,
# multimask_output=False,
# hq_token_only=hq_token_only,
# )
# Create mask image
mask_image = Image.new('RGBA', size, color=(0, 0, 0, 0))
mask_draw = ImageDraw.Draw(mask_image)
        # Draw each predicted mask onto the transparent overlay
        for mask in masks:
            draw_mask(mask[0].cpu().numpy(), mask_draw)
        # Draw boxes and labels on the original image (note: only the mask image is returned below)
image_draw = ImageDraw.Draw(image_pil)
for box, label in zip(boxes_filt, pred_phrases):
draw_box(box, image_draw, label)
# if task_type == 'scribble_box':
# for box in boxes_filt:
# draw_box(box, image_draw, None)
# elif task_type in ['text', 'automatic']:
# for box, label in zip(boxes_filt, pred_phrases):
# draw_box(box, image_draw, label)
# elif task_type == 'scribble_point':
# draw_point(point_coords, image_draw)
# Add caption text for automatic mode
# if task_type == 'automatic':
# image_draw.text((10, 10), text_prompt, fill='black')
# Combine original image with mask
# image_pil = image_pil.convert('RGBA')
# image_pil.alpha_composite(mask_image)
# return [image_pil, mask_image]
return [mask_image]
except Exception as e:
logger.error(f"Error in run_grounded_sam: {e}")
# Return original image on error
if isinstance(input_image, dict) and "image" in input_image:
return [input_image["image"], Image.new('RGBA', input_image["image"].size, color=(0, 0, 0, 0))]
elif isinstance(input_image, Image.Image):
return [input_image, Image.new('RGBA', input_image.size, color=(0, 0, 0, 0))]
else:
return [Image.new('RGB', (400, 300), color='gray'), Image.new('RGBA', (400, 300), color=(0, 0, 0, 0))]
def create_ui():
"""Create Gradio UI for CarViz demo"""
with gr.Blocks(title="CarViz Demo") as block:
gr.Markdown("""
# CarViz
""")
with gr.Row():
with gr.Column():
input_image = gr.Image(type="pil", label="image")
# input_image = gr.ImageMask(
# sources=["upload", "clipboard"],
# transforms=[],
# layers=False,
# format="pil",
# label="base image",
# show_label=True
# )
# task_type = gr.Dropdown(
# ["automatic", "scribble_point", "scribble_box", "text"],
# value="automatic",
# label="Task Type"
# )
# text_prompt = gr.Textbox(label="Text Prompt", placeholder="bench .")
# hq_token_only = gr.Dropdown(
# [False, True], value=False, label="hq_token_only"
# )
run_button = gr.Button(value='Run')
# with gr.Accordion("Advanced options", open=False):
# box_threshold = gr.Slider(
# label="Box Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.001
# )
# text_threshold = gr.Slider(
# label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001
# )
# iou_threshold = gr.Slider(
# label="IOU Threshold", minimum=0.0, maximum=1.0, value=0.8, step=0.001
# )
with gr.Column():
gallery = gr.Gallery(
label="Generated images", show_label=False, elem_id="gallery"
)
# # Update visibility of text prompt based on task type
# def update_text_prompt_visibility(task):
# return gr.update(visible=(task == "text"))
# task_type.change(
# fn=update_text_prompt_visibility,
# inputs=[task_type],
# outputs=[text_prompt]
# )
# Run button
run_button.click(
fn=run_grounded_sam,
inputs=[
input_image
# , text_prompt, task_type,
# box_threshold, text_threshold, iou_threshold, hq_token_only
],
outputs=gallery
)
return block
if __name__ == "__main__":
parser = argparse.ArgumentParser("Grounded SAM demo", add_help=True)
parser.add_argument("--debug", action="store_true", help="using debug mode")
parser.add_argument("--share", action="store_true", help="share the app")
parser.add_argument('--no-gradio-queue', action="store_true", help="disable gradio queue")
parser.add_argument('--port', type=int, default=7860, help="port to run the app")
parser.add_argument('--host', type=str, default="0.0.0.0", help="host to run the app")
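    # Example invocation (assuming this script is saved as app.py):
    #   python app.py --port 7860 --share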
args = parser.parse_args()
logger.info(f"Starting CarViz demo with args: {args}")
# Check for model files
if not os.path.exists(GROUNDINGDINO_CHECKPOINT):
logger.warning(f"GroundingDINO checkpoint not found at {GROUNDINGDINO_CHECKPOINT}")
if not os.path.exists(SAM_CHECKPOINT):
logger.warning(f"SAM-HQ checkpoint not found at {SAM_CHECKPOINT}")
# Create app
block = create_ui()
if not args.no_gradio_queue:
block = block.queue()
# Launch app
try:
block.launch(
debug=args.debug,
share=args.share,
show_error=True,
server_name=args.host,
server_port=args.port
)
except Exception as e:
logger.error(f"Error launching app: {e}")