import os
import random
from typing import List

import numpy as np
import torch
import torch.nn.functional as F
from diffusers import StableDiffusionPipeline
from PIL import Image
from safetensors import safe_open
from skimage.color import rgb2gray
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection


def is_torch2_available():
    """Return True if this torch build exposes fused scaled-dot-product attention (torch >= 2.0)."""
    return hasattr(F, "scaled_dot_product_attention")


def get_generator(seed, device):
    """Create seeded torch.Generator(s) for reproducible sampling.

    Args:
        seed: An int, a list of ints (one generator per seed), or None.
        device: Device the generator(s) are created on.

    Returns:
        A torch.Generator, a list of torch.Generator, or None when seed is None.
    """
    if seed is None:
        return None
    if isinstance(seed, list):
        return [torch.Generator(device).manual_seed(seed_item) for seed_item in seed]
    return torch.Generator(device).manual_seed(seed)


class MedicalPipeline:
    """Wraps a StableDiffusionPipeline to generate a medical image together with a
    segmentation label map, driven by paired text prompts ("... image" / "... label").
    """

    def __init__(self, sd_pipe, device):
        self.device = device
        self.pipe = sd_pipe.to(self.device)
        # Per-dataset mappings: label index -> anatomical class name.
        self.AMOS2022 = {1: 'liver', 2: 'right kidney', 3: 'spleen', 4: 'pancreas', 5: 'aorta',
                         6: 'inferior vena cava', 7: 'right adrenal gland', 8: 'left adrenal gland',
                         9: 'gall bladder', 10: 'esophagus', 11: 'stomach', 12: 'duodenum',
                         13: 'left kidney', 14: 'bladder', 15: 'prostate'}
        self.ACDC = {1: 'right ventricle', 2: 'myocardium', 3: 'left ventricle'}
        self.BUSI = {0: 'normal', 1: 'breast tumor'}
        self.CVC_ClinicDB = {1: 'polyp'}
        self.kvasir_seg = {1: 'polyp'}
        self.LiTS2017 = {1: 'liver', 2: 'liver tumor'}
        self.KiTS2019 = {1: 'kidney', 2: 'kidney tumor'}

    def numpy_to_pil(self, images):
        """Scale a float image (or batch) in [0, 1] to uint8 in [0, 255].

        NOTE(review): despite the name, this returns a uint8 NumPy batch of shape
        (N, H, W, C), not PIL images — callers (see ``generate``) rely on ``.squeeze()``
        working on the result, so the array return is kept.

        Args:
            images: ndarray with ndim 3 (single image, promoted to a batch of 1) or 4.

        Returns:
            uint8 ndarray of shape (N, H, W, C).
        """
        if images.ndim == 3:
            images = images[None, ...]
        images = (images * 255).round().astype("uint8")
        return images

    def map_to_classes(self, label_array, max_pixel):
        """Map a grayscale label array in [0, 1] to integer class ids in [0, max_pixel].

        Args:
            label_array: Float array of normalized label intensities.
            max_pixel: Highest class index (num_classes - 1).

        Returns:
            uint8 array of class ids clipped to [0, max_pixel].
        """
        return np.clip(np.round(label_array * (max_pixel)), 0, max_pixel).astype(np.uint8)

    def get_random_values(self, my_dict):
        """Randomly pick one or more distinct values from a dict and join them with commas.

        Args:
            my_dict: Dict whose values are candidate class names.

        Returns:
            A comma-separated string of 1..len(my_dict) distinct values.
        """
        values_list = list(my_dict.values())
        # Pick a random number of distinct values (at least one).
        num_choices = random.randint(1, len(values_list))
        kinds = random.sample(values_list, num_choices)
        return ','.join(kinds)

    def generate(
        self,
        organ=None,
        kind=None,
        keys=None,
        negative_prompt=None,
        height=256,
        width=256,
        num_samples=1,
        seed=None,
        guidance_scale=7.5,
        num_inference_steps=50,
        **kwargs,
    ):
        """Generate one image/label pair for the given organ and class description.

        The image prompt and the label prompt are encoded separately and batched
        through the pipeline so that the first output is the image and the second
        the (grayscale) label map.

        Args:
            organ: Organ/dataset description inserted into the prompts.
            kind: Comma-separated class names inserted into the prompts.
            keys: Dataset key selecting the number of label classes.
            negative_prompt: Optional negative prompt for classifier-free guidance.
            height, width: Output resolution.
            num_samples: Number of samples per prompt.
            seed: Optional int or list of ints for reproducibility.
            guidance_scale: Classifier-free guidance strength.
            num_inference_steps: Number of denoising steps.
            **kwargs: Forwarded to the underlying pipeline call.

        Returns:
            Tuple (image, label): uint8 image array and uint8 class-id label map
            (label stays a float grayscale array when ``keys`` is unrecognized).
        """
        img_prompt = [f'a photo of {organ} image, with {kind}.'] * num_samples
        mask_prompt = [f'a photo of {organ} label, with {kind}.'] * num_samples
        with torch.inference_mode():
            img_prompt_embeds_, img_negative_prompt_embeds_ = self.pipe.encode_prompt(
                img_prompt,
                device=self.device,
                num_images_per_prompt=num_samples,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )
            mask_prompt_embeds_, mask_negative_prompt_embeds_ = self.pipe.encode_prompt(
                mask_prompt,
                device=self.device,
                num_images_per_prompt=num_samples,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )
            # NOTE(review): positive embeddings of both prompts are batched as
            # "prompt_embeds" and the two negative embeddings as
            # "negative_prompt_embeds" — presumably so one pipeline call yields
            # the image (index 0) and the label (index 1); confirm against the
            # trained pipeline's expectations.
            prompt_embeds = torch.cat([img_prompt_embeds_, img_negative_prompt_embeds_], dim=0)
            negative_prompt_embeds = torch.cat([mask_prompt_embeds_, mask_negative_prompt_embeds_], dim=0)
        generator = get_generator(seed, self.device)
        data = self.pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            height=height,
            width=width,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
            output_type='np',
            **kwargs,
        ).images
        image, label = data[0], data[1]
        label = rgb2gray(label)
        image = self.numpy_to_pil(image).squeeze()
        # Dataset key -> highest class index; unknown keys leave the label as-is.
        num_classes_by_key = {
            'AMOS2022': 15,
            'ACDC': 3,
            'BUSI': 1,
            'CVC-ClinicDB': 1,
            'kvasir-seg': 1,
            'LiTS2017': 2,
            'KiTS2019': 2,
        }
        if keys in num_classes_by_key:
            label = self.map_to_classes(label, num_classes_by_key[keys])
        return image, label