import glob
import os
import random
import time
import uuid

import gradio as gr
import pandas as pd
import spaces
import torch
from apscheduler.schedulers.background import BackgroundScheduler
from PIL import Image

import config
from model import get_model_and_tokenizer
# from safety_checker_improved import maybe_nsfw

# TODO unify/merge origin and this
# TODO save & restart from (if it exists) dataframe parquet

device = "cuda"
model, model.prior_pipe.image_encoder = get_model_and_tokenizer(config.model_path, device, torch.bfloat16)

torch.set_grad_enabled(False)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# global store of generated media, their embeddings, and per-user ratings
prevs_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'latest_user_to_rate',
                                 'from_user_id', 'text', 'gemb'])

start_time = time.time()

####################### Setup Model

@spaces.GPU()
def generate_gpu(in_im_embs, prompt='the scene'):
    with torch.no_grad():
        in_im_embs = in_im_embs.to('cuda')
        negative_image_embeds = in_im_embs[0]  # model.prior_pipe.get_zero_embed()
        positive_image_embeds = in_im_embs[1]
        images = model.kandinsky_pipe(
            num_inference_steps=50,
            image_embeds=positive_image_embeds,
            negative_image_embeds=negative_image_embeds,
            guidance_scale=11,
        ).images[0]
        # re-encode the generated image so its embedding can be stored for retrieval
        cond = (
            model.prior_pipe.image_processor(images, return_tensors="pt")
            .pixel_values[0]
            .unsqueeze(0)
            .to(dtype=model.prior_pipe.image_encoder.dtype, device=device)
        )
        im_emb = model.prior_pipe.image_encoder(cond)["image_embeds"]
    return images, im_emb


def generate(in_im_embs):
    output, im_emb = generate_gpu(in_im_embs)
    nsfw = False  # maybe_nsfw(output) -- safety checker currently disabled
    name = str(uuid.uuid4()).replace("-", "")
    path = f"/tmp/{name}.png"

    if nsfw:
        gr.Warning("NSFW content detected.")
        # TODO could return an automatic dislike on the backend for `Neither` as well; just would need refactoring.
        return None, im_emb

    output.save(path)
    return path, im_emb
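# Usage sketch (hedged): `generate` takes a stacked [negative, positive] pair of
# image-embedding tensors, as produced by `get_user_emb` below, e.g.
#
#   emb_pair = torch.stack([neg_emb, pos_emb])  # shapes here are assumptions
#   path, im_emb = generate(emb_pair)
#
# where `neg_emb`/`pos_emb` are hypothetical outputs of `sample_embs`.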
#######################


@spaces.GPU()
def sample_embs(prompt_embeds):
    latent = torch.randn(prompt_embeds.shape[0], 1, prompt_embeds.shape[-1])
    if prompt_embeds.shape[1] < 8:  # TODO grab as `k` arg from config
        prompt_embeds = torch.nn.functional.pad(prompt_embeds, [0, 0, 0, 8 - prompt_embeds.shape[1]])
    assert prompt_embeds.shape[1] == 8, f"The model is set to take `k` cond image embeds but got shape {prompt_embeds.shape}"
    image_embeds = model(latent.to('cuda'), prompt_embeds.to('cuda')).predicted_image_embedding
    return image_embeds


@spaces.GPU()
def get_user_emb(embs, ys):
    # sample up to 8 embeddings the user liked; keep the filtered lists in their own
    # variables so the negatives below aren't accidentally drawn from the positives
    positives = [e for e, y in zip(embs, ys) if y == 1]
    pos_sample = random.sample(positives, min(8, len(positives)))
    if len(pos_sample) == 0:
        # falls back to the module-level `im_emb` left over from calibration
        positives = torch.zeros_like(im_emb)[None]
    else:
        positives = torch.stack(pos_sample, 1)

    negs = [e for e, y in zip(embs, ys) if y == 0]
    neg_sample = random.sample(negs, min(8, len(negs)))
    if len(neg_sample) == 0:
        negatives = torch.zeros_like(im_emb)[None]
    else:
        negatives = torch.stack(neg_sample, 1)

    image_embeds = torch.stack([sample_embs(negatives), sample_embs(positives)])
    return image_embeds


def background_next_image():
    global prevs_df, glob_idx
    # only let generation get a few items ahead of each user
    rated_rows = prevs_df[[i[1]['user:rating'] != {' ': ' '} for i in prevs_df.iterrows()]]
    if len(rated_rows) < 4:
        time.sleep(.1)
        return

    user_id_list = set(rated_rows['latest_user_to_rate'].to_list())
    for uid in user_id_list:
        rated_rows = prevs_df[[i[1]['user:rating'].get(uid, None) is not None for i in prevs_df.iterrows()]]
        not_rated_rows = prevs_df[[i[1]['user:rating'].get(uid, None) is None for i in prevs_df.iterrows()]]

        # intersect not_rated_rows with the media this user spawned;
        # the `from_user_id` column records which user_id spawned each item.
        unrated_from_user = not_rated_rows[[i[1]['from_user_id'] == uid for i in not_rated_rows.iterrows()]]

        # we don't compute more after n are in the queue for them
        if len(unrated_from_user) >= 10:
            continue
        if len(rated_rows) < 4:
            continue

        glob_idx += 1

        ems = rated_rows['embeddings'].to_list()
        ys = [i[uid][0] for i in rated_rows['user:rating'].to_list()]

        emz = get_user_emb(ems, ys)
        img, embs = generate(emz)

        if img:
            tmp_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating',
                                           'latest_user_to_rate', 'text', 'gemb'])
            tmp_df['paths'] = [img]
            tmp_df['embeddings'] = [embs.to(torch.float32).to('cpu')]
            tmp_df['user:rating'] = [{' ': ' '}]
            tmp_df['from_user_id'] = [uid]
            tmp_df['text'] = ['']
            prevs_df = pd.concat((prevs_df, tmp_df))

            # we can free up storage by deleting the image once the dataframe grows too large
            if len(prevs_df) > 500:
                oldest_path = prevs_df.iloc[6]['paths']
                if os.path.isfile(oldest_path):
                    os.remove(oldest_path)
                else:
                    # if it fails, inform the user
                    print(f"Error: {oldest_path} file not found")
                # keep the calibration rows (iloc[:6]) and drop the oldest generated row
                prevs_df = pd.concat((prevs_df.iloc[:6], prevs_df.iloc[7:]))
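# Assumed shape of a `prevs_df` row, for reference (inferred from the code above):
#   paths                -- str, e.g. /tmp/<uuid>.png
#   embeddings           -- torch.Tensor, float32, on CPU
#   user:rating          -- dict of user_id -> [content, style] votes;
#                           {' ': ' '} is the "not yet rated" sentinel
#   from_user_id         -- user whose preferences spawned this media
#   latest_user_to_rate  -- last user to record a vote on it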
print("Error: %s file not found" % oldest_path) # only keep 50 images & embeddings & ips, then remove oldest besides calibrating prevs_df = pd.concat((prevs_df.iloc[:6], prevs_df.iloc[7:])) def pluck_img(user_id): rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, None) is not None for i in prevs_df.iterrows()]] ems = rated_rows['embeddings'].to_list() ys = [i[user_id][0] for i in rated_rows['user:rating'].to_list()] user_emb = get_user_emb(ems, ys) not_rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, 'gone') == 'gone' for i in prevs_df.iterrows()]] while len(not_rated_rows) == 0: not_rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, 'gone') == 'gone' for i in prevs_df.iterrows()]] time.sleep(.1) # TODO optimize this lol unrated_from_user = not_rated_rows[[i[1]['from_user_id'] == user_id for i in not_rated_rows.iterrows()]] if len(unrated_from_user) > 0: print(unrated_from_user) # NOTE the way I've setup pandas here is so gdm horrible. TODO overhaul img = unrated_from_user['paths'].to_list()[-1] return img best_sim = -10000000 for i in not_rated_rows.iterrows(): # TODO sloppy .to but it is 3am. sim = torch.cosine_similarity(i[1]['embeddings'].detach().to('cpu'), user_emb.detach().to('cpu'), -1) if len(sim) > 1: sim = sim[1] if sim.squeeze() > best_sim: best_sim = sim best_row = i[1] img = best_row['paths'] return img def next_image(calibrate_prompts, user_id): with torch.no_grad(): # once we've done so many random calibration prompts out of the full media if len(m_calibrate) - len(calibrate_prompts) < 5: cal_video = calibrate_prompts.pop(random.randint(0, len(calibrate_prompts)-1)) image = prevs_df[prevs_df['paths'] == cal_video]['paths'].to_list()[0] # we switch to just getting media by similarity. else: image = pluck_img(user_id) return image, calibrate_prompts def start(_, calibrate_prompts, user_id, request: gr.Request): user_id = int(str(time.time())[-7:].replace('.', '')) image, calibrate_prompts = next_image(calibrate_prompts, user_id) return [ gr.Button(value='👍', interactive=True), gr.Button(value='Neither (Space)', interactive=True, visible=False), gr.Button(value='👎', interactive=True), gr.Button(value='Start', interactive=False), gr.Button(value='👍 Content', interactive=True, visible=False), gr.Button(value='👍 Style', interactive=True, visible=False), image, calibrate_prompts, user_id, ] def choose(img, choice, calibrate_prompts, user_id, request: gr.Request): global prevs_df if choice == '👍': choice = [1, 1] elif choice == 'Neither (Space)': img, calibrate_prompts = next_image(calibrate_prompts, user_id) return img, calibrate_prompts elif choice == '👎': choice = [0, 0] elif choice == '👍 Style': choice = [0, 1] elif choice == '👍 Content': choice = [1, 0] else: assert False, f'choice is {choice}' # if we detected NSFW, leave that area of latent space regardless of how they rated chosen. 
def choose(img, choice, calibrate_prompts, user_id, request: gr.Request):
    global prevs_df

    if choice == '👍':
        choice = [1, 1]
    elif choice == 'Neither (Space)':
        img, calibrate_prompts = next_image(calibrate_prompts, user_id)
        return img, calibrate_prompts
    elif choice == '👎':
        choice = [0, 0]
    elif choice == '👍 Style':
        choice = [0, 1]
    elif choice == '👍 Content':
        choice = [1, 0]
    else:
        assert False, f'choice is {choice}'

    # if we detected NSFW, leave that area of latent space regardless of how they rated chosen.
    # TODO skip allowing rating & just continue
    if img is None:
        print('NSFW -- choice is disliked')
        # without a path we can't match a row to record the rating, so just move on
        img, calibrate_prompts = next_image(calibrate_prompts, user_id)
        return img, calibrate_prompts

    # match on the filename because gradio serves media from its own tmp dirs (see NOTE below)
    row_mask = [p.split('/')[-1] in img for p in prevs_df['paths'].to_list()]
    # if it's still in the dataframe, add the choice
    if len(prevs_df.loc[row_mask, 'user:rating']) > 0:
        prevs_df.loc[row_mask, 'user:rating'][0][user_id] = choice
        prevs_df.loc[row_mask, 'latest_user_to_rate'] = [user_id]
    else:
        print('Image apparently removed', img)
    img, calibrate_prompts = next_image(calibrate_prompts, user_id)
    return img, calibrate_prompts


css = '''.gradio-container{max-width: 700px !important}
#description{text-align: center}
#description h1, #description h3{display: block}
#description p{margin-top: 0}
.fade-in-out {animation: fadeInOut 3s forwards}
@keyframes fadeInOut {
    0% { background: var(--bg-color); }
    100% { background: var(--button-secondary-background-fill); }
}
'''
js_head = '''
'''

with gr.Blocks(head=js_head, css=css) as demo:
    gr.Markdown('''# The Other Tiger

### Generative Recommenders for Exploration of Possible Images
Explore the latent space using binary feedback.

[rynmurdock.github.io](https://rynmurdock.github.io/)
''', elem_id="description")

    user_id = gr.State()
    # calibration videos -- this is a misnomer now :D
    calibrate_prompts = gr.State(glob.glob('image_init/*'))

    with gr.Row(elem_id='output-image'):
        img = gr.Image(
            label='Lightning',
            interactive=False,
            elem_id="output_im",
            type='filepath',
            height=700,
            width=700,
        )

    with gr.Row(equal_height=True):
        b3 = gr.Button(value='👎', interactive=False, elem_id="dislike")
        b2 = gr.Button(value='Neither (Space)', interactive=False, elem_id="neither", visible=False)
        b1 = gr.Button(value='👍', interactive=False, elem_id="like")
    with gr.Row(equal_height=True):
        b6 = gr.Button(value='👍 Style', interactive=False, elem_id="dislike like", visible=False)
        b5 = gr.Button(value='👍 Content', interactive=False, elem_id="like dislike", visible=False)

    # every feedback button routes through `choose`, which records the vote and
    # returns the next (image, calibrate_prompts) pair positionally
    b1.click(choose, [img, b1, calibrate_prompts, user_id], [img, calibrate_prompts])
    b2.click(choose, [img, b2, calibrate_prompts, user_id], [img, calibrate_prompts])
    b3.click(choose, [img, b3, calibrate_prompts, user_id], [img, calibrate_prompts])
    b5.click(choose, [img, b5, calibrate_prompts, user_id], [img, calibrate_prompts])
    b6.click(choose, [img, b6, calibrate_prompts, user_id], [img, calibrate_prompts])

    with gr.Row():
        b4 = gr.Button(value='Start')
        b4.click(start,
                 [b4, calibrate_prompts, user_id],
                 [b1, b2, b3, b4, b5, b6, img, calibrate_prompts, user_id])
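    # Hedged note: `start` derives a fresh user_id from the clock rather than from
    # gr.Request, so every click of Start begins a new ephemeral rating session.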
    with gr.Row():
        html = gr.HTML('''
You will calibrate for several images and then roam.
Thanks to @multimodalart for their contributions to the demo, esp. the interface, and @maxbittker for feedback.
''')

# TODO quiet logging

scheduler = BackgroundScheduler()
scheduler.add_job(func=background_next_image, trigger="interval", seconds=.2)
scheduler.start()
# TODO shouldn't call this before gradio launch, yeah?


@spaces.GPU()
def encode_space(x):
    im = (
        model.prior_pipe.image_processor(x, return_tensors="pt")
        .pixel_values[0]
        .unsqueeze(0)
        .to(dtype=model.prior_pipe.image_encoder.dtype, device=device)
    )
    im_emb = model.prior_pipe.image_encoder(im)["image_embeds"]
    return im_emb.detach().to('cpu').to(torch.float32)


# NOTE:
# media is moved into a random tmp folder, so we need to parse filenames carefully.
# Do not have any case where a filename is identical to, or contained `in`, another filename.
# You also may not be able to use jpegs.

# prep our calibration media
m_calibrate = glob.glob('image_init/*')
for im in m_calibrate:
    tmp_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'text', 'gemb', 'from_user_id'])
    tmp_df['paths'] = [im]
    image = Image.open(im).convert('RGB')
    im_emb = encode_space(image)
    tmp_df['embeddings'] = [im_emb.detach().to('cpu')]
    tmp_df['user:rating'] = [{' ': ' '}]
    tmp_df['text'] = ['']
    # seems to break things...
    tmp_df['from_user_id'] = [0]
    tmp_df['latest_user_to_rate'] = [0]
    prevs_df = pd.concat((prevs_df, tmp_df))

glob_idx = 0
demo.launch(share=True)
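# Hedged sketch for the "save & restart from parquet" TODO at the top of this file
# (not wired in; `checkpoint`/`restore` and the path are hypothetical names).
# Parquet can't hold tensors, so embeddings would be serialized as nested lists,
# and the `user:rating` dict column may also need JSON-encoding:
#
#   def checkpoint(df, path='prevs_df.parquet'):
#       out = df.copy()
#       out['embeddings'] = [e.numpy().tolist() for e in out['embeddings']]
#       out.to_parquet(path)
#
#   def restore(path='prevs_df.parquet'):
#       df = pd.read_parquet(path)
#       df['embeddings'] = [torch.tensor(e) for e in df['embeddings']]
#       return df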