File size: 3,161 Bytes
a326b94
3bb1400
88fb5fa
04a7bfd
 
 
af17427
 
04a7bfd
 
af17427
0000f4a
04a7bfd
af17427
6cc7ff9
04a7bfd
 
 
 
 
 
 
 
005d8cf
af17427
005d8cf
88fb5fa
af17427
 
 
 
 
 
806ecee
af17427
3bb1400
04a7bfd
88fb5fa
04a7bfd
 
 
88fb5fa
04a7bfd
88fb5fa
 
 
04a7bfd
6cc7ff9
88fb5fa
 
af17427
04a7bfd
 
0000f4a
af17427
 
 
04a7bfd
af17427
 
 
 
 
 
 
 
04a7bfd
 
af17427
04a7bfd
af17427
04a7bfd
 
 
 
 
af17427
 
 
 
 
78d26e0
af17427
04a7bfd
 
 
 
af17427
04a7bfd
 
af17427
 
 
 
04a7bfd
 
88fb5fa
af17427
88fb5fa
af17427
 
 
88fb5fa
af17427
 
 
6ea5ee2
3bb1400
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import base64
import io
import json

import numpy as np
import streamlit as st
import torch
from fastapi import FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image, ImageDraw, ImageOps
from starlette.responses import JSONResponse
from transformers import pipeline

# FastAPI application that serves the /detect endpoint defined below.
app = FastAPI()

# Cross-origin policy: every origin, HTTP method and request header is allowed.
# NOTE(review): a wildcard origin combined with allow_credentials=True means
# credentialed (cookie/auth) requests are accepted from any site, since the
# middleware echoes the caller's Origin back. Nothing in /detect reads
# credentials, so allow_credentials=False looks sufficient. Confirm no client
# relies on it before changing.
app.add_middleware(
    CORSMiddleware,
    allow_methods=["*"],
    allow_headers=["*"],
    allow_origins=["*"],
    allow_credentials=True,
)

# Model registry shared by the API endpoint and the Streamlit UI.
@st.cache_resource
def load_models():
    """Construct the three fracture-related Hugging Face pipelines.

    The ``st.cache_resource`` decorator lets Streamlit reuse the returned
    pipelines instead of rebuilding them on every script rerun.

    Returns:
        dict: ``"D3STRON"`` maps to an object-detection pipeline;
        ``"Heem2"`` and ``"Nandodeomkar"`` map to image-classification
        pipelines.
    """
    detr_detector = pipeline("object-detection", model="D3STRON/bone-fracture-detr")
    heem2_classifier = pipeline(
        "image-classification",
        model="Heem2/bone-fracture-detection-using-xray",
    )
    vit_classifier = pipeline(
        "image-classification",
        model="nandodeomkar/autotrain-fracture-detection-using-google-vit-base-patch-16-54382127388",
    )
    return {
        "D3STRON": detr_detector,
        "Heem2": heem2_classifier,
        "Nandodeomkar": vit_classifier,
    }

# Built at import time, so every process that imports this module pays the
# model initialisation cost on startup.
models = load_models()

def draw_boxes(image, predictions, threshold=0.6, *, color="red", width=2):
    """Draw labelled bounding boxes for detections at or above ``threshold``.

    Args:
        image: PIL image to annotate. RGB/RGBA images are drawn on in place.
            Any other mode (e.g. grayscale "L" X-rays, palette or 16-bit
            images) is first converted to a new RGB image. Otherwise a colour
            such as "red" is mapped to a gray level and the boxes do not
            stand out.
        predictions: Iterable of detection dicts with "score", "label" and a
            "box" dict holding "xmin", "ymin", "xmax" and "ymax".
        threshold: Minimum score (inclusive) for a detection to be kept.
        color: Outline and label colour (keyword-only).
        width: Box outline width in pixels (keyword-only).

    Returns:
        tuple: ``(annotated_image, kept_predictions)``, where
        ``kept_predictions`` lists the detections that met ``threshold``.
    """
    # Draw in colour even when the source image is grayscale/palette/16-bit.
    if image.mode not in ("RGB", "RGBA"):
        image = image.convert("RGB")

    filtered_preds = [p for p in predictions if p['score'] >= threshold]
    draw = ImageDraw.Draw(image)

    for pred in filtered_preds:
        box = pred['box']
        label = f"{pred['label']} ({pred['score']:.2%})"

        draw.rectangle(
            [(box['xmin'], box['ymin']), (box['xmax'], box['ymax'])],
            outline=color,
            width=width,
        )
        # Label anchored at the box's top-left corner.
        draw.text((box['xmin'], box['ymin']), label, fill=color)

    return image, filtered_preds

