mgbam committed
Commit bb16e72 · verified · 1 Parent(s): 2db0e0f

Update app.py

Files changed (1):
  1. app.py +64 -161
app.py CHANGED
@@ -1,187 +1,90 @@
 import torch
- from janus.janusflow.models import MultiModalityCausalLM, VLChatProcessor
+ from janus.models import MultiModalityCausalLM, VLChatProcessor
 from PIL import Image
- from diffusers.models import AutoencoderKL
+ from diffusers import AutoencoderKL
 import numpy as np
 import gradio as gr
- import warnings

- # Suppress unnecessary warnings
- warnings.filterwarnings("ignore")
+ # Configure device
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ print(f"Using device: {device}")

- # Force CPU usage
- device = torch.device("cpu")
- print("Using device: cpu")
-
- # Medical-specific model configuration
- MEDICAL_MODEL_CONFIG = {
-     "model_path": "deepseek-ai/JanusFlow-1.3B",
-     "vae_path": "stabilityai/sdxl-vae",
-     "max_analysis_length": 512,
-     "min_image_size": 512,
-     "max_image_size": 1024
- }
-
- # Load medical-optimized model and processor
- try:
-     vl_chat_processor = VLChatProcessor.from_pretrained(
-         MEDICAL_MODEL_CONFIG["model_path"],
-         medical_mode=True
-     )
-     tokenizer = vl_chat_processor.tokenizer
-
-     vl_gpt = MultiModalityCausalLM.from_pretrained(
-         MEDICAL_MODEL_CONFIG["model_path"],
-         medical_weights=True
-     ).to(device).eval()
-
-     # Load medical-optimized VAE
-     vae = AutoencoderKL.from_pretrained(
-         MEDICAL_MODEL_CONFIG["vae_path"],
-         subfolder="vae",
-         medical_config=True
-     ).to(device).eval()
+ # Initialize medical imaging components
+ def load_medical_models():
+     try:
+         # Load processor and tokenizer
+         processor = VLChatProcessor.from_pretrained("deepseek-ai/Janus-1.3B")
+
+         # Load base model
+         model = MultiModalityCausalLM.from_pretrained(
+             "deepseek-ai/Janus-1.3B",
+             torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32
+         ).to(device).eval()
+
+         # Load VAE for image processing
+         vae = AutoencoderKL.from_pretrained(
+             "stabilityai/sdxl-vae",
+             torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32
+         ).to(device).eval()
+
+         return processor, model, vae
+     except Exception as e:
+         print(f"Error loading models: {str(e)}")
+         raise

- except Exception as e:
-     print(f"Error loading medical models: {str(e)}")
-     raise
+ processor, model, vae = load_medical_models()

 # Medical image analysis function
- @torch.inference_mode()
- def medical_image_analysis(image, question, seed=42, top_p=0.95, temperature=0.1):
-     torch.manual_seed(seed)
-     np.random.seed(seed)
-
+ def medical_analysis(image, question, seed=42, top_p=0.95, temperature=0.1):
     try:
-         # Medical image preprocessing
+         # Set random seed for reproducibility
+         torch.manual_seed(seed)
+         np.random.seed(seed)
+
+         # Prepare inputs
         if isinstance(image, np.ndarray):
             image = Image.fromarray(image).convert("RGB")

-         # Medical conversation template
-         conversation = [{
-             "role": "Radiologist",
-             "content": f"<medical_image>\n{question}",
-             "images": [image],
-         }]
-
-         inputs = vl_chat_processor(
-             conversations=conversation,
+         inputs = processor(
+             text=question,
             images=[image],
-             medical_mode=True,
-             max_length=MEDICAL_MODEL_CONFIG["max_analysis_length"]
+             return_tensors="pt"
         ).to(device)
-
-         outputs = vl_gpt.generate(
-             inputs_embeds=inputs.inputs_embeds,
+
+         # Generate analysis
+         outputs = model.generate(
+             inputs.input_ids,
             attention_mask=inputs.attention_mask,
-             max_new_tokens=MEDICAL_MODEL_CONFIG["max_analysis_length"],
+             max_new_tokens=512,
             temperature=temperature,
-             top_p=top_p,
-             medical_context=True
+             top_p=top_p
         )
-
-         report = tokenizer.decode(outputs[0], skip_special_tokens=True)
-         return clean_medical_report(report)
-
+
+         return processor.decode(outputs[0], skip_special_tokens=True)
     except Exception as e:
-         return f"Medical analysis error: {str(e)}"
+         return f"Analysis error: {str(e)}"

- # Medical image generation function
- @torch.inference_mode()
- def generate_medical_image(prompt, seed=12345, guidance=5, steps=30):
-     torch.manual_seed(seed)
-
-     try:
-         # Medical prompt validation
-         if not validate_medical_prompt(prompt):
-             return ["Invalid medical prompt - please provide specific anatomical details"]
-
-         inputs = vl_chat_processor.encode_medical_prompt(
-             prompt,
-             max_length=MEDICAL_MODEL_CONFIG["max_analysis_length"],
-             device=device
-         )
-
-         # Medical image generation pipeline
-         with torch.autocast(device.type):
-             images = vae.decode_latents(
-                 vl_gpt.generate_medical_latents(
-                     inputs,
-                     guidance_scale=guidance,
-                     num_inference_steps=steps
-                 )
-             )
-
-         return postprocess_medical_images(images)
-
-     except Exception as e:
-         return [f"Medical imaging error: {str(e)}"]
-
- # Helper functions
- def validate_medical_prompt(prompt):
-     medical_terms = ["MRI", "CT", "X-ray", "ultrasound", "histology", "anatomy"]
-     return any(term in prompt.lower() for term in medical_terms)
-
- def postprocess_medical_images(images):
-     processed = []
-     for img in images:
-         img = Image.fromarray(img).resize(
-             (MEDICAL_MODEL_CONFIG["min_image_size"],
-              MEDICAL_MODEL_CONFIG["min_image_size"]),
-             Image.LANCZOS
-         )
-         processed.append(img)
-     return processed
-
- def clean_medical_report(text):
-     return text.replace("##MEDICAL_REPORT##", "").strip()
-
- # Medical-grade interface
- with gr.Blocks(title="Medical Imaging AI Assistant", theme="soft") as demo:
-     gr.Markdown("""# Medical Imaging Analysis & Generation System
-     **Certified for diagnostic support use**""")
-
-     with gr.Tab("Radiology Analysis"):
+ # Medical interface
+ with gr.Blocks(title="Medical Imaging Assistant") as demo:
+     gr.Markdown("# Medical Imaging AI Assistant")
+
+     with gr.Tab("Analysis"):
         with gr.Row():
-             gr.Markdown("## Patient Imaging Analysis")
-             with gr.Column():
-                 medical_image = gr.Image(label="DICOM/Medical Image", type="pil")
-                 clinical_query = gr.Textbox(label="Clinical Question")
-                 analysis_btn = gr.Button("Generate Report", variant="primary")
-
-         report_output = gr.Textbox(label="Clinical Findings", interactive=False)
-
-     with gr.Tab("Diagnostic Imaging Generation"):
-         with gr.Row():
-             gr.Markdown("## Synthetic Medical Image Generation")
-             with gr.Column():
-                 imaging_protocol = gr.Textbox(label="Imaging Protocol")
-                 generate_btn = gr.Button("Generate Study", variant="primary")
-
-         study_gallery = gr.Gallery(
-             label="Generated Images",
-             columns=2,
-             height=MEDICAL_MODEL_CONFIG["max_image_size"]
+             med_image = gr.Image(label="Input Image", type="pil")
+             med_question = gr.Textbox(label="Clinical Query")
+         analysis_output = gr.Textbox(label="Findings")
+         gr.Examples(
+             examples=[
+                 ["ultrasound_sample.jpg", "Identify any abnormalities in this ultrasound"],
+                 ["xray_sample.jpg", "Describe the bone structure visible in this X-ray"]
+             ],
+             inputs=[med_image, med_question]
         )
-
-     # Medical workflow connections
-     analysis_btn.click(
-         medical_image_analysis,
-         inputs=[medical_image, clinical_query],
-         outputs=report_output
-     )

-     generate_btn.click(
-         generate_medical_image,
-         inputs=[imaging_protocol],
-         outputs=study_gallery
+     med_question.submit(
+         medical_analysis,
+         inputs=[med_image, med_question],
+         outputs=analysis_output
     )

- # Launch with medical safety protocols
- demo.launch(
-     server_name="0.0.0.0",
-     server_port=7860,
-     enable_queue=True,
-     max_threads=2,
-     show_error=True
- )
+ demo.launch()
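
Note on the new `medical_analysis` path: the rewritten `processor(text=..., images=..., return_tensors="pt")` call follows the generic Hugging Face processor convention, whereas the upstream Janus examples drive `VLChatProcessor` with a conversation list and generate from prepared embeddings rather than raw `input_ids`. Below is a minimal sketch of that conversation-style convention, assuming the `deepseek-ai/Janus-1.3B` checkpoint and the `janus` package from DeepSeek's repo; argument names follow its README examples and may differ across `janus` versions.

```python
import torch
from PIL import Image
from janus.models import MultiModalityCausalLM, VLChatProcessor

processor = VLChatProcessor.from_pretrained("deepseek-ai/Janus-1.3B")
model = MultiModalityCausalLM.from_pretrained("deepseek-ai/Janus-1.3B").eval()

# Janus expects a chat-style conversation with an image placeholder token.
# The sample path here is the one referenced by the app's gr.Examples.
image = Image.open("ultrasound_sample.jpg").convert("RGB")
conversation = [
    {"role": "User",
     "content": "<image_placeholder>\nIdentify any abnormalities in this ultrasound.",
     "images": [image]},
    {"role": "Assistant", "content": ""},
]

# Batchify the conversation and build multimodal input embeddings.
inputs = processor(conversations=conversation, images=[image],
                   force_batchify=True).to(model.device)
inputs_embeds = model.prepare_inputs_embeds(**inputs)

# Generate the answer from the underlying language model.
with torch.inference_mode():
    outputs = model.language_model.generate(
        inputs_embeds=inputs_embeds,
        attention_mask=inputs.attention_mask,
        pad_token_id=processor.tokenizer.eos_token_id,
        max_new_tokens=512,
        do_sample=False,
        use_cache=True,
    )

print(processor.tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True))
```

Separately, the `enable_queue=True` flag dropped from `demo.launch()` no longer exists in Gradio 4.x; if request queuing is still wanted, `demo.queue().launch()` is the current spelling.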