File size: 6,318 Bytes
a326b94 81ff8c7 c16e85e 3bb1400 88fb5fa 04a7bfd c16e85e 04a7bfd c16e85e af17427 0000f4a 04a7bfd af17427 6cc7ff9 04a7bfd 005d8cf c16e85e 005d8cf 88fb5fa c16e85e 806ecee c16e85e af17427 3bb1400 04a7bfd c16e85e 88fb5fa 04a7bfd 88fb5fa 04a7bfd 88fb5fa c16e85e 88fb5fa 04a7bfd 6cc7ff9 88fb5fa c16e85e 04a7bfd 0000f4a c16e85e 04a7bfd 0b5ea2b 04a7bfd c16e85e 04a7bfd af17427 04a7bfd c16e85e 04a7bfd 0b5ea2b af17427 78d26e0 c16e85e 04a7bfd c16e85e 04a7bfd af17427 0b5ea2b 04a7bfd c16e85e 88fb5fa c16e85e 88fb5fa af17427 c16e85e 0b5ea2b c16e85e 0b5ea2b c16e85e 0b5ea2b c16e85e 0b5ea2b 6ea5ee2 3bb1400 c16e85e 0b5ea2b c16e85e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 |
import streamlit as st
# Streamlit expects set_page_config to be the first st.* command executed in a
# script run, so it is issued before anything else touches Streamlit.
st.set_page_config(
    layout="wide",
    page_icon="🦴",
    page_title="Fracture Detection System",
)
import base64
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from transformers import pipeline
import torch
from PIL import Image, ImageDraw
import io
from threading import Thread
import uvicorn
import json
import numpy as np
from starlette.responses import JSONResponse
# FastAPI app serving the same process_image() pipeline as the Streamlit UI.
app = FastAPI()

# Enable CORS so browser front-ends on any origin can call the API.
# Credentials are NOT allowed: the endpoint only reads the JSON body (no
# cookies/auth), and FastAPI's CORS documentation states that a wildcard
# allow_origins must not be combined with allow_credentials=True.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Load models with caching
@st.cache_resource
def _load_models_cached():
    """Build the three Hugging Face pipelines used by the app.

    Cached with ``st.cache_resource`` so the pipelines are built once and then
    shared. Exceptions are deliberately NOT caught here: Streamlit only stores
    successful return values, so a failed load is retried on the next call
    instead of a ``None`` being cached for the lifetime of the process.
    """
    return {
        "D3STRON": pipeline("object-detection", model="D3STRON/bone-fracture-detr"),
        "Heem2": pipeline("image-classification", model="Heem2/bone-fracture-detection-using-xray"),
        "Nandodeomkar": pipeline(
            "image-classification",
            model="nandodeomkar/autotrain-fracture-detection-using-google-vit-base-patch-16-54382127388",
        ),
    }


def load_models():
    """Return the model pipelines keyed by name, or ``None`` if loading fails.

    Returns:
        dict mapping "D3STRON" (object-detection pipeline) and "Heem2" /
        "Nandodeomkar" (image-classification pipelines) to transformers
        pipelines; ``None`` after reporting the error with ``st.error``.
    """
    try:
        return _load_models_cached()
    except Exception as e:
        st.error(f"Error loading models: {str(e)}")
        return None


