import argparse
import os
import random
import sys

import numpy as np
import matplotlib.pyplot as plt
import torch
import gradio as gr
from huggingface_hub import hf_hub_download


# Caching function setup
def cached_download(*args, **kwargs):
    """Stand-in for the removed ``huggingface_hub.cached_download``.

    Prints a deprecation notice and forwards every positional and keyword
    argument unchanged to ``hf_hub_download``, returning its result.
    """
    print("Warning: cached_download is deprecated, using hf_hub_download instead.")
    return hf_hub_download(*args, **kwargs)


# Register the shim under the dotted name ``huggingface_hub.cached_download``.
# This has to happen BEFORE the diffusers / project imports below, which is why
# those imports are not at the very top of the file.
# NOTE(review): presumably an older dependency still does
# ``from huggingface_hub import cached_download`` - confirm which one needs it.
sys.modules["huggingface_hub.cached_download"] = cached_download

from diffusers import AutoencoderKL, DDIMScheduler, DDPMScheduler  # noqa: E402
from StableDiffusion.Our_UNet import UNet2DConditionModel  # noqa: E402
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection  # noqa: E402
from medical_pipeline import MedicalPipeline  # noqa: E402
from StableDiffusion.Our_Pipe import StableDiffusionPipeline  # noqa: E402

model_repo_id = "runwayml/stable-diffusion-v1-5"
medsegfactory_id = "JohnWeck/StableDiffusion"
filename = 'checkpoint-300.pth'
device = "cuda" if torch.cuda.is_available() else "cpu"