# API Endpoint
@app.post("/detect")
async def detect_fracture(file: UploadFile = File(...), confidence: float = 0.6):
    """Run every loaded model on an uploaded image and return JSON results.

    Args:
        file: Uploaded image in any format Pillow can open.
        confidence: Minimum detection score for a box to be drawn and
            returned. Passed to ``draw_boxes`` as its threshold.

    Returns:
        JSONResponse:
            On success: ``{"success": True, "detections": [...],
            "classifications": {"Heem2": [...], "Nandodeomkar": [...]},
            "image": <base64-encoded PNG with the boxes drawn>}``.
            On any exception: ``{"success": False, "error": <message>}``.
            Both cases use the default HTTP status.
    """
    # NOTE(review): the model calls below are synchronous and run inside an
    # async handler, so they block the event loop for the whole request.
    # Consider starlette.concurrency.run_in_threadpool if throughput matters.
    try:
        contents = await file.read()
        # Normalise once up front. Applying the EXIF orientation gives the
        # detector, the drawn boxes and the returned PNG the same upright
        # pixel grid. Converting to RGB keeps the boxes red on grayscale
        # X-rays and lets the PNG encoder below handle inputs such as CMYK
        # JPEGs.
        with Image.open(io.BytesIO(contents)) as raw_image:
            image = ImageOps.exif_transpose(raw_image).convert("RGB")

        # Object detection: draw on a copy so `image` stays clean for the
        # classifiers below.
        detection_preds = models["D3STRON"](image)
        result_image, filtered_detections = draw_boxes(
            image.copy(), detection_preds, confidence
        )

        # Encode the annotated image as base64 PNG for the JSON payload.
        png_buffer = io.BytesIO()
        result_image.save(png_buffer, format='PNG')
        img_b64 = base64.b64encode(png_buffer.getvalue()).decode()

        # Classification models
        class_results = {
            "Heem2": models["Heem2"](image),
            "Nandodeomkar": models["Nandodeomkar"](image),
        }

        return JSONResponse({
            "success": True,
            "detections": filtered_detections,
            "classifications": class_results,
            "image": img_b64,
        })

    except Exception as e:
        # Top-level request boundary: report the failure in the JSON shape
        # clients already parse.
        return JSONResponse({
            "success": False,
            "error": str(e)
        })

# Streamlit UI
def main():
    """Render the Streamlit front end: upload an X-ray and show model output.

    Mirrors the /detect endpoint. Detector boxes at or above the chosen
    confidence are drawn on the image and listed. Each classifier's raw
    predictions are shown under the model's name.
    """
    st.title("🦴 Fraktur Detektion")

    uploaded_file = st.file_uploader("Röntgenbild hochladen", type=['png', 'jpg', 'jpeg'])
    confidence = st.slider("Konfidenzschwelle", 0.0, 1.0, 0.6, 0.05)

    # Nothing to show until a file has been uploaded.
    if not uploaded_file:
        return

    # Upright, RGB image: keeps boxes aligned with the pixels and visibly red
    # on grayscale X-rays.
    with Image.open(uploaded_file) as raw_image:
        image = ImageOps.exif_transpose(raw_image).convert("RGB")

    # Object detection with boxes drawn at the slider's threshold.
    detection_preds = models["D3STRON"](image)
    result_image, filtered_detections = draw_boxes(
        image.copy(), detection_preds, confidence
    )
    st.subheader("D3STRON")
    st.image(result_image)
    st.json(filtered_detections)

    # Whole-image classifiers.
    for model_name in ("Heem2", "Nandodeomkar"):
        st.subheader(model_name)
        st.json(models[model_name](image))


if __name__ == "__main__":
    main()