# Initialize models at import time; process_image() reads this module-level
# dict for both the Streamlit UI and the FastAPI endpoint.
models = load_models()
def draw_boxes(image, predictions, threshold=0.6, *, color="red", width=2):
    """Draw labelled bounding boxes for detections scoring at least ``threshold``.

    The image is drawn on in place and also returned; callers that need the
    original untouched should pass a copy (process_image passes ``image.copy()``).

    Args:
        image: PIL image to annotate.
        predictions: object-detection results; each item is expected to carry
            'label', 'score' and a 'box' dict with 'xmin', 'ymin', 'xmax', 'ymax'.
        threshold: minimum score for a prediction to be kept and drawn.
        color: outline and label colour (keyword-only, default "red").
        width: outline width in pixels (keyword-only, default 2).

    Returns:
        Tuple ``(image, filtered_preds)``: the annotated image and the list of
        predictions that passed the threshold.
    """
    draw = ImageDraw.Draw(image)
    filtered_preds = [p for p in predictions if p['score'] >= threshold]
    for pred in filtered_preds:
        box = pred['box']
        label = f"{pred['label']} ({pred['score']:.2%})"
        draw.rectangle(
            [(box['xmin'], box['ymin']), (box['xmax'], box['ymax'])],
            outline=color,
            width=width,
        )
        # Place the label just above the box. When the box is within 10 px of
        # the top edge there is no room (y would be negative and the text
        # clipped off-canvas), so draw it just inside the box instead.
        if box['ymin'] >= 10:
            text_y = box['ymin'] - 10
        else:
            text_y = box['ymin'] + width
        draw.text((box['xmin'], text_y), label, fill=color)
    return image, filtered_preds
def process_image(image, confidence_threshold):
    """Run the detector and both classifiers on one image.

    Args:
        image: PIL image as uploaded (any mode, possibly with EXIF orientation).
        confidence_threshold: minimum detection score to keep and draw.

    Returns:
        On success a dict::

            {"success": True,
             "detections": [...detector predictions above the threshold...],
             "classifications": {"Heem2": ..., "Nandodeomkar": ...},
             "image": "<base64-encoded PNG with the boxes drawn>"}

        On failure ``{"success": False, "error": "<message>"}``. Errors are
        returned rather than raised.
    """
    if models is None:
        # load_models() returned None (it reports the cause via st.error);
        # give a clear message instead of "'NoneType' ... not subscriptable".
        return {"success": False, "error": "Models are not loaded"}

    try:
        from PIL import ImageOps  # local import: only this block needs it

        # Apply EXIF orientation first. transformers' image loading does the
        # same before inference, so otherwise the detector's box coordinates
        # can refer to a rotated frame relative to the image drawn on. Since
        # every model below receives this same prepared image, boxes and
        # pixels stay aligned either way.
        prepared = ImageOps.exif_transpose(image)
        # Normalise the mode: on greyscale ("L") images PIL renders "red" as a
        # grey level, palette ("P") images can run out of colour slots, and
        # CMYK cannot be written as PNG.
        if prepared.mode not in ("RGB", "RGBA"):
            prepared = prepared.convert("RGB")

        # Object detection; boxes are drawn on a copy so the classifiers
        # below see clean pixels.
        detection_preds = models["D3STRON"](prepared)
        result_image, filtered_detections = draw_boxes(
            prepared.copy(), detection_preds, confidence_threshold
        )

        # Encode the annotated image as base64 PNG for JSON transport.
        png_buffer = io.BytesIO()
        result_image.save(png_buffer, format='PNG')
        result_base64 = base64.b64encode(png_buffer.getvalue()).decode()

        # Classifications
        class_results = {
            "Heem2": models["Heem2"](prepared),
            "Nandodeomkar": models["Nandodeomkar"](prepared),
        }
        return {
            "success": True,
            "detections": filtered_detections,
            "classifications": class_results,
            "image": result_base64,
        }
    except Exception as e:
        return {
            "success": False,
            "error": str(e),
        }
