# Wound_Treatment / app.py
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # Suppress TensorFlow warnings
import gradio as gr
import tensorflow as tf
from tensorflow.keras.models import load_model
import numpy as np
from PIL import Image
import requests
import json
# ================== MODEL LOADING ==================
try:
    model = load_model('wound_classifier_model_googlenet.h5')
    print("✅ Model loaded successfully")
except Exception as e:
    raise RuntimeError(f"❌ Model loading failed: {str(e)}")
# ================== CLASS LABELS ==================
CLASS_LABELS = [
"Abrasions", "Bruises", "Burns", "Cut", "Diabetic Wounds", "Gingivitis",
"Surgical Wounds", "Venous Wounds", "athlete foot", "cellulitis",
"chickenpox", "cutaneous larva migrans", "impetigo", "nail fungus",
"ringworm", "shingles", "tooth discoloration", "ulcer"
]
# Verify model compatibility
assert len(CLASS_LABELS) == model.output_shape[-1], \
f"Class mismatch: Model expects {model.output_shape[-1]} classes, found {len(CLASS_LABELS)}"
# ================== OPENROUTER CONFIG ==================
OPENROUTER_API_KEY = "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"
OPENROUTER_API_URL = "https://openrouter.ai/api/v1/chat/completions"
MODEL_NAME = "mistralai/mistral-7b-instruct"
# ================== IMAGE PROCESSING ==================
def preprocess_image(image, target_size=(224, 224)):
    """Resize, normalize, and validate the input image."""
    try:
        if image is None:
            raise ValueError("🚨 No image provided")
        image = image.convert("RGB").resize(target_size)
        array = np.array(image) / 255.0  # scale pixel values to [0, 1]
        print(f"🖼️ Image processed: Shape {array.shape}")
        return array
    except Exception as e:
        raise RuntimeError(f"🖼️ Image processing failed: {str(e)}")
# ================== MEDICAL GUIDELINES ==================
def get_medical_guidelines(wound_type):
    """Fetch treatment guidelines from the OpenRouter API."""
    headers = {
        "Authorization": f"Bearer {OPENROUTER_API_KEY}",
        "Content-Type": "application/json",
        "HTTP-Referer": "https://huggingface.co/spaces/MahatirTusher/Wound_Treatment",
    }
    prompt = f"""As a medical expert, provide treatment guidelines for {wound_type}:
- First aid steps
- Precautions
- When to seek professional help
Use clear, simple language without markdown."""
    try:
        print(f"📡 Sending API request for {wound_type}...")
        response = requests.post(
            OPENROUTER_API_URL,
            headers=headers,
            json={
                "model": MODEL_NAME,
                "messages": [{"role": "user", "content": prompt}],
                "temperature": 0.5,
            },
            timeout=20,
        )
        response.raise_for_status()
        result = response.json()
        print("🔧 Raw API response:", json.dumps(result, indent=2))
        if not result.get("choices"):
            return "⚠️ API response format unexpected"
        return result["choices"][0]["message"]["content"]
    except Exception as e:
        return f"⚠️ Guidelines unavailable: {str(e)}"
# ================== MAIN PREDICTION ==================
def predict(image):
    """Complete prediction pipeline: classify the wound, then fetch guidelines."""
    try:
        # Preprocess the image and add a batch dimension
        processed_img = preprocess_image(image)
        input_tensor = np.expand_dims(processed_img, axis=0)

        # Make prediction
        predictions = model.predict(input_tensor)[0]
        sorted_indices = np.argsort(predictions)[::-1]  # descending by confidence

        # Format the top-3 results for the gr.Label component
        results = {
            CLASS_LABELS[i]: float(predictions[i])
            for i in sorted_indices[:3]
        }
        top_class = CLASS_LABELS[sorted_indices[0]]

        # Get guidelines for the most likely class
        guidelines = get_medical_guidelines(top_class)
        return results, guidelines
    except Exception as e:
        return {"🚨 Error": str(e)}, ""
# ================== GRADIO INTERFACE ==================
def create_interface():
    with gr.Blocks(title="AI Wound Classifier") as demo:
        gr.Markdown("# 🩹 AI-Powered Wound Classification System")
        gr.Markdown("Upload a wound image or take a photo using your camera")

        file_input = gr.Image(type="pil", label="Upload Wound Image")
        submit_btn = gr.Button("Analyze Now", variant="primary")
        output_label = gr.Label(label="Top Predictions", num_top_classes=3)
        output_guidelines = gr.Textbox(label="Treatment Guidelines", lines=8)

        # Connect the button to the prediction pipeline
        submit_btn.click(
            fn=predict,
            inputs=[file_input],
            outputs=[output_label, output_guidelines],
        )
    return demo
if __name__ == "__main__":
    iface = create_interface()
    iface.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True  # Set to False if you do not want a public link
    )