"""Gradio app: search Stable Diffusion image sites by prompt, then order the
results by perceptual-hash distance to a reference image.

Images are retrieved either from Lexica (through the local ``Lex`` module) or
from the enterpix.app prompt-search endpoint, downscaled, and shown in a
gallery. Uploading an image, clicking a gallery item, or changing the hash
function re-orders the gallery so the images whose ``imagehash`` hash is
closest to the reference image come first.
"""
import os
import tempfile
from io import BytesIO
from time import sleep
from uuid import uuid4
from zipfile import ZipFile

import gradio as gr
import imagehash
import requests
from PIL import Image

from Lex import Lexica

# Example usage of the Lexica client:
#     lex = Lexica(query="man woman fire snow").images()

# Pause between consecutive image downloads, in seconds (simple rate limiting).
sleep_time = 0.5
# Timeout for every HTTP request, in seconds; without one a stalled server
# would block the Gradio worker indefinitely.
request_timeout = 30

# imagehash functions offered in the "Order by" radio.
hash_func_name = [
    "average_hash",
    "colorhash",
    "dhash",
    "phash",
    "whash",
    "crop_resistant_hash",
]


def min_dim_to_size(img, size=512):
    """Scale *img* so that its longer side becomes *size* pixels.

    Despite the name, the longer (max) dimension is matched to *size*; the
    aspect ratio is preserved and smaller images are upscaled.

    Returns:
        A ``(ratio, resized_image)`` tuple where ``ratio`` is the scale factor.
    """
    width, height = img.size  # PIL reports (width, height)
    ratio = size / max(width, height)
    # Clamp to 1 px so extreme aspect ratios never yield a zero-sized image,
    # which PIL's resize rejects.
    new_size = tuple(max(1, int(dim * ratio)) for dim in (width, height))
    return ratio, img.resize(new_size)


def image_click(images, evt: gr.SelectData):
    """Return the ``"name"`` (file path) of the gallery item that was clicked."""
    return images[evt.index]["name"]


def _hash_distance_key(im, func_name):
    """Build a sort key: hash distance from a PIL image to reference *im*.

    *im* is the reference image as an array (converted with
    ``Image.fromarray``); *func_name* names an ``imagehash`` function such as
    ``"phash"``. The reference is hashed once here, not once per comparison.
    """
    hash_func = getattr(imagehash, func_name)
    ref_hash = hash_func(Image.fromarray(im))
    return lambda pil_img: hash_func(pil_img) - ref_hash


def swap_gallery(im, images, func_name):
    """Re-order the current gallery by hash distance to the reference image.

    Args:
        im: reference image array, or None when no image is selected.
        images: current gallery value; each item's ``"name"`` is a path that
            is opened with ``Image.open``. May be empty/None before any search.
        func_name: name of the ``imagehash`` function to compare with.

    Returns:
        The gallery paths, closest first (original order if *im* is None).
    """
    # Uploading a reference image before any search leaves the gallery empty.
    if not images:
        return []
    paths = [item["name"] for item in images]
    if im is None:
        return paths
    distance = _hash_distance_key(im, func_name)

    def key(path):
        # Close each file once hashed instead of leaking one handle per image.
        with Image.open(path) as pil_img:
            return distance(pil_img)

    return sorted(paths, key=key)


def _fetch_images(urls):
    """Download and decode every URL as a PIL image, skipping failures.

    Raises:
        gr.Error: if not a single image could be retrieved.
    """
    images = []
    for url in urls:
        try:
            resp = requests.get(url, timeout=request_timeout)
            resp.raise_for_status()
            pil_img = Image.open(BytesIO(resp.content))
            # Decode now so a corrupt file is skipped here instead of failing
            # later, outside this try, and aborting the whole search.
            pil_img.load()
            images.append(pil_img)
        except Exception as exc:  # best-effort: one bad URL must not abort the search
            print("err", url, exc)
        sleep(sleep_time)
    if not images:
        raise gr.Error("No images could be retrieved for this prompt.")
    return images


def lexica(prompt, limit_size=128, ratio_size=256 + 128):
    """Search Lexica for *prompt*; return up to *limit_size* resized images."""
    urls = Lexica(query=prompt).images()[:limit_size]
    # Request the small "sm2" rendition instead of the full-size JPEG.
    urls = [url.replace("full_jpg", "sm2") for url in urls]
    return [min_dim_to_size(pil_img, ratio_size)[1] for pil_img in _fetch_images(urls)]


def enterpix(prompt, limit_size=100, ratio_size=256 + 128, use_key="bigThumbnailUrl"):
    """Search enterpix.app for *prompt*; return up to *limit_size* resized images.

    *use_key* selects which URL field of each result item is downloaded.
    """
    resp = requests.post(
        url="https://www.enterpix.app/enterpix/v1/image/prompt-search",
        data={
            "length": limit_size,
            "platform": "stable-diffusion,midjourney",
            "prompt": prompt,
            "start": 0,
        },
        timeout=request_timeout,
    )
    resp.raise_for_status()
    urls = [item[use_key] for item in resp.json()["images"]]
    return [min_dim_to_size(pil_img, ratio_size)[1] for pil_img in _fetch_images(urls)]


