kgupta21 committed on
Commit 746ae2b · 1 Parent(s): eb5c340

Local inference page: fix GPU handling for ZeroGPU, add accelerate for device mapping; remove the previous model-initialization code and clean up overall

Files changed (2)
  1. app.py +69 -98
  2. requirements.txt +6 -5
app.py CHANGED
@@ -21,59 +21,27 @@ logger = logging.getLogger(__name__)
 APP_VERSION = "1.0.0"
 logger.info(f"Starting Radiology Teaching App v{APP_VERSION}")
 
-# Global variables
-pipe = None
-llm = None
-tokenizer = None
-device = 0 if torch.cuda.is_available() else "cpu"
-logger.info(f"Using device: {device}")
-
-# Initialize Whisper
+# Model configuration
 MODEL_NAME = "openai/whisper-large-v3-turbo"
 BATCH_SIZE = 8
 FILE_LIMIT_MB = 5000
+MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
+device = 0 if torch.cuda.is_available() else "cpu"
 
-try:
-    logger.info("Initializing Whisper model...")
-    pipe = pipeline(
-        task="automatic-speech-recognition",
-        model=MODEL_NAME,
-        chunk_length_s=30,
-        device=device,
-    )
-except Exception as e:
-    logger.error(f"Error initializing Whisper model: {str(e)}")
-    pipe = None
-
-# Initialize Llama
-try:
-    logger.info("Initializing Llama model...")
+# Initialize the LLM
+if torch.cuda.is_available():
 llm_model_id = "chuanli11/Llama-3.2-3B-Instruct-uncensored"
-
-    # Initialize tokenizer first
+llm = AutoModelForCausalLM.from_pretrained(llm_model_id, torch_dtype=torch.float16, device_map="auto")
 tokenizer = AutoTokenizer.from_pretrained(llm_model_id)
 tokenizer.use_default_system_prompt = False
-
-    # Initialize model with proper device mapping
-    if torch.cuda.is_available():
-        logger.info("Loading Llama model on GPU...")
-        llm = AutoModelForCausalLM.from_pretrained(
-            llm_model_id,
-            torch_dtype=torch.float16,
-            device_map="auto",
-            load_in_8bit=True  # Use 8-bit quantization to reduce memory usage
-        )
-    else:
-        logger.info("Loading Llama model on CPU...")
-        llm = AutoModelForCausalLM.from_pretrained(
-            llm_model_id,
-            device_map={"": "cpu"},
-            low_cpu_mem_usage=True
-        )
-except Exception as e:
-    logger.error(f"Error initializing Llama model: {str(e)}")
-    llm = None
-    tokenizer = None
+
+# Initialize the transcription pipeline
+pipe = pipeline(
+    task="automatic-speech-recognition",
+    model=MODEL_NAME,
+    chunk_length_s=30,
+    device=device,
+)
 
 try:
     # Load only 10 rows from the dataset
@@ -133,8 +101,6 @@ def transcribe(inputs, task="transcribe"):
     """Transcribe audio using Whisper"""
     if inputs is None:
         raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")
-    if pipe is None:
-        raise gr.Error("Whisper model not initialized properly!")
 
     try:
         logger.info("Transcribing audio...")
@@ -151,61 +117,60 @@ def analyze_with_llama(
     ground_truth_impression: str,
     max_new_tokens: int = 1024,
     temperature: float = 0.6,
+    top_p: float = 0.9,
+    top_k: int = 50,
+    repetition_penalty: float = 1.2,
 ) -> Iterator[str]:
     """Analyze transcribed report against ground truth using Llama"""
-    global llm, tokenizer  # Add global declaration
-
-    if llm is None or tokenizer is None:
-        raise gr.Error("Llama model not initialized properly!")
-
-    try:
-        task_prompt = f"""You are an expert radiologist. Compare the following transcribed radiology report with the ground truth and provide detailed feedback.
+    task_prompt = f"""You are an expert radiologist. Compare the following transcribed radiology report with the ground truth and provide detailed feedback.
 
-Transcribed Report:
-{transcribed_text}
+Transcribed Report:
+{transcribed_text}
 
-Ground Truth Findings:
-{ground_truth_findings}
-
-Ground Truth Impression:
-{ground_truth_impression}
+Ground Truth Findings:
+{ground_truth_findings}
 
-Please analyze:
-1. Accuracy of findings
-2. Completeness of report
-3. Structure and clarity
-4. Areas for improvement
-
-Provide your analysis in a clear, structured format."""
+Ground Truth Impression:
+{ground_truth_impression}
 
-        conversation = [
-            {"role": "system", "content": "You are an expert radiologist providing detailed feedback."},
-            {"role": "user", "content": task_prompt}
-        ]
+Please analyze:
+1. Accuracy of findings
+2. Completeness of report
+3. Structure and clarity
+4. Areas for improvement
+
+Provide your analysis in a clear, structured format."""
 
-        input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
-        input_ids = input_ids.to(llm.device)
+    conversation = [
+        {"role": "system", "content": "You are an expert radiologist providing detailed feedback."},
+        {"role": "user", "content": task_prompt}
+    ]
 
-        streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
-        generate_kwargs = dict(
-            input_ids=input_ids,
-            streamer=streamer,
-            max_new_tokens=max_new_tokens,
-            do_sample=True,
-            temperature=temperature,
-            num_beams=1,
-        )
+    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
+    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+    input_ids = input_ids.to(llm.device)
 
-        t = Thread(target=llm.generate, kwargs=generate_kwargs)
-        t.start()
+    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+    generate_kwargs = dict(
+        {"input_ids": input_ids},
+        streamer=streamer,
+        max_new_tokens=max_new_tokens,
+        do_sample=True,
+        top_p=top_p,
+        top_k=top_k,
+        temperature=temperature,
+        num_beams=1,
+        repetition_penalty=repetition_penalty,
+    )
+    t = Thread(target=llm.generate, kwargs=generate_kwargs)
+    t.start()
 
-        outputs = []
-        for text in streamer:
-            outputs.append(text)
-            yield "".join(outputs)
-    except Exception as e:
-        logger.error(f"Error in Llama analysis: {str(e)}")
-        raise gr.Error(f"Analysis failed: {str(e)}")
+    outputs = []
+    for text in streamer:
+        outputs.append(text)
+        yield "".join(outputs)
 
 def load_random_case(hide_ground_truth):
     try:
@@ -279,14 +244,18 @@ with gr.Blocks() as demo:
 
             # Load case for comparison
            load_case_btn = gr.Button("Load Random Case for Comparison")
+            local_image_display = gr.Image(label="Chest X-ray Image", type="pil")
             local_ground_truth_findings = gr.Textbox(label="Ground Truth Findings", lines=5, interactive=False)
             local_ground_truth_impression = gr.Textbox(label="Ground Truth Impression", lines=5, interactive=False)
 
         with gr.Column():
             # Editable transcription and analysis interface
             edited_transcription = gr.Textbox(label="Edit Transcription", lines=10)
-            temperature_input = gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, value=0.6, step=0.1)
+            temperature_input = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, value=0.6, step=0.1)
+            top_p_input = gr.Slider(label="Top-p", minimum=0.05, maximum=1.0, value=0.9, step=0.05)
+            top_k_input = gr.Slider(label="Top-k", minimum=1, maximum=1000, value=50, step=1)
             max_tokens_input = gr.Slider(label="Max Tokens", minimum=256, maximum=2048, value=1024, step=128)
+            repetition_penalty_input = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, value=1.2, step=0.05)
             analyze_btn = gr.Button("Analyze with Llama")
             llama_analysis_output = gr.Textbox(label="Llama Analysis Output", lines=15, interactive=False)
 
@@ -305,12 +274,11 @@ with gr.Blocks() as demo:
         )
 
         # Load case for local analysis
-        local_image_display = gr.Image(label="Chest X-ray Image", type="pil")  # Add this line
         load_case_btn.click(
             fn=load_random_case,
             inputs=[gr.Checkbox(value=False, visible=False)],  # Hidden checkbox for hide_ground_truth
             outputs=[
-                local_image_display,  # Update this line
+                local_image_display,
                 local_ground_truth_findings,
                 local_ground_truth_impression,
                 gr.State(),  # Hidden state
@@ -326,7 +294,10 @@ with gr.Blocks() as demo:
                 local_ground_truth_findings,
                 local_ground_truth_impression,
                 max_tokens_input,
-                temperature_input
+                temperature_input,
+                top_p_input,
+                top_k_input,
+                repetition_penalty_input
             ],
             outputs=llama_analysis_output
         )
@@ -370,4 +341,4 @@ with gr.Blocks() as demo:
 )
 
 logger.info("Starting Gradio interface...")
-demo.launch()
+demo.queue().launch(ssr_mode=False)
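For readers unfamiliar with the ZeroGPU fix this commit refers to: on a ZeroGPU Space the GPU is attached only for the duration of a decorated call, so the usual pattern is to load models at import time with device_map="auto" and wrap GPU-bound entry points in @spaces.GPU from the spaces package (kept in requirements.txt below). A minimal sketch of that pattern, assuming a hypothetical analyze entry point rather than the app's actual functions:

import spaces  # the Hugging Face `spaces` package, required for ZeroGPU
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

llm_model_id = "chuanli11/Llama-3.2-3B-Instruct-uncensored"
tokenizer = AutoTokenizer.from_pretrained(llm_model_id)
# device_map="auto" lets accelerate decide weight placement; on ZeroGPU
# the weights are moved to the GPU only while a decorated call runs.
llm = AutoModelForCausalLM.from_pretrained(
    llm_model_id, torch_dtype=torch.float16, device_map="auto"
)

@spaces.GPU(duration=120)  # request a GPU for up to 120 s per call
def analyze(prompt: str) -> str:
    # hypothetical entry point; the real app streams tokens via TextIteratorStreamer
    inputs = tokenizer(prompt, return_tensors="pt").to(llm.device)
    output_ids = llm.generate(**inputs, max_new_tokens=256)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)

The decorator is, as far as documented, a no-op outside Spaces, so the same file should run unchanged on a local GPU.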
requirements.txt CHANGED
@@ -1,10 +1,11 @@
-gradio>=4.16.0
+transformers
+gradio
+torch
+accelerate
+SentencePiece
 pandas>=2.0.0
 datasets>=2.15.0
 openai>=1.0.0
 Pillow>=10.0.0
 huggingface-hub>=0.20.0
-torch>=2.0.0
-transformers>=4.36.0
-spaces>=0.19.3
-accelerate>=0.27.0
+spaces>=0.19.3
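On the accelerate line the commit message mentions: device_map="auto" in from_pretrained is implemented by accelerate, so transformers refuses to load with a device_map if accelerate is missing. A quick sanity check, a sketch assuming only the packages listed above are installed:

import torch
from transformers import AutoModelForCausalLM

# Loading with device_map="auto" delegates weight placement (GPU, CPU,
# or disk offload) to accelerate; this call errors out if accelerate is absent.
model = AutoModelForCausalLM.from_pretrained(
    "chuanli11/Llama-3.2-3B-Instruct-uncensored",
    torch_dtype=torch.float16,
    device_map="auto",
)
print(model.hf_device_map)  # e.g. {'': 0} when the whole model fits on one GPU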