|
import streamlit as st |
|
from fastapi import FastAPI, File, UploadFile, Form |
|
from fastapi.middleware.cors import CORSMiddleware |
|
from starlette.responses import JSONResponse |
|
from transformers import pipeline |
|
import torch |
|
from PIL import Image, ImageDraw |
|
import io |
|
import base64 |
|
import numpy as np |
|
import json |
|
import logging |
|
|
|
|
|
# Configure the root logger at INFO level when the module is imported.
logging.basicConfig(level="INFO")

# Logger named after this module, shared by every function below.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
# FastAPI application that serves the fracture-detection endpoints.
app = FastAPI(
    title="Fracture Detection API",
    description="API for detecting fractures in X-ray images using multiple ML models",
    version="1.0.0",
)

# Fully permissive CORS configuration so any browser front-end can call the API.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is
# risky. Browsers reject a literal "*" on credentialed requests, so the
# middleware may reflect the caller's Origin instead. That would effectively
# allow credentialed calls from any site. Nothing in this file uses cookies or
# auth, so allow_credentials=False may suffice. Confirm with the front-end
# before tightening.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
    "expose_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
|
|
|
|
|
@st.cache_resource
def load_models():
    """Build the three Hugging Face pipelines used by the app.

    The result is cached with ``st.cache_resource``, so repeated calls reuse
    the same pipeline objects instead of reloading the models.

    Returns:
        dict: Maps a model key to its ``transformers`` pipeline:
            - "D3STRON": object-detection pipeline (its predictions carry boxes)
            - "Heem2": image-classification pipeline
            - "Nandodeomkar": image-classification pipeline

    Raises:
        Exception: Whatever ``transformers.pipeline`` raises while building a
            pipeline. It is logged together with its traceback, then re-raised.
    """
    logger.info("Loading ML models...")
    try:
        return {
            "D3STRON": pipeline("object-detection", model="D3STRON/bone-fracture-detr"),
            "Heem2": pipeline(
                "image-classification",
                model="Heem2/bone-fracture-detection-using-xray",
            ),
            "Nandodeomkar": pipeline(
                "image-classification",
                model="nandodeomkar/autotrain-fracture-detection-using-google-vit-base-patch-16-54382127388",
            ),
        }
    except Exception as e:
        # logger.exception keeps the traceback, which plain logger.error dropped;
        # it is the only clue for download/loading failures.
        logger.exception("Error loading models: %s", e)
        raise
|
|
|
|
|
# Load the pipelines once, at import time. If loading fails, the error is
# logged and ``models`` stays ``None`` so the module itself still imports.
models = None
try:
    models = load_models()
except Exception as e:
    logger.error(f"Failed to load models: {str(e)}")
else:
    logger.info("Models loaded successfully")
|
|
|
def draw_boxes(image, predictions, threshold=0.6):
    """
    Draw bounding boxes and labels on the image for detected fractures.

    The drawing is done directly on ``image``, which is also returned.

    Args:
        image (PIL.Image): Image to annotate (modified in place)
        predictions (list): Detection predictions; each is a dict with 'label',
            'score' and 'box' (a dict with 'xmin', 'ymin', 'xmax', 'ymax')
        threshold (float): Minimum score (inclusive) a prediction must reach
            to be kept and drawn

    Returns:
        tuple: (annotated image, list of predictions with score >= threshold)
    """
    draw = ImageDraw.Draw(image)
    filtered_preds = [p for p in predictions if p['score'] >= threshold]

    for pred in filtered_preds:
        box = pred['box']
        label = f"{pred['label']} ({pred['score']:.2%})"

        draw.rectangle(
            [(box['xmin'], box['ymin']), (box['xmax'], box['ymax'])],
            outline="red",
            width=2,
        )

        # Put the label just above the box. Clamp it to the top edge so that
        # boxes touching the top of the image still get a visible label
        # instead of one drawn at a negative (off-canvas) y coordinate.
        text_y = max(0, box['ymin'] - 10)
        draw.text((box['xmin'], text_y), label, fill="red")

    return image, filtered_preds
|
|
|
def _encode_png_base64(image):
    """Serialize ``image`` as PNG and return the bytes base64-encoded as text."""
    buffer = io.BytesIO()
    image.save(buffer, format='PNG')
    return base64.b64encode(buffer.getvalue()).decode()


def _run_classifier(name, image):
    """Run the classification pipeline ``models[name]`` on ``image``.

    Failures are isolated: the error is logged and ``{"error": message}`` is
    returned, so one failing classifier does not abort the whole analysis.
    """
    try:
        return models[name](image)
    except Exception as e:
        logger.error("Error in %s model: %s", name, e)
        return {"error": str(e)}


