import random

import numpy as np
import torch
import torch.nn.functional as F
from diffusers import StableDiffusionPipeline
from skimage.color import rgb2gray

def is_torch2_available():
    return hasattr(F, "scaled_dot_product_attention")


def get_generator(seed, device):
    if seed is not None:
        if isinstance(seed, list):
            generator = [torch.Generator(device).manual_seed(seed_item) for seed_item in seed]
        else:
            generator = torch.Generator(device).manual_seed(seed)
    else:
        generator = None

    return generator


class MedicalPipeline:
    def __init__(self, sd_pipe, device):
        self.device = device
        self.pipe = sd_pipe.to(self.device)
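        # Label-id -> class-name vocabularies for the supported segmentation datasets.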
        self.AMOS2022 = {1: 'liver', 2: 'right kidney', 3: 'spleen', 4: 'pancreas', 5: 'aorta', 6: 'inferior vena cava',
                         7: 'right adrenal gland', 8: 'left adrenal gland',
                         9: 'gall bladder', 10: 'esophagus', 11: 'stomach', 12: 'duodenum', 13: 'left kidney',
                         14: 'bladder', 15: 'prostate'}
        self.ACDC = {1: 'right ventricle', 2: 'myocardium', 3: 'left ventricle'}

        self.BUSI = {0: 'normal', 1: 'breast tumor'}
        self.CVC_ClinicDB = {1: 'polyp'}
        self.kvasir_seg = {1: 'polyp'}

        self.LiTS2017 = {1: 'liver', 2: 'liver tumor'}
        self.KiTS2019 = {1: 'kidney', 2: 'kidney tumor'}

    def numpy_to_pil(self, images):
        """
        Convert a [0, 1] float numpy image or batch of images to uint8.

        Note: despite the name, this returns a numpy array (not PIL images);
        the caller squeezes the batch dimension afterwards.
        """
        if images.ndim == 3:
            images = images[None, ...]
        images = (images * 255).round().astype("uint8")
        return images

    def map_to_classes(self, label_array, max_pixel):
        """Map normalized label values back to integer class ids in [0, 1, ..., max_pixel]."""
        return np.clip(np.round(label_array * max_pixel), 0, max_pixel).astype(np.uint8)
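        # Illustrative example (assuming labels were normalized to [0, 1] during
        # training): for AMOS2022, max_pixel=15, so a gray value of 2/15 ≈ 0.133
        # maps back to class id round(0.133 * 15) = 2.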

    def get_random_values(self, my_dict):
        """
        Randomly pick one or more distinct values from a dict.
        :param my_dict: dict mapping class ids to class names
        :return: the chosen names joined into a comma-separated string
        """
        values_list = list(my_dict.values())
        # Choose how many classes to mention (between 1 and all of them) and
        # sample without replacement so no class is repeated.
        num_choices = random.randint(1, len(values_list))
        kinds = random.sample(values_list, num_choices)
        return ','.join(kinds)

    def generate(
            self,
            organ=None,
            kind=None,
            keys=None,
            negative_prompt=None,
            height=256,
            width=256,
            num_samples=1,
            seed=None,
            guidance_scale=7.5,
            num_inference_steps=50,
            **kwargs,
    ):

        img_prompt = [f'a photo of {organ} image, with {kind}.'] * num_samples
        mask_prompt = [f'a photo of {organ} label, with {kind}.'] * num_samples
        # keys = prompts['key']

        with torch.inference_mode():
            img_prompt_embeds_, img_negative_prompt_embeds_ = self.pipe.encode_prompt(
                img_prompt,
                device=self.device,
                num_images_per_prompt=num_samples,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )
            mask_prompt_embeds_, mask_negative_prompt_embeds_ = self.pipe.encode_prompt(
                mask_prompt,
                device=self.device,
                num_images_per_prompt=num_samples,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )
            # Batch the image prompt together with the mask prompt (and likewise
            # their negative embeddings) so a single pipeline call yields a paired
            # (image, label) sample as data[0] and data[1] below.
            prompt_embeds = torch.cat([img_prompt_embeds_, mask_prompt_embeds_], dim=0)
            negative_prompt_embeds = torch.cat([img_negative_prompt_embeds_, mask_negative_prompt_embeds_], dim=0)

        generator = get_generator(seed, self.device)

        data = self.pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            height=height,
            width=width,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
            output_type='np',
            **kwargs,
        ).images

        # num = data.shape[0]
        # index = int(num // 2)

        # With the batched prompts above, the first output is the generated image
        # and the second is the rendered label map.
        image, label = data[0], data[1]

        label = rgb2gray(label)
        image = self.numpy_to_pil(image).squeeze()

        # Number of foreground classes per dataset, used to undo label normalization.
        num_classes = {
            'AMOS2022': 15,
            'ACDC': 3,
            'BUSI': 1,
            'CVC-ClinicDB': 1,
            'kvasir-seg': 1,
            'LiTS2017': 2,
            'KiTS2019': 2,
        }
        if keys in num_classes:
            label = self.map_to_classes(label, num_classes[keys])

        return image, label
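

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file). The checkpoint path is
# a placeholder: substitute the fine-tuned image+label diffusion weights this
# pipeline is meant to be used with.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    sd_pipe = StableDiffusionPipeline.from_pretrained(
        "path/to/finetuned-medical-sd",  # hypothetical checkpoint path
        safety_checker=None,
    )
    med_pipe = MedicalPipeline(sd_pipe, device)

    # Compose a random subset of class names from one of the dataset vocabularies.
    kind = med_pipe.get_random_values(med_pipe.LiTS2017)
    image, label = med_pipe.generate(
        organ='liver',
        kind=kind,
        keys='LiTS2017',
        negative_prompt='low quality, blurry',
        seed=42,
    )
    print(image.shape, label.shape, np.unique(label))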