medsegfactory / app.py
JohnWeck's picture
Update app.py
08587f3 verified
import os
import random
import argparse
import numpy as np
import matplotlib.pyplot as plt
import torch
import gradio as gr
from huggingface_hub import hf_hub_download
# Compatibility shim: presumably some dependency (e.g. an older diffusers)
# still imports the removed `huggingface_hub.cached_download`; route such
# calls to `hf_hub_download` instead.
def cached_download(*positional, **keyword):
    """Forward all arguments unchanged to ``hf_hub_download``.

    A deprecation notice is printed on every call.

    NOTE(review): the legacy ``cached_download`` took a URL as its first
    argument, whereas ``hf_hub_download`` expects ``repo_id``/``filename``;
    confirm that callers routed through this shim pass compatible arguments.
    """
    notice = "Warning: cached_download is deprecated, using hf_hub_download instead."
    print(notice)
    return hf_hub_download(*positional, **keyword)


import sys

# Register the shim under the dotted module name, presumably so that
# `from huggingface_hub import cached_download` can fall back to the
# sys.modules entry when the attribute no longer exists on the package.
sys.modules["huggingface_hub.cached_download"] = cached_download
from diffusers import AutoencoderKL, DDPMScheduler
from StableDiffusion.Our_UNet import UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from medical_pipeline import MedicalPipeline
from diffusers import DDIMScheduler
from StableDiffusion.Our_Pipe import StableDiffusionPipeline
# Hugging Face repos: base Stable Diffusion v1.5 and the MedSegFactory checkpoint.
model_repo_id = "runwayml/stable-diffusion-v1-5"
medsegfactory_id = "JohnWeck/StableDiffusion"
filename = 'checkpoint-300.pth'
device = "cuda" if torch.cuda.is_available() else "cpu"
weight_dtype = torch.float32

