256
Browse files
app.py
CHANGED
|
@@ -27,7 +27,7 @@ from diffusion.flow_matching import ODEEulerFlowMatchingSolver
|
|
| 27 |
import utils
|
| 28 |
import libs.autoencoder
|
| 29 |
from libs.clip import FrozenCLIPEmbedder
|
| 30 |
-
from configs import t2i_512px_clip_dimr
|
| 31 |
|
| 32 |
|
| 33 |
def unpreprocess(x: torch.Tensor) -> torch.Tensor:
|
|
@@ -93,7 +93,8 @@ def get_caption(llm: str, text_model, prompt_dict: dict, batch_size: int):
|
|
| 93 |
return context, token_mask, tokens, captions
|
| 94 |
|
| 95 |
# Load configuration and initialize models.
|
| 96 |
-
config_dict = t2i_512px_clip_dimr.get_config()
|
|
|
|
| 97 |
config = ml_collections.ConfigDict(config_dict)
|
| 98 |
|
| 99 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
@@ -108,7 +109,8 @@ MAX_IMAGE_SIZE = 1024 # Currently not used.
|
|
| 108 |
|
| 109 |
# Load the main diffusion model.
|
| 110 |
repo_id = "QHL067/CrossFlow"
|
| 111 |
-
filename = "pretrained_models/t2i_512px_clip_dimr.pth"
|
|
|
|
| 112 |
checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)
|
| 113 |
nnet = utils.get_nnet(**config.nnet)
|
| 114 |
nnet = nnet.to(device)
|
|
|
|
| 27 |
import utils
|
| 28 |
import libs.autoencoder
|
| 29 |
from libs.clip import FrozenCLIPEmbedder
|
| 30 |
+
from configs import t2i_512px_clip_dimr, t2i_256px_clip_dimr
|
| 31 |
|
| 32 |
|
| 33 |
def unpreprocess(x: torch.Tensor) -> torch.Tensor:
|
|
|
|
| 93 |
return context, token_mask, tokens, captions
|
| 94 |
|
| 95 |
# Load configuration and initialize models.
|
| 96 |
+
# config_dict = t2i_512px_clip_dimr.get_config()
|
| 97 |
+
config_dict = t2i_256px_clip_dimr.get_config()
|
| 98 |
config = ml_collections.ConfigDict(config_dict)
|
| 99 |
|
| 100 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
| 109 |
|
| 110 |
# Load the main diffusion model.
|
| 111 |
repo_id = "QHL067/CrossFlow"
|
| 112 |
+
# filename = "pretrained_models/t2i_512px_clip_dimr.pth"
|
| 113 |
+
filename = "pretrained_models/t2i_256px_clip_dimr.pth"
|
| 114 |
checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)
|
| 115 |
nnet = utils.get_nnet(**config.nnet)
|
| 116 |
nnet = nnet.to(device)
|