rezashkv commited on
Commit
08cdffb
Β·
verified Β·
1 Parent(s): 2dc812e

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +46 -0
README.md CHANGED
@@ -88,6 +88,52 @@ APTP
88
  β”œβ”€β”€ ...
89
  └── arch7
90
  ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
 
93
  ## Uses
 
88
  β”œβ”€β”€ ...
89
  └── arch7
90
  ```
91
+ ## Simple Inference Example
92
+
93
+ Make sure follow the installation instructions in the [Github Repository](https://github.com/rezashkv/diffusion_pruning) to install pdm from source.
94
+
95
+ ```python
96
+ from diffusers import StableDiffusionPipeline, PNDMScheduler
97
+ from pdm.models import HyperStructure, StructureVectorQuantizer, UNet2DConditionModelPruned
98
+ from pdm.utils.data_utils import get_mpnet_embeddings
99
+ from transformers import AutoTokenizer, AutoModel
100
+ import torch
101
+
102
+ prompt_encoder_model_name_or_path = "sentence-transformers/all-mpnet-base-v2"
103
+ prompt_encoder_tokenizer = AutoTokenizer.from_pretrained(prompt_encoder_model_name_or_path)
104
+ prompt_encoder = AutoModel.from_pretrained(prompt_encoder_model_name_or_path)
105
+
106
+ aptp_model_name_or_path = f"rezashkv/APTP"
107
+ aptp_variant = "APTP-Base-CC3M"
108
+ hyper_net = HyperStructure.from_pretrained(aptp_model_name_or_path, subfolder=f"{aptp_variant}/hypernet")
109
+ quantizer = StructureVectorQuantizer.from_pretrained(aptp_model_name_or_path, subfolder=f"{aptp_variant}/quantizer")
110
+
111
+ prompts = ["a woman on a white background looks down and away from the camera the a forlorn look on her face"]
112
+ prompt_embedding = get_mpnet_embeddings(prompts, prompt_encoder, prompt_encoder_tokenizer)
113
+
114
+ arch_embedding = hyper_net(prompt_embedding)
115
+ expert_id = quantizer.get_cosine_sim_min_encoding_indices(arch_embedding)[0].item()
116
+
117
+ sd_model_name_or_path = "stabilityai/stable-diffusion-2-1"
118
+
119
+ unet = UNet2DConditionModelPruned.from_pretrained(aptp_model_name_or_path,
120
+ subfolder=f"{aptp_variant}/arch{expert_id}/checkpoint-30000/unet")
121
+
122
+ noise_scheduler = PNDMScheduler.from_pretrained(sd_model_name_or_path, subfolder="scheduler")
123
+ pipeline = StableDiffusionPipeline.from_pretrained(sd_model_name_or_path, unet=unet, scheduler=noise_scheduler)
124
+
125
+ pipeline.to('cuda')
126
+ generator = torch.Generator(device='cuda').manual_seed(43)
127
+
128
+ image = pipeline(
129
+ prompt=prompts[0],
130
+ guidance_scale=7.5,
131
+ generator=generator,
132
+ output_type='pil',
133
+ ).images[0]
134
+
135
+ image.save("image.png")
136
+ ```
137
 
138
 
139
  ## Uses