Spaces:
Sleeping
Sleeping
Update medical_pipeline.py
Browse files- medical_pipeline.py +3 -102
medical_pipeline.py
CHANGED
@@ -91,109 +91,10 @@ class MedicalPipeline:
|
|
91 |
num_inference_steps=50,
|
92 |
**kwargs,
|
93 |
):
|
94 |
-
if organ is None:
|
95 |
-
if not isinstance(keys, List):
|
96 |
-
if keys is None:
|
97 |
-
idx = random.randint(1, 7)
|
98 |
-
if idx == 1:
|
99 |
-
keys = 'AMOS2022'
|
100 |
-
organ = 'abdomen CT scans'
|
101 |
-
kind = self.get_random_values(self.AMOS2022)
|
102 |
-
img_prompt = [f'a photo of {organ} image, with {kind}.'] * num_samples
|
103 |
-
mask_prompt = [f'a photo of {organ} label, with {kind}.'] * num_samples
|
104 |
-
elif idx == 2:
|
105 |
-
keys = 'BUSI'
|
106 |
-
organ = 'breast ultrasound'
|
107 |
-
choice = random.random()
|
108 |
-
if choice < 0.5:
|
109 |
-
kind = 'normal'
|
110 |
-
else:
|
111 |
-
kind = 'breast tumor'
|
112 |
-
img_prompt = [f'a photo of {organ} image, with {kind}.'] * num_samples
|
113 |
-
mask_prompt = [f'a photo of {organ} label, with {kind}.'] * num_samples
|
114 |
-
elif idx == 3:
|
115 |
-
keys = 'ACDC'
|
116 |
-
organ = 'cardiovascular ventricle mri'
|
117 |
-
kind = self.get_random_values(self.ACDC)
|
118 |
-
img_prompt = [f'a photo of {organ} image, with {kind}.'] * num_samples
|
119 |
-
mask_prompt = [f'a photo of {organ} label, with {kind}.'] * num_samples
|
120 |
-
elif idx == 4:
|
121 |
-
keys = 'CVC-ClinicDB'
|
122 |
-
organ = 'polyp colonoscopy'
|
123 |
-
kind = 'polyp'
|
124 |
-
img_prompt = [f'a photo of {organ} image, with {kind}.'] * num_samples
|
125 |
-
mask_prompt = [f'a photo of {organ} label, with {kind}.'] * num_samples
|
126 |
-
elif idx == 5:
|
127 |
-
keys = 'kvasir-seg'
|
128 |
-
organ = 'polyp colonoscopy'
|
129 |
-
kind = 'polyp'
|
130 |
-
img_prompt = [f'a photo of {organ} image, with {kind}.'] * num_samples
|
131 |
-
mask_prompt = [f'a photo of {organ} label, with {kind}.'] * num_samples
|
132 |
-
elif idx == 6:
|
133 |
-
keys = 'LiTS2017'
|
134 |
-
organ = 'abdomen CT scans'
|
135 |
-
kind = self.get_random_values(self.LiTS2017)
|
136 |
-
img_prompt = [f'a photo of {organ} image, with {kind}.'] * num_samples
|
137 |
-
mask_prompt = [f'a photo of {organ} label, with {kind}.'] * num_samples
|
138 |
-
elif idx == 7:
|
139 |
-
keys = 'KiTS2019'
|
140 |
-
organ = 'abdomen CT scans'
|
141 |
-
kind = self.get_random_values(self.KiTS2019)
|
142 |
-
img_prompt = [f'a photo of {organ} image, with {kind}.'] * num_samples
|
143 |
-
mask_prompt = [f'a photo of {organ} label, with {kind}.'] * num_samples
|
144 |
-
else:
|
145 |
-
raise RuntimeError('no mode')
|
146 |
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
if keys == 'AMOS2022':
|
151 |
-
organ = 'abdomen CT scans'
|
152 |
-
kind = self.get_random_values(self.AMOS2022)
|
153 |
-
elif keys == 'BUSI':
|
154 |
-
organ = 'breast ultrasound'
|
155 |
-
choice = random.random()
|
156 |
-
if choice < 0.5:
|
157 |
-
kind = 'normal'
|
158 |
-
else:
|
159 |
-
kind = 'breast tumor'
|
160 |
-
elif keys == 'ACDC':
|
161 |
-
organ = 'cardiovascular ventricle mri'
|
162 |
-
kind = self.get_random_values(self.ACDC)
|
163 |
-
elif keys == 'CVC-ClinicDB':
|
164 |
-
organ = 'polyp colonoscopy'
|
165 |
-
kind = 'polyp'
|
166 |
-
elif keys == 'kvasir-seg':
|
167 |
-
organ = 'polyp colonoscopy'
|
168 |
-
kind = 'polyp'
|
169 |
-
elif keys == 'LiTS2017':
|
170 |
-
organ = 'abdomen CT scans'
|
171 |
-
kind = self.get_random_values(self.LiTS2017)
|
172 |
-
elif keys == 'KiTS2019':
|
173 |
-
organ = 'abdomen CT scans'
|
174 |
-
kind = self.get_random_values(self.KiTS2019)
|
175 |
-
else:
|
176 |
-
raise RuntimeError('undefined keys')
|
177 |
-
|
178 |
-
img_prompt = [f'a photo of {organ} image, with {kind}.'] * num_samples
|
179 |
-
mask_prompt = [f'a photo of {organ} label, with {kind}.'] * num_samples
|
180 |
-
keys = keys
|
181 |
-
print(img_prompt, mask_prompt)
|
182 |
-
else:
|
183 |
-
# img_prompt, mask_prompt = [], []
|
184 |
-
# for key in keys:
|
185 |
-
# img_prompt.append(f'a photo of {organ} image, with {kind}.')
|
186 |
-
# mask_prompt.append(f'a photo of {organ} label, with {kind}.')
|
187 |
-
# img_prompt = img_prompt * num_samples
|
188 |
-
# mask_prompt = mask_prompt * num_samples
|
189 |
-
img_prompt = [f'a photo of {organ} image, with {kind}.'] * num_samples
|
190 |
-
mask_prompt = [f'a photo of {organ} label, with {kind}.'] * num_samples
|
191 |
-
|
192 |
-
else:
|
193 |
-
|
194 |
-
img_prompt = [f'a photo of {organ} image, with {kind}.'] * num_samples
|
195 |
-
mask_prompt = [f'a photo of {organ} label, with {kind}.'] * num_samples
|
196 |
-
# keys = prompts['key']
|
197 |
|
198 |
with torch.inference_mode():
|
199 |
img_prompt_embeds_, img_negative_prompt_embeds_ = self.pipe.encode_prompt(
|
|
|
91 |
num_inference_steps=50,
|
92 |
**kwargs,
|
93 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
94 |
|
95 |
+
img_prompt = [f'a photo of {organ} image, with {kind}.'] * num_samples
|
96 |
+
mask_prompt = [f'a photo of {organ} label, with {kind}.'] * num_samples
|
97 |
+
# keys = prompts['key']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
98 |
|
99 |
with torch.inference_mode():
|
100 |
img_prompt_embeds_, img_negative_prompt_embeds_ = self.pipe.encode_prompt(
|