# Base components come from the Stable Diffusion repo; the UNet is only built
# from the repo's config here and receives its weights from the checkpoint below.
text_encoder = CLIPTextModel.from_pretrained(model_repo_id, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(model_repo_id, subfolder="vae")
unet = UNet2DConditionModel.from_config(model_repo_id, subfolder="unet")

medsegfactory_ckpt = hf_hub_download(repo_id=medsegfactory_id, filename=filename)
# NOTE(review): torch.load unpickles the downloaded file; only safe while the
# checkpoint repo is trusted. Consider ``weights_only=True`` if the installed
# torch version supports it and the checkpoint is a plain state dict.
unet.load_state_dict(torch.load(medsegfactory_ckpt, map_location='cpu'))

# Inference only: no gradients are needed for the VAE or the text encoder.
vae.requires_grad_(False)
text_encoder.requires_grad_(False)

weight_dtype = torch.float32
unet.to(device, dtype=weight_dtype)
vae.to(device, dtype=weight_dtype)
text_encoder.to(device, dtype=weight_dtype)

sd_noise_scheduler = DDIMScheduler(
    num_train_timesteps=1000,
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    set_alpha_to_one=False,
    steps_offset=1,
)

# load SD pipeline
# The already prepared vae / text_encoder are handed over alongside the unet so
# the pipeline reuses them instead of loading a second copy of each from the
# hub repo (previously they were loaded, moved to ``device`` and then unused).
pipe = StableDiffusionPipeline.from_pretrained(
    model_repo_id,
    torch_dtype=weight_dtype,
    unet=unet,
    vae=vae,
    text_encoder=text_encoder,
    scheduler=sd_noise_scheduler,
    feature_extractor=None,
    safety_checker=None,
)
pipeline = MedicalPipeline(pipe, device)

# Mapping from dataset key to its selectable organ and kind options
# Structure: dataset key -> {"organs": [organ, ...],
#                            "kinds": {organ: [kind, ...]}}
keys_to_organ_kind = {
    "CVC-ClinicDB": {
        "organs": ["polyp colonoscopy"],
        "kinds": {"polyp colonoscopy": ["polyp"]},
    },
    "BUSI": {
        "organs": ["breast ultrasound"],
        "kinds": {"breast ultrasound": ["normal", "breast tumor"]},
    },
    "LiTS2017": {
        "organs": ["abdomen CT scans"],
        "kinds": {"abdomen CT scans": ["liver", "liver tumor"]},
    },
    "KiTS2019": {
        "organs": ["abdomen CT scans"],
        "kinds": {"abdomen CT scans": ["kidney", "kidney tumor"]},
    },
    "ACDC": {
        "organs": ["cardiovascular ventricle mri"],
        "kinds": {
            "cardiovascular ventricle mri": [
                "right ventricle",
                "myocardium",
                "left ventricle",
            ]
        },
    },
    "AMOS2022": {
        "organs": ["abdomen CT scans"],
        "kinds": {
            "abdomen CT scans": [
                "liver",
                "right kidney",
                "spleen",
                "pancreas",
                "aorta",
                "inferior vena cava",
                "right adrenal gland",
                "left adrenal gland",
                "gall bladder",
                "esophagus",
                "stomach",
                "duodenum",
                "left kidney",
                "bladder",
                "prostate",
            ]
        },
    },
}

# Dataset key shown when the UI first loads.
_DEFAULT_KEY = "CVC-ClinicDB"


def update_organ_and_kind(selected_key):
    """Refresh the organ and kind widgets after the dataset key changes.

    Args:
        selected_key: A key of ``keys_to_organ_kind``.

    Returns:
        A pair of ``gr.update`` objects: the first sets the organ choices and
        selects the first organ (or ``""`` when the key lists no organs); the
        second sets the kind choices for that organ and selects all of them.

    Raises:
        KeyError: If ``selected_key`` is not in ``keys_to_organ_kind``.
    """
    entry = keys_to_organ_kind[selected_key]
    organs = entry["organs"]
    first_organ = organs[0] if organs else ""  # default to the first organ
    # Fall back to an empty list when the organ has no kinds entry.
    kinds = entry["kinds"].get(first_organ, [])
    return (
        gr.update(choices=organs, value=first_organ),
        gr.update(choices=kinds, value=kinds),
    )


def update_kind(selected_key, selected_organ):
    """Refresh the kind widget after the organ selection changes.

    Args:
        selected_key: A key of ``keys_to_organ_kind``.
        selected_organ: The organ currently selected for that key.

    Returns:
        A ``gr.update`` that sets the kind choices for the organ and selects
        all of them (an empty list when the organ has no kinds entry).

    Raises:
        KeyError: If ``selected_key`` is not in ``keys_to_organ_kind``.
    """
    kinds = keys_to_organ_kind[selected_key]["kinds"].get(selected_organ, [])
    return gr.update(choices=kinds, value=kinds)


def generate_image(organ, kinds, keys):
    """Run the pipeline and save the image/label pair side by side.

    Args:
        organ: Selected organ string.
        kinds: Iterable of selected kind strings; joined with ``","`` before
            being passed to the pipeline. ``None`` is treated as empty.
        keys: Selected dataset key.

    Returns:
        The path ``"pred.png"`` of the written figure.
    """
    kind = ",".join(kinds or [])
    print(f"Debug Info -> Organ: {organ}, Kind: {kind}, Keys: {keys}")
    image, label = pipeline.generate(organ=organ, kind=kind, keys=keys)

    # Draw on a dedicated figure and always close it. Relying on pyplot's
    # implicit current figure kept one figure alive for the whole process and
    # stacked every new image on top of the previous ones.
    fig, (image_ax, label_ax) = plt.subplots(1, 2)
    try:
        image_ax.imshow(image)
        image_ax.axis('off')
        label_ax.imshow(label)
        label_ax.axis('off')
        fig.savefig('pred.png', bbox_inches='tight', pad_inches=0)
    finally:
        plt.close(fig)
    return "pred.png"


with gr.Blocks() as demo:
    gr.Markdown(
        "### 📌 Note: This app is running under free CPU. \n"
        "Make sure to provide reasonable kind combinations."
    )

    keys_dropdown = gr.Dropdown(
        list(keys_to_organ_kind.keys()),
        label="Keys",
        value=_DEFAULT_KEY,
    )
    organ_dropdown = gr.Dropdown(
        keys_to_organ_kind[_DEFAULT_KEY]["organs"],
        label="Organ",
        value=keys_to_organ_kind[_DEFAULT_KEY]["organs"][0],
    )
    kind_checkbox = gr.CheckboxGroup(
        keys_to_organ_kind[_DEFAULT_KEY]["kinds"]["polyp colonoscopy"],
        label="Kind",
    )

    # Update organ and kind based on key
    keys_dropdown.change(
        update_organ_and_kind,
        inputs=keys_dropdown,
        outputs=[organ_dropdown, kind_checkbox],
    )
    organ_dropdown.change(
        update_kind,
        inputs=[keys_dropdown, organ_dropdown],
        outputs=kind_checkbox,
    )

    generate_button = gr.Button("Generate Image")
    output_image = gr.Image(label="Visualization")
    generate_button.click(
        generate_image,
        inputs=[organ_dropdown, kind_checkbox, keys_dropdown],
        outputs=output_image,
    )

demo.launch()