File size: 5,063 Bytes
d1b509c
 
5092eb8
 
d1b509c
5092eb8
 
3a8ed2e
 
5092eb8
d1b509c
5092eb8
 
70e7566
5092eb8
70e7566
5092eb8
d1b509c
691b7fe
 
 
 
 
 
5092eb8
d1b509c
691b7fe
d1b509c
 
 
 
 
 
691b7fe
d1b509c
691b7fe
d1b509c
971f77f
d1b509c
70e7566
d1b509c
691b7fe
d1b509c
70e7566
d1b509c
971f77f
70e7566
3a8ed2e
d1b509c
3a8ed2e
d1b509c
3a8ed2e
 
 
691b7fe
3a8ed2e
 
d1b509c
691b7fe
 
d1b509c
 
3a8ed2e
 
70e7566
691b7fe
 
 
 
 
 
 
 
d1b509c
691b7fe
971f77f
691b7fe
 
70e7566
971f77f
691b7fe
70e7566
971f77f
691b7fe
971f77f
3a8ed2e
70e7566
5092eb8
d1b509c
5092eb8
d1b509c
5092eb8
7ce08d2
691b7fe
 
 
7ce08d2
691b7fe
 
 
 
 
 
 
 
 
d1b509c
 
691b7fe
971f77f
 
691b7fe
5092eb8
70e7566
5092eb8
d1b509c
850d90a
 
70e7566
850d90a
 
70e7566
 
 
 
850d90a
70e7566
850d90a
 
 
 
 
70e7566
850d90a
5092eb8
691b7fe
850d90a
 
 
 
70e7566
0a0bed4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # Suppress TensorFlow warnings
import gradio as gr
import tensorflow as tf
from tensorflow.keras.models import load_model
import numpy as np
from PIL import Image
import requests
import json

# ================== MODEL LOADING ==================
# Load the pretrained GoogLeNet-based wound classifier at startup; any
# failure here is fatal, so re-raise as a RuntimeError with context.
try:
    model = load_model('wound_classifier_model_googlenet.h5')
except Exception as exc:
    raise RuntimeError(f"❌ Model loading failed: {str(exc)}")
else:
    print("✅ Model loaded successfully")

# ================== CLASS LABELS ==================
# Order matters: index i of the model's softmax output corresponds to
# CLASS_LABELS[i].
CLASS_LABELS = [
    "Abrasions", "Bruises", "Burns", "Cut", "Diabetic Wounds", "Gingivitis",
    "Surgical Wounds", "Venous Wounds", "athlete foot", "cellulitis",
    "chickenpox", "cutaneous larva migrans", "impetigo", "nail fungus",
    "ringworm", "shingles", "tooth discoloration", "ulcer"
]

# Verify model compatibility with an explicit check — `assert` statements
# are stripped when Python runs with -O, so they must not guard runtime
# invariants.
if len(CLASS_LABELS) != model.output_shape[-1]:
    raise ValueError(
        f"Class mismatch: Model expects {model.output_shape[-1]} classes, "
        f"found {len(CLASS_LABELS)}"
    )

# ================== OPENROUTER CONFIG ==================
# SECURITY: the API key was previously hard-coded here (i.e. committed to
# source control). Read it from the environment instead; any key that was
# ever committed should be rotated. An empty default keeps the app importable
# without the variable set (requests will then fail with 401 at call time).
OPENROUTER_API_KEY = os.environ.get("OPENROUTER_API_KEY", "")
OPENROUTER_API_URL = "https://openrouter.ai/api/v1/chat/completions"
MODEL_NAME = "mistralai/mistral-7b-instruct"

# ================== IMAGE PROCESSING ==================
def preprocess_image(image, target_size=(224, 224)):
    """Convert *image* to RGB, resize it, and scale pixel values to [0, 1].

    Args:
        image: PIL image from the Gradio input, or None when nothing was
            uploaded.
        target_size: (width, height) tuple passed to PIL ``resize``.

    Returns:
        np.ndarray of shape (target_size[1], target_size[0], 3) with float
        values in [0, 1].

    Raises:
        RuntimeError: if *image* is None or any processing step fails.
    """
    try:
        # Explicit None check: PIL images are always truthy, so the previous
        # `if not image` test effectively only ever caught None anyway.
        if image is None:
            raise ValueError("🚨 No image provided")

        image = image.convert("RGB").resize(target_size)
        array = np.array(image) / 255.0
        print(f"🖼️ Image processed: Shape {array.shape}")
        return array
    except Exception as e:
        # Wrap all failures so the caller surfaces one uniform error type.
        raise RuntimeError(f"🖼️ Image processing failed: {str(e)}")

