radpid / app.py
yassonee's picture
Update app.py
81ff8c7 verified
raw
history blame
6.32 kB
import streamlit as st
# st.set_page_config has to be the very first Streamlit command the script
# executes, so it is called immediately after importing streamlit.
st.set_page_config(page_title="Fracture Detection System", page_icon="🦴", layout="wide")
import base64
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from transformers import pipeline
import torch
from PIL import Image, ImageDraw
import io
from threading import Thread
import uvicorn
import json
import numpy as np
from starlette.responses import JSONResponse
# FastAPI application that serves the /api/predict endpoint next to the
# Streamlit UI.
app = FastAPI()

# CORS: accept requests from any origin, with any method and any header.
# NOTE(review): allow_credentials=True together with wildcard origins is
# discouraged by the Starlette CORSMiddleware docs. It effectively lets any site
# make credentialed requests. This API reads no cookies or auth headers, so
# allow_credentials=False is probably safe. Confirm that no client relies on it
# before changing.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
    allow_credentials=True,
)
# Load models with caching
@st.cache_resource
def load_models():
    """Build the three Hugging Face pipelines used by the app.

    Decorated with ``st.cache_resource`` so the pipelines are created once per
    process and reused across Streamlit reruns.

    Returns:
        dict mapping a short model key ("D3STRON", "Heem2", "Nandodeomkar")
        to its ``transformers`` pipeline, or ``None`` if any pipeline failed
        to load (the error is also shown in the UI via ``st.error``).
    """
    try:
        return {
            # Object detector: returns boxes with label/score.
            "D3STRON": pipeline("object-detection", model="D3STRON/bone-fracture-detr"),
            # Two whole-image classifiers.
            "Heem2": pipeline("image-classification", model="Heem2/bone-fracture-detection-using-xray"),
            "Nandodeomkar": pipeline(
                "image-classification",
                model="nandodeomkar/autotrain-fracture-detection-using-google-vit-base-patch-16-54382127388",
            ),
        }
    except Exception as e:
        st.error(f"Error loading models: {e}")
        return None


# Initialize models
models = load_models()
if models is None:
    # st.cache_resource memoizes whatever the function returns, including the
    # None produced by a failed load. Without clearing it, one transient
    # failure would disable the app for the life of the process; clearing
    # makes the next rerun retry the download/initialisation.
    load_models.clear()
def draw_boxes(image, predictions, threshold=0.6):
    """
    Draw bounding boxes on the image for fracture detections.

    Args:
        image: PIL image to annotate. It is modified in place.
        predictions: object-detection pipeline output; each entry must provide
            'score', 'label' and a 'box' dict with 'xmin', 'ymin', 'xmax' and
            'ymax' keys.
        threshold: minimum score (inclusive) a prediction needs to be drawn.

    Returns:
        (image, filtered_preds): the same image object that was passed in,
        now annotated, and the list of predictions that met the threshold.
    """
    draw = ImageDraw.Draw(image)
    filtered_preds = [p for p in predictions if p['score'] >= threshold]
    for pred in filtered_preds:
        box = pred['box']
        label = f"{pred['label']} ({pred['score']:.2%})"
        draw.rectangle(
            [(box['xmin'], box['ymin']), (box['xmax'], box['ymax'])],
            outline="red",
            width=2,
        )
        # Place the label just above the box. Clamp it to the top edge so
        # boxes that start near y=0 don't push their label off the image.
        label_y = max(box['ymin'] - 10, 0)
        draw.text((box['xmin'], label_y), label, fill="red")
    return image, filtered_preds
def process_image(image, confidence_threshold):
    """
    Run an image through every loaded model and collect the results.

    Args:
        image: PIL image to analyse. It is not modified.
        confidence_threshold: minimum detection score for a box to be drawn
            and reported.

    Returns:
        On success::

            {"success": True,
             "detections": [...],        # detections with score >= threshold
             "classifications": {"Heem2": ..., "Nandodeomkar": ...},
             "image": "<base64 PNG of the annotated image>"}

        On any failure: ``{"success": False, "error": <message>}``.
    """
    if models is None:
        # load_models() failed; report it clearly instead of a TypeError
        # about subscripting None.
        return {
            "success": False,
            "error": "Models are not loaded",
        }
    try:
        # Object detection
        detection_preds = models["D3STRON"](image)

        # Annotate an RGB copy: X-ray uploads are often single-channel ("L"),
        # and on such an image the "red" outline would be drawn as a gray level.
        # convert() returns a new image, so the caller's image is untouched.
        result_image = image.convert("RGB")
        result_image, filtered_detections = draw_boxes(
            result_image, detection_preds, confidence_threshold
        )

        # Encode the annotated image as a base64 PNG so it can travel in JSON.
        buffer = io.BytesIO()
        result_image.save(buffer, format='PNG')
        result_base64 = base64.b64encode(buffer.getvalue()).decode()

        # Classifications (raw pipeline outputs)
        class_results = {
            "Heem2": models["Heem2"](image),
            "Nandodeomkar": models["Nandodeomkar"](image),
        }

        return {
            "success": True,
            "detections": filtered_detections,
            "classifications": class_results,
            "image": result_base64,
        }
    except Exception as e:
        return {
            "success": False,
            "error": str(e),
        }
