import gradio as gr
import sys
import os
import tqdm
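# add the parent directory to sys.path so the repo's `utils` and `sampling` modules resolve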
sys.path.append(os.path.abspath(os.path.join("", "..")))
import torch
import gc
import warnings
warnings.filterwarnings("ignore")
from PIL import Image
from utils import load_models, save_model_w2w, save_model_for_diffusers
from sampling import sample_weights
from huggingface_hub import snapshot_download

device = "cuda:0"
generator = torch.Generator(device=device)

models_path = snapshot_download(repo_id="Snapchat/w2w")
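# the snapshot bundles the released w2w assets: weight-space statistics (mean.pt, std.pt),
# a principal-direction basis (V.pt), identity projections onto the top 1000 components
# (proj_1000pc.pt), and metadata about the identity dataset and per-parameter weight shapes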

mean = torch.load(f"{models_path}/mean.pt").bfloat16().to(device)
std = torch.load(f"{models_path}/std.pt").bfloat16().to(device)
v = torch.load(f"{models_path}/V.pt").bfloat16().to(device)
proj = torch.load(f"{models_path}/proj_1000pc.pt").bfloat16().to(device)
df = torch.load(f"{models_path}/identity_df.pt")
weight_dimensions = torch.load(f"{models_path}/weight_dimensions.pt")

unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)
network = sample_weights(unet, proj, mean, std, v[:, :1000], device, factor = 1.00)
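# `network` holds the sampled identity's weight offsets; entering `with network:` below
# patches the UNet with those weights for the duration of each denoising step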

def sample_model():
    # reload a fresh UNet and draw a new identity's weights from the w2w space
    global unet
    global network
    del unet
    gc.collect()
    torch.cuda.empty_cache()
    unet, _, _, _, _ = load_models(device)
    network = sample_weights(unet, proj, mean, std, v[:, :1000], device, factor = 1.00)

def inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed):
    # seed in place: manual_seed returns the generator itself, so the module-level generator
    # can be reused without rebinding (which would raise UnboundLocalError); Gradio's Number
    # component may deliver the seed as a float, so cast it to int
    generator.manual_seed(int(seed))
    latents = torch.randn(
        (1, unet.in_channels, 512 // 8, 512 // 8),
        generator=generator,
        device=device,
    ).bfloat16()

    # encode the prompt and the negative prompt for classifier-free guidance
    text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")

    text_embeddings = text_encoder(text_input.input_ids.to(device))[0]

    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer(
                            [negative_prompt], padding="max_length", max_length=max_length, return_tensors="pt"
                        )
    uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
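    # note: the [uncond, cond] ordering matters; chunk(2) in the denoising loop relies on it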
    noise_scheduler.set_timesteps(ddim_steps) 
    latents = latents * noise_scheduler.init_noise_sigma
    
    for t in tqdm.tqdm(noise_scheduler.timesteps):
        # duplicate the latents so the unconditional and conditional branches run in one batch
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)
        # the `with network:` context applies the sampled identity's weights to the UNet
        with network:
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings, timestep_cond=None).sample
        # classifier-free guidance
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        latents = noise_scheduler.step(noise_pred, t, latents).prev_sample
    
    # rescale by the Stable Diffusion VAE scaling factor and decode the latents into an image
    latents = 1 / 0.18215 * latents
    image = vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().float().permute(0, 2, 3, 1).numpy()[0]

    image = Image.fromarray((image * 255).round().astype("uint8"))

    return [image] 

with gr.Blocks() as demo:
    gr.Markdown("# <em>weights2weights</em> Demo")
    with gr.Row():
        with gr.Column():
            files = gr.Files(
                        label="Upload a photo of your face to invert, or sample a new model",
                        file_types=["image"]
                    )
            uploaded_files = gr.Gallery(label="Your images", visible=False, columns=5, rows=1, height=125)
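            # uploading/inversion is not wired up in this demo; generation uses the sampled identity weights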

            sample = gr.Button("Sample New Model")

            with gr.Column(visible=False) as clear_button:
                remove_and_reupload = gr.ClearButton(value="Remove and upload new ones", components=files, size="sm")
            prompt = gr.Textbox(label="Prompt",
                       info="Make sure to include 'sks person'" ,
                       placeholder="sks person", 
                       value="sks person")
            negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="low quality, blurry, unfinished, cartoon", value="low quality, blurry, unfinished, cartoon")
            seed = gr.Number(value=5, label="Seed", interactive=True)
            cfg = gr.Slider(label="CFG", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True)
            steps = gr.Slider(label="Inference Steps", value=50, step=1, minimum=0, maximum=100, interactive=True)


            submit = gr.Button("Submit")

        with gr.Column():
            gallery = gr.Gallery(label="Generated Images")

        sample.click(fn=sample_model)
        
        submit.click(fn=inference,
                    inputs=[prompt, negative_prompt, cfg, steps, seed],
                    outputs=gallery)
            
demo.launch(share=True)