Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -7,10 +7,15 @@ import torch
|
|
7 |
import gradio as gr
|
8 |
|
9 |
from huggingface_hub import hf_hub_download
|
|
|
|
|
10 |
def cached_download(*args, **kwargs):
|
11 |
print("Warning: cached_download is deprecated, using hf_hub_download instead.")
|
12 |
return hf_hub_download(*args, **kwargs)
|
|
|
|
|
13 |
import sys
|
|
|
14 |
sys.modules["huggingface_hub.cached_download"] = cached_download
|
15 |
|
16 |
from diffusers import AutoencoderKL, DDPMScheduler
|
@@ -20,20 +25,6 @@ from medical_pipeline import MedicalPipeline
|
|
20 |
from diffusers import DDIMScheduler
|
21 |
from StableDiffusion.Our_Pipe import StableDiffusionPipeline
|
22 |
|
23 |
-
def get_random_values(my_dict):
|
24 |
-
values_list = list(my_dict.values())
|
25 |
-
num_choices = random.randint(1, len(values_list))
|
26 |
-
kinds = random.sample(values_list, num_choices)
|
27 |
-
kind = ''
|
28 |
-
|
29 |
-
for k in kinds:
|
30 |
-
if kind == '':
|
31 |
-
kind = k
|
32 |
-
else:
|
33 |
-
kind = kind + ',' + k
|
34 |
-
|
35 |
-
return kind
|
36 |
-
|
37 |
model_repo_id = "runwayml/stable-diffusion-v1-5"
|
38 |
medsegfactory_id = "JohnWeck/StableDiffusion"
|
39 |
filename = 'checkpoint-300.pth'
|
@@ -75,42 +66,72 @@ pipe = StableDiffusionPipeline.from_pretrained(
|
|
75 |
|
76 |
pipeline = MedicalPipeline(pipe, device)
|
77 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
def generate_image(organ, kind, keys):
|
79 |
-
# 使用 MedSegFactory 生成图像
|
80 |
image, label = pipeline.generate(organ=organ, kind=kind, keys=keys)
|
81 |
-
plt.subplot(1,2,1)
|
82 |
plt.imshow(image)
|
83 |
plt.axis('off')
|
84 |
-
plt.subplot(1,2,2)
|
85 |
plt.imshow(label)
|
86 |
plt.axis('off')
|
87 |
-
plt.savefig('pred.png',bbox_inches='tight', pad_inches
|
88 |
return "pred.png"
|
89 |
|
90 |
-
AMOS2022 = {1: 'liver', 2: 'right kidney', 3: 'spleen', 4: 'pancreas', 5: 'aorta', 6: 'inferior vena cava',
|
91 |
-
7: 'right adrenal gland', 8: 'left adrenal gland',
|
92 |
-
9: 'gall bladder', 10: 'esophagus', 11: 'stomach', 12: 'duodenum', 13: 'left kidney',
|
93 |
-
14: 'bladder', 15: 'prostate'}
|
94 |
-
ACDC = {1: 'right ventricle', 2: 'myocardium', 3: 'left ventricle'}
|
95 |
-
LiTS2017 = {1: 'liver', 2: 'liver tumor'}
|
96 |
-
KiTS2019 = {1: 'kidney', 2: 'kidney tumor'}
|
97 |
-
|
98 |
-
# 预定义的输入案例
|
99 |
-
examples = [
|
100 |
-
['polyp colonoscopy', 'polyp', 'CVC-ClinicDB'],
|
101 |
-
['breast ultrasound', 'normal', 'BUSI'],
|
102 |
-
['breast ultrasound', 'breast tumor', 'BUSI'],
|
103 |
-
]
|
104 |
-
|
105 |
-
# 创建 Gradio 接口
|
106 |
-
interface = gr.Interface(
|
107 |
-
fn=generate_image,
|
108 |
-
inputs=[gr.Dropdown(["polyp colonoscopy"], label="organ"),
|
109 |
-
gr.Dropdown(["polyp"], label="kind"),
|
110 |
-
gr.Dropdown(['CVC-ClinicDB'], label="keys")],
|
111 |
-
outputs=gr.Image(label="Visualization"), # 返回 Matplotlib 渲染的图片
|
112 |
-
examples=examples # 添加输入案例
|
113 |
-
)
|
114 |
|
115 |
-
|
116 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
import gradio as gr
|
8 |
|
9 |
from huggingface_hub import hf_hub_download
|
10 |
+
|
11 |
+
|
12 |
def cached_download(*args, **kwargs):
|
13 |
print("Warning: cached_download is deprecated, using hf_hub_download instead.")
|
14 |
return hf_hub_download(*args, **kwargs)
|
15 |
+
|
16 |
+
|
17 |
import sys
|
18 |
+
|
19 |
sys.modules["huggingface_hub.cached_download"] = cached_download
|
20 |
|
21 |
from diffusers import AutoencoderKL, DDPMScheduler
|
|
|
25 |
from diffusers import DDIMScheduler
|
26 |
from StableDiffusion.Our_Pipe import StableDiffusionPipeline
|
27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
model_repo_id = "runwayml/stable-diffusion-v1-5"
|
29 |
medsegfactory_id = "JohnWeck/StableDiffusion"
|
30 |
filename = 'checkpoint-300.pth'
|
|
|
66 |
|
67 |
pipeline = MedicalPipeline(pipe, device)
|
68 |
|
69 |
+
# 定义 keys 与 organ 及 kind 的映射
|
70 |
+
keys_to_organ_kind = {
|
71 |
+
"CVC-ClinicDB": {
|
72 |
+
"organs": ["polyp colonoscopy"],
|
73 |
+
"kinds": {"polyp colonoscopy": ["polyp"]}
|
74 |
+
},
|
75 |
+
"BUSI": {
|
76 |
+
"organs": ["breast ultrasound"],
|
77 |
+
"kinds": {"breast ultrasound": ["normal", "breast tumor"]}
|
78 |
+
},
|
79 |
+
"LiTS2017": {
|
80 |
+
"organs": ["abdomen CT scans"],
|
81 |
+
"kinds": {"abdomen CT scans": ["liver","liver tumor"]}
|
82 |
+
},
|
83 |
+
"KiTS2019": {
|
84 |
+
"organs": ["abdomen CT scans"],
|
85 |
+
"kinds": {"abdomen CT scans": ["kidney","kidney tumor"]}
|
86 |
+
},
|
87 |
+
"ACDC": {
|
88 |
+
"organs": ["cardiovascular ventricle mri"],
|
89 |
+
"kinds": {"cardiovascular ventricle mri": ["right ventricle", "myocardium","left ventricle"]}
|
90 |
+
},
|
91 |
+
"AMOS2022": {
|
92 |
+
"organs": ["abdomen CT scans"],
|
93 |
+
"kinds": {"abdomen CT scans": ["liver", "right kidney", "spleen", "pancreas", "aorta", "inferior vena cava",
|
94 |
+
"right adrenal gland", "left adrenal gland", "gall bladder", "esophagus", "stomach", "duodenum", "left kidney",
|
95 |
+
"bladder", "prostate"]}
|
96 |
+
}
|
97 |
+
}
|
98 |
+
|
99 |
+
|
100 |
+
def update_organ_and_kind(selected_key):
|
101 |
+
organs = keys_to_organ_kind[selected_key]["organs"]
|
102 |
+
first_organ = organs[0] # 默认选第一个 organ
|
103 |
+
kinds = keys_to_organ_kind[selected_key]["kinds"][first_organ]
|
104 |
+
return gr.update(choices=organs, value=first_organ), gr.update(choices=kinds, value=kinds)
|
105 |
+
|
106 |
+
|
107 |
+
def update_kind(selected_key, selected_organ):
|
108 |
+
kinds = keys_to_organ_kind[selected_key]["kinds"][selected_organ]
|
109 |
+
return gr.update(choices=kinds, value=kinds)
|
110 |
+
|
111 |
+
|
112 |
def generate_image(organ, kind, keys):
|
|
|
113 |
image, label = pipeline.generate(organ=organ, kind=kind, keys=keys)
|
114 |
+
plt.subplot(1, 2, 1)
|
115 |
plt.imshow(image)
|
116 |
plt.axis('off')
|
117 |
+
plt.subplot(1, 2, 2)
|
118 |
plt.imshow(label)
|
119 |
plt.axis('off')
|
120 |
+
plt.savefig('pred.png', bbox_inches='tight', pad_inches=0)
|
121 |
return "pred.png"
|
122 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
123 |
|
124 |
+
with gr.Blocks() as demo:
|
125 |
+
keys_dropdown = gr.Dropdown(list(keys_to_organ_kind.keys()), label="Keys", value="CVC-ClinicDB")
|
126 |
+
organ_dropdown = gr.Dropdown(keys_to_organ_kind["CVC-ClinicDB"]["organs"], label="Organ")
|
127 |
+
kind_checkbox = gr.CheckboxGroup(keys_to_organ_kind["CVC-ClinicDB"]["kinds"]["polyp colonoscopy"], label="Kind")
|
128 |
+
|
129 |
+
keys_dropdown.change(update_organ_and_kind, inputs=keys_dropdown, outputs=[organ_dropdown, kind_checkbox])
|
130 |
+
organ_dropdown.change(update_kind, inputs=[keys_dropdown, organ_dropdown], outputs=kind_checkbox)
|
131 |
+
|
132 |
+
generate_button = gr.Button("Generate Image")
|
133 |
+
output_image = gr.Image(label="Visualization")
|
134 |
+
|
135 |
+
generate_button.click(generate_image, inputs=[organ_dropdown, kind_checkbox, keys_dropdown], outputs=output_image)
|
136 |
+
|
137 |
+
demo.launch()
|