rynmurdock commited on
Commit
d83af99
Β·
1 Parent(s): 07d8d5d
app.py CHANGED
@@ -3,7 +3,7 @@ import gradio as gr
3
  import random
4
  import time
5
  import torch
6
-
7
 
8
  import config
9
  from model import get_model_and_tokenizer
@@ -137,8 +137,8 @@ def background_next_image():
137
  if len(unrated_from_user) >= 10:
138
  continue
139
 
140
- if len(rated_rows) < 5:
141
- continue
142
 
143
  global glob_idx
144
  glob_idx += 1
@@ -170,11 +170,13 @@ def background_next_image():
170
 
171
  def pluck_img(user_id):
172
  rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, None) is not None for i in prevs_df.iterrows()]]
 
173
  ems = rated_rows['embeddings'].to_list()
174
  ys = [i[user_id][0] for i in rated_rows['user:rating'].to_list()]
175
  user_emb = get_user_emb(ems, ys)
176
 
177
  not_rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, 'gone') == 'gone' for i in prevs_df.iterrows()]]
 
178
  while len(not_rated_rows) == 0:
179
  not_rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, 'gone') == 'gone' for i in prevs_df.iterrows()]]
180
  time.sleep(.1)
@@ -182,8 +184,10 @@ def pluck_img(user_id):
182
 
183
  unrated_from_user = not_rated_rows[[i[1]['from_user_id'] == user_id for i in not_rated_rows.iterrows()]]
184
  if len(unrated_from_user) > 0:
 
185
  # NOTE the way I've setup pandas here is so gdm horrible. TODO overhaul
186
- img = unrated_from_user['paths'].to_list()[0]
 
187
 
188
  best_sim = -10000000
189
  for i in not_rated_rows.iterrows():
@@ -201,7 +205,7 @@ def next_image(calibrate_prompts, user_id):
201
  if len(calibrate_prompts) > 0:
202
  cal_video = calibrate_prompts.pop(0)
203
  image = prevs_df[prevs_df['paths'] == cal_video]['paths'].to_list()[0]
204
- return image, calibrate_prompts,
205
  else:
206
  image = pluck_img(user_id)
207
  return image, calibrate_prompts
@@ -211,9 +215,6 @@ def next_image(calibrate_prompts, user_id):
211
 
212
 
213
 
214
-
215
-
216
-
217
  def start(_, calibrate_prompts, user_id, request: gr.Request):
218
  user_id = int(str(time.time())[-7:].replace('.', ''))
219
  image, calibrate_prompts = next_image(calibrate_prompts, user_id)
@@ -227,19 +228,17 @@ def start(_, calibrate_prompts, user_id, request: gr.Request):
227
  image,
228
  calibrate_prompts,
229
  user_id,
230
-
231
  ]
232
 
233
 
234
  def choose(img, choice, calibrate_prompts, user_id, request: gr.Request):
235
  global prevs_df
236
 
237
-
238
  if choice == 'πŸ‘':
239
  choice = [1, 1]
240
  elif choice == 'Neither (Space)':
241
- img, calibrate_prompts, = next_image(calibrate_prompts, user_id)
242
- return img, calibrate_prompts,
243
  elif choice == 'πŸ‘Ž':
244
  choice = [0, 0]
245
  elif choice == 'πŸ‘ Style':
@@ -251,7 +250,6 @@ def choose(img, choice, calibrate_prompts, user_id, request: gr.Request):
251
 
252
  # if we detected NSFW, leave that area of latent space regardless of how they rated chosen.
253
  # TODO skip allowing rating & just continue
254
-
255
  if img is None:
256
  print('NSFW -- choice is disliked')
257
  choice = [0, 0]
@@ -260,8 +258,10 @@ def choose(img, choice, calibrate_prompts, user_id, request: gr.Request):
260
  # if it's still in the dataframe, add the choice
261
  if len(prevs_df.loc[row_mask, 'user:rating']) > 0:
262
  prevs_df.loc[row_mask, 'user:rating'][0][user_id] = choice
