Spaces:
Runtime error
Runtime error
File size: 2,212 Bytes
f8eb0cf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 |
import streamlit as st
import torch
import numpy as np
import cv2
from PIL import Image
import tempfile
from torchvision.models.detection import maskrcnn_resnet50_fpn
from torchvision.transforms import functional as F
from transformers import BlipProcessor, BlipForConditionalGeneration
@st.cache_resource
def load_models():
    """Build the segmentation and captioning models used by the app.

    The function is wrapped in ``st.cache_resource`` so that Streamlit can
    reuse the returned objects instead of rebuilding them on every script
    rerun.

    Returns:
        A ``(seg_model, caption_model, caption_processor)`` tuple: the
        Mask R-CNN model switched to eval mode, the BLIP captioning model,
        and the matching BLIP processor.
    """
    # ``pretrained=True`` is deprecated since torchvision 0.13 and warns on
    # every cold start; ``weights="DEFAULT"`` selects the same COCO checkpoint.
    seg_model = maskrcnn_resnet50_fpn(weights="DEFAULT")
    seg_model.eval()

    # Processor and model must come from the same checkpoint.
    blip_checkpoint = "Salesforce/blip-image-captioning-base"
    caption_processor = BlipProcessor.from_pretrained(blip_checkpoint)
    caption_model = BlipForConditionalGeneration.from_pretrained(blip_checkpoint)
    return seg_model, caption_model, caption_processor
def apply_masks(img, pred, threshold=0.7):
    """Paint every confidently detected instance mask green on a copy of *img*.

    Args:
        img: H x W x 3 uint8 NumPy array holding the RGB image; not modified.
        pred: One Mask R-CNN prediction dict; only ``"scores"`` and
            ``"masks"`` are read.
        threshold: Detections whose score is below this value are skipped.

    Returns:
        A new array with the selected mask pixels set to ``[0, 255, 0]``.
    """
    painted = img.copy()
    for score, mask in zip(pred["scores"], pred["masks"]):
        if score.item() < threshold:
            continue
        # Masks are soft (0..1) with a leading channel axis; scale to bytes
        # and keep only the pixels that are more likely than not in the mask.
        mask_u8 = mask[0].mul(255).byte().cpu().numpy()
        painted[mask_u8 > 128] = [0, 255, 0]
    return painted


seg_model, caption_model, caption_processor = load_models()

st.title("🖼️ Image Segmentation & Captioning App")
uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    image = Image.open(uploaded_file).convert("RGB")
    st.image(image, caption="Original Image", use_column_width=True)

    # Instance segmentation: the model takes a list of CHW float tensors and
    # returns one prediction dict per input image.
    img_np = np.array(image)
    img_tensor = F.to_tensor(img_np)
    with torch.no_grad():
        pred = seg_model([img_tensor])[0]

    masked_img = apply_masks(img_np, pred)
    st.image(masked_img, caption="Segmented Image", use_column_width=True)

    # Captioning runs on the original image, not on the painted one.
    inputs = caption_processor(images=image, return_tensors="pt")
    out = caption_model.generate(**inputs)
    caption = caption_processor.decode(out[0], skip_special_tokens=True)
    st.markdown(f"**📝 Caption:** _{caption}_")

    # Encode the result as JPEG through a self-deleting temporary file so
    # that reruns do not accumulate files on disk, then hand Streamlit the
    # raw bytes rather than an open file handle.
    result_img = Image.fromarray(masked_img)
    with tempfile.TemporaryFile() as tmp:
        result_img.save(tmp, format="JPEG")
        tmp.seek(0)
        jpeg_bytes = tmp.read()
    st.download_button(
        "📥 Download Output",
        jpeg_bytes,
        file_name="output_result.jpg",
        mime="image/jpeg",
    )
|