512 model
app.py CHANGED
@@ -95,13 +95,16 @@ def get_caption(llm: str, text_model, prompt_dict: dict, batch_size: int):
 # Load configuration and initialize models.
 # config_dict = t2i_512px_clip_dimr.get_config()
 config_dict = t2i_256px_clip_dimr.get_config()
-config = ml_collections.ConfigDict(config_dict)
+config_1 = ml_collections.ConfigDict(config_dict)
+config_dict = t2i_512px_clip_dimr.get_config()
+config_2 = ml_collections.ConfigDict(config_dict)
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 logging.info(f"Using device: {device}")
 
 # Freeze configuration.
-config = ml_collections.FrozenConfigDict(config)
+config_1 = ml_collections.FrozenConfigDict(config_1)
+config_2 = ml_collections.FrozenConfigDict(config_2)
 
 torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
 MAX_SEED = np.iinfo(np.int32).max
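The hunk above builds two configurations, one for the 256px model and one for the 512px model, then freezes both so they cannot be mutated later. A minimal sketch of this `ml_collections` pattern, with a hypothetical `img_size` field for illustration (not taken from the CrossFlow configs):

    import ml_collections

    def get_config():
        cfg = ml_collections.ConfigDict()
        cfg.nnet = ml_collections.ConfigDict()
        cfg.nnet.img_size = 256  # hypothetical field for illustration
        return cfg

    config = ml_collections.FrozenConfigDict(get_config())
    print(config.nnet.img_size)   # fields read via attribute access
    # config.nnet.img_size = 512  # would raise: FrozenConfigDict is immutable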
@@ -112,11 +115,19 @@ repo_id = "QHL067/CrossFlow"
 # filename = "pretrained_models/t2i_512px_clip_dimr.pth"
 filename = "pretrained_models/t2i_256px_clip_dimr.pth"
 checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)
-nnet = utils.get_nnet(**config.nnet)
-nnet = nnet.to(device)
+nnet_1 = utils.get_nnet(**config_1.nnet)
+nnet_1 = nnet_1.to(device)
 state_dict = torch.load(checkpoint_path, map_location=device)
-nnet.load_state_dict(state_dict)
-nnet.eval()
+nnet_1.load_state_dict(state_dict)
+nnet_1.eval()
+
+filename = "pretrained_models/t2i_512px_clip_dimr.pth"
+checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)
+nnet_2 = utils.get_nnet(**config_2.nnet)
+nnet_2 = nnet_2.to(device)
+state_dict = torch.load(checkpoint_path, map_location=device)
+nnet_2.load_state_dict(state_dict)
+nnet_2.eval()
 
 # Initialize text model.
 llm = "clip"
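Both checkpoints come from the same `QHL067/CrossFlow` repo and are loaded into separate networks that stay resident in memory. A sketch of the download-and-load pattern, using a stand-in `torch.nn.Linear` since `utils.get_nnet` is repo-specific:

    import torch
    from huggingface_hub import hf_hub_download

    path = hf_hub_download(repo_id="QHL067/CrossFlow",
                           filename="pretrained_models/t2i_512px_clip_dimr.pth")
    model = torch.nn.Linear(4, 4)  # stand-in; the app builds this via utils.get_nnet
    state_dict = torch.load(path, map_location="cpu")
    # model.load_state_dict(state_dict)  # keys must match the real architecture
    model.eval()  # disable dropout/batch-norm updates for inference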
@@ -170,6 +181,11 @@ def infer(
     else:
         assert num_of_interpolation == 3, "For arithmetic, please sample three images."
 
+    if num_of_interpolation == 3:
+        nnet = nnet_2
+    else:
+        nnet = nnet_1
+
     # Get text embeddings and tokens.
     _context, _token_mask, _token, _caption = get_caption(
         llm, clip, prompt_dict=prompt_dict, batch_size=num_of_interpolation
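With both networks loaded, `infer()` now routes the three-image arithmetic case to the 512px model (`nnet_2`) and everything else to the 256px model (`nnet_1`). A hypothetical helper expressing the same dispatch:

    def pick_model(num_of_interpolation: int, nnet_1, nnet_2):
        # Arithmetic mode samples exactly three images -> 512px model.
        return nnet_2 if num_of_interpolation == 3 else nnet_1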