# FastAPI endpoint
@app.post("/api/predict")
async def predict(request: Request):
    """
    Analyse a base64-encoded image sent as JSON.

    Expected body: ``{"data": [<base64 image>, <confidence threshold>]}``.
    The image string may carry a data-URL header
    (``"data:image/png;base64,..."``); it is stripped before decoding.

    Returns:
        JSONResponse containing the ``process_image`` result dict. Status 400
        when the request body or image is malformed; 500 for unexpected errors.
    """
    from PIL import UnidentifiedImageError
    from starlette.concurrency import run_in_threadpool

    try:
        # Read JSON request body
        body = await request.json()
        # Extract base64 image and confidence threshold
        image_base64 = body['data'][0]
        confidence_threshold = float(body['data'][1])
        # Browser clients often send a data URL. Its header letters (and the
        # "/" in the mime type) are valid base64 characters, so they would be
        # decoded as payload and corrupt the image. Keep only the payload.
        if image_base64.startswith("data:"):
            image_base64 = image_base64.partition(",")[2]
        # Decode base64 image
        image_bytes = base64.b64decode(image_base64)
        image = Image.open(io.BytesIO(image_bytes))
    except (KeyError, IndexError, TypeError, AttributeError, ValueError,
            UnidentifiedImageError) as e:
        # Client-side problems: missing fields, wrong types, invalid JSON or
        # base64 (both ValueError subclasses), or bytes that aren't an image.
        return JSONResponse({
            "success": False,
            "error": str(e),
        }, status_code=400)

    try:
        # Model inference is synchronous and CPU/GPU heavy. Run it in a worker
        # thread so it doesn't block the event loop for other requests.
        result = await run_in_threadpool(process_image, image, confidence_threshold)
        return JSONResponse(result)
    except Exception as e:
        return JSONResponse({
            "success": False,
            "error": str(e),
        }, status_code=500)
# Streamlit interface
def _show_results(outcome):
    """Render the annotated image, the detections and the classifier outputs."""
    st.subheader("Detection Results")
    png_bytes = base64.b64decode(outcome["image"])
    st.image(Image.open(io.BytesIO(png_bytes)), use_column_width=True)

    st.subheader("Detected Fractures:")
    for det in outcome["detections"]:
        st.write(f"- {det['label']}: {det['score']:.2%}")

    st.subheader("Classification Results:")
    st.json(outcome["classifications"])


def streamlit_interface():
    """Render the UI: upload an X-ray, choose a threshold, then analyse it.

    The uploaded image is shown in the left column. After "Analyze" is
    clicked, a successful run is rendered in the right column; failures are
    reported with ``st.error``.
    """
    st.title("🦴 Système de Détection de Fractures")

    upload = st.file_uploader(
        "Upload X-ray Image",
        help="Upload an X-ray image for fracture detection",
        type=['png', 'jpg', 'jpeg'],
    )
    threshold = st.slider(
        "Confidence Threshold",
        min_value=0.0,
        max_value=1.0,
        step=0.05,
        value=0.6,
        help="Adjust the confidence threshold for detection",
    )

    # Nothing more to draw until the user provides a file.
    if not upload:
        return

    left, right = st.columns(2)
    with left:
        st.subheader("Original X-ray")
        xray = Image.open(upload)
        st.image(xray, use_column_width=True)

    if not st.button("Analyze"):
        return

    with st.spinner('Analyzing image...'):
        try:
            outcome = process_image(xray, threshold)
            if not outcome["success"]:
                st.error("Error processing image: " + outcome.get("error", "Unknown error"))
                return
            with right:
                _show_results(outcome)
        except Exception as e:
            st.error(f"Error during analysis: {str(e)}")
def run_fastapi():
    """Run the FastAPI server on 0.0.0.0:8000; blocks until the server stops."""
    uvicorn.run(app, host="0.0.0.0", port=8000)


@st.cache_resource
def _start_api_server():
    """Start the FastAPI server in a daemon thread, at most once per process.

    Streamlit re-executes this script on every user interaction. Without the
    cache, each rerun would spawn another uvicorn thread that tries, and fails,
    to bind port 8000 again. Caching the started thread makes later calls
    return it without starting a new one.
    """
    api_thread = Thread(target=run_fastapi, daemon=True)
    api_thread.start()
    return api_thread


if __name__ == "__main__":
    # Start FastAPI in a separate thread (only on the first run)
    _start_api_server()
    # Run Streamlit interface
    streamlit_interface()