# Spaces status: Running
import os | |
import random | |
import argparse | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import torch | |
import gradio as gr | |
from huggingface_hub import hf_hub_download | |
# Compatibility shim for code that still expects `cached_download`.
def cached_download(*args, **kwargs):
    """Forward a legacy ``cached_download`` call to ``hf_hub_download``.

    Prints a deprecation notice, then passes every positional and keyword
    argument through unchanged and returns whatever ``hf_hub_download``
    returns.
    """
    notice = "Warning: cached_download is deprecated, using hf_hub_download instead."
    print(notice)
    downloaded = hf_hub_download(*args, **kwargs)
    return downloaded


import sys

# Register the shim under the dotted name below. Presumably this lets the
# imports that follow still resolve `huggingface_hub.cached_download` --
# confirm against the installed diffusers / huggingface_hub versions.
sys.modules["huggingface_hub.cached_download"] = cached_download
from diffusers import AutoencoderKL, DDPMScheduler | |
from StableDiffusion.Our_UNet import UNet2DConditionModel | |
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection | |
from medical_pipeline import MedicalPipeline | |
from diffusers import DDIMScheduler | |
from StableDiffusion.Our_Pipe import StableDiffusionPipeline | |
# Hub locations: base Stable Diffusion repo and the fine-tuned UNet checkpoint.
model_repo_id = "runwayml/stable-diffusion-v1-5"
medsegfactory_id = "JohnWeck/StableDiffusion"
filename = 'checkpoint-300.pth'

device = "cuda" if torch.cuda.is_available() else "cpu"
weight_dtype = torch.float32

