# Wound_Treatment / app.py — Hugging Face Space by MahatirTusher (commit b4e137f: "Update app.py")
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # Suppress TensorFlow warnings
import gradio as gr
import tensorflow as tf
from tensorflow.keras.models import load_model
import numpy as np
from PIL import Image
import requests
import json
# ================== MODEL LOADING ==================
# Load the trained Keras classifier once, at import time. Failure is fatal on
# purpose: the app cannot do anything useful without the model, so startup is
# aborted with a RuntimeError that keeps the original exception as its cause.
try:
    model = load_model('wound_classifier_model_googlenet.h5')
except Exception as e:
    raise RuntimeError(f"❌ Model loading failed: {e}") from e
print("✅ Model loaded successfully")
print(f"ℹ️ Input shape: {model.input_shape}")
print(f"ℹ️ Output shape: {model.output_shape}")
# ================== CLASS LABELS ==================
# Human-readable names for the model's output units: predict() maps output
# index i to CLASS_LABELS[i], so this order must match the model's output
# layer (presumably the class order used at training time — verify if the
# model is retrained).
CLASS_LABELS = [
    "Abrasions", "Bruises", "Burns", "Cut", "Diabetic Wounds", "Gingivitis",
    "Surgical Wounds", "Venous Wounds", "athlete foot", "cellulitis",
    "chickenpox", "cutaneous larva migrans", "impetigo", "nail fungus",
    "ringworm", "shingles", "tooth discoloration", "ulcer",
]
# Verify model compatibility. This is an explicit check rather than `assert`
# so it still runs when Python is started with -O (which strips asserts).
if len(CLASS_LABELS) != model.output_shape[-1]:
    raise RuntimeError(
        f"Class mismatch: Model expects {model.output_shape[-1]} classes, "
        f"found {len(CLASS_LABELS)}"
    )
# ================== OPENROUTER CONFIG ==================
# The API key is read from the environment (e.g. a Hugging Face Space secret)
# instead of being committed to source control. A key used to be hard-coded
# here; treat that key as leaked and revoke/rotate it on OpenRouter.
# An empty string means "not configured".
OPENROUTER_API_KEY = os.environ.get("OPENROUTER_API_KEY", "")
OPENROUTER_API_URL = "https://openrouter.ai/api/v1/chat/completions"
MODEL_NAME = "mistralai/mistral-7b-instruct"
# ================== IMAGE PROCESSING ==================
def preprocess_image(image, target_size=(224, 224)):
    """Convert an input image into a normalized array for the classifier.

    Args:
        image: Image to process (the Gradio input supplies a PIL image), or
            None when nothing was uploaded.
        target_size: Size passed to ``image.resize``; for PIL this is
            ``(width, height)``.

    Returns:
        A float32 array of the resized RGB pixels scaled from 0-255 to
        [0, 1]; for a PIL image the shape is ``(height, width, 3)``.

    Raises:
        RuntimeError: If no image was provided or conversion failed. The
            underlying exception is attached as ``__cause__``.
    """
    try:
        # Explicit None check: truthiness is ambiguous for array-like inputs.
        if image is None:
            raise ValueError("🚨 No image provided")
        image = image.convert("RGB").resize(target_size)
        # float32 is the dtype the model computes in; avoids a float64 copy.
        array = np.asarray(image, dtype=np.float32) / 255.0
    except Exception as e:
        raise RuntimeError(f"🖼️ Image processing failed: {e}") from e
    print(f"🖼️ Image processed: Shape {array.shape}")
    return array
