|
import streamlit as st |
|
|
|
|
|
# Page-level settings for the Streamlit UI. This is the first Streamlit
# command the script issues (older Streamlit releases require that ordering).
st.set_page_config(
    layout="wide",
    page_icon="🦴",
    page_title="Fracture Detection System",
)
|
|
|
import base64 |
|
from fastapi import FastAPI, Request |
|
from fastapi.middleware.cors import CORSMiddleware |
|
from transformers import pipeline |
|
import torch |
|
from PIL import Image, ImageDraw |
|
import io |
|
from threading import Thread |
|
import uvicorn |
|
import json |
|
import numpy as np |
|
from starlette.responses import JSONResponse |
|
|
|
|
|
# HTTP API exposing the same analysis as the Streamlit UI (served by run_fastapi).
app = FastAPI()

# Let browser clients from any origin call the API. Credentials are disabled
# on purpose: combining a wildcard origin with credentials makes Starlette
# echo back the caller's Origin header, which would let any website make
# credentialed requests. The /api/predict endpoint only reads the JSON body,
# so no client needs to send credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
@st.cache_resource
def _build_pipelines():
    """Instantiate the three Hugging Face pipelines, cached once per process.

    Any exception from ``pipeline`` propagates. Because the exception is raised
    instead of being turned into a return value, a failed load is not stored
    in the resource cache and will be attempted again on the next call.
    """
    return {
        "D3STRON": pipeline("object-detection", model="D3STRON/bone-fracture-detr"),
        "Heem2": pipeline("image-classification", model="Heem2/bone-fracture-detection-using-xray"),
        "Nandodeomkar": pipeline(
            "image-classification",
            model="nandodeomkar/autotrain-fracture-detection-using-google-vit-base-patch-16-54382127388",
        ),
    }


def load_models():
    """Return the model pipelines keyed by name, or ``None`` if loading fails.

    Keys: ``"D3STRON"`` (object detection), ``"Heem2"`` and ``"Nandodeomkar"``
    (image classification). On failure the error is shown in the Streamlit UI
    and ``None`` is returned. Only successful loads are cached (see
    ``_build_pipelines``), so a transient failure does not stick for the
    lifetime of the process.
    """
    try:
        return _build_pipelines()
    except Exception as e:
        st.error(f"Error loading models: {str(e)}")
        return None


# Loaded at import time so both the Streamlit UI and the API thread share them.
models = load_models()
|
|
|
def draw_boxes(image, predictions, threshold=0.6):
    """Draw a labelled red rectangle for every sufficiently confident detection.

    Args:
        image: PIL image to draw on. It is modified in place and also returned.
        predictions: detection results, a list of dicts with ``"score"``,
            ``"label"`` and ``"box"`` keys, where ``"box"`` holds ``xmin``,
            ``ymin``, ``xmax`` and ``ymax``.
        threshold: minimum score (inclusive) a prediction needs to be drawn.

    Returns:
        ``(image, filtered_preds)``. ``filtered_preds`` is the list of
        predictions that met the threshold, in their original order.
    """
    draw = ImageDraw.Draw(image)
    filtered_preds = [p for p in predictions if p['score'] >= threshold]

    for pred in filtered_preds:
        box = pred['box']
        label = f"{pred['label']} ({pred['score']:.2%})"

        draw.rectangle(
            [(box['xmin'], box['ymin']), (box['xmax'], box['ymax'])],
            outline="red",
            width=2,
        )

        # Place the label just above the box, but never above the image's top
        # edge. Otherwise a box touching the top would get its label drawn at a
        # negative y and clipped out of view.
        draw.text(
            (box['xmin'], max(0, box['ymin'] - 10)),
            label,
            fill="red",
        )

    return image, filtered_preds
|
|
|
def process_image(image, confidence_threshold):
    """Run all three models on *image* and bundle their results.

    Args:
        image: PIL image (any mode) to analyse.
        confidence_threshold: minimum detection score for a box to be drawn
            and reported.

    Returns:
        On success::

            {"success": True,
             "detections": [...],          # detections passing the threshold
             "classifications": {"Heem2": [...], "Nandodeomkar": [...]},
             "image": "<base64 PNG with the boxes drawn>"}

        On failure ``{"success": False, "error": <message>}``. Errors are
        reported through this dict rather than raised.
    """
    from PIL import ImageOps  # local import: only this function needs it

    if models is None:
        # load_models() already reported the cause. Give a clear message
        # instead of "'NoneType' object is not subscriptable".
        return {
            "success": False,
            "error": "Models are not loaded",
        }

    try:
        # Normalise the input once:
        # - exif_transpose bakes any EXIF orientation into the pixels, so the
        #   boxes are drawn on exactly the pixels the models looked at.
        # - convert("RGB") keeps the red annotations red. On a grayscale ("L")
        #   X-ray, "red" would otherwise be drawn as a shade of grey.
        rgb_image = ImageOps.exif_transpose(image).convert("RGB")

        detection_preds = models["D3STRON"](rgb_image)
        # Draw on a copy; the classifiers below must see the unannotated image.
        result_image, filtered_detections = draw_boxes(
            rgb_image.copy(), detection_preds, confidence_threshold
        )

        buffer = io.BytesIO()
        result_image.save(buffer, format='PNG')
        result_base64 = base64.b64encode(buffer.getvalue()).decode()

        class_results = {
            "Heem2": models["Heem2"](rgb_image),
            "Nandodeomkar": models["Nandodeomkar"](rgb_image),
        }

        return {
            "success": True,
            "detections": filtered_detections,
            "classifications": class_results,
            "image": result_base64,
        }

    except Exception as e:
        return {
            "success": False,
            "error": str(e),
        }