263
- print(row_mask, prevs_df.loc[row_mask, 'latest_user_to_rate'], [user_id])
264
  prevs_df.loc[row_mask, 'latest_user_to_rate'] = [user_id]
 
 
 
265
  img, calibrate_prompts = next_image(calibrate_prompts, user_id)
266
  return img, calibrate_prompts
267
 
@@ -330,19 +330,7 @@ Explore the latent space without text prompts based on your preferences. [rynmur
330
  ''', elem_id="description")
331
  user_id = gr.State()
332
  # calibration videos -- this is a misnomer now :D
333
- calibrate_prompts = [
334
- './5o.png',
335
- './2o.png',
336
- './6o.png',
337
- './7o.png',
338
- './1o.png',
339
- './8o.png',
340
- './3o.png',
341
- './4o.png',
342
- './10o.png',
343
- './9o.png',
344
- ]
345
- calibrate_prompts = gr.State(['image_init/'+c for c in calibrate_prompts])
346
  def l():
347
  return None
348
 
@@ -424,34 +412,25 @@ def encode_space(x):
424
  im_emb = model.prior_pipe.image_encoder(im)["image_embeds"]
425
  return im_emb.detach().to('cpu').to(torch.float32)
426
 
 
 
 
 
 
427
  # prep our calibration videos
428
- m_calibrate = [ # DO NOT NAME THESE PNGs JUST NUMBERS! apparently we assign images by number
429
- ('./1o.png', 'describe the scene: omens in the suburbs'),
430
- ('./2o.png', 'describe the scene: geometric abstract art of a windmill'),
431
- ('./3o.png', 'describe the scene: memento mori'),
432
- ('./4o.png', 'describe the scene: a green plate with anespresso'),
433
- ('./5o.png', '5 '),
434
- ('./6o.png', '6 '),
435
- ('./7o.png', '7 '),
436
- ('./8o.png', '8 '),
437
- ('./9o.png', '9 '),
438
- ('./10o.png', '10 '),
439
- ]
440
- m_calibrate = [('image_init/'+c[0], c[1]) for c in m_calibrate]
441
- for im, txt in m_calibrate:
442
- tmp_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'text', 'gemb'])
443
  tmp_df['paths'] = [im]
444
  image = Image.open(im).convert('RGB')
445
  im_emb = encode_space(image)
446
 
447
  tmp_df['embeddings'] = [im_emb.detach().to('cpu')]
448
  tmp_df['user:rating'] = [{' ': ' '}]
449
- tmp_df['text'] = [txt]
 
 
450
  prevs_df = pd.concat((prevs_df, tmp_df))
451
 
452
  glob_idx = 0
453
  demo.launch(share=True,)
454
-
455
-
456
- # TODO interface is shifted -- auto-resize images to all be the same.
457
-
 
3
  import random
4
  import time
5
  import torch
6
+ import glob
7
 
8
  import config
9
  from model import get_model_and_tokenizer
 
137
  if len(unrated_from_user) >= 10:
138
  continue
139
 
140
+ if len(rated_rows) < 4:
141
+ continue
142
 
143
  global glob_idx
144
  glob_idx += 1
 
170
 
171
  def pluck_img(user_id):
172
  rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, None) is not None for i in prevs_df.iterrows()]]
173
+ print(rated_rows)
174
  ems = rated_rows['embeddings'].to_list()
175
  ys = [i[user_id][0] for i in rated_rows['user:rating'].to_list()]
176
  user_emb = get_user_emb(ems, ys)
177
 
178
  not_rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, 'gone') == 'gone' for i in prevs_df.iterrows()]]
179
+ print(not_rated_rows)
180
  while len(not_rated_rows) == 0:
181
  not_rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, 'gone') == 'gone' for i in prevs_df.iterrows()]]
182
  time.sleep(.1)
 
184
 
185
  unrated_from_user = not_rated_rows[[i[1]['from_user_id'] == user_id for i in not_rated_rows.iterrows()]]
186
  if len(unrated_from_user) > 0:
187
+ print(unrated_from_user)
188
  # NOTE the way I've setup pandas here is so gdm horrible. TODO overhaul
189
+ img = unrated_from_user['paths'].to_list()[-1]
190
+ return img
191
 
192
  best_sim = -10000000
193
  for i in not_rated_rows.iterrows():
 
205
  if len(calibrate_prompts) > 0:
206
  cal_video = calibrate_prompts.pop(0)
207
  image = prevs_df[prevs_df['paths'] == cal_video]['paths'].to_list()[0]
208
+ return image, calibrate_prompts
209
  else:
210
  image = pluck_img(user_id)
211
  return image, calibrate_prompts
 
215
 
216
 
217
 
 
 
 
218
  def start(_, calibrate_prompts, user_id, request: gr.Request):
219
  user_id = int(str(time.time())[-7:].replace('.', ''))
220
  image, calibrate_prompts = next_image(calibrate_prompts, user_id)
 
228
  image,
229
  calibrate_prompts,
230
  user_id,
 
231
  ]
232
 
233
 
234
  def choose(img, choice, calibrate_prompts, user_id, request: gr.Request):
235
  global prevs_df
236
 
 
237
  if choice == 'πŸ‘':
238
  choice = [1, 1]
239
  elif choice == 'Neither (Space)':
240
+ img, calibrate_prompts = next_image(calibrate_prompts, user_id)
241
+ return img, calibrate_prompts
242
  elif choice == 'πŸ‘Ž':
243
  choice = [0, 0]
244
  elif choice == 'πŸ‘ Style':
 
250
 
251
  # if we detected NSFW, leave that area of latent space regardless of how they rated chosen.
252
  # TODO skip allowing rating & just continue
 
253
  if img is None:
254
  print('NSFW -- choice is disliked')
255
  choice = [0, 0]
 
258
  # if it's still in the dataframe, add the choice
259
  if len(prevs_df.loc[row_mask, 'user:rating']) > 0:
260
  prevs_df.loc[row_mask, 'user:rating'][0][user_id] = choice
 
261
  prevs_df.loc[row_mask, 'latest_user_to_rate'] = [user_id]
262
+ else:
263
+ print('Image apparently removed', img)
264
+ breakpoint()
265
  img, calibrate_prompts = next_image(calibrate_prompts, user_id)
266
  return img, calibrate_prompts
267
 
 
330
  ''', elem_id="description")