# ================== MEDICAL GUIDELINES ==================
def get_medical_guidelines(wound_type):
    """Fetch first-aid / treatment guidelines for a wound type from OpenRouter.

    Best-effort: a missing API key, network or HTTP failure, invalid JSON,
    or an unexpected response shape is reported as a warning string instead
    of raising, so the caller can still show the classification result.

    Args:
        wound_type: Class label to ask about, e.g. "Burns".

    Returns:
        The LLM's reply text, or a string starting with "⚠️" explaining why
        guidelines are unavailable.
    """
    if not OPENROUTER_API_KEY:
        # Skip a request that could only fail authentication.
        return "⚠️ Guidelines unavailable: OPENROUTER_API_KEY is not configured"

    headers = {
        "Authorization": f"Bearer {OPENROUTER_API_KEY}",
        "Content-Type": "application/json",
        "HTTP-Referer": "https://huggingface.co/spaces/MahatirTusher/Wound_Treatment",
    }
    # Built by concatenation so no source indentation leaks into the prompt.
    prompt = (
        f"As a medical expert, provide treatment guidelines for {wound_type}:\n"
        "- First aid steps\n"
        "- Precautions\n"
        "- When to seek professional help\n"
        "Use clear, simple language without markdown."
    )

    print(f"📡 Sending API request for {wound_type}...")
    try:
        response = requests.post(
            OPENROUTER_API_URL,
            headers=headers,
            json={
                "model": MODEL_NAME,
                "messages": [{"role": "user", "content": prompt}],
                "temperature": 0.5,
            },
            timeout=20,
        )
        response.raise_for_status()
        result = response.json()
    except (requests.RequestException, ValueError) as e:
        # Connection errors, timeouts, non-2xx statuses, and non-JSON bodies.
        return f"⚠️ Guidelines unavailable: {e}"

    print("🔧 Raw API response:", json.dumps(result, indent=2))

    # Expected shape: {"choices": [{"message": {"content": "..."}}]}.
    try:
        content = result["choices"][0]["message"]["content"]
    except (KeyError, IndexError, TypeError):
        return "⚠️ API response format unexpected"
    if not isinstance(content, str) or not content.strip():
        return "⚠️ API response format unexpected"
    return content.strip()
# ================== MAIN PREDICTION ==================
def predict(image):
    """Classify a wound image and fetch guidelines for the top class.

    Args:
        image: The image from the Gradio ``Image`` input (may be None).

    Returns:
        A ``(label_value, guidelines)`` tuple feeding the Label and Textbox
        outputs. On success ``label_value`` maps the top-3 class names to
        their scores (highest first) and ``guidelines`` is the text from
        ``get_medical_guidelines``. On any failure ``label_value`` is the
        plain string "🚨 Error" and ``guidelines`` carries the message.
        ``gr.Label`` accepts a plain string, but a dict's values must be
        numeric confidences, so the error text cannot go in a dict.
    """
    try:
        processed_img = preprocess_image(image)
        input_tensor = np.expand_dims(processed_img, axis=0)  # add batch axis
        # verbose=0 keeps Keras' per-call progress bar out of the server log.
        predictions = model.predict(input_tensor, verbose=0)[0]
        top_indices = np.argsort(predictions)[::-1][:3]  # highest score first
        results = {CLASS_LABELS[i]: float(predictions[i]) for i in top_indices}
        top_class = CLASS_LABELS[top_indices[0]]
        guidelines = get_medical_guidelines(top_class)
        return results, guidelines
    except Exception as e:
        # UI boundary: report the failure in the outputs instead of crashing.
        return "🚨 Error", f"🚨 Error: {e}"
# ================== GRADIO INTERFACE ==================
# One image input; two outputs matching predict()'s (labels, guidelines) tuple.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload Wound Image"),
    outputs=[
        gr.Label(label="Top Predictions", num_top_classes=3),
        gr.Textbox(label="Treatment Guidelines", lines=8),
    ],
    title="AI Wound Classification System",
    description="Identifies 18 wound types and provides treatment recommendations",
    allow_flagging="never",
    # Offer only the sample images that exist (relative to the working dir).
    examples=list(filter(os.path.exists, [
        "abrasion.jpg",
        "burn.png",
        "bruise.png",
        "chicken-pox.png",
        "cut.png",
    ])),
)
if __name__ == "__main__":
    # Listen on all network interfaces (0.0.0.0), port 7860.
    iface.launch(server_port=7860, server_name="0.0.0.0")