|
|
|
|
|
def _decode_predict_payload(body):
    """Extract ``(image, confidence_threshold)`` from a /api/predict JSON body.

    The body is expected to look like ``{"data": [<image>, <threshold>]}``.
    The image is either bare base64 or a ``data:<mime>;base64,<payload>`` URL.

    Raises KeyError, IndexError, TypeError, ValueError (including
    binascii.Error) or OSError (including PIL's UnidentifiedImageError) when
    the payload is malformed.
    """
    image_base64 = body['data'][0]
    confidence_threshold = float(body['data'][1])

    # ':' is not in the base64 alphabet, so this prefix check can only match
    # a data URL, never a bare base64 string.
    if isinstance(image_base64, str) and image_base64.startswith("data:"):
        image_base64 = image_base64.partition(",")[2]

    image_bytes = base64.b64decode(image_base64)
    image = Image.open(io.BytesIO(image_bytes))
    # Decode the pixel data now, so a truncated or corrupt file is reported as
    # bad input instead of failing later inside the models.
    image.load()
    return image, confidence_threshold


@app.post("/api/predict")
async def predict(request: Request):
    """Analyse a base64-encoded image posted as JSON.

    Request body: ``{"data": [<base64 image or data URL>, <confidence threshold>]}``.
    The response is the dict produced by ``process_image``, serialised as JSON.

    Status codes:
        - 200: processing ran. A model failure is still reported through
          ``"success": False`` in the body.
        - 400: the payload was malformed.
        - 500: an unexpected server-side error occurred.
    """
    from starlette.concurrency import run_in_threadpool

    try:
        body = await request.json()
        image, confidence_threshold = _decode_predict_payload(body)
    except (KeyError, IndexError, TypeError, ValueError, OSError) as e:
        # Bad JSON, missing fields, non-numeric threshold, or undecodable image:
        # a client error, not a server error.
        return JSONResponse({
            "success": False,
            "error": str(e)
        }, status_code=400)
    except Exception as e:
        return JSONResponse({
            "success": False,
            "error": str(e)
        }, status_code=500)

    try:
        # Model inference is synchronous and heavy. Run it in a worker thread
        # so the event loop can keep serving other requests meanwhile.
        result = await run_in_threadpool(process_image, image, confidence_threshold)
        return JSONResponse(result)
    except Exception as e:
        return JSONResponse({
            "success": False,
            "error": str(e)
        }, status_code=500)
|
|
|
|
|
def streamlit_interface():
    """Render the Streamlit UI: upload an X-ray, choose a threshold, analyse it.

    The uploaded image is shown in the left column. After "Analyze" is
    clicked, the right column shows the annotated image, the detections that
    passed the threshold, and the raw classifier outputs.
    """
    st.title("🦴 Système de Détection de Fractures")

    uploaded_file = st.file_uploader(
        "Upload X-ray Image",
        type=['png', 'jpg', 'jpeg'],
        help="Upload an X-ray image for fracture detection"
    )

    confidence = st.slider(
        "Confidence Threshold",
        min_value=0.0,
        max_value=1.0,
        value=0.6,
        step=0.05,
        help="Adjust the confidence threshold for detection"
    )

    if not uploaded_file:
        return

    try:
        image = Image.open(uploaded_file)
        # Decode eagerly so a corrupt file is caught here rather than raising
        # an uncaught exception while rendering.
        image.load()
    except OSError as e:
        # Covers PIL's UnidentifiedImageError (a file that is not an image)
        # and truncated or corrupt image data.
        st.error(f"Could not read the uploaded file as an image: {e}")
        return

    col1, col2 = st.columns(2)

    with col1:
        st.subheader("Original X-ray")
        st.image(image, use_column_width=True)

    if st.button("Analyze"):
        with st.spinner('Analyzing image...'):
            try:
                results = process_image(image, confidence)

                if results["success"]:
                    with col2:
                        st.subheader("Detection Results")
                        result_image = Image.open(io.BytesIO(base64.b64decode(results["image"])))
                        st.image(result_image, use_column_width=True)

                        st.subheader("Detected Fractures:")
                        if results["detections"]:
                            for detection in results["detections"]:
                                st.write(f"- {detection['label']}: {detection['score']:.2%}")
                        else:
                            # Say so explicitly rather than leaving an empty section.
                            st.info("No fractures detected above the confidence threshold.")

                        st.subheader("Classification Results:")
                        st.json(results["classifications"])
                else:
                    st.error("Error processing image: " + results.get("error", "Unknown error"))

            except Exception as e:
                st.error(f"Error during analysis: {str(e)}")
|
|
|
def run_fastapi(host="0.0.0.0", port=8000):
    """Run the FastAPI server with uvicorn (blocks until the server stops).

    Args:
        host: interface to bind. The default, "0.0.0.0", listens on all interfaces.
        port: TCP port to listen on.
    """
    uvicorn.run(app, host=host, port=port)
|
|
|
@st.cache_resource
def _start_api_server():
    """Start the FastAPI server in a daemon thread, at most once per process.

    Streamlit re-executes this script on every widget interaction. Caching
    the started thread with ``st.cache_resource`` stops each rerun from
    spawning another uvicorn instance that would try to bind the same port
    again.
    """
    api_thread = Thread(target=run_fastapi, daemon=True)
    api_thread.start()
    return api_thread


if __name__ == "__main__":
    _start_api_server()
    streamlit_interface()