331
  user_id = gr.State()
332
  # calibration videos -- this is a misnomer now :D
333
+ calibrate_prompts = gr.State( [l for l in random.sample(glob.glob('image_init/*'), k=8)] )
 
 
 
 
 
 
 
 
 
 
 
 
334
  def l():
335
  return None
336
 
 
412
  im_emb = model.prior_pipe.image_encoder(im)["image_embeds"]
413
  return im_emb.detach().to('cpu').to(torch.float32)
414
 
415
+ # NOTE:
416
+ # media is moved into a random tmp folder so we need to parse filenames carefully.
417
+ # do not have any cases where a file name is the same or could be `in` another filename
418
+ # you also can't use jpegs lmao
419
+
420
  # prep our calibration videos
421
+ m_calibrate = glob.glob('image_init/*')
422
+ for im in m_calibrate:
423
+ tmp_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'text', 'gemb', 'from_user_id'])
 
 
 
 
 
 
 
 
 
 
 
 
424
  tmp_df['paths'] = [im]
425
  image = Image.open(im).convert('RGB')
426
  im_emb = encode_space(image)
427
 
428
  tmp_df['embeddings'] = [im_emb.detach().to('cpu')]
429
  tmp_df['user:rating'] = [{' ': ' '}]
430
+ tmp_df['text'] = ['']
431
+ # tmp_df['from_user_id'] = [0]
432
+ # tmp_df['latest_user_to_rate'] = [0]
433
  prevs_df = pd.concat((prevs_df, tmp_df))
434
 
435
  glob_idx = 0
436
  demo.launch(share=True,)
 
 
 
 
image_init/10o.png CHANGED

