JohnWeck commited on
Commit
914b995
·
verified ·
1 Parent(s): 201280d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -44
app.py CHANGED
@@ -7,10 +7,15 @@ import torch
7
  import gradio as gr
8
 
9
  from huggingface_hub import hf_hub_download
 
 
10
  def cached_download(*args, **kwargs):
11
  print("Warning: cached_download is deprecated, using hf_hub_download instead.")
12
  return hf_hub_download(*args, **kwargs)
 
 
13
  import sys
 
14
  sys.modules["huggingface_hub.cached_download"] = cached_download
15
 
16
  from diffusers import AutoencoderKL, DDPMScheduler
@@ -20,20 +25,6 @@ from medical_pipeline import MedicalPipeline
20
  from diffusers import DDIMScheduler
21
  from StableDiffusion.Our_Pipe import StableDiffusionPipeline
22
 
23
- def get_random_values(my_dict):
24
- values_list = list(my_dict.values())
25
- num_choices = random.randint(1, len(values_list))
26
- kinds = random.sample(values_list, num_choices)
27
- kind = ''
28
-
29
- for k in kinds:
30
- if kind == '':
31
- kind = k
32
- else:
33
- kind = kind + ',' + k
34
-
35
- return kind
36
-
37
  model_repo_id = "runwayml/stable-diffusion-v1-5"
38
  medsegfactory_id = "JohnWeck/StableDiffusion"
39
  filename = 'checkpoint-300.pth'
