MahatirTusher commited on
Commit
d1b509c
·
verified ·
1 Parent(s): 7303271

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -34
app.py CHANGED
@@ -1,25 +1,23 @@
 
 
1
  import gradio as gr
2
  import tensorflow as tf
 
3
  import numpy as np
4
  from PIL import Image
5
  import requests
6
  import json
7
- import os
8
 
9
- # Load the model with enhanced error handling
10
  try:
11
  model = load_model('wound_classifier_model_googlenet.h5')
12
  print("✅ Model loaded successfully")
13
- print(f"ℹ️ Model expects {model.output_shape[-1]} output classes") # Should be 18
 
14
  except Exception as e:
15
  raise RuntimeError(f"❌ Model loading failed: {str(e)}")
16
 
17
- # OpenRouter configuration
18
- OPENROUTER_API_KEY = os.getenv("OPENROUTER_KEY", "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")
19
- OPENROUTER_API_URL = "https://openrouter.ai/api/v1/chat/completions"
20
- MODEL_NAME = "mistralai/mistral-7b-instruct"
21
-
22
- # Class labels from your classes.txt
23
  CLASS_LABELS = [
24
  "Abrasions", "Bruises", "Burns", "Cut", "Diabetic Wounds", "Gingivitis",
25
  "Surgical Wounds", "Venous Wounds", "athlete foot", "cellulitis",
@@ -27,36 +25,46 @@ CLASS_LABELS = [
27
  "ringworm", "shingles", "tooth discoloration", "ulcer"
28
  ]
29
 
30
- # Verify class labels match model output
31
  assert len(CLASS_LABELS) == model.output_shape[-1], \
32
- f"Class labels mismatch: {len(CLASS_LABELS)} vs {model.output_shape[-1]}"
 
 
 
 
 
33
 
 
34
  def preprocess_image(image, target_size=(224, 224)):
35
- """Enhanced image preprocessing with validation"""
36
- if not image:
37
- raise ValueError("🖼️ No image provided")
38
-
39
  try:
 
 
 
40
  image = image.convert("RGB").resize(target_size)
41
- return np.array(image) / 255.0
 
 
42
  except Exception as e:
43
  raise RuntimeError(f"🖼️ Image processing failed: {str(e)}")
44
 
 
45
  def get_medical_guidelines(wound_type):
46
- """Robust API handler with better error reporting"""
47
  headers = {
48
  "Authorization": f"Bearer {OPENROUTER_API_KEY}",
49
  "Content-Type": "application/json",
50
  "HTTP-Referer": "https://huggingface.co/spaces/MahatirTusher/Wound_Treatment"
51
  }
52
 
53
- prompt = f"""As a medical professional, provide concise guidelines for {wound_type}:
54
  - First aid steps
55
  - Precautions
56
- - When to seek help
57
- Avoid markdown, use simple language."""
58
 
59
  try:
 
60
  response = requests.post(
61
  OPENROUTER_API_URL,
62
  headers=headers,
@@ -65,24 +73,24 @@ def get_medical_guidelines(wound_type):
65
  "messages": [{"role": "user", "content": prompt}],
66
  "temperature": 0.5
67
  },
68
- timeout=15
69
  )
70
 
71
  response.raise_for_status()
72
  result = response.json()
 
73
 
74
  if not result.get("choices"):
75
- return f"⚠️ API Error: Unexpected response format"
76
 
77
  return result["choices"][0]["message"]["content"]
78
 
79
- except requests.exceptions.RequestException as e:
80
- return f"🔌 Connection Error: {str(e)}"
81
  except Exception as e:
82
- return f"⚠️ Processing Error: {str(e)}"
83
 
 
84
  def predict(image):
85
- """Main prediction pipeline with validation"""
86
  try:
87
  # Preprocess image
88
  processed_img = preprocess_image(image)
@@ -97,9 +105,9 @@ def predict(image):
97
  CLASS_LABELS[i]: float(predictions[i])
98
  for i in sorted_indices[:3] # Top 3 predictions
99
  }
100
-
101
- # Get guidelines for top prediction
102
  top_class = CLASS_LABELS[sorted_indices[0]]
 
 
103
  guidelines = get_medical_guidelines(top_class)
104
 
105
  return results, guidelines
@@ -107,18 +115,21 @@ def predict(image):
107
  except Exception as e:
108
  return {f"🚨 Error": str(e)}, ""
109
 
110
- # Gradio interface configuration
111
  iface = gr.Interface(
112
  fn=predict,
113
  inputs=gr.Image(type="pil", label="Upload Wound Image"),
114
  outputs=[
115
  gr.Label(label="Top Predictions", num_top_classes=3),
116
- gr.Textbox(label="Treatment Guidelines", lines=5)
117
  ],
118
- title="Advanced Wound Classification System",
119
- description="Identifies 18 wound types and provides treatment guidelines",
120
- examples=["./example_abrasion.jpg", "./example_burn.jpg"] if os.path.exists("example_abrasion.jpg") else None,
121
- allow_flagging="never"
 
 
 
122
  )
123
 
124
  if __name__ == "__main__":
 