def search(prompt, search_name, im, func_name):
    """Retrieve images for *prompt*, ordered by hash distance to *im* if given.

    *search_name* is ``"lexica"``; any other value uses enterpix.
    """
    fetch = lexica if search_name == "lexica" else enterpix
    results = fetch(prompt)
    if im is None:
        return results
    return sorted(results, key=_hash_distance_key(im, func_name))


def zip_ims(g):
    """Zip every gallery image into a fresh archive and return its path.

    Each call writes to its own temporary directory so concurrent users never
    overwrite each other's archive. Returns None for an empty gallery.
    """
    if not g:
        return None
    zip_file_name = os.path.join(tempfile.mkdtemp(), "tmp.zip")
    with ZipFile(zip_file_name, "w") as zip_obj:
        for item in g:
            # Random member names avoid collisions between same-named sources.
            zip_obj.write(item["name"], "{}.png".format(uuid4()))
    return zip_file_name


with gr.Blocks(css="custom.css") as demo:
    title = gr.HTML(
        """

SD StableDiffusion Search by Prompt order by Image Hash

""",
        elem_id="title",
    )
    with gr.Row():
        with gr.Column():
            with gr.Row():
                search_func_name = gr.Radio(
                    choices=["lexica", "enterpix"],
                    value="lexica",
                    label="Search by",
                    elem_id="search_radio",
                )
            with gr.Row():
                inputs = gr.Textbox(
                    label="Prompt",
                    show_label=False,
                    lines=1,
                    max_lines=20,
                    min_width=256,
                    placeholder="Enter prompt to search",
                    elem_id="prompt",
                )
                text_button = gr.Button("Retrieve Images", elem_id="run_button")
            i = gr.Image(
                elem_id="result-image",
                label="Image upload or selected",
                height=768 - 256 - 32,
            )
            with gr.Row():
                with gr.Tab(label="Download"):
                    zip_button = gr.Button("Zip Images to Download", elem_id="zip_button")
                    downloads = gr.File(label="Image zipped", elem_id="zip_file")
        with gr.Column():
            hint = gr.Markdown(
                value="### Click on a Image in the gallery to select it, and the grid order will change",
                visible=True,
                elem_id="selected_model",
            )
            order_func_name = gr.Radio(
                choices=hash_func_name,
                value=hash_func_name[0],
                label="Order by",
                elem_id="order_radio",
            )
            outputs = gr.Gallery(label="Output gallery", elem_id="gallery").style(
                grid=5, height=768 + 64 + 32, allow_preview=False, label="retrieve Images"
            )
    with gr.Row():
        gr.Examples(
            [
                ["chinese zodiac signs", "lexica", "images/chinese_zodiac_signs.png", "average_hash"],
                ["trending digital art", "lexica", "images/trending_digital_art.png", "colorhash"],
                ["masterpiece, best quality, 1girl, solo, crop top, denim shorts, choker, (graffiti:1.5), paint splatter, arms behind back, against wall, looking at viewer, armband, thigh strap, paint on body, head tilt, bored, multicolored hair, aqua eyes, headset,", "lexica", "images/yuzu_girl0.png", "average_hash"],
                ["beautiful home", "enterpix", "images/beautiful_home.png", "whash"],
                ["interior design of living room", "enterpix", "images/interior_design_of_living_room.png", "whash"],
                ["1girl, aqua eyes, baseball cap, blonde hair, closed mouth, earrings, green background, hat, hoop earrings, jewelry, looking at viewer, shirt, short hair, simple background, solo, upper body, yellow shirt", "enterpix", "images/waifu_girl0.png", "phash"],
            ],
            inputs=[inputs, search_func_name, i, order_func_name],
            label="Examples",
        )

    # Clicking a gallery item makes it the reference image.
    outputs.select(image_click, outputs, i)
    # Re-rank the existing gallery when the reference image or hash function changes.
    i.change(
        fn=swap_gallery,
        inputs=[i, outputs, order_func_name],
        outputs=outputs,
        queue=False,
    )
    order_func_name.change(
        fn=swap_gallery,
        inputs=[i, outputs, order_func_name],
        outputs=outputs,
        queue=False,
    )
    text_button.click(
        search,
        inputs=[inputs, search_func_name, i, order_func_name],
        outputs=outputs,
    )
    zip_button.click(zip_ims, inputs=outputs, outputs=downloads)

# The host must go in server_name: the first positional parameter of launch()
# is `inline`, so launch("0.0.0.0") never set the bind address.
demo.launch(server_name="0.0.0.0")