mgbam commited on
Commit
ab9c414
·
verified ·
1 Parent(s): e259a47

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -31
app.py CHANGED
@@ -5,23 +5,23 @@ from diffusers import AutoencoderKL
5
  import numpy as np
6
  import gradio as gr
7
 
8
- # Configure device
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
- print(f"Using device: {device}")
 
11
 
12
  # Initialize medical imaging components
13
  def load_medical_models():
14
  try:
15
- # Load processor and tokenizer
16
  processor = VLChatProcessor.from_pretrained("deepseek-ai/Janus-1.3B")
17
 
18
- # Load base model
19
  model = MultiModalityCausalLM.from_pretrained(
20
  "deepseek-ai/Janus-1.3B",
21
- torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32
 
 
22
  ).to(device).eval()
23
 
24
- # Load VAE for image processing
25
  vae = AutoencoderKL.from_pretrained(
26
  "stabilityai/sdxl-vae",
27
  torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32
@@ -29,62 +29,61 @@ def load_medical_models():
29
 
30
  return processor, model, vae
31
  except Exception as e:
32
- print(f"Error loading models: {str(e)}")
33
  raise
34
 
35
  processor, model, vae = load_medical_models()
36
 
37
- # Medical image analysis function
38
- def medical_analysis(image, question, seed=42, top_p=0.95, temperature=0.1):
39
  try:
40
- # Set random seed for reproducibility
41
  torch.manual_seed(seed)
42
  np.random.seed(seed)
43
 
44
- # Prepare inputs
45
  if isinstance(image, np.ndarray):
46
  image = Image.fromarray(image).convert("RGB")
47
 
48
  inputs = processor(
49
- text=question,
50
  images=[image],
51
  return_tensors="pt"
52
  ).to(device)
53
 
54
- # Generate analysis
55
  outputs = model.generate(
56
  inputs.input_ids,
57
  attention_mask=inputs.attention_mask,
58
  max_new_tokens=512,
59
- temperature=temperature,
60
- top_p=top_p
 
61
  )
62
 
63
  return processor.decode(outputs[0], skip_special_tokens=True)
64
  except Exception as e:
65
- return f"Analysis error: {str(e)}"
66
 
67
  # Medical interface
68
- with gr.Blocks(title="Medical Imaging Assistant") as demo:
69
- gr.Markdown("# Medical Imaging AI Assistant")
 
70
 
71
- with gr.Tab("Analysis"):
72
  with gr.Row():
73
- med_image = gr.Image(label="Input Image", type="pil")
74
- med_question = gr.Textbox(label="Clinical Query")
75
- analysis_output = gr.Textbox(label="Findings")
76
- gr.Examples(
77
- examples=[
78
- ["ultrasound_sample.jpg", "Identify any abnormalities in this ultrasound"],
79
- ["xray_sample.jpg", "Describe the bone structure visible in this X-ray"]
80
- ],
81
- inputs=[med_image, med_question]
82
- )
83
 
84
  med_question.submit(
85
  medical_analysis,
86
  inputs=[med_image, med_question],
87
- outputs=analysis_output
 
 
 
 
 
88
  )
89
 
90
- demo.launch()
 
5
  import numpy as np
6
  import gradio as gr
7
 
8
+ # Configure device and attention implementation
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
+ attn_implementation = "flash_attention_2" if device == "cuda" else "eager"
11
+ print(f"Using device: {device} with {attn_implementation}")
12
 
13
  # Initialize medical imaging components
14
  def load_medical_models():
15
  try:
 
16
  processor = VLChatProcessor.from_pretrained("deepseek-ai/Janus-1.3B")
17
 
 
18
  model = MultiModalityCausalLM.from_pretrained(
19
  "deepseek-ai/Janus-1.3B",
20
+ torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
21
+ attn_implementation=attn_implementation,
22
+ use_flash_attention_2=(attn_implementation == "flash_attention_2")
23
  ).to(device).eval()
24
 
 
25
  vae = AutoencoderKL.from_pretrained(
26
  "stabilityai/sdxl-vae",
27
  torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32
 
29
 
30
  return processor, model, vae
31
  except Exception as e:
32
+ print(f"Error loading medical models: {str(e)}")
33
  raise
34
 
35
  processor, model, vae = load_medical_models()
36
 
37
+ # Medical image analysis function with attention control
38
+ def medical_analysis(image, question, seed=42):
39
  try:
 
40
  torch.manual_seed(seed)
41
  np.random.seed(seed)
42
 
 
43
  if isinstance(image, np.ndarray):
44
  image = Image.fromarray(image).convert("RGB")
45
 
46
  inputs = processor(
47
+ text=f"<medical_query>{question}</medical_query>",
48
  images=[image],
49
  return_tensors="pt"
50
  ).to(device)
51
 
 
52
  outputs = model.generate(
53
  inputs.input_ids,
54
  attention_mask=inputs.attention_mask,
55
  max_new_tokens=512,
56
+ temperature=0.1,
57
+ top_p=0.95,
58
+ pad_token_id=processor.tokenizer.eos_token_id
59
  )
60
 
61
  return processor.decode(outputs[0], skip_special_tokens=True)
62
  except Exception as e:
63
+ return f"Radiology analysis error: {str(e)}"
64
 
65
  # Medical interface
66
+ with gr.Blocks(title="Medical Imaging Assistant", theme=gr.themes.Soft()) as demo:
67
+ gr.Markdown("""# AI Radiology Assistant
68
+ **CT/MRI/X-ray Analysis System**""")
69
 
70
+ with gr.Tab("Diagnostic Imaging"):
71
  with gr.Row():
72
+ med_image = gr.Image(label="DICOM Image", type="pil")
73
+ med_question = gr.Textbox(label="Clinical Query",
74
+ placeholder="Describe findings in this CT scan...")
75
+ analysis_btn = gr.Button("Analyze", variant="primary")
76
+ report_output = gr.Textbox(label="Radiology Report", interactive=False)
 
 
 
 
 
77
 
78
  med_question.submit(
79
  medical_analysis,
80
  inputs=[med_image, med_question],
81
+ outputs=report_output
82
+ )
83
+ analysis_btn.click(
84
+ medical_analysis,
85
+ inputs=[med_image, med_question],
86
+ outputs=report_output
87
  )
88
 
89
+ demo.launch(server_name="0.0.0.0", server_port=7860)