1
+ import os
2
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # Suppress TensorFlow warnings
3
  import gradio as gr
4
  import tensorflow as tf
5
+ from tensorflow.keras.models import load_model
6
  import numpy as np
7
  from PIL import Image
8
  import requests
9
  import json
 
10
 
11
+ # ================== MODEL LOADING ==================
12
  try:
13
  model = load_model('wound_classifier_model_googlenet.h5')
14
  print("✅ Model loaded successfully")
15
+ print(f"ℹ️ Input shape: {model.input_shape}")
16
+ print(f"ℹ️ Output shape: {model.output_shape}")
17
  except Exception as e:
18
  raise RuntimeError(f"❌ Model loading failed: {str(e)}")
19
 
20
+ # ================== CLASS LABELS ==================
 
 
 
 
 
21
  CLASS_LABELS = [
22
  "Abrasions", "Bruises", "Burns", "Cut", "Diabetic Wounds", "Gingivitis",
23
  "Surgical Wounds", "Venous Wounds", "athlete foot", "cellulitis",
 
25
  "ringworm", "shingles", "tooth discoloration", "ulcer"
26
  ]
27
 
28
+ # Verify model compatibility
29
  assert len(CLASS_LABELS) == model.output_shape[-1], \
30
+ f"Class mismatch: Model expects {model.output_shape[-1]} classes, found {len(CLASS_LABELS)}"
31
+
32
+ # ================== OPENROUTER CONFIG ==================
33
+ OPENROUTER_API_KEY = "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"
34
+ OPENROUTER_API_URL = "https://openrouter.ai/api/v1/chat/completions"
35
+ MODEL_NAME = "mistralai/mistral-7b-instruct"
36
 
37
+ # ================== IMAGE PROCESSING ==================
38
  def preprocess_image(image, target_size=(224, 224)):
39
+ """Process and validate input images"""
 
 
 
40
  try:
41
+ if not image:
42
+ raise ValueError("🚨 No image provided")
43
+
44
  image = image.convert("RGB").resize(target_size)
45
+ array = np.array(image) / 255.0
46
+ print(f"🖼️ Image processed: Shape {array.shape}")
47
+ return array
48
  except Exception as e:
49
  raise RuntimeError(f"🖼️ Image processing failed: {str(e)}")
50
 
51
+ # ================== MEDICAL GUIDELINES ==================
52
  def get_medical_guidelines(wound_type):
53
+ """Fetch treatment guidelines from OpenRouter API"""
54
  headers = {
55
  "Authorization": f"Bearer {OPENROUTER_API_KEY}",
56
  "Content-Type": "application/json",
57
  "HTTP-Referer": "https://huggingface.co/spaces/MahatirTusher/Wound_Treatment"
58
  }
59
 
60
+ prompt = f"""As a medical expert, provide treatment guidelines for {wound_type}:
61
  - First aid steps
62
  - Precautions
63
+ - When to seek professional help
64
+ Use clear, simple language without markdown."""
65
 
66
  try:
67
+ print(f"📡 Sending API request for {wound_type}...")
68
  response = requests.post(
69
  OPENROUTER_API_URL,
70
  headers=headers,
 
73
  "messages": [{"role": "user", "content": prompt}],
74
  "temperature": 0.5
75
  },
76
+ timeout=20
77
  )
78
 
79
  response.raise_for_status()
80
  result = response.json()
81
+ print("🔧 Raw API response:", json.dumps(result, indent=2))
82
 
83
  if not result.get("choices"):
84
+ return "⚠️ API response format unexpected"
85
 
86
  return result["choices"][0]["message"]["content"]
87
 
 
 
88
  except Exception as e:
89
+ return f"⚠️ Guidelines unavailable: {str(e)}"
90
 
91
+ # ================== MAIN PREDICTION ==================
92
  def predict(image):
93
+ """Complete prediction pipeline"""
94
  try:
95
  # Preprocess image
96
  processed_img = preprocess_image(image)
 
105
  CLASS_LABELS[i]: float(predictions[i])
106
  for i in sorted_indices[:3] # Top 3 predictions
107
  }
 
 
108
  top_class = CLASS_LABELS[sorted_indices[0]]
109
+
110
+ # Get guidelines
111
  guidelines = get_medical_guidelines(top_class)
112
 
113
  return results, guidelines
 
115
  except Exception as e:
116
  return {f"🚨 Error": str(e)}, ""
117
 
118
+ # ================== GRADIO INTERFACE ==================
119
  iface = gr.Interface(
120
  fn=predict,
121
  inputs=gr.Image(type="pil", label="Upload Wound Image"),
122
  outputs=[
123
  gr.Label(label="Top Predictions", num_top_classes=3),
124
+ gr.Textbox(label="Treatment Guidelines", lines=8)
125
  ],
126
+ title="AI Wound Classification System",
127
+ description="Identifies 18 wound types and provides treatment recommendations",
128
+ allow_flagging="never",
129
+ examples=[
130
+ f for f in ["abrasion.jpg", "burn.jpg"]
131
+ if os.path.exists(f)
132
+ ]
133
  )
134
 
135
  if __name__ == "__main__":