File size: 4,866 Bytes
d1b509c
 
5092eb8
 
d1b509c
5092eb8
 
3a8ed2e
 
5092eb8
d1b509c
5092eb8
 
971f77f
d1b509c
 
5092eb8
691b7fe
5092eb8
d1b509c
691b7fe
 
 
 
 
 
5092eb8
d1b509c
691b7fe
d1b509c
 
 
 
 
 
691b7fe
d1b509c
691b7fe
d1b509c
971f77f
d1b509c
 
 
691b7fe
d1b509c
 
 
971f77f
691b7fe
3a8ed2e
d1b509c
3a8ed2e
d1b509c
3a8ed2e
 
 
691b7fe
3a8ed2e
 
d1b509c
691b7fe
 
d1b509c
 
3a8ed2e
 
d1b509c
691b7fe
 
 
 
 
 
 
 
d1b509c
691b7fe
971f77f
691b7fe
 
d1b509c
971f77f
691b7fe
d1b509c
971f77f
691b7fe
971f77f
3a8ed2e
d1b509c
5092eb8
d1b509c
5092eb8
d1b509c
5092eb8
7ce08d2
691b7fe
 
 
7ce08d2
691b7fe
 
 
 
 
 
 
 
 
d1b509c
 
691b7fe
971f77f
 
691b7fe
5092eb8
691b7fe
5092eb8
d1b509c
5092eb8
691b7fe
 
3a8ed2e
691b7fe
d1b509c
3a8ed2e
d1b509c
 
 
 
b4e137f
d1b509c
 
5092eb8
 
691b7fe
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # Suppress TensorFlow warnings
import gradio as gr
import tensorflow as tf
from tensorflow.keras.models import load_model
import numpy as np
from PIL import Image
import requests
import json

# ================== MODEL LOADING ==================
# Keras model file, resolved relative to the current working directory.
MODEL_PATH = 'wound_classifier_model_googlenet.h5'

try:
    model = load_model(MODEL_PATH)
except Exception as e:
    # Fail fast at import time: the app cannot serve predictions without the
    # model. Chain the original error so its traceback is kept as the cause.
    raise RuntimeError(f"❌ Model loading failed: {str(e)}") from e

print("✅ Model loaded successfully")
print(f"ℹ️ Input shape: {model.input_shape}")
print(f"ℹ️ Output shape: {model.output_shape}")

# ================== CLASS LABELS ==================
# Position i names output unit i of the model: predict() pairs
# predictions[i] with CLASS_LABELS[i], so the order here must match training.
CLASS_LABELS = [
    "Abrasions", "Bruises", "Burns", "Cut", "Diabetic Wounds", "Gingivitis",
    "Surgical Wounds", "Venous Wounds", "athlete foot", "cellulitis",
    "chickenpox", "cutaneous larva migrans", "impetigo", "nail fungus",
    "ringworm", "shingles", "tooth discoloration", "ulcer"
]

# Verify model compatibility. This uses an explicit check rather than `assert`,
# because assertions are stripped when Python runs with -O. That would let a
# mismatched label list silently mislabel predictions.
if len(CLASS_LABELS) != model.output_shape[-1]:
    raise RuntimeError(
        f"Class mismatch: Model expects {model.output_shape[-1]} classes, "
        f"found {len(CLASS_LABELS)}"
    )

# ================== OPENROUTER CONFIG ==================
# The API key comes from the environment (e.g. a Hugging Face Space secret)
# instead of being hard-coded: a key committed to source is readable by anyone
# who can see the file or its history.
# NOTE(review): the key previously hard-coded here must be treated as leaked
# and revoked/rotated in the OpenRouter dashboard.
# An empty value means "not configured"; get_medical_guidelines then reports
# the guidelines as unavailable instead of crashing.
OPENROUTER_API_KEY = os.environ.get("OPENROUTER_API_KEY", "").strip()
OPENROUTER_API_URL = "https://openrouter.ai/api/v1/chat/completions"
MODEL_NAME = "mistralai/mistral-7b-instruct"

# ================== IMAGE PROCESSING ==================
def preprocess_image(image, target_size=(224, 224)):
    """Convert an uploaded image into a scaled pixel array for the classifier.

    Args:
        image: A PIL-style image (anything exposing ``convert`` and ``resize``).
            ``None`` is rejected.
        target_size: Size tuple handed to ``resize``. The 224x224 default
            presumably matches ``model.input_shape``; TODO confirm.

    Returns:
        A numpy array of the resized RGB pixels divided by 255.0. For 8-bit
        RGB input, the values fall in [0, 1].

    Raises:
        RuntimeError: If no image is given or conversion/resizing fails. The
            underlying error is attached as ``__cause__``.
    """
    try:
        # Compare against None explicitly. A truthiness test would also reject
        # any image object whose __bool__/__len__ happens to be falsy.
        if image is None:
            raise ValueError("🚨 No image provided")

        # Force three channels so grayscale / RGBA / palette uploads all
        # produce the same array layout.
        image = image.convert("RGB").resize(target_size)
        array = np.array(image) / 255.0
        print(f"🖼️ Image processed: Shape {array.shape}")
        return array
    except Exception as e:
        raise RuntimeError(f"🖼️ Image processing failed: {str(e)}") from e

