ashhal commited on
Commit
22af831
Β·
verified Β·
1 Parent(s): b1d04a5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -50
app.py CHANGED
@@ -1,74 +1,97 @@
1
  import gradio as gr
2
- from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import torch
 
4
  import matplotlib.pyplot as plt
5
- import pdfplumber
6
  import io
 
7
 
8
- # Load MedAlpaca (slow tokenizer to avoid SentencePiece issue)
9
- model_name = "TheBloke/medalpaca-7B-GPTQ"
10
- tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
11
- model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
 
 
 
 
 
12
 
 
13
  def extract_text_from_pdf(pdf_file):
 
14
  try:
15
- with pdfplumber.open(pdf_file) as pdf:
16
- return "\n".join(page.extract_text() for page in pdf.pages if page.extract_text())
 
17
  except Exception as e:
18
- return f"Error reading PDF: {e}"
 
19
 
 
 
 
 
 
 
 
 
20
  def generate_chart(text):
21
- # Dummy abnormal value check
22
- keywords = ['fever', 'pain', 'high', 'low']
23
- is_abnormal = any(word in text.lower() for word in keywords)
24
-
25
- labels = ['Normal', 'Abnormal']
26
- values = [1, 0]
27
- colors = ['green', 'red'] if is_abnormal else ['green', 'gray']
 
 
 
 
 
 
 
 
 
28
 
29
- if is_abnormal:
30
- values = [0.2, 0.8]
31
 
32
  fig, ax = plt.subplots()
33
- bars = ax.bar(labels, values, color=colors)
34
-
35
- for bar in bars:
36
- height = bar.get_height()
37
- ax.text(bar.get_x() + bar.get_width() / 2.0, height, f'{height:.2f}', ha='center', va='bottom')
38
 
39
- ax.set_ylim([0, 1])
40
- ax.set_title("Normal vs Abnormal Indicator")
41
  buf = io.BytesIO()
42
- plt.savefig(buf, format="png")
43
  buf.seek(0)
 
44
  plt.close(fig)
45
- return buf
46
 
47
- def medalpaca_response(input_text, pdf=None):
48
- if pdf:
49
- extracted_text = extract_text_from_pdf(pdf)
50
- input_text += "\n\n" + extracted_text
51
 
52
- inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
53
- with torch.no_grad():
54
- outputs = model.generate(**inputs, max_new_tokens=200)
55
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
56
-
57
- chart = generate_chart(input_text)
 
 
58
  return response, chart
59
 
60
- demo = gr.Interface(
61
- fn=medalpaca_response,
62
- inputs=[
63
- gr.Textbox(label="Enter Symptoms or Question"),
64
- gr.File(label="Upload Medical PDF (optional)", type="file")
65
- ],
66
- outputs=[
67
- gr.Textbox(label="MedAlpaca Response"),
68
- gr.Image(label="Normal vs Abnormal Chart")
69
- ],
70
- title="🧠 MedAlpaca Medical Assistant",
71
- description="Ask medical questions or upload a PDF. Get AI-powered responses with visual abnormality flags."
72
- )
73
 
 
74
  demo.launch()
 
1
  import gradio as gr
 
2
  import torch
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM
4
  import matplotlib.pyplot as plt
5
+ import fitz # PyMuPDF
6
  import io
7
+ import base64
8
 
9
+ # Load Mistral model (or any other open-access instruct model)
10
+ model_id = "mistralai/Mistral-7B-Instruct-v0.2"
11
+
12
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
13
+ model = AutoModelForCausalLM.from_pretrained(
14
+ model_id,
15
+ torch_dtype=torch.float16,
16
+ device_map="auto"
17
+ )
18
 
19
+ # Function to extract text from PDF
20
  def extract_text_from_pdf(pdf_file):
21
+ text = ""
22
  try:
23
+ with fitz.open(stream=pdf_file.read(), filetype="pdf") as doc:
24
+ for page in doc:
25
+ text += page.get_text()
26
  except Exception as e:
27
+ text = f"Error reading PDF: {str(e)}"
28
+ return text
29
 
30
+ # Function to generate answer using model
31
+ def ask_model(input_text, context):
32
+ prompt = f"""[INST] Given the following context, answer the question:\n\nContext:\n{context}\n\nQuestion:\n{input_text} [/INST]"""
33
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
34
+ outputs = model.generate(**inputs, max_new_tokens=512, temperature=0.7)
35
+ return tokenizer.decode(outputs[0], skip_special_tokens=True)
36
+
37
+ # Create a flag for abnormal values in chart
38
  def generate_chart(text):
39
+ # Dummy logic to extract and visualize values
40
+ lines = [line.strip() for line in text.split('\n') if ':' in line]
41
+ labels, values, flags = [], [], []
42
+
43
+ for line in lines:
44
+ try:
45
+ key, val = line.split(':')
46
+ num = float(val.strip().split()[0])
47
+ labels.append(key.strip())
48
+ values.append(num)
49
+ flags.append(num > 100 or num < 20) # just dummy abnormal range
50
+ except:
51
+ continue
52
+
53
+ if not labels:
54
+ return "No numerical data found for plotting."
55
 
56
+ colors = ['red' if flag else 'green' for flag in flags]
 
57
 
58
  fig, ax = plt.subplots()
59
+ ax.barh(labels, values, color=colors)
60
+ ax.set_xlabel('Values')
61
+ ax.set_title('πŸ“Š Test Results (Green = Normal, Red = Abnormal)')
62
+ plt.tight_layout()
 
63
 
 
 
64
  buf = io.BytesIO()
65
+ plt.savefig(buf, format='png')
66
  buf.seek(0)
67
+ encoded = base64.b64encode(buf.read()).decode('utf-8')
68
  plt.close(fig)
 
69
 
70
+ return f"data:image/png;base64,{encoded}"
 
 
 
71
 
72
+ # Main Gradio interface
73
+ def process_input(user_question, pdf_file):
74
+ if pdf_file is None:
75
+ return "Please upload a medical report (PDF).", None
76
+
77
+ context = extract_text_from_pdf(pdf_file)
78
+ response = ask_model(user_question, context)
79
+ chart = generate_chart(context)
80
  return response, chart
81
 
82
+ # Gradio UI
83
+ with gr.Blocks() as demo:
84
+ gr.Markdown("# 🩺 Medical Report Analyzer\nUpload a PDF report, ask questions, and see abnormalities visualized.")
85
+
86
+ with gr.Row():
87
+ user_question = gr.Textbox(label="Ask a medical question", placeholder="e.g. What does the report say about cholesterol?")
88
+ pdf_input = gr.File(label="Upload medical report (PDF)", file_types=['.pdf'])
89
+
90
+ submit_btn = gr.Button("Analyze")
91
+ output_text = gr.Textbox(label="Answer", lines=8)
92
+ output_img = gr.Image(label="πŸ“Š Chart")
93
+
94
+ submit_btn.click(fn=process_input, inputs=[user_question, pdf_input], outputs=[output_text, output_img])
95
 
96
+ # Launch the app
97
  demo.launch()