bfd-rg / app.py
3dredstone's picture
Update app.py
39557f3 verified
import base64
import html
import io
import logging
import re
from urllib.parse import quote

import numpy as np
import uvicorn
from fastapi import FastAPI, File, UploadFile, Form
from fastapi.responses import HTMLResponse, StreamingResponse
from PIL import Image, ImageDraw
from reportlab.lib.enums import TA_CENTER
from reportlab.lib.pagesizes import letter
from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
from reportlab.lib.units import inch
from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, Image as ReportLabImage
from transformers import pipeline
# ASGI application instance; the "/" and "/analyze" routes below are
# registered on it through the @app.get / @app.post decorators.
app = FastAPI()
def load_models():
    """Construct the three inference pipelines used by the /analyze endpoint.

    Returns:
        dict mapping each model's display name to a transformers ``pipeline``:
        one object-detection model ("KnochenAuge", produces boxes) and two
        image-classification models ("KnochenWächter", "RöntgenMeister").
    """
    model_specs = (
        ("KnochenAuge", "object-detection", "D3STRON/bone-fracture-detr"),
        ("KnochenWächter", "image-classification", "Heem2/bone-fracture-detection-using-xray"),
        ("RöntgenMeister", "image-classification",
         "nandodeomkar/autotrain-fracture-detection-using-google-vit-base-patch-16-54382127388"),
    )
    return {
        display_name: pipeline(task, model=repo_id)
        for display_name, task, repo_id in model_specs
    }


# Built once at import time; every request reuses these pipelines.
models = load_models()
# Raw model label (lower-cased) -> English display label.
# Keys MUST be lower-case: translate_label lower-cases its input before lookup.
_LABEL_TRANSLATIONS = {
    "fracture": "Fracture",
    "no fracture": "No Fracture",
    "normal": "Normal",
    "abnormal": "Abnormal",
    "f1": "Fracture",  # assumed: "F1" denotes a fracture class — TODO confirm against the model card
    "nf": "No Fracture",  # assumed: "NF" denotes no fracture — TODO confirm against the model card
}


def translate_label(label):
    """Map a raw model label to an English display label.

    The lookup ignores case and surrounding whitespace, so "F1", "nf" or
    " Fracture " all resolve. Labels with no mapping are returned unchanged.

    Args:
        label: label string produced by one of the pipelines.

    Returns:
        The display label, or ``label`` itself when no translation exists.
    """
    return _LABEL_TRANSLATIONS.get(label.strip().lower(), label)
def create_heatmap_overlay(image, box, score):
    """Return a transparent RGBA layer, the size of *image*, highlighting *box*.

    The box is filled with a translucent colour (alpha 100) and then outlined
    with the same colour fully opaque, 2 px wide. The colour encodes the score:
    red when score > 0.8, orange when score > 0.6, yellow otherwise.

    Args:
        image: source image; only its ``size`` is read.
        box: dict with ``xmin``, ``ymin``, ``xmax``, ``ymax``.
        score: detection confidence.
    """
    # First tier whose threshold the score strictly exceeds wins; yellow is the fallback.
    tiers = ((0.8, (255, 0, 0)), (0.6, (255, 165, 0)))
    rgb = next((colour for threshold, colour in tiers if score > threshold), (255, 255, 0))

    corners = [box['xmin'], box['ymin'], box['xmax'], box['ymax']]
    layer = Image.new('RGBA', image.size, (0, 0, 0, 0))
    painter = ImageDraw.Draw(layer)
    # Two passes: translucent fill first, then the opaque border on top.
    painter.rectangle(corners, fill=rgb + (100,))
    painter.rectangle(corners, outline=rgb + (255,), width=2)
    return layer
def draw_boxes(image, predictions):
    """Render detection overlays and captions onto a copy of *image*.

    Args:
        image: PIL image; it is copied, not modified.
        predictions: object-detection results; each is a dict whose ``'box'``
            holds ``xmin``/``ymin``/``xmax``/``ymax``, plus ``'score'`` and
            ``'label'`` (the only keys read here).

    Returns:
        A new RGBA image with one coloured overlay per prediction (from
        ``create_heatmap_overlay``) and a caption near the top of each box.
    """
    result_image = image.copy().convert('RGBA')
    for pred in predictions:
        box = pred['box']
        score = pred['score']
        overlay = create_heatmap_overlay(image, box, score)
        result_image = Image.alpha_composite(result_image, overlay)
        # alpha_composite returns a new image, so a fresh Draw is needed each time.
        draw = ImageDraw.Draw(result_image)
        # NOTE(review): this "temperature" is synthetic, a linear function of the
        # detector confidence rather than a measurement; kept for parity with the PDF.
        temp = 36.5 + (score * 2.5)
        # Separator between confidence and temperature (previously "85.0%38.6°C").
        label = f"{translate_label(pred['label'])} ({score:.1%}, {temp:.1f}°C)"
        # Caption sits 20 px above the box but never above the image's top edge,
        # where it would be clipped away entirely.
        text_x, text_y = box['xmin'], max(0, box['ymin'] - 20)
        try:
            text_bbox = draw.textbbox((text_x, text_y), label)
        except AttributeError:
            # Pillow without textbbox(): estimate the caption extent, anchored
            # at the same top-left point the text is drawn from.
            font_size = 10
            text_bbox = (
                text_x,
                text_y,
                text_x + len(label) * font_size * 0.6,
                text_y + font_size * 1.2,
            )
        draw.rectangle(text_bbox, fill=(0, 0, 0, 180))
        draw.text((text_x, text_y), label, fill=(255, 255, 255, 255))
    return result_image
