kgupta21 commited on
Commit
41cd3de
·
1 Parent(s): 45177a3

local inference page

Browse files
Files changed (3) hide show
  1. .gitignore +44 -0
  2. app.py +193 -14
  3. requirements.txt +4 -1
.gitignore ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ *.so
6
+ .Python
7
+ build/
8
+ develop-eggs/
9
+ dist/
10
+ downloads/
11
+ eggs/
12
+ .eggs/
13
+ lib/
14
+ lib64/
15
+ parts/
16
+ sdist/
17
+ var/
18
+ wheels/
19
+ *.egg-info/
20
+ .installed.cfg
21
+ *.egg
22
+
23
+ # Virtual Environment
24
+ venv/
25
+ ENV/
26
+ env/
27
+
28
+ # IDE
29
+ .idea/
30
+ .vscode/
31
+ *.swp
32
+ *.swo
33
+
34
+ # Logs
35
+ *.log
36
+
37
+ # Local development
38
+ .env
39
+ .env.local
40
+ .env.*.local
41
+
42
+ # Misc
43
+ .DS_Store
44
+ Thumbs.db
app.py CHANGED
@@ -6,6 +6,11 @@ from PIL import Image
6
  import io
7
  import base64
8
  import logging
 
 
 
 
 
9
 
10
  # Set up logging
11
  logging.basicConfig(level=logging.INFO)
@@ -15,6 +20,40 @@ logger = logging.getLogger(__name__)
15
  APP_VERSION = "1.0.0"
16
  logger.info(f"Starting Radiology Teaching App v{APP_VERSION}")
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  try:
19
  # Load only 10 rows from the dataset
20
  logger.info("Loading MIMIC-CXR dataset...")
