QHL067 commited on
Commit
e5dbbbe
·
1 Parent(s): f2b01f3
Files changed (1) hide show
  1. app.py +13 -12
app.py CHANGED
@@ -121,13 +121,13 @@ state_dict = torch.load(checkpoint_path, map_location=device)
121
  nnet_1.load_state_dict(state_dict)
122
  nnet_1.eval()
123
 
124
- filename = "pretrained_models/t2i_512px_clip_dimr.pth"
125
- checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)
126
- nnet_2 = utils.get_nnet(**config_2.nnet)
127
- nnet_2 = nnet_2.to(device)
128
- state_dict = torch.load(checkpoint_path, map_location=device)
129
- nnet_2.load_state_dict(state_dict)
130
- nnet_2.eval()
131
 
132
  # Initialize text model.
133
  llm = "clip"
@@ -181,10 +181,11 @@ def infer(
181
  else:
182
  assert num_of_interpolation == 3, "For arithmetic, please sample three images."
183
 
184
- if num_of_interpolation == 3:
185
- nnet = nnet_2
186
- else:
187
- nnet = nnet_1
 
188
 
189
  # Get text embeddings and tokens.
190
  _context, _token_mask, _token, _caption = get_caption(
@@ -301,7 +302,7 @@ examples_1 = [
301
  ]
302
 
303
  examples_2 = [
304
- ["A corgi in the park", "red hat"],
305
  ]
306
 
307
  css = """
 
121
  nnet_1.load_state_dict(state_dict)
122
  nnet_1.eval()
123
 
124
+ # filename = "pretrained_models/t2i_512px_clip_dimr.pth"
125
+ # checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)
126
+ # nnet_2 = utils.get_nnet(**config_2.nnet)
127
+ # nnet_2 = nnet_2.to(device)
128
+ # state_dict = torch.load(checkpoint_path, map_location=device)
129
+ # nnet_2.load_state_dict(state_dict)
130
+ # nnet_2.eval()
131
 
132
  # Initialize text model.
133
  llm = "clip"
 
181
  else:
182
  assert num_of_interpolation == 3, "For arithmetic, please sample three images."
183
 
184
+ # if num_of_interpolation == 3:
185
+ # nnet = nnet_2
186
+ # else:
187
+ # nnet = nnet_1
188
+ nnet = nnet_1
189
 
190
  # Get text embeddings and tokens.
191
  _context, _token_mask, _token, _caption = get_caption(
 
302
  ]
303
 
304
  examples_2 = [
305
+ ["A dog wearing sunglasses", "red hat"],
306
  ]
307
 
308
  css = """