# VAE and text encoder are taken unchanged from the base SD v1.5 repo.
text_encoder = CLIPTextModel.from_pretrained(model_repo_id, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(model_repo_id, subfolder="vae")

# The UNet uses the project's own UNet2DConditionModel class: it is built from
# the base repo's "unet" config, then filled with the fine-tuned weights.
unet = UNet2DConditionModel.from_config(model_repo_id, subfolder="unet")
medsegfactory_ckpt = hf_hub_download(repo_id=medsegfactory_id, filename=filename)
# NOTE(review): torch.load unpickles the downloaded file, which can execute
# arbitrary code -- only point this at a trusted repo. The result is passed
# straight to load_state_dict, so `weights_only=True` should be usable on
# torch >= 1.13; confirm the installed torch version before enabling it.
unet.load_state_dict(torch.load(medsegfactory_ckpt, map_location='cpu'))

# Inference only. torch modules start in training mode and (for diffusers
# models) only from_pretrained switches to eval, not from_config -- so put the
# UNet in eval mode explicitly and freeze it like the other components.
unet.eval()
unet.requires_grad_(False)
vae.requires_grad_(False)
text_encoder.requires_grad_(False)

unet.to(device, dtype=weight_dtype)
vae.to(device, dtype=weight_dtype)
text_encoder.to(device, dtype=weight_dtype)

# DDIM sampler with a scaled-linear beta schedule over 1000 training steps.
sd_noise_scheduler = DDIMScheduler(
    num_train_timesteps=1000,
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    set_alpha_to_one=False,
    steps_offset=1,
)

# load SD pipeline around the components prepared above. Passing vae and
# text_encoder explicitly keeps from_pretrained from loading a second copy of
# each (which would not be frozen or moved to `device`). Safety checker and
# feature extractor are disabled.
pipe = StableDiffusionPipeline.from_pretrained(
    model_repo_id,
    torch_dtype=weight_dtype,
    unet=unet,
    vae=vae,
    text_encoder=text_encoder,
    scheduler=sd_noise_scheduler,
    feature_extractor=None,
    safety_checker=None,
)
# Project-level wrapper used by generate_image(); see medical_pipeline.
pipeline = MedicalPipeline(pipe, device)
# Map each dataset key (shown in the "Keys" dropdown) to its organ/modality
# description ("organs") and, per organ, the selectable structures ("kinds").
# Every dataset listed here exposes exactly one organ.


def _dataset_entry(organ, kinds):
    """Build one mapping entry holding a single organ and its kinds."""
    return {"organs": [organ], "kinds": {organ: kinds}}


keys_to_organ_kind = {
    "CVC-ClinicDB": _dataset_entry("polyp colonoscopy", ["polyp"]),
    "BUSI": _dataset_entry("breast ultrasound", ["normal", "breast tumor"]),
    "LiTS2017": _dataset_entry("abdomen CT scans", ["liver", "liver tumor"]),
    "KiTS2019": _dataset_entry("abdomen CT scans", ["kidney", "kidney tumor"]),
    "ACDC": _dataset_entry(
        "cardiovascular ventricle mri",
        ["right ventricle", "myocardium", "left ventricle"],
    ),
    "AMOS2022": _dataset_entry(
        "abdomen CT scans",
        [
            "liver",
            "right kidney",
            "spleen",
            "pancreas",
            "aorta",
            "inferior vena cava",
            "right adrenal gland",
            "left adrenal gland",
            "gall bladder",
            "esophagus",
            "stomach",
            "duodenum",
            "left kidney",
            "bladder",
            "prostate",
        ],
    ),
}
def update_organ_and_kind(selected_key):
    """Reset the organ and kind widgets after a new dataset key is chosen.

    Returns a pair of ``gr.update`` objects:
    the organ dropdown receives the key's organs with the first one selected
    ("" when the key lists none), and the kind checkbox group receives that
    organ's kinds with all of them selected.
    """
    organ_choices = keys_to_organ_kind[selected_key]["organs"]
    if organ_choices:
        default_organ = organ_choices[0]
    else:
        default_organ = ""
    organ_update = gr.update(choices=organ_choices, value=default_organ)
    # Refreshing kinds for the default organ is exactly the organ-change
    # handler's job, so reuse it.
    kind_update = update_kind(selected_key, default_organ)
    return organ_update, kind_update
def update_kind(selected_key, selected_organ):
    """Repopulate the kind checkbox group for *selected_organ* under *selected_key*.

    Every available kind is both offered and pre-selected; an organ not listed
    for the key yields an empty group.
    """
    kinds_by_organ = keys_to_organ_kind[selected_key]["kinds"]
    available = kinds_by_organ.get(selected_organ, [])
    return gr.update(choices=available, value=available)
def generate_image(organ, kinds, keys, output_path="pred.png"):
    """Generate an image/label pair and save them side by side as a PNG.

    Args:
        organ: Organ/modality string selected in the UI.
        kinds: Iterable of selected kind strings; they are joined with ","
            before being passed to the pipeline. ``None`` is treated as no
            selection.
        keys: Dataset key selected in the UI.
        output_path: File the visualization is written to (default
            "pred.png", as before).

    Returns:
        ``output_path``, for display by the ``gr.Image`` output.
    """
    kind = ",".join(kinds or [])
    print(f"Debug Info -> Organ: {organ}, Kind: {kind}, Keys: {keys}")
    image, label = pipeline.generate(organ=organ, kind=kind, keys=keys)

    # Draw on a dedicated figure and always close it. Calling plt.subplot on
    # the implicit current figure returns the already-existing axes on later
    # calls, so images pile up on them and the figure is never released.
    fig, (image_ax, label_ax) = plt.subplots(1, 2)
    try:
        for ax, panel in ((image_ax, image), (label_ax, label)):
            ax.imshow(panel)
            ax.axis('off')
        fig.savefig(output_path, bbox_inches='tight', pad_inches=0)
    finally:
        plt.close(fig)
    return output_path
# Selection shown when the app first loads.
default_key = "CVC-ClinicDB"
default_organ = keys_to_organ_kind[default_key]["organs"][0]
default_kinds = keys_to_organ_kind[default_key]["kinds"][default_organ]

with gr.Blocks() as demo:
    gr.Markdown("### 📌 Note: This app is running under free CPU. Make sure to provide reasonable kind combinations.")
    keys_dropdown = gr.Dropdown(list(keys_to_organ_kind.keys()), label="Keys", value=default_key)
    organ_dropdown = gr.Dropdown(keys_to_organ_kind[default_key]["organs"], label="Organ", value=default_organ)
    # Pre-select every kind, matching what update_organ_and_kind/update_kind
    # do whenever the key or organ changes; without this the first Generate
    # click would send an empty kind list.
    kind_checkbox = gr.CheckboxGroup(default_kinds, label="Kind", value=default_kinds)

    # Update organ and kind based on key; update kind based on organ.
    keys_dropdown.change(update_organ_and_kind, inputs=keys_dropdown, outputs=[organ_dropdown, kind_checkbox])
    organ_dropdown.change(update_kind, inputs=[keys_dropdown, organ_dropdown], outputs=kind_checkbox)

    generate_button = gr.Button("Generate Image")
    output_image = gr.Image(label="Visualization")
    generate_button.click(generate_image, inputs=[organ_dropdown, kind_checkbox, keys_dropdown], outputs=output_image)

demo.launch()