Commit
·
8cfdc75
1
Parent(s):
452abeb
finish
Browse files- parti_prompts.py +6 -4
parti_prompts.py
CHANGED
|
@@ -35,15 +35,16 @@ def get_karlo_eval(ckpt):
|
|
| 35 |
return karlo_eval
|
| 36 |
|
| 37 |
def get_if_eval(ckpt):
|
| 38 |
-
pipe_low = DiffusionPipeline.from_pretrained(ckpt, safety_checker=None, torch_dtype=torch.float16)
|
| 39 |
pipe_low.enable_model_cpu_offload()
|
| 40 |
|
| 41 |
-
pipe_up = DiffusionPipeline.from_pretrained("DeepFloyd/IF-II-L-v1.0", safety_checker=None, text_encoder=pipe_low.text_encoder, torch_dtype=torch.float16)
|
| 42 |
pipe_up.enable_model_cpu_offload()
|
| 43 |
|
| 44 |
def if_eval(prompt, generator=None):
|
| 45 |
-
|
| 46 |
-
images =
|
|
|
|
| 47 |
return images
|
| 48 |
|
| 49 |
return if_eval
|
|
@@ -69,6 +70,7 @@ if __name__ == "__main__":
|
|
| 69 |
args = parser.parse_args()
|
| 70 |
|
| 71 |
dataset = load_dataset("nateraw/parti-prompts")["train"]
|
|
|
|
| 72 |
|
| 73 |
eval_fn = MODELS[args.model_repo_or_id](args.model_repo_or_id)
|
| 74 |
|
|
|
|
| 35 |
return karlo_eval
|
| 36 |
|
| 37 |
def get_if_eval(ckpt):
|
| 38 |
+
pipe_low = DiffusionPipeline.from_pretrained(ckpt, safety_checker=None, watermarker=None, torch_dtype=torch.float16, variant="fp16")
|
| 39 |
pipe_low.enable_model_cpu_offload()
|
| 40 |
|
| 41 |
+
pipe_up = DiffusionPipeline.from_pretrained("DeepFloyd/IF-II-L-v1.0", safety_checker=None, watermarker=None, text_encoder=pipe_low.text_encoder, torch_dtype=torch.float16, variant="fp16")
|
| 42 |
pipe_up.enable_model_cpu_offload()
|
| 43 |
|
| 44 |
def if_eval(prompt, generator=None):
|
| 45 |
+
prompt_embeds, negative_prompt_embeds = pipe_low.encode_prompt(prompt)
|
| 46 |
+
images = pipe_low(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, num_inference_steps=NUM_INFERENCE_STEPS, generator=generator, output_type="pt").images
|
| 47 |
+
images = pipe_up(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, image=images, num_inference_steps=NUM_INFERENCE_STEPS, generator=generator).images
|
| 48 |
return images
|
| 49 |
|
| 50 |
return if_eval
|
|
|
|
| 70 |
args = parser.parse_args()
|
| 71 |
|
| 72 |
dataset = load_dataset("nateraw/parti-prompts")["train"]
|
| 73 |
+
# dataset = dataset.select(range(4))
|
| 74 |
|
| 75 |
eval_fn = MODELS[args.model_repo_or_id](args.model_repo_or_id)
|
| 76 |
|