File size: 3,161 Bytes
a326b94 3bb1400 88fb5fa 04a7bfd af17427 04a7bfd af17427 0000f4a 04a7bfd af17427 6cc7ff9 04a7bfd 005d8cf af17427 005d8cf 88fb5fa af17427 806ecee af17427 3bb1400 04a7bfd 88fb5fa 04a7bfd 88fb5fa 04a7bfd 88fb5fa 04a7bfd 6cc7ff9 88fb5fa af17427 04a7bfd 0000f4a af17427 04a7bfd af17427 04a7bfd af17427 04a7bfd af17427 04a7bfd af17427 78d26e0 af17427 04a7bfd af17427 04a7bfd af17427 04a7bfd 88fb5fa af17427 88fb5fa af17427 88fb5fa af17427 6ea5ee2 3bb1400 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 |
import base64
import io
import json
import logging

import numpy as np
import streamlit as st
import torch
from fastapi import FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image, ImageDraw, ImageOps
from starlette.responses import JSONResponse
from transformers import pipeline
# FastAPI application that serves the /detect endpoint below.
app = FastAPI()

# Allow cross-origin requests from any origin, with any method and header.
# NOTE(review): combining a "*" origin with allow_credentials=True makes
# Starlette's CORSMiddleware echo the caller's Origin back instead of "*",
# so any site may issue credentialed requests. The endpoint reads no
# cookies/auth; consider allow_credentials=False or an explicit origin list.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
# Load models
@st.cache_resource
def load_models():
return {
"D3STRON": pipeline("object-detection", model="D3STRON/bone-fracture-detr"),
"Heem2": pipeline("image-classification", model="Heem2/bone-fracture-detection-using-xray"),
"Nandodeomkar": pipeline("image-classification",
model="nandodeomkar/autotrain-fracture-detection-using-google-vit-base-patch-16-54382127388")
}
models = load_models()
def draw_boxes(image, predictions, threshold=0.6):
    """Draw every prediction scoring at least ``threshold`` onto ``image``.

    Each kept prediction gets a red, 2-pixel-wide rectangle spanning its
    ``box`` (xmin/ymin/xmax/ymax) and a red "<label> (<score as %>)" caption
    at the box's top-left corner. ``image`` is modified in place.

    Args:
        image: PIL image to annotate.
        predictions: iterable of dicts with "score", "label" and "box" keys,
            as returned by the object-detection pipeline.
        threshold: inclusive minimum score for a prediction to be drawn.

    Returns:
        Tuple ``(image, kept)``: the same image object and the list of
        predictions that passed the threshold, in their original order.
    """
    canvas = ImageDraw.Draw(image)

    kept = []
    for candidate in predictions:
        if candidate['score'] >= threshold:
            kept.append(candidate)

    for candidate in kept:
        bbox = candidate['box']
        caption = f"{candidate['label']} ({candidate['score']:.2%})"
        corner = (bbox['xmin'], bbox['ymin'])
        canvas.rectangle([corner, (bbox['xmax'], bbox['ymax'])], outline="red", width=2)
        canvas.text(corner, caption, fill="red")

    return image, kept
# API Endpoint
@app.post("/detect")
async def detect_fracture(file: UploadFile = File(...), confidence: float = 0.6):
    """Run all three models on an uploaded X-ray and return the results as JSON.

    Args:
        file: uploaded image in any format Pillow can open.
        confidence: inclusive minimum score for a D3STRON detection to be
            returned and drawn on the result image.

    Returns:
        A ``JSONResponse`` (HTTP 200 in both cases). On success::

            {"success": True,
             "detections": [...kept D3STRON predictions...],
             "classifications": {"Heem2": [...], "Nandodeomkar": [...]},
             "image": "<base64-encoded PNG with the kept boxes drawn>"}

        On any failure: ``{"success": False, "error": "<message>"}``.
    """
    try:
        contents = await file.read()
        # Apply the EXIF orientation and convert to RGB *before* inference.
        # The transformers image pipelines perform the same normalisation
        # internally, so the box coordinates they return refer to this image.
        # Drawing on the raw upload instead misplaces boxes on rotated JPEGs,
        # renders the "red" outline grey on grayscale X-rays, and fails to
        # encode CMYK inputs as PNG.
        image = ImageOps.exif_transpose(Image.open(io.BytesIO(contents))).convert("RGB")

        # NOTE(review): the pipelines run synchronously inside this async
        # handler, so each request blocks the event loop while it runs.
        detection_preds = models["D3STRON"](image)
        # draw_boxes mutates its input, so draw on a copy: the classifiers
        # below must see the unannotated image.
        result_image, filtered_detections = draw_boxes(image.copy(), detection_preds, confidence)

        png_buffer = io.BytesIO()
        result_image.save(png_buffer, format='PNG')
        img_b64 = base64.b64encode(png_buffer.getvalue()).decode()

        class_results = {
            "Heem2": models["Heem2"](image),
            "Nandodeomkar": models["Nandodeomkar"](image),
        }
        return JSONResponse({
            "success": True,
            "detections": filtered_detections,
            "classifications": class_results,
            "image": img_b64,
        })
    except Exception as e:
        # Request boundary: keep the JSON error contract for clients, but log
        # the traceback server-side instead of discarding it.
        logging.getLogger(__name__).exception("Fracture detection failed")
        return JSONResponse({
            "success": False,
            "error": str(e),
        })
# Streamlit UI
def main():
    """Render the Streamlit front end for interactive fracture detection.

    The user uploads an X-ray (PNG/JPEG) and chooses a confidence threshold.
    The app then shows the image with the D3STRON detections at or above that
    threshold drawn as boxes, lists those detections, and shows the outputs
    of the two image-classification models.
    """
    st.title("🦴 Fraktur Detektion")
    uploaded_file = st.file_uploader("Röntgenbild hochladen", type=['png', 'jpg', 'jpeg'])
    confidence = st.slider("Konfidenzschwelle", 0.0, 1.0, 0.6, 0.05)
    if not uploaded_file:
        return

    try:
        # Same normalisation the transformers pipelines apply internally
        # (EXIF orientation, RGB), so the returned boxes line up with the
        # image we draw on and the red outlines stay red on grayscale X-rays.
        image = ImageOps.exif_transpose(Image.open(uploaded_file)).convert("RGB")
    except OSError as err:  # includes PIL.UnidentifiedImageError
        st.error(f"Bild konnte nicht gelesen werden: {err}")
        return

    with st.spinner("Analysiere Bild..."):
        detections = models["D3STRON"](image)
        # draw_boxes mutates its input; keep `image` clean for the classifiers.
        annotated, kept = draw_boxes(image.copy(), detections, confidence)
        classifications = {
            name: models[name](image) for name in ("Heem2", "Nandodeomkar")
        }

    st.image(annotated, caption="Detektionen")

    st.subheader("Detektionen")
    if kept:
        for pred in kept:
            st.write(f"{pred['label']}: {pred['score']:.2%}")
    else:
        st.info("Keine Detektionen über der Konfidenzschwelle.")

    st.subheader("Klassifikation")
    for name, preds in classifications.items():
        st.markdown(f"**{name}**")
        for pred in preds:
            st.write(f"{pred['label']}: {pred['score']:.2%}")


if __name__ == "__main__":
    main()