# Spaces: Running
import os | |
import random | |
from typing import List | |
import torch.nn.functional as F | |
import torch | |
import numpy as np | |
from diffusers import StableDiffusionPipeline | |
from PIL import Image | |
from safetensors import safe_open | |
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection | |
from skimage.color import rgb2gray | |
def is_torch2_available() -> bool:
    """Report whether ``torch.nn.functional`` exposes ``scaled_dot_product_attention``."""
    has_sdpa = hasattr(F, "scaled_dot_product_attention")
    return has_sdpa
def get_generator(seed, device):
    """Build torch random generator(s) seeded from ``seed``.

    Args:
        seed: ``None``, a single integer seed, or a list/tuple of integer
            seeds (one generator is created per entry).
        device: Device passed to ``torch.Generator``.

    Returns:
        ``None`` when ``seed`` is ``None``; a list of seeded generators when
        ``seed`` is a list or tuple; otherwise a single seeded generator.
    """
    if seed is None:
        return None
    # Tuples are handled like lists so that a sequence of seeds is never
    # handed to manual_seed() as a single value.
    if isinstance(seed, (list, tuple)):
        return [torch.Generator(device).manual_seed(seed_item) for seed_item in seed]
    return torch.Generator(device).manual_seed(seed)
class MedicalPipeline:
    """Sample an image and a matching label map from a diffusion pipeline.

    The wrapped ``sd_pipe`` object must support ``.to(device)``, provide an
    ``encode_prompt`` method and be callable with pre-computed prompt
    embeddings. The per-dataset dictionaries set in ``__init__`` map integer
    class indices to class names.
    """

    # Largest class index for each dataset key accepted by ``generate``; the
    # grayscale label output is rescaled to integer ids in [0, max index].
    _MAX_LABEL_BY_DATASET = {
        'AMOS2022': 15,
        'ACDC': 3,
        'BUSI': 1,
        'CVC-ClinicDB': 1,
        'kvasir-seg': 1,
        'LiTS2017': 2,
        'KiTS2019': 2,
    }

    def __init__(self, sd_pipe, device):
        """Move ``sd_pipe`` to ``device`` and set up the class-name tables.

        Args:
            sd_pipe: Pipeline object; it is moved with ``sd_pipe.to(device)``.
            device: Target device, also used for prompt encoding and for the
                random generator in ``generate``.
        """
        self.device = device
        self.pipe = sd_pipe.to(self.device)
        self.AMOS2022 = {
            1: 'liver', 2: 'right kidney', 3: 'spleen', 4: 'pancreas',
            5: 'aorta', 6: 'inferior vena cava',
            7: 'right adrenal gland', 8: 'left adrenal gland',
            9: 'gall bladder', 10: 'esophagus', 11: 'stomach',
            12: 'duodenum', 13: 'left kidney',
            14: 'bladder', 15: 'prostate',
        }
        self.ACDC = {1: 'right ventricle', 2: 'myocardium', 3: 'left ventricle'}
        self.BUSI = {0: 'normal', 1: 'breast tumor'}
        self.CVC_ClinicDB = {1: 'polyp'}
        self.kvasir_seg = {1: 'polyp'}
        self.LiTS2017 = {1: 'liver', 2: 'liver tumor'}
        self.KiTS2019 = {1: 'kidney', 2: 'kidney tumor'}

    def numpy_to_pil(self, images):
        """Scale a float image (or batch of images) to a uint8 NumPy array.

        Despite the name, no PIL object is created: the result is a NumPy
        array, and ``generate`` relies on that by calling ``.squeeze()`` on it.

        Args:
            images: NumPy array; a 3-D input gets a leading batch axis added.
                Values are multiplied by 255, so they are presumably in
                [0, 1] -- TODO confirm against the pipeline output.

        Returns:
            ``uint8`` array holding ``round(images * 255)``.
        """
        if images.ndim == 3:
            images = images[None, ...]
        return (images * 255).round().astype("uint8")

    def map_to_classes(self, label_array, max_pixel):
        """Map label intensities to integer class ids in ``[0, max_pixel]``.

        Args:
            label_array: NumPy array of label intensities; each value is
                multiplied by ``max_pixel`` and rounded.
            max_pixel: Largest class id; results are clipped to
                ``[0, max_pixel]``.

        Returns:
            ``uint8`` array of class ids.
        """
        scaled = np.round(label_array * max_pixel)
        return np.clip(scaled, 0, max_pixel).astype(np.uint8)

    def get_random_values(self, my_dict):
        """Pick one or more distinct values of ``my_dict`` at random.

        Args:
            my_dict: Dictionary whose values are strings to choose from.

        Returns:
            The chosen values joined with ``','`` into a single string.

        Raises:
            ValueError: If ``my_dict`` has no values.
        """
        values_list = list(my_dict.values())
        if not values_list:
            raise ValueError("my_dict must contain at least one value to choose from")
        # Randomly decide how many values to take (between 1 and all of them).
        num_choices = random.randint(1, len(values_list))
        # random.sample() guarantees that no value is chosen twice.
        return ','.join(random.sample(values_list, num_choices))

    def generate(
        self,
        organ=None,
        kind=None,
        keys=None,
        negative_prompt=None,
        height=256,
        width=256,
        num_samples=1,
        seed=None,
        guidance_scale=7.5,
        num_inference_steps=50,
        **kwargs,
    ):
        """Run the pipeline once and return an image and its label map.

        Args:
            organ: Text interpolated into both prompts.
            kind: Text interpolated into both prompts (e.g. the string
                returned by ``get_random_values``).
            keys: Dataset name selecting the label range (one of the keys of
                ``_MAX_LABEL_BY_DATASET``). Any other value leaves the label
                unmapped.
            negative_prompt: Forwarded to ``encode_prompt``.
            height: Forwarded to the pipeline call.
            width: Forwarded to the pipeline call.
            num_samples: Number of prompt repetitions; also passed as
                ``num_images_per_prompt`` to ``encode_prompt``.
            seed: Forwarded to ``get_generator``.
            guidance_scale: Forwarded to the pipeline call.
            num_inference_steps: Forwarded to the pipeline call.
            **kwargs: Extra keyword arguments forwarded to the pipeline call.

        Returns:
            Tuple ``(image, label)``. ``image`` is the first pipeline output
            converted to ``uint8`` with singleton axes squeezed out.
            ``label`` is the second pipeline output converted to grayscale;
            for a known ``keys`` it is mapped to ``uint8`` class ids,
            otherwise the grayscale array is returned as is.
        """
        img_prompt = [f'a photo of {organ} image, with {kind}.'] * num_samples
        mask_prompt = [f'a photo of {organ} label, with {kind}.'] * num_samples

        # NOTE(review): the prompt list is already repeated num_samples times
        # and num_images_per_prompt=num_samples repeats it again -- confirm
        # this is intended for num_samples > 1.
        with torch.inference_mode():
            img_prompt_embeds_, img_negative_prompt_embeds_ = self.pipe.encode_prompt(
                img_prompt,
                device=self.device,
                num_images_per_prompt=num_samples,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )
            mask_prompt_embeds_, mask_negative_prompt_embeds_ = self.pipe.encode_prompt(
                mask_prompt,
                device=self.device,
                num_images_per_prompt=num_samples,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )
            # NOTE(review): the image prompt and its negative embedding are
            # stacked as the positive batch, and the mask prompt and its
            # negative embedding as the negative batch. Kept exactly as
            # written; verify this pairing against the trained model.
            prompt_embeds = torch.cat([img_prompt_embeds_, img_negative_prompt_embeds_], dim=0)
            negative_prompt_embeds = torch.cat([mask_prompt_embeds_, mask_negative_prompt_embeds_], dim=0)

        generator = get_generator(seed, self.device)
        data = self.pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            height=height,
            width=width,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
            output_type='np',
            **kwargs,
        ).images

        # The first output is used as the image, the second as the label.
        image, label = data[0], data[1]
        label = rgb2gray(label)
        # numpy_to_pil adds a batch axis; squeeze() removes it again.
        image = self.numpy_to_pil(image).squeeze()

        # Unknown dataset names fall through and return the grayscale label.
        max_label = self._MAX_LABEL_BY_DATASET.get(keys)
        if max_label is not None:
            label = self.map_to_classes(label, max_label)
        return image, label