radpid / app.py
yassonee's picture
Update app.py
af17427 verified
raw
history blame
3.16 kB
import base64
import io
import json
import logging

import numpy as np
import streamlit as st
import torch
from fastapi import FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image, ImageDraw, ImageOps
from starlette.responses import JSONResponse
from transformers import pipeline
# FastAPI application that exposes the /detect endpoint.
# NOTE(review): this file does not start a server for `app` itself; it has to
# be launched separately (e.g. with an ASGI server such as uvicorn).
app = FastAPI()

# CORS: accept cross-origin requests from any origin, with any method/header.
# NOTE(review): combined with allow_credentials=True, recent Starlette versions
# answer a wildcard origin by echoing the caller's Origin, i.e. credentialed
# requests are accepted from any site. /detect reads no cookies, so consider
# allow_credentials=False or an explicit origin list — confirm with the client.
app.add_middleware(
    CORSMiddleware,
    allow_methods=["*"],
    allow_headers=["*"],
    allow_origins=["*"],
    allow_credentials=True,
)
# Load models
@st.cache_resource
def load_models():
return {
"D3STRON": pipeline("object-detection", model="D3STRON/bone-fracture-detr"),
"Heem2": pipeline("image-classification", model="Heem2/bone-fracture-detection-using-xray"),
"Nandodeomkar": pipeline("image-classification",
model="nandodeomkar/autotrain-fracture-detection-using-google-vit-base-patch-16-54382127388")
}
models = load_models()
def draw_boxes(image, predictions, threshold=0.6):
    """Draw red boxes and labels for predictions scoring at least ``threshold``.

    Args:
        image: PIL image to annotate. RGB/RGBA images are drawn on in place;
            any other mode (e.g. grayscale "L" or 16-bit X-rays) is first
            converted to an RGB copy so the red annotations stay red.
        predictions: Object-detection results, each a dict with "score",
            "label" and "box" (a dict with "xmin", "ymin", "xmax", "ymax").
        threshold: Minimum score (inclusive) for a prediction to be kept.

    Returns:
        Tuple ``(annotated_image, kept_predictions)`` where
        ``kept_predictions`` are the predictions that passed the threshold,
        in their original order.
    """
    if image.mode not in ("RGB", "RGBA"):
        # Named colors on single-channel images are reduced to one intensity
        # value (red becomes a gray level), so draw on an RGB copy instead.
        image = image.convert("RGB")

    draw = ImageDraw.Draw(image)
    kept = [p for p in predictions if p["score"] >= threshold]
    for pred in kept:
        box = pred["box"]
        label = f"{pred['label']} ({pred['score']:.2%})"
        draw.rectangle(
            [(box["xmin"], box["ymin"]), (box["xmax"], box["ymax"])],
            outline="red",
            width=2,
        )
        # Label is anchored at the box's top-left corner.
        draw.text((box["xmin"], box["ymin"]), label, fill="red")
    return image, kept
# API Endpoint
@app.post("/detect")
async def detect_fracture(file: UploadFile = File(...), confidence: float = 0.6):
    """Detect and classify fractures in an uploaded X-ray image.

    Args:
        file: The uploaded image (multipart form field ``file``).
        confidence: Minimum detection score (inclusive) for a D3STRON box to
            be returned and drawn on the result image.

    Returns:
        A ``JSONResponse``. On success it contains ``success: True``,
        ``detections`` (D3STRON predictions scoring >= ``confidence``),
        ``classifications`` (raw outputs of the Heem2 and Nandodeomkar
        pipelines) and ``image`` (base64-encoded PNG with the boxes drawn).
        If anything fails it contains ``success: False`` and ``error`` (the
        exception message). The HTTP status is 200 in both cases.
    """
    # NOTE(review): the model calls below are synchronous and run inside an
    # async handler, so they block the event loop for the whole inference.
    try:
        contents = await file.read()
        # Apply the EXIF orientation and convert to RGB up front. The
        # transformers image pipelines do the same internally, so detection
        # boxes refer to the normalized image. Drawing on the raw upload would
        # misplace boxes on rotated JPEGs and render red as gray on grayscale
        # X-rays.
        image = ImageOps.exif_transpose(Image.open(io.BytesIO(contents))).convert("RGB")

        # Object detection: draw the kept boxes on a copy so the classifiers
        # below see the unannotated image.
        detection_preds = models["D3STRON"](image)
        result_image, filtered_detections = draw_boxes(
            image.copy(), detection_preds, confidence
        )

        buffer = io.BytesIO()
        result_image.save(buffer, format="PNG")
        img_b64 = base64.b64encode(buffer.getvalue()).decode()

        # Whole-image classification models.
        class_results = {
            "Heem2": models["Heem2"](image),
            "Nandodeomkar": models["Nandodeomkar"](image),
        }

        return JSONResponse({
            "success": True,
            "detections": filtered_detections,
            "classifications": class_results,
            "image": img_b64,
        })
    except Exception as e:
        # Top-level API boundary: report the failure to the client, and log
        # the traceback so it is not silently lost.
        logging.getLogger(__name__).exception("Fracture detection failed")
        return JSONResponse({
            "success": False,
            "error": str(e),
        })
# Streamlit UI
def main():
    """Render the Streamlit front end: title, X-ray uploader and threshold slider.

    NOTE(review): the uploaded file is not processed yet and the slider value
    is not used; result rendering is still to be implemented.
    """
    st.title("🦴 Fraktur Detektion")
    accepted_types = ["png", "jpg", "jpeg"]
    upload = st.file_uploader("Röntgenbild hochladen", type=accepted_types)
    threshold = st.slider("Konfidenzschwelle", 0.0, 1.0, 0.6, 0.05)  # noqa: F841 (unused until TODO below)
    if upload:
        # TODO: run the models on `upload`, draw detections scoring at least
        # `threshold`, and display the results.
        pass


# Run the UI when this file is executed as a script.
if __name__ == "__main__":
    main()