text_encoder = CLIPTextModel.from_pretrained(model_repo_id, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(model_repo_id, subfolder="vae")
# Only the UNet configuration is read from the base repo; its weights come
# from the separately downloaded checkpoint below.
unet = UNet2DConditionModel.from_config(model_repo_id, subfolder="unet")

medsegfactory_ckpt = hf_hub_download(repo_id=medsegfactory_id, filename=filename)
# NOTE(security): torch.load can unpickle arbitrary objects from the downloaded
# file. If the checkpoint is a plain tensor state dict, passing
# `weights_only=True` would be safer -- TODO confirm the checkpoint format.
unet.load_state_dict(torch.load(medsegfactory_ckpt, map_location='cpu'))

# This app only runs inference: freeze every sub-model. The UNet is built via
# `from_config`, and torch modules start in training mode, so switch it to
# eval explicitly.
unet.requires_grad_(False)
unet.eval()
vae.requires_grad_(False)
text_encoder.requires_grad_(False)

unet.to(device, dtype=weight_dtype)
vae.to(device, dtype=weight_dtype)
text_encoder.to(device, dtype=weight_dtype)

sd_noise_scheduler = DDIMScheduler(
    num_train_timesteps=1000,
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    set_alpha_to_one=False,
    steps_offset=1,
)

# Load the SD pipeline. The already-loaded VAE and text encoder are passed in
# so that `from_pretrained` reuses them instead of loading a second copy of
# each from the same repo.
pipe = StableDiffusionPipeline.from_pretrained(
    model_repo_id,
    torch_dtype=weight_dtype,
    vae=vae,
    text_encoder=text_encoder,
    unet=unet,
    scheduler=sd_noise_scheduler,
    feature_extractor=None,
    safety_checker=None,
)

pipeline = MedicalPipeline(pipe, device)
# Mapping from each dataset key to its organ prompts and, per organ, the
# kinds (structures / lesions) that can be requested.
keys_to_organ_kind = {
    "CVC-ClinicDB": {
        "organs": ["polyp colonoscopy"],
        "kinds": {
            "polyp colonoscopy": ["polyp"],
        },
    },
    "BUSI": {
        "organs": ["breast ultrasound"],
        "kinds": {
            "breast ultrasound": ["normal", "breast tumor"],
        },
    },
    "LiTS2017": {
        "organs": ["abdomen CT scans"],
        "kinds": {
            "abdomen CT scans": ["liver", "liver tumor"],
        },
    },
    "KiTS2019": {
        "organs": ["abdomen CT scans"],
        "kinds": {
            "abdomen CT scans": ["kidney", "kidney tumor"],
        },
    },
    "ACDC": {
        "organs": ["cardiovascular ventricle mri"],
        "kinds": {
            "cardiovascular ventricle mri": [
                "right ventricle",
                "myocardium",
                "left ventricle",
            ],
        },
    },
    "AMOS2022": {
        "organs": ["abdomen CT scans"],
        "kinds": {
            "abdomen CT scans": [
                "liver",
                "right kidney",
                "spleen",
                "pancreas",
                "aorta",
                "inferior vena cava",
                "right adrenal gland",
                "left adrenal gland",
                "gall bladder",
                "esophagus",
                "stomach",
                "duodenum",
                "left kidney",
                "bladder",
                "prostate",
            ],
        },
    },
}
def update_organ_and_kind(selected_key):
    """Refresh the organ and kind choices after the dataset key changes.

    Looks up ``selected_key`` in ``keys_to_organ_kind``, picks the first
    listed organ (or an empty string when the list is empty) and returns
    two ``gr.update`` payloads: one for the organ selector and one for the
    kind selector, the latter with every available kind pre-selected.
    """
    organ_choices = keys_to_organ_kind[selected_key]["organs"]
    # Default to the first organ of the dataset.
    if organ_choices:
        default_organ = organ_choices[0]
    else:
        default_organ = ""
    # Fall back to an empty list when the organ has no kinds entry.
    kinds_by_organ = keys_to_organ_kind[selected_key]["kinds"]
    kind_choices = kinds_by_organ.get(default_organ, [])
    organ_update = gr.update(choices=organ_choices, value=default_organ)
    kind_update = gr.update(choices=kind_choices, value=kind_choices)
    return organ_update, kind_update
def update_kind(selected_key, selected_organ):
    """Refresh the kind choices for the given dataset key and organ.

    Returns a ``gr.update`` payload whose choices are the kinds registered
    for ``selected_organ`` under ``selected_key`` (an empty list when the
    organ is not registered), with all of them pre-selected.
    """
    kinds_by_organ = keys_to_organ_kind[selected_key]["kinds"]
    available = kinds_by_organ.get(selected_organ, [])
    return gr.update(choices=available, value=available)
def generate_image(organ, kinds, keys):
    """Generate an image/label pair and return the path of a preview PNG.

    Args:
        organ: Organ prompt selected in the UI.
        kinds: Selected kind names; they are joined with "," before being
            handed to the pipeline.
        keys: Dataset key selected in the UI.

    Returns:
        Path of a PNG file showing the generated image (left panel) next to
        its label (right panel).
    """
    # Local import: only this handler needs it.
    import tempfile

    kind = ",".join(kinds)
    print(f"Debug Info -> Organ: {organ}, Kind: {kind}, Keys: {keys}")
    # Assumes both results are something `imshow` accepts (array / PIL image)
    # -- TODO confirm against MedicalPipeline.generate.
    image, label = pipeline.generate(organ=organ, kind=kind, keys=keys)

    # Use a dedicated figure per call instead of pyplot's implicit global
    # figure, so repeated or concurrent requests do not draw onto (and keep
    # growing) the same axes.
    fig, axes = plt.subplots(1, 2)
    try:
        for ax, picture in zip(axes, (image, label)):
            ax.imshow(picture)
            ax.axis('off')
        # A unique file per request: a fixed 'pred.png' can be overwritten by
        # another request before the UI has read it. The file is left on disk
        # because the caller still needs to read it after we return.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as out_file:
            out_path = out_file.name
        fig.savefig(out_path, bbox_inches='tight', pad_inches=0)
    finally:
        # Always release the figure, even when drawing or saving fails.
        plt.close(fig)
    return out_path
# Initial UI state, derived from the mapping so it cannot drift from it.
_DEFAULT_KEY = "CVC-ClinicDB"
_default_organs = keys_to_organ_kind[_DEFAULT_KEY]["organs"]
_default_organ = _default_organs[0]
_default_kinds = keys_to_organ_kind[_DEFAULT_KEY]["kinds"][_default_organ]

with gr.Blocks() as demo:
    gr.Markdown("### 📌 Note: This app is running under free CPU. Make sure to provide reasonable kind combinations.")
    keys_dropdown = gr.Dropdown(list(keys_to_organ_kind), label="Keys", value=_DEFAULT_KEY)
    organ_dropdown = gr.Dropdown(_default_organs, label="Organ", value=_default_organ)
    # Pre-select every kind, matching what update_organ_and_kind / update_kind
    # do on later changes; otherwise the first "Generate" click would send an
    # empty kind string to the pipeline.
    kind_checkbox = gr.CheckboxGroup(_default_kinds, label="Kind", value=_default_kinds)

    # Update organ and kind based on key
    keys_dropdown.change(update_organ_and_kind, inputs=keys_dropdown, outputs=[organ_dropdown, kind_checkbox])
    organ_dropdown.change(update_kind, inputs=[keys_dropdown, organ_dropdown], outputs=kind_checkbox)

    generate_button = gr.Button("Generate Image")
    output_image = gr.Image(label="Visualization")
    generate_button.click(generate_image, inputs=[organ_dropdown, kind_checkbox, keys_dropdown], outputs=output_image)

demo.launch()