@@ -68,6 +107,81 @@ def analyze_report(user_findings, ground_truth_findings, ground_truth_impression
68
  logger.error(f"Error in report analysis: {str(e)}")
69
  return f"Error analyzing report: {str(e)}"
70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  def load_random_case(hide_ground_truth):
72
  try:
73
  # Randomly select a case from our dataset
@@ -112,21 +226,86 @@ with gr.Blocks() as demo:
112
  actual_findings_state = gr.State("")
113
  actual_impression_state = gr.State("")
114
 
115
- with gr.Row():
116
- with gr.Column():
117
- image_display = gr.Image(label="Chest X-ray Image", type="pil")
118
- api_key_input = gr.Textbox(label="DeepSeek API Key", type="password")
119
- hide_truth = gr.Checkbox(label="Hide Ground Truth", value=False)
120
- load_btn = gr.Button("Load Random Case")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
 
122
- with gr.Column():
123
- user_findings_input = gr.Textbox(label="Your Findings", lines=10, placeholder="Type or dictate your findings here...")
124
- ground_truth_findings = gr.Textbox(label="Ground Truth Findings", lines=5, interactive=False)
125
- ground_truth_impression = gr.Textbox(label="Ground Truth Impression", lines=5, interactive=False)
126
- analysis_output = gr.Textbox(label="Analysis and Feedback", lines=10, interactive=False)
127
- submit_btn = gr.Button("Submit Report")
128
-
129
- # Event handlers
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  load_btn.click(
131
  fn=load_random_case,
132
  inputs=[hide_truth],
 
6
  import io
7
  import base64
8
  import logging
9
+ import torch
10
+ from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
11
+ from threading import Thread
12
+ from typing import Iterator
13
+ import os
14
 
15
  # Set up logging
16
  logging.basicConfig(level=logging.INFO)
 
20
  APP_VERSION = "1.0.0"
21
  logger.info(f"Starting Radiology Teaching App v{APP_VERSION}")
22
 
23
+ # Initialize models for local inference
24
+ device = 0 if torch.cuda.is_available() else "cpu"
25
+ logger.info(f"Using device: {device}")
26
+
27
+ # Initialize Whisper
28
+ MODEL_NAME = "openai/whisper-large-v3-turbo"
29
+ BATCH_SIZE = 8
30
+ FILE_LIMIT_MB = 5000
31
+
32
+ try:
33
+ logger.info("Initializing Whisper model...")
34
+ pipe = pipeline(
35
+ task="automatic-speech-recognition",
36
+ model=MODEL_NAME,
37
+ chunk_length_s=30,
38
+ device=device,
39
+ )
40
+ except Exception as e:
41
+ logger.error(f"Error initializing Whisper model: {str(e)}")
42
+ pipe = None
43
+
44
+ # Initialize Llama
45
+ try:
46
+ logger.info("Initializing Llama model...")
47
+ if torch.cuda.is_available():
48
+ llm_model_id = "chuanli11/Llama-3.2-3B-Instruct-uncensored"
49
+ llm = AutoModelForCausalLM.from_pretrained(llm_model_id, torch_dtype=torch.float16, device_map="auto")
50
+ tokenizer = AutoTokenizer.from_pretrained(llm_model_id)
51
+ tokenizer.use_default_system_prompt = False
52
+ except Exception as e:
53
+ logger.error(f"Error initializing Llama model: {str(e)}")
54
+ llm = None
55
+ tokenizer = None
56
+
57
  try:
58
  # Load only 10 rows from the dataset
59
  logger.info("Loading MIMIC-CXR dataset...")
 
107
  logger.error(f"Error in report analysis: {str(e)}")
108
  return f"Error analyzing report: {str(e)}"
109
 
110
+ def transcribe(inputs, task="transcribe"):
111
+ """Transcribe audio using Whisper"""
112
+ if inputs is None:
113
+ raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")
114
+ if pipe is None:
115
+ raise gr.Error("Whisper model not initialized properly!")
116
+
117
+ try:
118
+ logger.info("Transcribing audio...")
119
+ text = pipe(inputs, batch_size=BATCH_SIZE, generate_kwargs={"task": task}, return_timestamps=True)["text"]
120
+ return text
121
+ except Exception as e:
122
+ logger.error(f"Error in transcription: {str(e)}")
123
+ raise gr.Error(f"Transcription failed: {str(e)}")
124
+
125
+ def analyze_with_llama(
126
+ transcribed_text: str,
127
+ ground_truth_findings: str,
128
+ ground_truth_impression: str,
129
+ max_new_tokens: int = 1024,
130
+ temperature: float = 0.6,
131
+ ) -> Iterator[str]:
132
+ """Analyze transcribed report against ground truth using Llama"""
133
+ if llm is None or tokenizer is None:
134
+ raise gr.Error("Llama model not initialized properly!")
135
+
136
+ try:
137
+ task_prompt = f"""You are an expert radiologist. Compare the following transcribed radiology report with the ground truth and provide detailed feedback.
138
+
139
+ Transcribed Report:
140
+ {transcribed_text}
141
+
142
+ Ground Truth Findings:
143
+ {ground_truth_findings}
144
+
145
+ Ground Truth Impression:
146
+ {ground_truth_impression}
147
+
148
+ Please analyze:
149
+ 1. Accuracy of findings
150
+ 2. Completeness of report
151
+ 3. Structure and clarity
152
+ 4. Areas for improvement
153
+
154
+ Provide your analysis in a clear, structured format."""
155
+
156
+ conversation = [
157
+ {"role": "system", "content": "You are an expert radiologist providing detailed feedback."},
158
+ {"role": "user", "content": task_prompt}
159
+ ]
160
+
161
+ input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
162
+ input_ids = input_ids.to(llm.device)
163
+
164
+ streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
165
+ generate_kwargs = dict(
166
+ input_ids=input_ids,
167
+ streamer=streamer,
168
+ max_new_tokens=max_new_tokens,
169
+ do_sample=True,
170
+ temperature=temperature,
171
+ num_beams=1,
172
+ )
173
+
174
+ t = Thread(target=llm.generate, kwargs=generate_kwargs)
175
+ t.start()
176
+
177
+ outputs = []
178
+ for text in streamer:
179
+ outputs.append(text)
180
+ yield "".join(outputs)
181
+ except Exception as e:
182
+ logger.error(f"Error in Llama analysis: {str(e)}")
183
+ raise gr.Error(f"Analysis failed: {str(e)}")
184
+
185
  def load_random_case(hide_ground_truth):
186
  try:
187
  # Randomly select a case from our dataset
 
226
  actual_findings_state = gr.State("")
227
  actual_impression_state = gr.State("")
228
 
229
+ with gr.Tab("DeepSeek Analysis"):
230
+ with gr.Row():
231
+ with gr.Column():
232
+ image_display = gr.Image(label="Chest X-ray Image", type="pil")
233
+ api_key_input = gr.Textbox(label="DeepSeek API Key", type="password")
234
+ hide_truth = gr.Checkbox(label="Hide Ground Truth", value=False)
235
+ load_btn = gr.Button("Load Random Case")
236
+
237
+ with gr.Column():
238
+ user_findings_input = gr.Textbox(label="Your Findings", lines=10, placeholder="Type or dictate your findings here...")
239
+ ground_truth_findings = gr.Textbox(label="Ground Truth Findings", lines=5, interactive=False)
240
+ ground_truth_impression = gr.Textbox(label="Ground Truth Impression", lines=5, interactive=False)
241
+ analysis_output = gr.Textbox(label="Analysis and Feedback", lines=10, interactive=False)
242
+ submit_btn = gr.Button("Submit Report")
243
+
244
+ with gr.Tab("Local Inference"):
245
+ gr.Markdown("### Use Local Models for Transcription and Analysis")
246
+
247
+ with gr.Row():
248
+ with gr.Column():
249
+ # Transcription Interface
250
+ audio_input = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Record or Upload Audio")
251
+ task_input = gr.Radio(["transcribe", "translate"], label="Task", value="transcribe")
252
+ transcribe_button = gr.Button("Transcribe Audio")
253
+ transcription_output = gr.Textbox(label="Transcription Output", lines=5)
254
+
255
+ # Load case for comparison
256
+ load_case_btn = gr.Button("Load Random Case for Comparison")
257
+ local_ground_truth_findings = gr.Textbox(label="Ground Truth Findings", lines=5, interactive=False)
258
+ local_ground_truth_impression = gr.Textbox(label="Ground Truth Impression", lines=5, interactive=False)
259
+
260
+ with gr.Column():
261
+ # Editable transcription and analysis interface
262
+ edited_transcription = gr.Textbox(label="Edit Transcription", lines=10)
263
+ temperature_input = gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, value=0.6, step=0.1)
264
+ max_tokens_input = gr.Slider(label="Max Tokens", minimum=256, maximum=2048, value=1024, step=128)
265
+ analyze_btn = gr.Button("Analyze with Llama")
266
+ llama_analysis_output = gr.Textbox(label="Llama Analysis Output", lines=15, interactive=False)
267
+
268
+ # Event handlers for Local Inference tab
269
+ transcribe_button.click(
270
+ fn=transcribe,
271
+ inputs=[audio_input, task_input],
272
+ outputs=transcription_output
273
+ )
274
+
275
+ # Copy transcription to editable box
276
+ transcription_output.change(
277
+ fn=lambda x: x,
278
+ inputs=[transcription_output],
279
+ outputs=[edited_transcription]
280
+ )
281
 
282
+ # Load case for local analysis
283
+ load_case_btn.click(
284
+ fn=load_random_case,
285
+ inputs=[gr.Checkbox(value=False, visible=False)], # Hidden checkbox for hide_ground_truth
286
+ outputs=[
287
+ gr.Image(visible=False), # Hidden image output
288
+ local_ground_truth_findings,
289
+ local_ground_truth_impression,
290
+ gr.State(), # Hidden state
291
+ gr.State() # Hidden state
292
+ ]
293
+ )
294
+
295
+ # Analyze with Llama
296
+ analyze_btn.click(
297
+ fn=analyze_with_llama,
298
+ inputs=[
299
+ edited_transcription,
300
+ local_ground_truth_findings,
301
+ local_ground_truth_impression,
302
+ max_tokens_input,
303
+ temperature_input
304
+ ],
305
+ outputs=llama_analysis_output
306
+ )
307
+
308
+ # Event handlers for DeepSeek Analysis tab
309
  load_btn.click(
310
  fn=load_random_case,
311
  inputs=[hide_truth],
requirements.txt CHANGED
@@ -3,4 +3,7 @@ pandas>=2.0.0
3
  datasets>=2.15.0
4
  openai>=1.0.0
5
  Pillow>=10.0.0
6
- huggingface-hub>=0.20.0
 
 
 
 
3
  datasets>=2.15.0
4
  openai>=1.0.0
5
  Pillow>=10.0.0
6
+ huggingface-hub>=0.20.0
7
+ torch>=2.0.0
8
+ transformers>=4.36.0
9
+ spaces>=0.19.3