# ================== MEDICAL GUIDELINES ==================
def get_medical_guidelines(wound_type):
    """Ask the OpenRouter chat-completions API for treatment guidance.

    Args:
        wound_type: Class label (an entry of ``CLASS_LABELS``) that is inserted
            into the prompt.

    Returns:
        The model's reply text on success. Otherwise a human-readable string
        starting with "⚠️", returned when:
        - the API key is not configured;
        - the HTTP request or JSON decoding fails;
        - the response lacks a non-empty ``choices[0].message.content``.
        Those failures are reported instead of raised, so a guidelines
        problem does not hide the classification result.
    """
    # Don't send a request that is guaranteed to be rejected ("Bearer ").
    if not OPENROUTER_API_KEY:
        return "⚠️ Guidelines unavailable: OPENROUTER_API_KEY is not configured"

    headers = {
        "Authorization": f"Bearer {OPENROUTER_API_KEY}",
        "Content-Type": "application/json",
        "HTTP-Referer": "https://huggingface.co/spaces/MahatirTusher/Wound_Treatment"
    }

    # Built line by line so the source-code indentation does not leak into the
    # prompt text (a triple-quoted f-string keeps the leading spaces).
    prompt = (
        f"As a medical expert, provide treatment guidelines for {wound_type}:\n"
        "- First aid steps\n"
        "- Precautions\n"
        "- When to seek professional help\n"
        "Use clear, simple language without markdown."
    )

    try:
        print(f"📡 Sending API request for {wound_type}...")
        response = requests.post(
            OPENROUTER_API_URL,
            headers=headers,
            json={
                "model": MODEL_NAME,
                "messages": [{"role": "user", "content": prompt}],
                "temperature": 0.5
            },
            timeout=20
        )
        response.raise_for_status()
        result = response.json()
    except (requests.RequestException, ValueError) as e:
        # Network errors, HTTP error statuses, and undecodable JSON.
        return f"⚠️ Guidelines unavailable: {str(e)}"

    print("🔧 Raw API response:", json.dumps(result, indent=2))

    # Walk the payload defensively. Error payloads or partial replies would
    # otherwise surface as a bare KeyError/IndexError, or as a None "guideline".
    choices = result.get("choices") if isinstance(result, dict) else None
    if not isinstance(choices, list) or not choices:
        return "⚠️ API response format unexpected"

    first = choices[0]
    message = first.get("message") if isinstance(first, dict) else None
    content = message.get("content") if isinstance(message, dict) else None
    if not isinstance(content, str) or not content.strip():
        return "⚠️ API response format unexpected"

    return content

# ================== MAIN PREDICTION ==================
def predict(image):
    """Classify a wound image and fetch treatment guidance for the top class.

    Args:
        image: The upload delivered by the ``gr.Image(type="pil")`` input.
            This is presumably ``None`` when nothing was uploaded; that case is
            rejected by ``preprocess_image``.

    Returns:
        A ``(predictions, guidelines)`` pair matching the two Gradio outputs:
        - On success, ``predictions`` maps the three highest-scoring class
          labels to their scores, and ``guidelines`` is the text from
          ``get_medical_guidelines`` for the top class.
        - On failure, ``predictions`` is an error string (shown as the label
          heading) and ``guidelines`` is empty.
    """
    try:
        processed_img = preprocess_image(image)
        # The model consumes batches, so wrap the single image in a batch of one.
        input_tensor = np.expand_dims(processed_img, axis=0)

        # verbose=0 suppresses Keras' progress bar, which is otherwise printed
        # to the server log on every request.
        predictions = model.predict(input_tensor, verbose=0)[0]
        top_indices = np.argsort(predictions)[::-1][:3]  # highest score first

        results = {CLASS_LABELS[i]: float(predictions[i]) for i in top_indices}
        top_class = CLASS_LABELS[top_indices[0]]

        guidelines = get_medical_guidelines(top_class)
        return results, guidelines

    except Exception as e:
        # gr.Label needs numeric confidences when given a dict, so a
        # {label: "message"} dict cannot be rendered. A plain string is
        # displayed as the label itself.
        return f"🚨 Error: {e}", ""

# ================== GRADIO INTERFACE ==================
# Candidate sample images for the UI. Only files present on disk are offered.
EXAMPLE_IMAGES = ["abrasion.jpg", "burn.png", "bruise.png", "chicken-pox.png", "cut.png"]

iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload Wound Image"),
    outputs=[
        gr.Label(label="Top Predictions", num_top_classes=3),
        gr.Textbox(label="Treatment Guidelines", lines=8),
    ],
    title="AI Wound Classification System",
    # Derive the count from CLASS_LABELS so the text cannot drift from the model.
    description=(
        f"Identifies {len(CLASS_LABELS)} wound types "
        "and provides treatment recommendations"
    ),
    allow_flagging="never",
    # Pass None rather than an empty list when no sample files are present.
    examples=[path for path in EXAMPLE_IMAGES if os.path.exists(path)] or None,
)

if __name__ == "__main__":
    # Start the web server when run as a script, listening on all network
    # interfaces (0.0.0.0) at port 7860.
    launch_options = dict(server_name="0.0.0.0", server_port=7860)
    iface.launch(**launch_options)