|
import os |
|
os.environ["GRADIO_SSR_MODE"] = "false" |
|
|
|
from segment_anything import sam_model_registry, SamPredictor |
|
import gradio as gr |
|
import numpy as np |
|
import cv2 |
|
import base64 |
|
import torch |
|
from PIL import Image |
|
import io |
|
import argparse |
|
from fastapi import FastAPI |
|
from fastapi.staticfiles import StaticFiles |
|
from transformers import SamModel, SamProcessor |
|
from dam import DescribeAnythingModel, disable_torch_init |
|
try:
    from spaces import GPU
except ImportError:
    print("Spaces not installed, using dummy GPU decorator")

    def GPU(fn=None, **_options):
        """No-op stand-in for ``spaces.GPU`` when ``spaces`` is not installed.

        Works in both forms used with the real decorator:

        * ``@GPU`` -- called with the function itself, which is returned
          unchanged;
        * ``@GPU(duration=75)`` -- called with keyword options only, in which
          case a pass-through decorator is returned and the options are
          ignored.
        """
        if callable(fn):
            # Bare ``@GPU`` usage: hand the function back untouched.
            return fn
        # Parameterized ``@GPU(...)`` usage: return an identity decorator.
        return lambda wrapped: wrapped
|
|
|
|
|
# Use CUDA when PyTorch reports it as available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Load the SAM checkpoint and its matching processor once, at import time,
# so every request handled by image_to_sam_embedding() reuses them.
sam_model = SamModel.from_pretrained("facebook/sam-vit-huge")
sam_model = sam_model.to(device)
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
|
|
|
@GPU(duration=75)
def image_to_sam_embedding(base64_image):
    """Compute the SAM image embedding for a base64-encoded image.

    Args:
        base64_image: Base64 text of an encoded image file. A data-URL
            header (everything up to the first comma) is accepted and
            dropped, the same way ``describe`` treats its inputs.

    Returns:
        The raw bytes of the embedding array returned by
        ``sam_model.get_image_embeddings``, base64-encoded as a UTF-8 string.
        NOTE(review): the array's dtype and shape are not transmitted, so the
        client has to know them -- confirm against the consumer.

    Raises:
        gr.Error: If decoding, image loading, or the model call fails. The
            original exception is attached as the cause.
    """
    try:
        # Without this, the header of a data URL would be fed to the base64
        # decoder and corrupt the image bytes.
        if ',' in base64_image:
            base64_image = base64_image.split(',')[1]

        image_bytes = base64.b64decode(base64_image)
        image = Image.open(io.BytesIO(image_bytes))

        # Preprocess for SAM and move the tensors to the model's device.
        inputs = sam_processor(image, return_tensors="pt").to(device)

        # Inference only: no gradients are needed.
        with torch.no_grad():
            image_embedding = sam_model.get_image_embeddings(inputs["pixel_values"])

        # Bring the result back to host memory before serializing it.
        image_embedding = image_embedding.cpu().numpy()

        embedding_bytes = image_embedding.tobytes()
        embedding_base64 = base64.b64encode(embedding_bytes).decode('utf-8')

        return embedding_base64
    except Exception as e:
        print(f"Error processing image: {str(e)}")
        # Chain the cause so the server-side traceback shows the real failure.
        raise gr.Error(f"Failed to process image: {str(e)}") from e
|
|
|
@GPU(duration=75)
def describe(image_base64: str, mask_base64: str, query: str):
    """Stream a description of the masked image region.

    Args:
        image_base64: Base64 text of the image, optionally with a data-URL
            header before the first comma.
        mask_base64: Base64 text of the mask image, same format.
        query: Prompt forwarded to ``dam.get_description``.

    Yields:
        The text accumulated so far, once per token produced by
        ``dam.get_description(..., streaming=True)``.
    """
    decoded = []
    for payload in (image_base64, mask_base64):
        # Keep only the part after the data-URL header when one is present.
        if ',' in payload:
            payload = payload.split(',')[1]
        decoded.append(Image.open(io.BytesIO(base64.b64decode(payload))))
    img, mask = decoded

    # Binarize the mask: every non-zero grayscale pixel becomes 255.
    nonzero = np.array(mask.convert('L')) > 0
    mask = Image.fromarray(nonzero.astype(np.uint8) * 255)

    # `dam` is the module-level model created in the __main__ section.
    pieces = []
    for token in dam.get_description(img, mask, query, streaming=True):
        pieces.append(token)
        yield "".join(pieces)
