blackshadow1 committed
Commit 0bdcda7 · verified · 1 parent: 3e1378a

added the code ✅✅

Files changed (1): app.py (+44 -37)
app.py CHANGED
@@ -5,54 +5,52 @@ from PIL import Image
import logging

# Configure logging
-logging.basicConfig(level=logging.INFO)
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

-# Load smaller BLIP model (more reliable)
-def load_model():
-    try:
-        device = "cuda" if torch.cuda.is_available() else "cpu"
-        logger.info(f"Loading model on {device}...")
-
-        processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
-        model = BlipForConditionalGeneration.from_pretrained(
-            "Salesforce/blip-image-captioning-base",
-            torch_dtype=torch.float32  # More stable than float16
-        ).to(device)
+class MedicalAnalyzer:
+    def __init__(self):
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.model = None
+        self.processor = None

-        logger.info("Model loaded successfully")
-        return model, processor, device
-    except Exception as e:
-        logger.error(f"Model loading failed: {e}")
-        raise
+    def load_model(self):
+        try:
+            logger.info(f"Loading model on {self.device}...")
+            self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
+            self.model = BlipForConditionalGeneration.from_pretrained(
+                "Salesforce/blip-image-captioning-base",
+                torch_dtype=torch.float32
+            ).to(self.device)
+            logger.info("Model loaded successfully")
+        except Exception as e:
+            logger.error(f"Model loading failed: {e}")
+            raise RuntimeError(f"Model loading failed. Please check:\n1. Internet connection\n2. Disk space (1GB+ needed)\n3. Try: pip install -r requirements.txt")

-try:
-    model, processor, device = load_model()
-except Exception as e:
-    raise gr.Error(f"Failed to load model. Please check:\n1. Internet connection\n2. Disk space (5GB+ needed)\n3. Try again later")
+analyzer = MedicalAnalyzer()

-def analyze_medical_image(image, question, history=[]):
+def analyze_medical_image(image, question):
    try:
        if not image:
-            return "⚠️ Please upload a medical image", history
+            return "⚠️ Please upload a medical image"
+
+        if analyzer.model is None:
+            analyzer.load_model()

-        # Medical-focused prompt
        prompt = f"Question: As a doctor, {question if question else 'describe any abnormalities in this medical image'} Answer:"
-
-        inputs = processor(image, prompt, return_tensors="pt").to(device)
+        inputs = analyzer.processor(image, prompt, return_tensors="pt").to(analyzer.device)
+
        with torch.no_grad():
-            outputs = model.generate(**inputs, max_new_tokens=100)
+            outputs = analyzer.model.generate(**inputs, max_new_tokens=100)

-        result = processor.decode(outputs[0], skip_special_tokens=True)
-        result = result.replace(prompt, "").strip()
-
-        return result, history + [(question, result)]
+        result = analyzer.processor.decode(outputs[0], skip_special_tokens=True)
+        return result.replace(prompt, "").strip()

    except Exception as e:
-        logger.error(f"Error: {e}")
-        return f"❌ Analysis failed: {str(e)}", history
+        logger.error(f"Analysis error: {e}")
+        return f"❌ Analysis failed: {str(e)}"

-# Simple Gradio Interface
+# Simplified Gradio Interface
with gr.Blocks(title="Medical Image Analyzer") as app:
    gr.Markdown("# 🩺 Medical Image Analyzer")

@@ -63,13 +61,22 @@ with gr.Blocks(title="Medical Image Analyzer") as app:
            submit_btn = gr.Button("Analyze")

        with gr.Column():
-            chatbot = gr.Chatbot(label="Analysis Report")
+            output = gr.Textbox(label="Analysis Result", interactive=False)

    submit_btn.click(
        analyze_medical_image,
-        [image_input, question_input, chatbot],
-        [chatbot, chatbot]
+        inputs=[image_input, question_input],
+        outputs=output
    )

if __name__ == "__main__":
-    app.launch(server_name="0.0.0.0", share=False)
+    try:
+        analyzer.load_model()
+        app.launch(
+            server_name="0.0.0.0",
+            server_port=7860,
+            show_error=True
+        )
+    except Exception as e:
+        logger.error(f"Application failed: {e}")
+        raise gr.Error("Application failed to start. Please check the logs.")
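
A quick way to sanity-check the refactored analyze_medical_image outside the Gradio UI is to call it directly. This is only a sketch under stated assumptions: app.py is importable from the working directory and a test image exists at the hypothetical path sample_xray.png.

# smoke_test.py - hypothetical helper, not part of this commit
from PIL import Image
from app import analyze_medical_image  # importing app builds the Blocks UI but does not launch it

img = Image.open("sample_xray.png").convert("RGB")  # hypothetical local test image
# The first call lazily loads the BLIP model via analyzer.load_model(), the same path the Analyze button takes
print(analyze_medical_image(img, "Are there any visible abnormalities?"))

On failure the function returns an "❌ Analysis failed: ..." string rather than raising, so the script prints either a caption or that error message.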