Git LFS Details

  • SHA256: c4cc3479937b13dd2b3d66479a44becd24f5f7e8262e47b93a6ec94883733ffe
  • Pointer size: 132 Bytes
  • Size of remote file: 1.15 MB

Git LFS Details

  • SHA256: 16edbcb1b5cab0d32244ac9faac2f9e25724e00532c867f0b1cd4808cacc1054
  • Pointer size: 131 Bytes
  • Size of remote file: 388 kB
image_init/1o.png CHANGED

Git LFS Details

  • SHA256: a92dd4a8836ae42131fec755874ee1c1e0dda4679835ef1b5b86b565976090de
  • Pointer size: 132 Bytes
  • Size of remote file: 1.22 MB

Git LFS Details

  • SHA256: 15faf7373356cbcc59531da9d02ffe3642826c156409a4a8d88904aca295ebba
  • Pointer size: 131 Bytes
  • Size of remote file: 357 kB
image_init/2o.png CHANGED

Git LFS Details

  • SHA256: 15faf7373356cbcc59531da9d02ffe3642826c156409a4a8d88904aca295ebba
  • Pointer size: 131 Bytes
  • Size of remote file: 357 kB

Git LFS Details

  • SHA256: 5783269bf3847cd08bbf312b2cb53b241a29345564db8e5fe04235b664ee7d5c
  • Pointer size: 132 Bytes
  • Size of remote file: 5.49 MB
image_init/3o.png CHANGED

Git LFS Details

  • SHA256: 5808921156fd63d4f03c9dbc077fc719cb559ffa279385eb555f8bcf13bd0775
  • Pointer size: 131 Bytes
  • Size of remote file: 430 kB

Git LFS Details

  • SHA256: 9070c259c85fc86cce84c1d5ae951588fcc331210350d75601ed89d114b33ec7
  • Pointer size: 132 Bytes
  • Size of remote file: 2 MB
image_init/4o.png CHANGED

Git LFS Details

  • SHA256: 27c35d730f0b67039915ae14a0ea07bd2425bce04525cb88672428177e7b0117
  • Pointer size: 131 Bytes
  • Size of remote file: 956 kB

Git LFS Details

  • SHA256: dffbc9cebbd743c2cdbda578be45c8114b6528d02ab8f71b7cdc057ff84b26a4
  • Pointer size: 132 Bytes
  • Size of remote file: 1.98 MB
image_init/5o.png CHANGED

Git LFS Details

  • SHA256: 73875f50f84790970849277a5c6ed534a94c15e3b41dc2dffc38f01cb45cdd99
  • Pointer size: 132 Bytes
  • Size of remote file: 1.15 MB

Git LFS Details

  • SHA256: 6683ea703e4f134185a7100ef3e78dac04041340584494799cba5d925fb228df
  • Pointer size: 132 Bytes
  • Size of remote file: 1.21 MB
image_init/7o.png CHANGED

Git LFS Details

  • SHA256: a2919c8daed335927f5b6cf9b46016bf1717952b9a475db3eee8a1cd36b89d56
  • Pointer size: 132 Bytes
  • Size of remote file: 1.98 MB

Git LFS Details

  • SHA256: 76813201ae0f9f845ad8c75c500f72cfa63f2b5d59e6f95ceff3288d0303e1c3
  • Pointer size: 131 Bytes
  • Size of remote file: 469 kB
image_init/9o.png CHANGED

Git LFS Details

  • SHA256: ed521b4950bf0793040808daada124d41503d834fe3ea97a34cb46d71b9c2d24
  • Pointer size: 132 Bytes
  • Size of remote file: 1.09 MB

Git LFS Details

  • SHA256: 3b687042ce41cf61aa112473c276753992571f15be1e3526479139394d8ed49d
  • Pointer size: 131 Bytes
  • Size of remote file: 463 kB
requirements.txt CHANGED
@@ -14,4 +14,5 @@ peft
14
  imageio
15
  apscheduler
16
  pandas
17
- av
 
 
14
  imageio
15
  apscheduler
16
  pandas
17
+ av
18
+ glob2