rlindberg committed
Commit 2bf57d0 · 1 Parent(s): 1a3451e

Create new file

Files changed (1)
  1. app.py +46 -0
app.py ADDED
@@ -0,0 +1,46 @@
+ import os
+
+ os.system("git clone --recursive https://github.com/JD-P/cloob-latent-diffusion")
+ os.system("cd cloob-latent-diffusion;pip install omegaconf pillow pytorch-lightning einops wandb ftfy regex ./CLIP")
+
+ import argparse
+ from functools import partial
+ from pathlib import Path
+ import sys
+ sys.path.append('./cloob-latent-diffusion')
+ sys.path.append('./cloob-latent-diffusion/cloob-training')
+ sys.path.append('./cloob-latent-diffusion/latent-diffusion')
+ sys.path.append('./cloob-latent-diffusion/taming-transformers')
+ sys.path.append('./cloob-latent-diffusion/v-diffusion-pytorch')
+ from omegaconf import OmegaConf
+ from PIL import Image
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+ from torchvision import transforms
+ from torchvision.transforms import functional as TF
+ from tqdm import trange
+ from CLIP import clip
+ from cloob_training import model_pt, pretrained
+ import ldm.models.autoencoder
+ from diffusion import sampling, utils
+ import train_latent_diffusion as train
+ from huggingface_hub import hf_hub_url, cached_download
+ import random
+
+ # Download the model files
+ checkpoint = cached_download(hf_hub_url("huggan/distill-ccld-wa", filename="model_student.ckpt"))
+ ae_model_path = cached_download(hf_hub_url("huggan/ccld_wa", filename="ae_model.ckpt"))
+ ae_config_path = cached_download(hf_hub_url("huggan/ccld_wa", filename="ae_model.yaml"))
+
+ # Define a few utility functions
+
+
+ def parse_prompt(prompt, default_weight=3.):
+     if prompt.startswith('http://') or prompt.startswith('https://'):
+         vals = prompt.rsplit(':', 2)
+         vals = [vals[0] + ':' + vals[1], *vals[2:]]
+     else:
+         vals = prompt.rsplit(':', 1)
+     vals = vals + ['', default_weight][len(vals):]
+     return vals[0], float(vals[1])
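For context, parse_prompt splits a "text:weight" prompt into a (text, weight) pair, falls back to the default weight of 3 when no weight is given, and splits URLs from the right twice so the colon in "http(s)://" is preserved. A minimal sketch of its behavior, using made-up example prompts:

    parse_prompt("a watercolor landscape:2")         # ('a watercolor landscape', 2.0)
    parse_prompt("a watercolor landscape")           # ('a watercolor landscape', 3.0), the default weight
    parse_prompt("https://example.com/image.png:4")  # ('https://example.com/image.png', 4.0)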