|
|
|
@GPU(duration=75)
def describe_without_streaming(image_base64: str, mask_base64: str, query: str):
    """Describe the masked image region in a single, non-streaming call.

    Args:
        image_base64: Base64 text of the image, optionally with a data-URL
            header before the first comma.
        mask_base64: Base64 text of the mask image, same format.
        query: Prompt forwarded to ``dam.get_description``.

    Returns:
        Whatever ``dam.get_description(img, mask, query)`` returns.
    """
    def _open_base64(payload):
        # Strip the data-URL header, if any, then decode and open the image.
        encoded = payload.split(',')[1] if ',' in payload else payload
        return Image.open(io.BytesIO(base64.b64decode(encoded)))

    img = _open_base64(image_base64)
    mask = _open_base64(mask_base64)

    # Binarize the mask: every non-zero grayscale pixel becomes 255.
    grayscale = np.array(mask.convert('L'))
    binary = (grayscale > 0).astype(np.uint8) * 255
    mask = Image.fromarray(binary)

    # `dam` is the module-level model created in the __main__ section.
    return dam.get_description(img, mask, query)
|
|
|
if __name__ == "__main__":
    # Command-line options that configure the Describe Anything Model.
    parser = argparse.ArgumentParser(description="Describe Anything gradio demo")
    parser.add_argument("--model-path", type=str, default="nvidia/DAM-3B", help="Path to the model checkpoint")
    parser.add_argument("--prompt-mode", type=str, default="full+focal_crop", help="Prompt mode")
    parser.add_argument("--conv-mode", type=str, default="v1", help="Conversation mode")
    parser.add_argument("--temperature", type=float, default=0.2, help="Sampling temperature")
    parser.add_argument("--top_p", type=float, default=0.5, help="Top-p for sampling")

    args = parser.parse_args()

    disable_torch_init()

    # `dam` is deliberately a module-level name: describe() and
    # describe_without_streaming() look it up as a global.
    dam = DescribeAnythingModel(
        model_path=args.model_path,
        conv_mode=args.conv_mode,
        prompt_mode=args.prompt_mode,
        temperature=args.temperature,
        top_p=args.top_p,
        num_beams=1,
        max_new_tokens=512,
    ).to(device)

    # One Interface per function; each is exposed under its own api_name.
    with gr.Blocks() as demo:
        gr.Interface(
            fn=image_to_sam_embedding,
            inputs=gr.Textbox(label="Image Base64"),
            outputs=gr.Textbox(label="Embedding Base64"),
            title="Image Embedding Generator",
            api_name="image_to_sam_embedding"
        )
        gr.Interface(
            fn=describe,
            inputs=[
                gr.Textbox(label="Image Base64"),
                gr.Text(label="Mask Base64"),
                gr.Text(label="Prompt")
            ],
            outputs=[
                gr.Text(label="Description")
            ],
            title="Mask Description Generator",
            api_name="describe"
        )
        gr.Interface(
            fn=describe_without_streaming,
            inputs=[
                gr.Textbox(label="Image Base64"),
                gr.Text(label="Mask Base64"),
                gr.Text(label="Prompt")
            ],
            outputs=[
                gr.Text(label="Description")
            ],
            title="Mask Description Generator (Non-Streaming)",
            api_name="describe_without_streaming"
        )

    # Swap block_thread for a no-op during launch(), keeping the real one to
    # call at the end. NOTE(review): presumably this is what lets launch()
    # return so the routes can be edited below -- confirm against the
    # installed Gradio version.
    demo._block_thread = demo.block_thread
    demo.block_thread = lambda: None
    demo.launch()

    # Drop every route registered at "/" so the static front-end can take it.
    # Iterate over a snapshot: removing from the live list while looping over
    # it skips the element right after each removal, which left an adjacent
    # second "/" route in place. getattr() tolerates routes with no `path`.
    for route in list(demo.app.routes):
        if getattr(route, "path", None) == "/":
            demo.app.routes.remove(route)
    demo.app.mount("/", StaticFiles(directory="dist", html=True), name="demo")

    # Now block the main thread with the original implementation.
    demo._block_thread()
|
|