def process_image(image, confidence_threshold):
    """
    Process an image through all models and return combined results.

    The detector's boxes are drawn on an RGB copy of the image, and that copy
    is returned as a base64-encoded PNG. The classifiers run independently; a
    failing classifier yields ``{"error": message}`` in its slot.

    Args:
        image (PIL.Image): Input image (not modified)
        confidence_threshold (float): Minimum detection score to keep and draw

    Returns:
        dict: {"success": True, "detections": [...],
               "classifications": {model_name: result}, "image": <base64 PNG>}

    Raises:
        RuntimeError: If the models failed to load at import time.
        Exception: Any error from detection, drawing or PNG encoding
            (logged with traceback, then re-raised).
    """
    try:
        if models is None:
            # ``models`` is None when loading failed at import time; without this
            # check the caller would only see "'NoneType' object is not subscriptable".
            raise RuntimeError("ML models are not loaded; check the startup logs")

        detection_preds = models["D3STRON"](image)

        # Draw on an RGB copy (convert() always returns a new image). On
        # grayscale X-rays (mode "L") a "red" outline drawn in the original mode
        # would be rendered as a shade of gray and be hard to see.
        result_image, filtered_detections = draw_boxes(
            image.convert("RGB"),
            detection_preds,
            confidence_threshold,
        )
        img_b64 = _encode_png_base64(result_image)

        class_results = {
            name: _run_classifier(name, image)
            for name in ("Heem2", "Nandodeomkar")
        }

        return {
            "success": True,
            "detections": filtered_detections,
            "classifications": class_results,
            "image": img_b64,
        }

    except Exception as e:
        logger.exception("Error processing image: %s", e)
        raise
|
|
|
|
|
@app.post("/detect")
@app.post("/api/predict")
async def detect_fracture(
    file: UploadFile = File(...),
    confidence: float = Form(default=0.6)
):
    """
    Endpoint for fracture detection in X-ray images.

    Registered on both ``POST /detect`` and ``POST /api/predict``.

    Args:
        file (UploadFile): Uploaded image file (multipart field ``file``)
        confidence (float): Detection confidence threshold in [0, 1]
            (form field ``confidence``, default 0.6)

    Returns:
        JSONResponse: On success, the dict produced by ``process_image``
        (detections, classifications, annotated base64 PNG). On failure,
        ``{"success": False, "error": ...}`` with status 400 for an invalid
        threshold or undecodable image, and 500 for processing errors.
    """
    # Local import keeps this fix self-contained; starlette is already a
    # dependency of this module (starlette.responses).
    from starlette.concurrency import run_in_threadpool

    logger.info("Received request with confidence threshold: %s", confidence)

    try:
        # The chained comparison also rejects NaN, since comparisons with NaN are False.
        if not 0 <= confidence <= 1:
            return JSONResponse(
                status_code=400,
                content={
                    "success": False,
                    "error": "Confidence threshold must be between 0 and 1",
                },
            )

        # NOTE(review): the whole upload is read into memory with no size cap;
        # consider enforcing a maximum size if this endpoint is publicly exposed.
        contents = await file.read()
        try:
            image = Image.open(io.BytesIO(contents))
            # Image.open only parses the header. Force a full decode here so
            # truncated or corrupt files are reported as a 400. Otherwise they
            # fail later, inside inference, as a 500.
            image.load()
        except Exception as e:
            return JSONResponse(
                status_code=400,
                content={
                    "success": False,
                    "error": f"Invalid image file: {str(e)}",
                },
            )

        try:
            # Model inference is synchronous and compute-heavy. Run it in a
            # worker thread so it does not block the event loop (and therefore
            # every other request) while it executes.
            results = await run_in_threadpool(process_image, image, confidence)
            logger.info("Image processed successfully")
            return JSONResponse(content=results)

        except Exception as e:
            logger.exception("Error processing image: %s", e)
            return JSONResponse(
                status_code=500,
                content={
                    "success": False,
                    "error": f"Error processing image: {str(e)}",
                },
            )

    except Exception as e:
        logger.exception("Unexpected error: %s", e)
        return JSONResponse(
            status_code=500,
            content={
                "success": False,
                "error": f"Unexpected error: {str(e)}",
            },
        )
|
|
|
|
|
def _show_analysis(results):
    """Render the dict returned by ``process_image`` on the Streamlit page."""
    if not results["success"]:
        st.error("Analysis failed. Please try again.")
        return

    st.success("Analysis completed successfully!")

    # The annotated image travels as base64 PNG; decode it back for display.
    png_bytes = base64.b64decode(results["image"])
    annotated = Image.open(io.BytesIO(png_bytes))
    st.image(annotated, caption="Detected Fractures", use_column_width=True)

    found = results["detections"]
    if found:
        st.subheader("Detected Fractures")
        for item in found:
            st.write(f"- {item['label']}: {item['score']:.2%} confidence")

    st.subheader("Classification Results")
    for model_name, model_output in results["classifications"].items():
        st.write(f"**{model_name} Model:**")
        st.json(model_output)


def main():
    """Streamlit UI: upload an X-ray, pick a threshold, and analyze it.

    Shows the uploaded image. When "Analyze Image" is clicked, it runs
    ``process_image`` and displays the annotated image, the detections and
    each classifier's output. Any error is reported with ``st.error``.
    """
    st.title("🦴 Fracture Detection System")
    st.write("Upload an X-ray image to detect potential fractures")

    uploaded_file = st.file_uploader(
        "Upload X-ray image",
        type=['png', 'jpg', 'jpeg'],
    )
    confidence = st.slider(
        "Confidence Threshold",
        min_value=0.0,
        max_value=1.0,
        value=0.6,
        step=0.05,
    )

    # Nothing more to show until a file has been uploaded.
    if uploaded_file is None:
        return

    picture = Image.open(uploaded_file)
    st.image(picture, caption="Original Image", use_column_width=True)

    if not st.button("Analyze Image"):
        return

    try:
        _show_analysis(process_image(picture, confidence))
    except Exception as e:
        st.error(f"Error during analysis: {str(e)}")


if __name__ == "__main__":
    main()