def image_to_base64(image):
    """Serialize *image* as PNG and return it as a ``data:image/png;base64,...`` URI."""
    with io.BytesIO() as png_stream:
        image.save(png_stream, format="PNG")
        png_bytes = png_stream.getvalue()
    encoded = base64.b64encode(png_bytes).decode()
    return "data:image/png;base64," + encoded
# Shared CSS for the upload page (main) and the error page (analyze_file).
# This is a plain string, not an f-string, so the CSS braces are single; it is
# spliced into those pages' f-strings via {COMMON_STYLES}, and substituted
# values are not re-parsed, so no brace escaping is needed here.
# NOTE(review): the error page uses class "back-button", which has no rule here.
COMMON_STYLES = """
body {
font-family: system-ui, -apple-system, sans-serif;
background: #f0f2f5;
margin: 0;
padding: 20px;
color: #1a1a1a;
}
::-webkit-scrollbar {
width: 8px;
height: 8px;
}
::-webkit-scrollbar-track {
background: transparent;
}
::-webkit-scrollbar-thumb {
background-color: rgba(156, 163, 175, 0.5);
border-radius: 4px;
}
.container {
max-width: 1200px;
margin: 0 auto;
background: white;
padding: 20px;
border-radius: 10px;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}
.button {
background: #2d2d2d;
color: white;
border: none;
padding: 12px 30px;
border-radius: 8px;
cursor: pointer;
font-size: 1.1em;
transition: all 0.3s ease;
position: relative;
}
.button:hover {
background: #404040;
}
@keyframes progress {
0% { width: 0; }
100% { width: 100%; }
}
.button-progress {
position: absolute;
bottom: 0;
left: 0;
height: 4px;
background: rgba(255, 255, 255, 0.5);
width: 0;
}
.button:active .button-progress {
animation: progress 2s linear forwards;
}
img {
max-width: 100%;
height: auto;
border-radius: 8px;
}
@keyframes blink {
0% { opacity: 1; }
50% { opacity: 0; }
100% { opacity: 1; }
}
#loading {
display: none;
color: white;
margin-top: 10px;
animation: blink 1s infinite;
text-align: center;
}
"""
@app.get("/", response_class=HTMLResponse)
async def main():
    """Serve the upload form page.

    The form posts ``patient_name``, ``file`` and ``threshold`` as multipart
    data to ``/analyze`` and reveals the "Loading..." indicator on submit.
    """
    # Doubled braces are literal CSS braces; {COMMON_STYLES} is the only substitution.
    return f"""
<!DOCTYPE html>
<html>
<head>
<title>Fracture Detection</title>
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<style>
{COMMON_STYLES}
.upload-section {{
background: #2d2d2d;
padding: 40px;
border-radius: 12px;
margin: 20px 0;
text-align: center;
border: 2px dashed #404040;
transition: all 0.3s ease;
color: white;
}}
.upload-section:hover {{
border-color: #555;
}}
input[type="file"] {{
font-size: 1.1em;
margin: 20px 0;
color: white;
}}
input[type="file"]::file-selector-button {{
font-size: 1em;
padding: 10px 20px;
border-radius: 8px;
border: 1px solid #404040;
background: #2d2d2d;
color: white;
transition: all 0.3s ease;
cursor: pointer;
}}
input[type="file"]::file-selector-button:hover {{
background: #404040;
}}
.confidence-slider {{
width: 100%;
max-width: 300px;
margin: 20px auto;
}}
input[type="range"] {{
width: 100%;
height: 8px;
border-radius: 4px;
background: #404040;
outline: none;
transition: all 0.3s ease;
-webkit-appearance: none;
}}
input[type="range"]::-webkit-slider-thumb {{
-webkit-appearance: none;
width: 20px;
height: 20px;
border-radius: 50%;
background: white;
cursor: pointer;
border: none;
}}
.input-field {{
margin-bottom: 20px;
}}
.input-field label {{
display: block;
margin-bottom: 5px;
font-size: 1.1em;
}}
.input-field input[type="text"] {{
width: calc(100% - 20px);
padding: 10px;
border-radius: 5px;
border: 1px solid #ccc;
background: #fff;
color: #1a1a1a;
font-size: 1em;
}}
</style>
</head>
<body>
<div class="container">
<div class="upload-section">
<form action="/analyze" method="post" enctype="multipart/form-data" onsubmit="document.getElementById('loading').style.display = 'block';">
<div class="input-field">
<label for="patient_name">Patient Name:</label>
<input type="text" id="patient_name" name="patient_name" required>
</div>
<div>
<input type="file" name="file" accept="image/*" required>
</div>
<div class="confidence-slider">
<label for="threshold">Confidence Threshold: <span id="thresholdValue">0.60</span></label>
<input type="range" id="threshold" name="threshold"
min="0" max="1" step="0.05" value="0.60"
oninput="document.getElementById('thresholdValue').textContent = parseFloat(this.value).toFixed(2)">
</div>
<button type="submit" class="button">
Analyze & Generate PDF
<div class="button-progress"></div>
</button>
<div id="loading">Loading...</div>
</form>
</div>
</div>
</body>
</html>
"""
def _paragraph_text(value):
    """Escape *value* so ReportLab's Paragraph mini-markup renders it as plain text.

    Paragraph parses XML-like tags and entities. An unescaped ``<`` or ``&``
    in user input (e.g. the patient name) would break ``doc.build`` or
    inject formatting.
    """
    return html.escape(str(value), quote=False)