# ================== MEDICAL GUIDELINES ==================
def get_medical_guidelines(wound_type):
    """Fetch first-aid/treatment guidance for *wound_type* from OpenRouter.

    Best-effort: returns the LLM's reply text on success, or a
    human-readable warning string on any failure — this function never
    raises, so the caller can display whatever comes back.
    """
    request_headers = {
        "Authorization": f"Bearer {OPENROUTER_API_KEY}",
        "Content-Type": "application/json",
        "HTTP-Referer": "https://huggingface.co/spaces/MahatirTusher/Wound_Treatment"
    }

    prompt = f"""As a medical expert, provide treatment guidelines for {wound_type}:
    - First aid steps
    - Precautions
    - When to seek professional help
    Use clear, simple language without markdown."""

    payload = {
        "model": MODEL_NAME,
        "messages": [{"role": "user", "content": prompt}],
        "temperature": 0.5
    }

    try:
        print(f"📡 Sending API request for {wound_type}...")
        response = requests.post(
            OPENROUTER_API_URL,
            headers=request_headers,
            json=payload,
            timeout=20
        )
        response.raise_for_status()

        result = response.json()
        print("🔧 Raw API response:", json.dumps(result, indent=2))

        choices = result.get("choices")
        if not choices:
            return "⚠️ API response format unexpected"
        return choices[0]["message"]["content"]

    except Exception as e:
        return f"⚠️ Guidelines unavailable: {str(e)}"

# ================== MAIN PREDICTION ==================
def predict(image):
    """Classify a wound image and fetch guidelines for the top prediction.

    Args:
        image: PIL image from the Gradio input (may be None).

    Returns:
        (results, guidelines): *results* maps the three most confident class
        labels to their softmax scores (for the gr.Label component);
        *guidelines* is the LLM text for the top class. On any failure,
        returns ({"🚨 Error": <message>}, "").
    """
    try:
        # Preprocess and add the batch dimension the model expects.
        processed_img = preprocess_image(image)
        input_tensor = np.expand_dims(processed_img, axis=0)

        # Single-image batch -> take row 0 of the prediction matrix.
        predictions = model.predict(input_tensor)[0]
        sorted_indices = np.argsort(predictions)[::-1]  # descending confidence

        # Top-3 predictions for display.
        results = {
            CLASS_LABELS[i]: float(predictions[i])
            for i in sorted_indices[:3]
        }
        top_class = CLASS_LABELS[sorted_indices[0]]

        # Best-effort guideline lookup (never raises).
        guidelines = get_medical_guidelines(top_class)

        return results, guidelines

    except Exception as e:
        # Was `{f"🚨 Error": ...}` — the f-prefix had no placeholders and
        # served no purpose; the key is a plain literal.
        return {"🚨 Error": str(e)}, ""

# ================== GRADIO INTERFACE ==================
def create_interface():
    """Build and return the Gradio Blocks UI for the wound classifier."""
    with gr.Blocks(title="AI Wound Classifier") as demo:
        gr.Markdown("# 🩹 AI-Powered Wound Classification System")
        gr.Markdown("Upload a wound image or take a photo using your camera")

        image_input = gr.Image(type="pil", label="Upload Wound Image")
        analyze_button = gr.Button("Analyze Now", variant="primary")
        predictions_output = gr.Label(label="Top Predictions", num_top_classes=3)
        guidelines_output = gr.Textbox(label="Treatment Guidelines", lines=8)

        # Wire the button to the full prediction pipeline.
        analyze_button.click(
            fn=predict,
            inputs=[image_input],
            outputs=[predictions_output, guidelines_output],
        )

    return demo

if __name__ == "__main__":
    # Launch on all interfaces (container/Spaces friendly) on port 7860.
    create_interface().launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,  # Set to False if you do not want a public link
    )