blackshadow1 committed on
Commit
38279f0
·
verified ·
1 Parent(s): 8887ad4

added the model code ✅✅

Browse files
Files changed (1) hide show
  1. app.py +57 -33
app.py CHANGED
@@ -5,45 +5,56 @@ from PIL import Image
5
  import logging
6
 
7
# Configure logging: every record carries timestamp, logger name and severity.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)

# Module-level logger, named after this module per logging convention.
logger = logging.getLogger(__name__)
10
 
11
class MedicalAnalyzer:
    """Holds the BLIP captioning model and processor, loaded lazily on demand."""

    #: Hugging Face repo the model/processor are pulled from.
    _REPO = "Salesforce/blip-image-captioning-base"

    def __init__(self):
        # Prefer GPU when one is visible to torch; otherwise fall back to CPU.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Both stay None until load_model() succeeds (lazy initialisation).
        self.model = None
        self.processor = None

    def load_model(self):
        """Download and initialise the BLIP processor and model onto self.device.

        Raises:
            RuntimeError: if anything in the download/initialisation fails.
        """
        try:
            logger.info(f"Loading model on {self.device}...")
            self.processor = BlipProcessor.from_pretrained(self._REPO)
            self.model = BlipForConditionalGeneration.from_pretrained(
                self._REPO,
                torch_dtype=torch.float32,
            ).to(self.device)
            logger.info("Model loaded successfully")
        except Exception as e:
            logger.error(f"Model loading failed: {e}")
            raise RuntimeError(
                "Model loading failed. Please check:\n"
                "1. Internet connection\n"
                "2. Disk space (1GB+ needed)\n"
                "3. Try: pip install -r requirements.txt"
            )


# Shared module-level instance; the model itself is loaded later.
analyzer = MedicalAnalyzer()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
def analyze_medical_image(image, question):
    """Run the BLIP model over a medical image, optionally guided by a question.

    Returns the generated answer text, or a user-facing warning/error string.
    """
    try:
        # Guard clause: nothing to analyse without an image.
        if not image:
            return "⚠️ Please upload a medical image"

        # Lazily load the model on first use.
        if analyzer.model is None:
            analyzer.load_model()

        fallback = 'describe any abnormalities in this medical image'
        prompt = f"Question: As a doctor, {question if question else fallback} Answer:"
        inputs = analyzer.processor(image, prompt, return_tensors="pt").to(analyzer.device)

        with torch.no_grad():
            outputs = analyzer.model.generate(**inputs, max_new_tokens=100)

        decoded = analyzer.processor.decode(outputs[0], skip_special_tokens=True)
        # Strip the echoed prompt so only the model's answer remains.
        return decoded.replace(prompt, "").strip()
    except Exception as e:
        return f"❌ Analysis failed: {str(e)}"
52
 
53
  # Simplified Gradio Interface
54
- with gr.Blocks(title="Medical Image Analyzer") as app:
 
 
 
55
  gr.Markdown("# 🩺 Medical Image Analyzer")
56
 
57
  with gr.Row():
58
  with gr.Column():
59
- image_input = gr.Image(type="pil", label="Upload Scan/X-ray")
60
- question_input = gr.Textbox(label="Clinical Question (optional)")
61
- submit_btn = gr.Button("Analyze")
 
 
 
 
 
 
 
62
 
63
  with gr.Column():
64
- output = gr.Textbox(label="Analysis Result", interactive=False)
 
 
 
 
65
 
66
  submit_btn.click(
67
  analyze_medical_image,
@@ -71,7 +96,6 @@ with gr.Blocks(title="Medical Image Analyzer") as app:
71
 
72
  if __name__ == "__main__":
73
  try:
74
- analyzer.load_model()
75
  app.launch(
76
  server_name="0.0.0.0",
77
  server_port=7860,
 
5
  import logging
6
 
7
# Configure logging: every record carries timestamp, logger name and severity.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Initialize model components.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_LOADED = False
# Predefine so later references are always bound even when loading fails.
processor = None
model = None

try:
    logger.info("Loading model on %s...", DEVICE)
    processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
    model = BlipForConditionalGeneration.from_pretrained(
        "Salesforce/blip-image-captioning-base",
        torch_dtype=torch.float32
    ).to(DEVICE)
    MODEL_LOADED = True
    logger.info("Model loaded successfully")
except Exception:
    # Do NOT re-raise here: raising at import time kills the whole app and
    # makes the `if not MODEL_LOADED` fallback in analyze_medical_image dead
    # code. Leaving MODEL_LOADED False lets the UI launch and surface the
    # failure to the user. logger.exception also records the full traceback,
    # which logger.error(f"...{e}") was dropping.
    logger.exception(
        "Model failed to load. Please:\n"
        "1. Check internet connection\n"
        "2. Verify at least 1GB disk space\n"
        "3. Try: pip install -r requirements.txt\n"
        "4. Restart your runtime"
    )
36
 
37
def analyze_medical_image(image: Image.Image, question: str) -> str:
    """Analyze medical image with optional question"""
    # Model failed to load at startup — report instead of crashing.
    if not MODEL_LOADED:
        return "❌ Model not available. Please check server logs."

    try:
        if not image:
            return "⚠️ Please upload a medical image"

        # Medical-focused prompt; fall back to a generic instruction when the
        # user asked no question.
        fallback = 'describe any abnormalities in this medical image'
        prompt = (
            f"Question: As a doctor, {question if question else fallback} "
            "Answer professionally:"
        )

        inputs = processor(image, prompt, return_tensors="pt").to(DEVICE)
        with torch.no_grad():
            outputs = model.generate(**inputs, max_new_tokens=100)

        decoded = processor.decode(outputs[0], skip_special_tokens=True)
        # Strip the echoed prompt so only the model's answer remains.
        return decoded.replace(prompt, "").strip()
    except Exception as e:
        return f"❌ Analysis failed: {str(e)}"
63
 
64
  # Simplified Gradio Interface
65
+ with gr.Blocks(
66
+ title="Medical Image Analyzer",
67
+ css=".gradio-container {max-width: 800px !important}"
68
+ ) as app:
69
  gr.Markdown("# 🩺 Medical Image Analyzer")
70
 
71
  with gr.Row():
72
  with gr.Column():
73
+ image_input = gr.Image(
74
+ type="pil",
75
+ label="Upload Scan/X-ray",
76
+ sources=["upload", "clipboard"]
77
+ )
78
+ question_input = gr.Textbox(
79
+ label="Clinical Question (optional)",
80
+ placeholder="Describe symptoms or ask about findings..."
81
+ )
82
+ submit_btn = gr.Button("Analyze", variant="primary")
83
 
84
  with gr.Column():
85
+ output = gr.Textbox(
86
+ label="Analysis Result",
87
+ interactive=False,
88
+ lines=10
89
+ )
90
 
91
  submit_btn.click(
92
  analyze_medical_image,
 
96
 
97
  if __name__ == "__main__":
98
  try:
 
99
  app.launch(
100
  server_name="0.0.0.0",
101
  server_port=7860,