def _content_disposition(patient_name):
    """Return a header-safe ``Content-Disposition`` value for the PDF download.

    ``filename`` is an ASCII-only fallback: every character outside
    ``[A-Za-z0-9_.-]`` becomes ``_`` (so spaces still map to ``_``).
    ``filename*`` carries the full UTF-8 name, percent-encoded per RFC 6266 /
    RFC 5987. The resulting header value is pure ASCII, so quotes, ``;`` or
    CR/LF in the user-supplied name cannot break or extend the header.
    """
    filename = f"Fracture_Report_{patient_name.replace(' ', '_')}.pdf"
    ascii_fallback = re.sub(r"[^A-Za-z0-9_.-]", "_", filename)
    return f"attachment; filename=\"{ascii_fallback}\"; filename*=UTF-8''{quote(filename, safe='')}"


def _error_page(message):
    """Render the HTML error page shown when analysis fails.

    *message* is HTML-escaped before insertion because exception text can
    contain markup characters (e.g. object reprs such as ``<_io.BytesIO ...>``).
    """
    return f"""
<!DOCTYPE html>
<html>
<head>
<title>Error</title>
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<style>
{COMMON_STYLES}
.error-box {{
background: #fee2e2;
border: 1px solid #ef4444;
padding: 20px;
border-radius: 8px;
margin: 20px 0;
}}
</style>
</head>
<body>
<div class="container">
<div class="error-box">
<h3>Error</h3>
<p>{html.escape(message)}</p>
</div>
<a href="/" class="button back-button">
← Back
<div class="button-progress"></div>
</a>
</div>
</body>
</html>
"""


