import json
import os
from collections import defaultdict
from functools import lru_cache
from typing import List, Dict

import faiss
import gradio as gr
import numpy as np
from PIL import Image
from cheesechaser.datapool import YandeWebpDataPool, ZerochanWebpDataPool, GelbooruWebpDataPool, \
    KonachanWebpDataPool, AnimePicturesWebpDataPool, DanbooruNewestWebpDataPool, Rule34WebpDataPool
from hfutils.operate import get_hf_fs, get_hf_client
from hfutils.utils import TemporaryDirectory
from imgutils.tagging import wd14

_REPO_ID = 'deepghs/index_experiments'

hf_fs = get_hf_fs()
hf_client = get_hf_client()

_DEFAULT_MODEL_NAME = 'SwinV2_v3_dgzyka_23325111_8GB'
_ALL_MODEL_NAMES = [
    os.path.dirname(os.path.relpath(path, _REPO_ID))
    for path in hf_fs.glob(f'{_REPO_ID}/*/knn.index')
]

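# Map from the site prefix used in stored image ids to its cheesechaser webp data pool.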
_SITE_CLS = {
    'danbooru': DanbooruNewestWebpDataPool,
    'yandere': YandeWebpDataPool,
    'zerochan': ZerochanWebpDataPool,
    'gelbooru': GelbooruWebpDataPool,
    'konachan': KonachanWebpDataPool,
    'anime_pictures': AnimePicturesWebpDataPool,
    'rule34': Rule34WebpDataPool,
}


def _get_from_ids(site_name: str, ids: List[int]) -> Dict[int, Image.Image]:
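    """Download the given ids from one site's data pool and return them as PIL images keyed by numeric id."""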
    with TemporaryDirectory() as td:
        datapool = _SITE_CLS[site_name]()
        datapool.batch_download_to_directory(
            resource_ids=ids,
            dst_dir=td,
        )

        retval = {}
        for file in os.listdir(td):
            id_ = int(os.path.splitext(file)[0])
            image = Image.open(os.path.join(td, file))
            # Load the pixel data eagerly, before the temporary directory is removed.
            image.load()
            retval[id_] = image

        return retval


def _get_from_raw_ids(ids: List[str]) -> Dict[str, Image.Image]:
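    """Fetch images for ids of the form '<site>_<numeric id>', grouping the downloads by site."""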
    _sites = defaultdict(list)
    for id_ in ids:
        site_name, num_id = id_.rsplit('_', maxsplit=1)
        num_id = int(num_id)
        _sites[site_name].append(num_id)

    _retval = {}
    for site_name, site_ids in _sites.items():
        _retval.update({
            f'{site_name}_{id_}': image
            for id_, image in _get_from_ids(site_name, site_ids).items()
        })
    return _retval


@lru_cache(maxsize=3)
def _get_index_info(repo_id: str, model_name: str):
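    """Load the id array and FAISS index for ``model_name`` from ``repo_id``, applying the index parameters stored in infos.json."""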
    image_ids = np.load(hf_client.hf_hub_download(
        repo_id=repo_id,
        repo_type='model',
        filename=f'{model_name}/ids.npy',
    ))
    knn_index = faiss.read_index(hf_client.hf_hub_download(
        repo_id=repo_id,
        repo_type='model',
        filename=f'{model_name}/knn.index',
    ))

    with open(hf_client.hf_hub_download(
            repo_id=repo_id,
            repo_type='model',
            filename=f'{model_name}/infos.json',
    )) as f:
        config = json.load(f)["index_param"]
    faiss.ParameterSpace().set_index_parameters(knn_index, config)
    return image_ids, knn_index


def search(model_name: str, img_input, n_neighbours: int):
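    """Embed the query image with the WD14 tagger, search the selected index, and return (image, caption) pairs for the gallery."""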
    images_ids, knn_index = _get_index_info(_REPO_ID, model_name)
    embeddings = wd14.get_wd14_tags(
        img_input,
        model_name="SwinV2_v3",
        fmt="embedding",
    )
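    # FAISS expects a 2D batch of query vectors; normalize the query to unit length before searching.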
    embeddings = np.expand_dims(embeddings, 0)
    faiss.normalize_L2(embeddings)

    dists, indexes = knn_index.search(embeddings, k=n_neighbours)
    neighbours_ids = images_ids[indexes][0]

    captions = []
    images = []
    ids_to_images = _get_from_raw_ids(neighbours_ids)
    for image_id, dist in zip(neighbours_ids, dists[0]):
        if image_id in ids_to_images:
            images.append(ids_to_images[image_id])
            captions.append(f"{image_id}/{dist:.2f}")

    return list(zip(images, captions))


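# Gradio UI: query image on the left, index selection and neighbour count on the right, matched images in a gallery below.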
if __name__ == "__main__":
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                img_input = gr.Image(type="pil", label="Input")

            with gr.Column():
                with gr.Row():
                    n_model = gr.Dropdown(
                        choices=_ALL_MODEL_NAMES,
                        value=_DEFAULT_MODEL_NAME,
                        label='Index to Use',
                    )
                with gr.Row():
                    n_neighbours = gr.Slider(
                        minimum=1,
                        maximum=50,
                        value=20,
                        step=1,
                        label="# of images",
                    )
                find_btn = gr.Button("Find similar images")

        with gr.Row():
            similar_images = gr.Gallery(label="Similar images", columns=[5])

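        # Clicking the button runs `search` with the selected index, the uploaded image, and the neighbour count.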
        find_btn.click(
            fn=search,
            inputs=[
                n_model,
                img_input,
                n_neighbours,
            ],
            outputs=[similar_images],
        )

    demo.queue().launch()