|
import base64
import io
import json
import logging

import numpy as np
import streamlit as st
import torch
from fastapi import FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image, ImageDraw
from starlette.concurrency import run_in_threadpool
from starlette.responses import JSONResponse
from transformers import pipeline
|
|
|
|
|
app = FastAPI()


# Allow browser clients from any origin to call the API. Credentials are
# disabled: the /detect route uses no cookies or auth, and the CORS spec does
# not permit a wildcard origin together with credentials (Starlette would
# instead echo back the caller's Origin, letting any site send credentialed
# requests).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
@st.cache_resource
def load_models():
    """Build the Hugging Face inference pipelines used by this app.

    Returns:
        dict mapping a short model name to its pipeline:
        ``"D3STRON"`` (object detection) and ``"Heem2"`` /
        ``"Nandodeomkar"`` (image classification).

    Wrapped in ``st.cache_resource`` so Streamlit can reuse the pipelines
    instead of rebuilding them on every call.
    """
    # (result key, pipeline task, model id), built in this order.
    specs = (
        ("D3STRON", "object-detection", "D3STRON/bone-fracture-detr"),
        ("Heem2", "image-classification", "Heem2/bone-fracture-detection-using-xray"),
        (
            "Nandodeomkar",
            "image-classification",
            "nandodeomkar/autotrain-fracture-detection-using-google-vit-base-patch-16-54382127388",
        ),
    )

    loaded = {}
    for key, task, model_id in specs:
        loaded[key] = pipeline(task, model=model_id)
    return loaded


# Loaded once at import time; the /detect route reads these pipelines.
models = load_models()
|
|
|
def draw_boxes(image, predictions, threshold=0.6):
    """Draw labelled red bounding boxes for detections scoring >= *threshold*.

    Args:
        image: PIL image to annotate. RGB/RGBA images are drawn on in place.
            Any other mode (e.g. a single-channel grayscale X-ray) is first
            converted to a new RGB image, so the annotations really are red.
        predictions: iterable of detection dicts with ``'score'``, ``'label'``
            and ``'box'`` keys, where ``'box'`` holds ``xmin``/``ymin``/
            ``xmax``/``ymax``. This is the shape produced by the detection
            pipeline that this file passes in.
        threshold: minimum score (inclusive) for a prediction to be kept.

    Returns:
        ``(annotated_image, kept_predictions)``. Callers must use the returned
        image, which may be a converted copy rather than *image* itself.
    """
    if image.mode not in ("RGB", "RGBA"):
        # On non-colour modes Pillow maps colour names to a single grey/index
        # value, so "red" would not render as red.
        image = image.convert("RGB")

    draw = ImageDraw.Draw(image)
    filtered_preds = [p for p in predictions if p['score'] >= threshold]

    for pred in filtered_preds:
        box = pred['box']
        label = f"{pred['label']} ({pred['score']:.2%})"

        draw.rectangle(
            [(box['xmin'], box['ymin']), (box['xmax'], box['ymax'])],
            outline="red",
            width=2,
        )
        # The label is anchored at the box's top-left corner.
        draw.text((box['xmin'], box['ymin']), label, fill="red")

    return image, filtered_preds
|
|
|
|
|
logger = logging.getLogger(__name__)


def _decode_upload(contents):
    """Decode uploaded bytes into a fully loaded PIL image.

    Raises:
        OSError: the bytes are not a readable image (Pillow's
            ``UnidentifiedImageError`` is an ``OSError`` subclass).
        PIL.Image.DecompressionBombError: the image exceeds Pillow's pixel limit.
    """
    image = Image.open(io.BytesIO(contents))
    # Image.open only parses the header. Force full decoding here so a corrupt
    # upload is reported as a client error rather than failing inside a model.
    image.load()
    return image


def _encode_png_base64(image):
    """Serialize *image* as PNG and return it base64-encoded as a str."""
    buffer = io.BytesIO()
    image.save(buffer, format='PNG')
    return base64.b64encode(buffer.getvalue()).decode()