# FastAPI endpoint
@app.post("/api/predict")
async def predict(request: Request):
    """Analyse a base64-encoded image posted as JSON.

    Expected body: ``{"data": [<base64 image>, <confidence threshold>]}``.
    The image string may be plain base64 or a ``data:<mime>;base64,...`` URL.

    Responds with the ``process_image`` result as JSON:
    HTTP 200 on success, 400 when the body cannot be parsed into an image and
    a threshold, 500 when inference (or anything unexpected) fails.

    NOTE(review): inference runs synchronously inside this async handler, so
    the server's event loop is blocked while a request is being processed.
    """
    # Parse and validate the request; client mistakes get HTTP 400.
    try:
        body = await request.json()
        image_base64 = body['data'][0]
        confidence_threshold = float(body['data'][1])
        # b64decode silently skips ':' ';' '/' ',' but keeps letters, so a
        # data-URL header would corrupt the payload. Strip it; plain base64
        # can never start with "data:" (':' is not in the base64 alphabet).
        if isinstance(image_base64, str) and image_base64.startswith("data:"):
            image_base64 = image_base64.split(",", 1)[1]
        image_bytes = base64.b64decode(image_base64)
        image = Image.open(io.BytesIO(image_bytes))
        # Image.open is lazy; decode now so corrupt data is reported as 400.
        image.load()
    except (KeyError, IndexError, TypeError, ValueError, OSError) as e:
        return JSONResponse({
            "success": False,
            "error": str(e)
        }, status_code=400)
    except Exception as e:
        return JSONResponse({
            "success": False,
            "error": str(e)
        }, status_code=500)

    # Run inference; process_image reports its own failures via "success".
    try:
        result = process_image(image, confidence_threshold)
        status = 200 if result.get("success") else 500
        return JSONResponse(result, status_code=status)
    except Exception as e:
        return JSONResponse({
            "success": False,
            "error": str(e)
        }, status_code=500)
# Streamlit interface
def streamlit_interface():
    """Render the upload / threshold / analyze page and show model outputs.

    The uploaded image is shown in the left column. After "Analyze" is
    clicked, the annotated image, the kept detections and the raw classifier
    output are shown in the right column. Failures are shown with st.error.
    """
    st.title("🦴 Système de Détection de Fractures")

    uploaded_file = st.file_uploader(
        "Upload X-ray Image",
        type=['png', 'jpg', 'jpeg'],
        help="Upload an X-ray image for fracture detection",
    )
    confidence = st.slider(
        "Confidence Threshold",
        min_value=0.0,
        max_value=1.0,
        value=0.6,
        step=0.05,
        help="Adjust the confidence threshold for detection",
    )

    # Guard: nothing else to render until a file has been uploaded.
    if not uploaded_file:
        return

    left_col, right_col = st.columns(2)
    with left_col:
        st.subheader("Original X-ray")
        image = Image.open(uploaded_file)
        st.image(image, use_column_width=True)

    # Guard: analysis only happens when the button reports a click.
    if not st.button("Analyze"):
        return

    with st.spinner('Analyzing image...'):
        try:
            outcome = process_image(image, confidence)
            if not outcome["success"]:
                st.error("Error processing image: " + outcome.get("error", "Unknown error"))
            else:
                with right_col:
                    st.subheader("Detection Results")
                    png_bytes = base64.b64decode(outcome["image"])
                    st.image(Image.open(io.BytesIO(png_bytes)), use_column_width=True)

                    st.subheader("Detected Fractures:")
                    for det in outcome["detections"]:
                        st.write(f"- {det['label']}: {det['score']:.2%}")

                    st.subheader("Classification Results:")
                    st.json(outcome["classifications"])
        except Exception as e:
            st.error(f"Error during analysis: {str(e)}")
def run_fastapi(host="0.0.0.0", port=8000):
    """Run the FastAPI ``app`` with uvicorn; blocks until the server stops.

    Args:
        host: interface to bind (default "0.0.0.0", i.e. all interfaces).
        port: TCP port to listen on (default 8000).
    """
    uvicorn.run(app, host=host, port=port)
@st.cache_resource
def _start_api_server():
    """Start the FastAPI server in a daemon thread at most once per process.

    Streamlit re-executes this script (as ``__main__``) on every rerun, e.g.
    each widget interaction. Without caching, every rerun would start another
    uvicorn thread that fails to bind the port already held by the first one.
    ``st.cache_resource`` runs this once and returns the same thread after.
    """
    api_thread = Thread(target=run_fastapi, daemon=True)
    api_thread.start()
    return api_thread


if __name__ == "__main__":
    # Start FastAPI in a separate thread (only on the first run)
    _start_api_server()
    # Run Streamlit interface
    streamlit_interface()