@@ -75,42 +66,72 @@ pipe = StableDiffusionPipeline.from_pretrained(
75
 
76
  pipeline = MedicalPipeline(pipe, device)
77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  def generate_image(organ, kind, keys):
79
- # 使用 MedSegFactory 生成图像
80
  image, label = pipeline.generate(organ=organ, kind=kind, keys=keys)
81
- plt.subplot(1,2,1)
82
  plt.imshow(image)
83
  plt.axis('off')
84
- plt.subplot(1,2,2)
85
  plt.imshow(label)
86
  plt.axis('off')
87
- plt.savefig('pred.png',bbox_inches='tight', pad_inches = 0)
88
  return "pred.png"
89
 
90
- AMOS2022 = {1: 'liver', 2: 'right kidney', 3: 'spleen', 4: 'pancreas', 5: 'aorta', 6: 'inferior vena cava',
91
- 7: 'right adrenal gland', 8: 'left adrenal gland',
92
- 9: 'gall bladder', 10: 'esophagus', 11: 'stomach', 12: 'duodenum', 13: 'left kidney',
93
- 14: 'bladder', 15: 'prostate'}
94
- ACDC = {1: 'right ventricle', 2: 'myocardium', 3: 'left ventricle'}
95
- LiTS2017 = {1: 'liver', 2: 'liver tumor'}
96
- KiTS2019 = {1: 'kidney', 2: 'kidney tumor'}
97
-
98
- # 预定义的输入案例
99
- examples = [
100
- ['polyp colonoscopy', 'polyp', 'CVC-ClinicDB'],
101
- ['breast ultrasound', 'normal', 'BUSI'],
102
- ['breast ultrasound', 'breast tumor', 'BUSI'],
103
- ]
104
-
105
- # 创建 Gradio 接口
106
- interface = gr.Interface(
107
- fn=generate_image,
108
- inputs=[gr.Dropdown(["polyp colonoscopy"], label="organ"),
109
- gr.Dropdown(["polyp"], label="kind"),
110
- gr.Dropdown(['CVC-ClinicDB'], label="keys")],
111
- outputs=gr.Image(label="Visualization"), # 返回 Matplotlib 渲染的图片
112
- examples=examples # 添加输入案例
113
- )
114
 
115
- # 启动 Gradio 应用
116
- interface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  import gradio as gr
8
 
9
  from huggingface_hub import hf_hub_download
10
+
11
+
12
  def cached_download(*args, **kwargs):
13
  print("Warning: cached_download is deprecated, using hf_hub_download instead.")
14
  return hf_hub_download(*args, **kwargs)
15
+
16
+
17
  import sys
18
+
19
  sys.modules["huggingface_hub.cached_download"] = cached_download
20
 
21
  from diffusers import AutoencoderKL, DDPMScheduler
 
25
  from diffusers import DDIMScheduler
26
  from StableDiffusion.Our_Pipe import StableDiffusionPipeline
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  model_repo_id = "runwayml/stable-diffusion-v1-5"
29
  medsegfactory_id = "JohnWeck/StableDiffusion"
30
  filename = 'checkpoint-300.pth'
 
66
 
67
  pipeline = MedicalPipeline(pipe, device)
68
 
69
+ # 定义 keys 与 organ 及 kind 的映射
70
+ keys_to_organ_kind = {
71
+ "CVC-ClinicDB": {
72
+ "organs": ["polyp colonoscopy"],
73
+ "kinds": {"polyp colonoscopy": ["polyp"]}
74
+ },
75
+ "BUSI": {
76
+ "organs": ["breast ultrasound"],
77
+ "kinds": {"breast ultrasound": ["normal", "breast tumor"]}
78
+ },
79
+ "LiTS2017": {
80
+ "organs": ["abdomen CT scans"],
81
+ "kinds": {"abdomen CT scans": ["liver","liver tumor"]}
82
+ },
83
+ "KiTS2019": {
84
+ "organs": ["abdomen CT scans"],
85
+ "kinds": {"abdomen CT scans": ["kidney","kidney tumor"]}
86
+ },
87
+ "ACDC": {
88
+ "organs": ["cardiovascular ventricle mri"],
89
+ "kinds": {"cardiovascular ventricle mri": ["right ventricle", "myocardium","left ventricle"]}
90
+ },
91
+ "AMOS2022": {
92
+ "organs": ["abdomen CT scans"],
93
+ "kinds": {"abdomen CT scans": ["liver", "right kidney", "spleen", "pancreas", "aorta", "inferior vena cava",
94
+ "right adrenal gland", "left adrenal gland", "gall bladder", "esophagus", "stomach", "duodenum", "left kidney",
95
+ "bladder", "prostate"]}
96
+ }
97
+ }
98
+
99
+
100
+ def update_organ_and_kind(selected_key):
101
+ organs = keys_to_organ_kind[selected_key]["organs"]
102
+ first_organ = organs[0] # 默认选第一个 organ
103
+ kinds = keys_to_organ_kind[selected_key]["kinds"][first_organ]
104
+ return gr.update(choices=organs, value=first_organ), gr.update(choices=kinds, value=kinds)
105
+
106
+
107
+ def update_kind(selected_key, selected_organ):
108
+ kinds = keys_to_organ_kind[selected_key]["kinds"][selected_organ]
109
+ return gr.update(choices=kinds, value=kinds)
110
+
111
+
112
  def generate_image(organ, kind, keys):
 
113
  image, label = pipeline.generate(organ=organ, kind=kind, keys=keys)
114
+ plt.subplot(1, 2, 1)
115
  plt.imshow(image)
116
  plt.axis('off')
117
+ plt.subplot(1, 2, 2)
118
  plt.imshow(label)
119
  plt.axis('off')
120
+ plt.savefig('pred.png', bbox_inches='tight', pad_inches=0)
121
  return "pred.png"
122
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
 
124
+ with gr.Blocks() as demo:
125
+ keys_dropdown = gr.Dropdown(list(keys_to_organ_kind.keys()), label="Keys", value="CVC-ClinicDB")
126
+ organ_dropdown = gr.Dropdown(keys_to_organ_kind["CVC-ClinicDB"]["organs"], label="Organ")
127
+ kind_checkbox = gr.CheckboxGroup(keys_to_organ_kind["CVC-ClinicDB"]["kinds"]["polyp colonoscopy"], label="Kind")
128
+
129
+ keys_dropdown.change(update_organ_and_kind, inputs=keys_dropdown, outputs=[organ_dropdown, kind_checkbox])
130
+ organ_dropdown.change(update_kind, inputs=[keys_dropdown, organ_dropdown], outputs=kind_checkbox)
131
+
132
+ generate_button = gr.Button("Generate Image")
133
+ output_image = gr.Image(label="Visualization")
134
+
135
+ generate_button.click(generate_image, inputs=[organ_dropdown, kind_checkbox, keys_dropdown], outputs=output_image)
136
+
137
+ demo.launch()