Spaces:
Running
Running
File size: 5,464 Bytes
9996200 435691a 9996200 435691a c362d04 9996200 f9fdef1 9996200 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 |
import os
import random
from typing import List
import torch.nn.functional as F
import torch
import numpy as np
from diffusers import StableDiffusionPipeline
from PIL import Image
from safetensors import safe_open
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from skimage.color import rgb2gray
def is_torch2_available():
    """Report whether this torch build exposes fused SDPA (i.e. torch >= 2.0)."""
    return getattr(F, "scaled_dot_product_attention", None) is not None
def get_generator(seed, device):
    """Build seeded torch RNG state for reproducible sampling.

    :param seed: int, list of ints, or None (None means non-deterministic)
    :param device: device the generator(s) should live on
    :return: a ``torch.Generator``, a list of them (one per seed), or None
    """
    if seed is None:
        return None
    if isinstance(seed, list):
        return [torch.Generator(device).manual_seed(s) for s in seed]
    return torch.Generator(device).manual_seed(seed)
class MedicalPipeline:
    """Jointly generates a medical image and its segmentation label map by
    running a Stable Diffusion pipeline with paired "image" / "label" text
    prompts, then quantizing the label output into per-dataset class ids.
    """

    # Dataset key -> highest class index present in its label maps.
    # Used by generate() to quantize the grayscale label output.
    NUM_CLASSES = {
        'AMOS2022': 15,
        'ACDC': 3,
        'BUSI': 1,
        'CVC-ClinicDB': 1,
        'kvasir-seg': 1,
        'LiTS2017': 2,
        'KiTS2019': 2,
    }

    def __init__(self, sd_pipe, device):
        """
        :param sd_pipe: a diffusers StableDiffusionPipeline; moved to *device*
        :param device: torch device (string or torch.device) for inference
        """
        self.device = device
        self.pipe = sd_pipe.to(self.device)
        # Per-dataset mapping of class index -> anatomical structure name.
        self.AMOS2022 = {1: 'liver', 2: 'right kidney', 3: 'spleen', 4: 'pancreas', 5: 'aorta', 6: 'inferior vena cava',
                         7: 'right adrenal gland', 8: 'left adrenal gland',
                         9: 'gall bladder', 10: 'esophagus', 11: 'stomach', 12: 'duodenum', 13: 'left kidney',
                         14: 'bladder', 15: 'prostate'}
        self.ACDC = {1: 'right ventricle', 2: 'myocardium', 3: 'left ventricle'}
        self.BUSI = {0: 'normal', 1: 'breast tumor'}
        self.CVC_ClinicDB = {1: 'polyp'}
        self.kvasir_seg = {1: 'polyp'}
        self.LiTS2017 = {1: 'liver', 2: 'liver tumor'}
        self.KiTS2019 = {1: 'kidney', 2: 'kidney tumor'}

    def numpy_to_pil(self, images):
        """Scale float images in [0, 1] to uint8 in [0, 255].

        NOTE(review): despite the name, this returns a numpy uint8 array
        (with a batch axis prepended for 3-D input), not PIL images --
        the PIL conversion step was apparently dropped. Callers rely on
        the array return (generate() calls .squeeze() on it).
        """
        if images.ndim == 3:
            # Add a batch dimension so the return shape is always 4-D.
            images = images[None, ...]
        images = (images * 255).round().astype("uint8")
        return images

    def map_to_classes(self, label_array, max_pixel):
        """Quantize a [0, 1] grayscale label array onto integer class ids
        in [0, max_pixel].

        :param label_array: float array with values in [0, 1]
        :param max_pixel: highest class index for the dataset
        :return: uint8 array of class ids
        """
        return np.clip(np.round(label_array * max_pixel), 0, max_pixel).astype(np.uint8)

    def get_random_values(self, my_dict):
        """Randomly select one or more distinct values from *my_dict*.

        :param my_dict: dict whose values are candidate strings
        :return: the chosen values joined into one comma-separated string
        """
        values_list = list(my_dict.values())
        # Randomly decide how many values to pick (at least one, no repeats).
        num_choices = random.randint(1, len(values_list))
        return ','.join(random.sample(values_list, num_choices))

    def generate(
        self,
        organ=None,
        kind=None,
        keys=None,
        negative_prompt=None,
        height=256,
        width=256,
        num_samples=1,
        seed=None,
        guidance_scale=7.5,
        num_inference_steps=50,
        **kwargs,
    ):
        """Generate an (image, label) pair for the given organ/structure.

        :param organ: anatomy/modality name inserted into the prompt
        :param kind: structure name(s) inserted into the prompt
        :param keys: dataset key selecting the label quantization
            (one of NUM_CLASSES; unknown keys leave the label as raw grayscale)
        :param seed: int, list of ints, or None for nondeterministic sampling
        :return: (image, label) -- image as a squeezed uint8 numpy array,
            label as a uint8 class-id map (or float grayscale for unknown keys)
        """
        # NOTE(review): num_samples multiplies the prompt list AND is passed as
        # num_images_per_prompt below, so the effective batch grows
        # quadratically for num_samples > 1 -- confirm intended.
        img_prompt = [f'a photo of {organ} image, with {kind}.'] * num_samples
        mask_prompt = [f'a photo of {organ} label, with {kind}.'] * num_samples
        with torch.inference_mode():
            img_prompt_embeds_, img_negative_prompt_embeds_ = self.pipe.encode_prompt(
                img_prompt,
                device=self.device,
                num_images_per_prompt=num_samples,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )
            mask_prompt_embeds_, mask_negative_prompt_embeds_ = self.pipe.encode_prompt(
                mask_prompt,
                device=self.device,
                num_images_per_prompt=num_samples,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )
            # NOTE(review): the "prompt" batch pairs the image positives with
            # the image *negatives*, and the "negative" batch pairs the mask
            # positives with the mask negatives. This cross-pairing looks
            # unusual for classifier-free guidance but may match how the model
            # was trained -- confirm before changing.
            prompt_embeds = torch.cat([img_prompt_embeds_, img_negative_prompt_embeds_], dim=0)
            negative_prompt_embeds = torch.cat([mask_prompt_embeds_, mask_negative_prompt_embeds_], dim=0)
        generator = get_generator(seed, self.device)
        data = self.pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            height=height,
            width=width,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
            output_type='np',
            **kwargs,
        ).images
        # First output is the image, second is its label rendering.
        image, label = data[0], data[1]
        label = rgb2gray(label)
        image = self.numpy_to_pil(image).squeeze()
        # Quantize the label into class ids when the dataset is known.
        num_classes = self.NUM_CLASSES.get(keys)
        if num_classes is not None:
            label = self.map_to_classes(label, num_classes)
        return image, label
|