JohnWeck commited on
Commit
c362d04
·
verified ·
1 Parent(s): a7ebb0c

Update medical_pipeline.py

Browse files
Files changed (1) hide show
  1. medical_pipeline.py +3 -102
medical_pipeline.py CHANGED
@@ -91,109 +91,10 @@ class MedicalPipeline:
91
  num_inference_steps=50,
92
  **kwargs,
93
  ):
94
- if organ is None:
95
- if not isinstance(keys, List):
96
- if keys is None:
97
- idx = random.randint(1, 7)
98
- if idx == 1:
99
- keys = 'AMOS2022'
100
- organ = 'abdomen CT scans'
101
- kind = self.get_random_values(self.AMOS2022)
102
- img_prompt = [f'a photo of {organ} image, with {kind}.'] * num_samples
103
- mask_prompt = [f'a photo of {organ} label, with {kind}.'] * num_samples
104
- elif idx == 2:
105
- keys = 'BUSI'
106
- organ = 'breast ultrasound'
107
- choice = random.random()
108
- if choice < 0.5:
109
- kind = 'normal'
110
- else:
111
- kind = 'breast tumor'
112
- img_prompt = [f'a photo of {organ} image, with {kind}.'] * num_samples
113
- mask_prompt = [f'a photo of {organ} label, with {kind}.'] * num_samples
114
- elif idx == 3:
115
- keys = 'ACDC'
116
- organ = 'cardiovascular ventricle mri'
117
- kind = self.get_random_values(self.ACDC)
118
- img_prompt = [f'a photo of {organ} image, with {kind}.'] * num_samples
119
- mask_prompt = [f'a photo of {organ} label, with {kind}.'] * num_samples
120
- elif idx == 4:
121
- keys = 'CVC-ClinicDB'
122
- organ = 'polyp colonoscopy'
123
- kind = 'polyp'
124
- img_prompt = [f'a photo of {organ} image, with {kind}.'] * num_samples
125
- mask_prompt = [f'a photo of {organ} label, with {kind}.'] * num_samples
126
- elif idx == 5:
127
- keys = 'kvasir-seg'
128
- organ = 'polyp colonoscopy'
129
- kind = 'polyp'
130
- img_prompt = [f'a photo of {organ} image, with {kind}.'] * num_samples
131
- mask_prompt = [f'a photo of {organ} label, with {kind}.'] * num_samples
132
- elif idx == 6:
133
- keys = 'LiTS2017'
134
- organ = 'abdomen CT scans'
135
- kind = self.get_random_values(self.LiTS2017)
136
- img_prompt = [f'a photo of {organ} image, with {kind}.'] * num_samples
137
- mask_prompt = [f'a photo of {organ} label, with {kind}.'] * num_samples
138
- elif idx == 7:
139
- keys = 'KiTS2019'
140
- organ = 'abdomen CT scans'
141
- kind = self.get_random_values(self.KiTS2019)
142
- img_prompt = [f'a photo of {organ} image, with {kind}.'] * num_samples
143
- mask_prompt = [f'a photo of {organ} label, with {kind}.'] * num_samples
144
- else:
145
- raise RuntimeError('no mode')
146
 
147
- print(img_prompt, mask_prompt)
148
-
149
- else:
150
- if keys == 'AMOS2022':
151
- organ = 'abdomen CT scans'
152
- kind = self.get_random_values(self.AMOS2022)
153
- elif keys == 'BUSI':
154
- organ = 'breast ultrasound'
155
- choice = random.random()
156
- if choice < 0.5:
157
- kind = 'normal'
158
- else:
159
- kind = 'breast tumor'
160
- elif keys == 'ACDC':
161
- organ = 'cardiovascular ventricle mri'
162
- kind = self.get_random_values(self.ACDC)
163
- elif keys == 'CVC-ClinicDB':
164
- organ = 'polyp colonoscopy'
165
- kind = 'polyp'
166
- elif keys == 'kvasir-seg':
167
- organ = 'polyp colonoscopy'
168
- kind = 'polyp'
169
- elif keys == 'LiTS2017':
170
- organ = 'abdomen CT scans'
171
- kind = self.get_random_values(self.LiTS2017)
172
- elif keys == 'KiTS2019':
173
- organ = 'abdomen CT scans'
174
- kind = self.get_random_values(self.KiTS2019)
175
- else:
176
- raise RuntimeError('undefined keys')
177
-
178
- img_prompt = [f'a photo of {organ} image, with {kind}.'] * num_samples
179
- mask_prompt = [f'a photo of {organ} label, with {kind}.'] * num_samples
180
- keys = keys
181
- print(img_prompt, mask_prompt)
182
- else:
183
- # img_prompt, mask_prompt = [], []
184
- # for key in keys:
185
- # img_prompt.append(f'a photo of {organ} image, with {kind}.')
186
- # mask_prompt.append(f'a photo of {organ} label, with {kind}.')
187
- # img_prompt = img_prompt * num_samples
188
- # mask_prompt = mask_prompt * num_samples
189
- img_prompt = [f'a photo of {organ} image, with {kind}.'] * num_samples
190
- mask_prompt = [f'a photo of {organ} label, with {kind}.'] * num_samples
191
-
192
- else:
193
-
194
- img_prompt = [f'a photo of {organ} image, with {kind}.'] * num_samples
195
- mask_prompt = [f'a photo of {organ} label, with {kind}.'] * num_samples
196
- # keys = prompts['key']
197
 
198
  with torch.inference_mode():
199
  img_prompt_embeds_, img_negative_prompt_embeds_ = self.pipe.encode_prompt(
 
91
  num_inference_steps=50,
92
  **kwargs,
93
  ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
+ img_prompt = [f'a photo of {organ} image, with {kind}.'] * num_samples
96
+ mask_prompt = [f'a photo of {organ} label, with {kind}.'] * num_samples
97
+ # keys = prompts['key']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
 
99
  with torch.inference_mode():
100
  img_prompt_embeds_, img_negative_prompt_embeds_ = self.pipe.encode_prompt(