MahatirTusher committed
Commit 691b7fe · verified · 1 Parent(s): 971f77f

Update app.py

Files changed (1):
  app.py +78 -83
app.py CHANGED
@@ -1,130 +1,125 @@
  import gradio as gr
  import tensorflow as tf
- from tensorflow.keras.models import load_model
  import numpy as np
  from PIL import Image
  import requests
  import json

- # Load the model
  try:
      model = load_model('wound_classifier_model_googlenet.h5')
      print("✅ Model loaded successfully")
  except Exception as e:
-     raise RuntimeError(f"❌ Model loading failed: {e}")

- # OpenRouter.ai Configuration
- OPENROUTER_API_KEY = "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"
  OPENROUTER_API_URL = "https://openrouter.ai/api/v1/chat/completions"
- MODEL_NAME = "mistralai/mistral-7b-instruct"  # Updated model name

- input_shape = (224, 224, 3)

- def preprocess_image(image, target_size):
-     """Preprocess the input image for the model."""
      try:
-         if image is None:
-             raise ValueError("No image provided")
-         image = image.convert("RGB")
-         image = image.resize(target_size)
          return np.array(image) / 255.0
      except Exception as e:
-         print(f"⚠️ Image preprocessing error: {e}")
-         raise

  def get_medical_guidelines(wound_type):
-     """Fetch medical guidelines using OpenRouter.ai's API."""
      headers = {
          "Authorization": f"Bearer {OPENROUTER_API_KEY}",
          "Content-Type": "application/json",
-         "HTTP-Referer": "https://huggingface.co/spaces/MahatirTusher/Wound_Treatment",
-         "X-Title": "Wound Treatment Advisor"
      }

-     prompt = f"""As a medical professional, provide detailed guidelines for treating a {wound_type} wound.
-     Include:
-     1. First aid steps
-     2. Precautions
-     3. When to seek professional help
-     Output in markdown with clear sections."""
-
-     data = {
-         "model": MODEL_NAME,
-         "messages": [{"role": "user", "content": prompt}],
-         "temperature": 0.5
-     }

      try:
-         print(f"🚀 Sending request to OpenRouter API for {wound_type}...")
-         response = requests.post(OPENROUTER_API_URL, headers=headers, json=data, timeout=10)
-         response.raise_for_status()

-         response_json = response.json()
-         print("🔧 Raw API response:", json.dumps(response_json, indent=2))

-         if "choices" not in response_json:
-             return "⚠️ API response format unexpected. Please check logs."

-         return response_json["choices"][0]["message"]["content"]

-     except requests.exceptions.HTTPError as e:
-         print(f"HTTP Error: {e.response.status_code} - {e.response.text}")
-         return f"API Error: {e.response.status_code} - Check console for details"
      except Exception as e:
-         print(f"⚠️ General API error: {str(e)}")
-         return f"Error: {str(e)}"

  def predict(image):
-     """Main prediction function."""
      try:
          # Preprocess image
-         input_data = preprocess_image(image, (input_shape[0], input_shape[1]))
-         input_data = np.expand_dims(input_data, axis=0)
-         print("🖼️ Image preprocessed successfully")
-
-         # Load class labels
-         try:
-             with open('classes.txt', 'r') as file:
-                 class_labels = file.read().splitlines()
-             print("📋 Class labels loaded:", class_labels)
-         except Exception as e:
-             raise RuntimeError(f"Class labels loading failed: {e}")
-
-         # Verify model compatibility
-         if len(class_labels) != model.output_shape[-1]:
-             raise ValueError(f"Model expects {model.output_shape[-1]} classes, found {len(class_labels)}")
-
          # Make prediction
-         predictions = model.predict(input_data)
-         print("📊 Raw predictions:", predictions)

-         results = {class_labels[i]: float(predictions[0][i])
-                    for i in range(len(class_labels))}
-         predicted_class = max(results, key=results.get)
-         print(f"🏆 Predicted class: {predicted_class}")
-
-         # Get medical guidelines
-         guidelines = get_medical_guidelines(predicted_class)
-         print("📜 Guidelines generated successfully")
-
          return results, guidelines
-
      except Exception as e:
-         print(f"🔥 Critical error in prediction: {str(e)}")
-         return {"Error": str(e)}, ""

- # Gradio Interface
  iface = gr.Interface(
-     fn=predict,
-     inputs=gr.Image(type="pil", label="Upload Wound Image"),
      outputs=[
-         gr.Label(num_top_classes=3, label="Classification Results"),
-         gr.Markdown(label="Medical Guidelines")
      ],
-     live=False,
-     title="Wound Classification & Treatment Advisor",
-     description="Upload a wound image for AI-powered classification and treatment guidelines.",
      allow_flagging="never"
  )

- iface.launch(server_name="0.0.0.0", server_port=7860)
 
  import gradio as gr
  import tensorflow as tf
  import numpy as np
  from PIL import Image
  import requests
  import json
+ import os

+ # Load the model with enhanced error handling
  try:
      model = load_model('wound_classifier_model_googlenet.h5')
      print("✅ Model loaded successfully")
+     print(f"ℹ️ Model expects {model.output_shape[-1]} output classes")  # Should be 18
  except Exception as e:
+     raise RuntimeError(f"❌ Model loading failed: {str(e)}")

+ # OpenRouter configuration
+ OPENROUTER_API_KEY = os.getenv("OPENROUTER_KEY", "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")
  OPENROUTER_API_URL = "https://openrouter.ai/api/v1/chat/completions"
+ MODEL_NAME = "mistralai/mistral-7b-instruct"

+ # Class labels from your classes.txt
+ CLASS_LABELS = [
+     "Abrasions", "Bruises", "Burns", "Cut", "Diabetic Wounds", "Gingivitis",
+     "Surgical Wounds", "Venous Wounds", "athlete foot", "cellulitis",
+     "chickenpox", "cutaneous larva migrans", "impetigo", "nail fungus",
+     "ringworm", "shingles", "tooth discoloration", "ulcer"
+ ]

+ # Verify class labels match model output
+ assert len(CLASS_LABELS) == model.output_shape[-1], \
+     f"Class labels mismatch: {len(CLASS_LABELS)} vs {model.output_shape[-1]}"
+
+ def preprocess_image(image, target_size=(224, 224)):
+     """Enhanced image preprocessing with validation"""
+     if not image:
+         raise ValueError("🖼️ No image provided")
+
      try:
+         image = image.convert("RGB").resize(target_size)
          return np.array(image) / 255.0
      except Exception as e:
+         raise RuntimeError(f"🖼️ Image processing failed: {str(e)}")

  def get_medical_guidelines(wound_type):
+     """Robust API handler with better error reporting"""
      headers = {
          "Authorization": f"Bearer {OPENROUTER_API_KEY}",
          "Content-Type": "application/json",
+         "HTTP-Referer": "https://huggingface.co/spaces/MahatirTusher/Wound_Treatment"
      }

+     prompt = f"""As a medical professional, provide concise guidelines for {wound_type}:
+     - First aid steps
+     - Precautions
+     - When to seek help
+     Avoid markdown, use simple language."""

      try:
+         response = requests.post(
+             OPENROUTER_API_URL,
+             headers=headers,
+             json={
+                 "model": MODEL_NAME,
+                 "messages": [{"role": "user", "content": prompt}],
+                 "temperature": 0.5
+             },
+             timeout=15
+         )

+         response.raise_for_status()
+         result = response.json()

+         if not result.get("choices"):
+             return f"⚠️ API Error: Unexpected response format"

+         return result["choices"][0]["message"]["content"]

+     except requests.exceptions.RequestException as e:
+         return f"🔌 Connection Error: {str(e)}"
      except Exception as e:
+         return f"⚠️ Processing Error: {str(e)}"

  def predict(image):
+     """Main prediction pipeline with validation"""
      try:
          # Preprocess image
+         processed_img = preprocess_image(image)
+         input_tensor = np.expand_dims(processed_img, axis=0)
+
          # Make prediction
+         predictions = model.predict(input_tensor)[0]
+         sorted_indices = np.argsort(predictions)[::-1]  # Descending order
+
+         # Format results
+         results = {
+             CLASS_LABELS[i]: float(predictions[i])
+             for i in sorted_indices[:3]  # Top 3 predictions
+         }
+
+         # Get guidelines for top prediction
+         top_class = CLASS_LABELS[sorted_indices[0]]
+         guidelines = get_medical_guidelines(top_class)

          return results, guidelines
+
      except Exception as e:
+         return {f"🚨 Error": str(e)}, ""

+ # Gradio interface configuration
  iface = gr.Interface(
+     fn=predict,
+     inputs=gr.Image(type="pil", label="Upload Wound Image"),
      outputs=[
+         gr.Label(label="Top Predictions", num_top_classes=3),
+         gr.Textbox(label="Treatment Guidelines", lines=5)
      ],
+     title="Advanced Wound Classification System",
+     description="Identifies 18 wound types and provides treatment guidelines",
+     examples=["./example_abrasion.jpg", "./example_burn.jpg"] if os.path.exists("example_abrasion.jpg") else None,
      allow_flagging="never"
  )

+ if __name__ == "__main__":
+     iface.launch(server_name="0.0.0.0", server_port=7860)
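A minimal local smoke test for the updated app.py might look like the sketch below; it is not part of the commit. It assumes the .h5 weights sit next to app.py, that load_model resolves at import time (e.g. via tensorflow.keras.models), and that an OpenRouter key is exported as OPENROUTER_KEY so the fallback key baked into the file is never used; sample_wound.jpg is a placeholder filename.

# Hypothetical smoke test (not part of the commit); export OPENROUTER_KEY before running
# so app.py does not fall back to the key baked into the file.
from PIL import Image
import app   # importing runs the module-level model load; launch() only fires under __main__

results, guidelines = app.predict(Image.open("sample_wound.jpg"))   # placeholder image file
print(results)       # top-3 {label: probability} dict shown in gr.Label
print(guidelines)    # plain-text guidelines string shown in the gr.Textbox

Because predict() returns the (results, guidelines) pair directly, the same call can be looped over a folder of test images without going through the Gradio UI.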