def _run_models(image, confidence):
    """Run the detection model and both classifiers on *image* (blocking).

    Returns:
        ``(filtered_detections, class_results, annotated_png_b64)``, where
        ``filtered_detections`` holds the detections scoring >= *confidence*.
    """
    detection_preds = models["D3STRON"](image)
    result_image, filtered_detections = draw_boxes(image.copy(), detection_preds, confidence)
    class_results = {
        "Heem2": models["Heem2"](image),
        "Nandodeomkar": models["Nandodeomkar"](image),
    }
    return filtered_detections, class_results, _encode_png_base64(result_image)


def _error_response(exc, status_code):
    """Build the endpoint's ``{"success": False, "error": ...}`` payload."""
    return JSONResponse({"success": False, "error": str(exc)}, status_code=status_code)


@app.post("/detect")
async def detect_fracture(file: UploadFile = File(...), confidence: float = 0.6):
    """Detect and classify bone fractures in an uploaded X-ray image.

    Args:
        file: uploaded image file.
        confidence: minimum detection score for a box to be returned and drawn.

    Returns:
        JSONResponse with these fields:
        - ``success``.
        - ``detections``: the filtered detection dicts.
        - ``classifications``: raw output of each classifier, keyed by model name.
        - ``image``: the annotated image as base64-encoded PNG.

        On failure ``success`` is False and ``error`` holds the message. The
        status code is 400 for an unreadable image and 500 for any other error.
    """
    try:
        image = _decode_upload(await file.read())
    except (OSError, Image.DecompressionBombError) as e:
        return _error_response(e, status_code=400)

    try:
        # Inference is CPU/GPU-bound and blocking. Run it in a worker thread so
        # the event loop keeps serving other requests in the meantime.
        filtered_detections, class_results, img_b64 = await run_in_threadpool(
            _run_models, image, confidence
        )
    except Exception as e:
        # API boundary: answer in the usual JSON shape, but keep the
        # traceback in the server log instead of discarding it.
        logger.exception("Fracture detection failed")
        return _error_response(e, status_code=500)

    return JSONResponse({
        "success": True,
        "detections": filtered_detections,
        "classifications": class_results,
        "image": img_b64,
    })
|
|
|
|
|
def main():
    """Streamlit UI: upload an X-ray, run the fracture models, show the results.

    The uploaded image goes through the detection model. Boxes at or above
    the slider's confidence threshold are drawn with ``draw_boxes``. The image
    also goes through both classification models in ``models``.
    """
    st.title("🦴 Fraktur Detektion")

    uploaded_file = st.file_uploader("Röntgenbild hochladen", type=['png', 'jpg', 'jpeg'])
    confidence = st.slider("Konfidenzschwelle", 0.0, 1.0, 0.6, 0.05)

    if not uploaded_file:
        return

    try:
        image = Image.open(uploaded_file)
        # Image.open is lazy; decode now so a corrupt upload is reported here.
        image.load()
    except (OSError, Image.DecompressionBombError) as exc:
        st.error(f"Bild konnte nicht gelesen werden: {exc}")
        return

    with st.spinner("Analyse läuft..."):
        detection_preds = models["D3STRON"](image)
        result_image, filtered_detections = draw_boxes(image.copy(), detection_preds, confidence)
        class_results = {
            "Heem2": models["Heem2"](image),
            "Nandodeomkar": models["Nandodeomkar"](image),
        }

    st.image(result_image, caption="Erkannte Regionen")

    st.subheader("Detektionen")
    if filtered_detections:
        st.json(filtered_detections)
    else:
        st.info("Keine Detektion oberhalb der Konfidenzschwelle.")

    st.subheader("Klassifikation")
    st.json(class_results)
|
|
|
# Script entry point: renders the Streamlit UI defined in main().
# NOTE(review): nothing here starts the FastAPI `app`; presumably it is served
# separately by an ASGI server. Confirm how this module is deployed.
if __name__ == "__main__":
    main()