def _build_report_pdf(patient_name, result_image, predictions_watcher, predictions_master, filtered_preds):
    """Lay out the fracture report and return it as an in-memory PDF.

    Args:
        patient_name: name as submitted; escaped before it enters Paragraph markup.
        result_image: PIL image to embed (annotated when detections passed the threshold).
        predictions_watcher: "KnochenWächter" classifier output (dicts with 'label'/'score').
        predictions_master: "RöntgenMeister" classifier output (dicts with 'label'/'score').
        filtered_preds: "KnochenAuge" detections that met the confidence threshold.

    Returns:
        ``io.BytesIO`` holding the PDF, rewound to offset 0 for streaming.
    """
    buffer = io.BytesIO()
    doc = SimpleDocTemplate(buffer, pagesize=letter)
    styles = getSampleStyleSheet()
    centered_style = ParagraphStyle(
        name='Centered', parent=styles['Normal'], alignment=TA_CENTER, fontSize=12, leading=14,
    )
    heading_style = ParagraphStyle(
        name='Heading', parent=styles['h1'], alignment=TA_CENTER, fontSize=24, spaceAfter=20,
    )
    subheading_style = ParagraphStyle(
        name='SubHeading', parent=styles['h2'], alignment=TA_CENTER, fontSize=16, spaceAfter=10,
    )
    report_text_style = ParagraphStyle(
        name='ReportText', parent=styles['Normal'], alignment=TA_CENTER, fontSize=12, spaceAfter=5,
    )

    story = [
        Paragraph("<b>Fracture Detection Report</b>", heading_style),
        Spacer(1, 0.2 * inch),
        Paragraph(f"<b>Patient Name:</b> {_paragraph_text(patient_name)}", subheading_style),
        Spacer(1, 0.4 * inch),
    ]

    # Whole-image classifier sections: one line per returned label, then a gap.
    for model_name, predictions, gap in (
        ("KnochenWächter", predictions_watcher, 0.2),
        ("RöntgenMeister", predictions_master, 0.4),
    ):
        story.append(Paragraph(f"<b>{model_name} Results:</b>", subheading_style))
        for pred in predictions:
            story.append(Paragraph(
                f"{_paragraph_text(translate_label(pred['label']))}: {pred['score']:.1%}",
                report_text_style,
            ))
        story.append(Spacer(1, gap * inch))

    # Analyzed image
    story.append(Paragraph("<b>X-ray Image Analysis:</b>", subheading_style))
    img_buffer = io.BytesIO()
    result_image.save(img_buffer, format="PNG")
    img_buffer.seek(0)
    img_rl = ReportLabImage(img_buffer)
    # Shrink (never enlarge), keeping aspect ratio, to at most 5 in wide and
    # 8 in tall. Capping only the width would let a tall image exceed the
    # ~9 in frame of a letter page with default margins and fail doc.build.
    scale = min(1.0, (5 * inch) / img_rl.drawWidth, (8 * inch) / img_rl.drawHeight)
    img_rl.drawWidth *= scale
    img_rl.drawHeight *= scale
    img_rl.hAlign = 'CENTER'
    story.append(img_rl)
    story.append(Spacer(1, 0.4 * inch))

    # Final report text based on object detection
    if filtered_preds:
        story.append(Paragraph(
            "<b>The X-ray image analysis shows potential fracture localization.</b>",
            report_text_style,
        ))
        for pred in filtered_preds:
            score = pred['score']
            # NOTE(review): synthetic value, a linear function of detector
            # confidence rather than a measured temperature.
            temp = 36.5 + (score * 2.5)
            story.append(Paragraph(
                f"Detection: {_paragraph_text(translate_label(pred['label']))} with {score:.1%} confidence ({temp:.1f}°C)",
                report_text_style,
            ))
    else:
        story.append(Paragraph(
            "<b>Based on object localization analysis, no fracture was detected with sufficient confidence.</b>",
            report_text_style,
        ))
    story.append(Spacer(1, 0.2 * inch))
    story.append(Paragraph(
        "This is an automatically generated report and should be reviewed by a medical professional.",
        centered_style,
    ))

    doc.build(story)
    buffer.seek(0)
    return buffer


@app.post("/analyze", response_class=StreamingResponse)
async def analyze_file(patient_name: str = Form(...), file: UploadFile = File(...), threshold: float = Form(0.6)):
    """Run all three models on an uploaded X-ray and return a PDF report.

    Args:
        patient_name: free-text name from the form (untrusted; escaped/sanitized).
        file: uploaded image.
        threshold: minimum detector score for a box to be drawn and reported.

    Returns:
        A StreamingResponse with the PDF as an attachment. If any step raises,
        an HTML error page with HTTP status 500 is returned instead.
    """
    try:
        contents = await file.read()
        image = Image.open(io.BytesIO(contents)).convert("RGB")  # Ensure RGB for PDF
        # NOTE(review): inference and PDF layout are synchronous and block the
        # event loop while they run. Offloading them (e.g. run_in_threadpool)
        # would allow concurrent pipeline calls; confirm thread-safety first.
        predictions_watcher = models["KnochenWächter"](image)
        predictions_master = models["RöntgenMeister"](image)
        predictions_locator = models["KnochenAuge"](image)
        filtered_preds = [p for p in predictions_locator if p['score'] >= threshold]
        result_image = draw_boxes(image, filtered_preds) if filtered_preds else image
        buffer = _build_report_pdf(
            patient_name, result_image, predictions_watcher, predictions_master, filtered_preds,
        )
        return StreamingResponse(
            buffer,
            media_type="application/pdf",
            headers={"Content-Disposition": _content_disposition(patient_name)},
        )
    except Exception as e:
        # Top-level request boundary: keep the traceback for operators and
        # show the user an (escaped) error page with a non-2xx status.
        logging.getLogger(__name__).exception("Fracture analysis failed")
        return HTMLResponse(_error_page(str(e)), status_code=500)
# Direct-execution entry point: serve the app with uvicorn on all interfaces, port 7860.
if __name__ == "__main__":
    uvicorn.run(app, port=7860, host="0.0.0.0")