|
import json |
|
import os |
|
from collections import defaultdict |
|
from typing import List, Dict |
|
|
|
import faiss |
|
import gradio as gr |
|
import numpy as np |
|
from PIL import Image |
|
from cheesechaser.datapool import YandeWebpDataPool, ZerochanWebpDataPool, GelbooruWebpDataPool, \ |
|
KonachanWebpDataPool, AnimePicturesWebpDataPool, DanbooruNewestWebpDataPool, Rule34WebpDataPool |
|
from hfutils.operate import get_hf_fs, get_hf_client |
|
from hfutils.utils import TemporaryDirectory |
|
from realutils.metrics import siglip |
|
from imgutils.utils import ts_lru_cache |
|
|
|
from pools import quick_webp_pool |
|
|
|
# Hugging Face repository that stores one sub-directory per searchable index.
_REPO_ID = 'deepghs/anime_sites_indices'

# Redirect the siglip helper module to the repository holding the embedding models.
siglip._REPO_ID = "deepghs/siglip_beta"

# Shared Hugging Face handles: a filesystem view (for globbing) and an API client
# (for downloads).
hf_fs = get_hf_fs()
hf_client = get_hf_client()

# Index pre-selected in the UI dropdown.
_DEFAULT_MODEL_NAME = 'SwinV2_v3_danbooru_8005009_4GB'

# Every directory directly under the repository that contains a ``knn.index``
# file is offered as a selectable index; the directory name is the model name.
_ALL_MODEL_NAMES = [
    os.path.dirname(os.path.relpath(index_path, _REPO_ID))
    for index_path in hf_fs.glob(f'{_REPO_ID}/*/knn.index')
]
|
|
|
# Site name (the prefix of a raw id such as ``danbooru_123``) -> data pool class
# used to download that site's images. Sites missing here are handled by the
# ``quick_webp_pool`` fallback in ``_get_from_ids``.
_SITE_CLS = dict(
    danbooru=DanbooruNewestWebpDataPool,
    yandere=YandeWebpDataPool,
    zerochan=ZerochanWebpDataPool,
    gelbooru=GelbooruWebpDataPool,
    konachan=KonachanWebpDataPool,
    anime_pictures=AnimePicturesWebpDataPool,
    rule34=Rule34WebpDataPool,
)
|
|
|
|
|
def _get_from_ids(site_name: str, ids: List[int]) -> Dict[int, Image.Image]:
    """
    Download the requested resources of one site and return them as images.

    :param site_name: Site key. It is looked up in ``_SITE_CLS``; when there is
        no entry, a pool class is obtained from ``quick_webp_pool(site_name, 3)``.
    :param ids: Numeric resource ids to download.
    :return: Mapping from the numeric id (parsed from each downloaded file's
        name without extension) to the loaded PIL image. Only files that
        actually appear in the download directory are included.
    """
    with TemporaryDirectory() as workdir:
        pool_cls = _SITE_CLS.get(site_name) or quick_webp_pool(site_name, 3)
        pool_cls().batch_download_to_directory(resource_ids=ids, dst_dir=workdir)

        loaded = {}
        for filename in os.listdir(workdir):
            stem, _ext = os.path.splitext(filename)
            resource_id = int(stem)
            picture = Image.open(os.path.join(workdir, filename))
            # Read the pixel data eagerly, before leaving the temporary directory.
            picture.load()
            loaded[resource_id] = picture

        return loaded
|
|
|
|
|
def _get_from_raw_ids(ids: List[str]) -> Dict[str, Image.Image]:
    """
    Fetch images for raw ids of the form ``<site_name>_<numeric_id>``.

    The ids are grouped by site so that each site is downloaded with a single
    ``_get_from_ids`` call.

    :param ids: Raw ids; each is split on its last underscore into the site
        name and the numeric id.
    :return: Mapping from ``<site_name>_<numeric_id>`` to the image, for every
        image ``_get_from_ids`` returned.
    """
    per_site = defaultdict(list)
    for raw_id in ids:
        site, number = raw_id.rsplit('_', 1)
        per_site[site].append(int(number))

    found = {}
    for site, numbers in per_site.items():
        for number, picture in _get_from_ids(site, numbers).items():
            # Rebuild the raw-id form so callers can look results up by it.
            found[f'{site}_{number}'] = picture
    return found
|
|
|
|
|
@ts_lru_cache(maxsize=3)
def _get_index_info(repo_id: str, model_name: str):
    """
    Download and load the id table and the FAISS index of one model.

    The result is cached by ``ts_lru_cache`` (at most 3 entries), so each
    ``(repo_id, model_name)`` pair is normally loaded only once.

    :param repo_id: Hugging Face model repository that holds the index files.
    :param model_name: Directory inside the repository containing ``ids.npy``,
        ``knn.index`` and ``infos.json``.
    :return: Tuple ``(image_ids, knn_index)``: the array loaded from
        ``ids.npy`` and the FAISS index, already configured with the
        ``index_param`` string stored in ``infos.json``.
    """

    def _download(filename: str) -> str:
        # All three artefacts live in the same model directory of the repository.
        return hf_client.hf_hub_download(
            repo_id=repo_id,
            repo_type='model',
            filename=f'{model_name}/{filename}',
        )

    image_ids = np.load(_download('ids.npy'))
    knn_index = faiss.read_index(_download('knn.index'))

    # Use a context manager so the file handle is closed deterministically, and
    # read as UTF-8 instead of relying on the platform's default encoding.
    with open(_download('infos.json'), encoding='utf-8') as info_file:
        index_param = json.load(info_file)["index_param"]

    # Apply the search-time parameters recorded alongside the index.
    faiss.ParameterSpace().set_index_parameters(knn_index, index_param)
    return image_ids, knn_index
|
|
|
|
|
def search(model_name: str, img_input, str_input: str, n_neighbours: int):
    """
    Find the indexed images closest to an image query or a text query.

    A non-blank ``str_input`` takes precedence; otherwise ``img_input`` is
    embedded and used as the query.

    :param model_name: Name of the index (directory inside ``_REPO_ID``) to search.
    :param img_input: Query image; only used when ``str_input`` is blank.
    :param str_input: Query text; leave empty (or whitespace only) to search by image.
    :param n_neighbours: Number of neighbours to request from the index.
    :return: List of ``(image, caption)`` pairs in the order returned by the
        index, where the caption is ``"<image id>/<score with 2 decimals>"``.
        Neighbours whose image could not be downloaded are omitted.
    :raises gr.Error: If neither a text nor an image query was supplied.
    """
    embed_model = "smilingwolf/siglip_swinv2_base_2025_02_22_18h56m54s"

    has_text = bool(str_input and str_input.strip())
    if not has_text and img_input is None:
        # Without this guard ``None`` would be handed to the image embedder and
        # the user would only see an opaque failure.
        raise gr.Error("Please provide an image or some text to search with.")

    images_ids, knn_index = _get_index_info(_REPO_ID, model_name)

    if has_text:
        embeddings = siglip.get_siglip_text_embedding(
            str_input,
            model_name=embed_model,
            fmt="embeddings",
        )
    else:
        embeddings = siglip.get_siglip_image_embedding(
            img_input,
            model_name=embed_model,
            fmt="embeddings",
        )

    # The slider may deliver a float; FAISS needs an integer ``k``.
    dists, indexes = knn_index.search(embeddings, k=int(n_neighbours))

    # FAISS pads the result with label -1 when it finds fewer than ``k``
    # neighbours. Indexing with -1 would silently pick the last id, so those
    # slots are dropped.
    hit_mask = indexes[0] >= 0
    neighbours_ids = images_ids[indexes][0][hit_mask]
    neighbours_dists = dists[0][hit_mask]

    ids_to_images = _get_from_raw_ids(neighbours_ids)

    results = []
    for image_id, dist in zip(neighbours_ids, neighbours_dists):
        # Some neighbours may not have been downloadable; skip them.
        if image_id in ids_to_images:
            results.append((ids_to_images[image_id], f"{image_id}/{dist:.2f}"))
    return results
|
|
|
|
|
if __name__ == "__main__":
    with gr.Blocks() as demo:
        # Top row: query inputs on the left, search options on the right.
        with gr.Row():
            with gr.Column():
                image_box = gr.Image(type="pil", image_mode="RGBA", label="Image input")
                text_box = gr.Textbox(label="Text input (leave empty to use image input)")
            with gr.Column():
                with gr.Row():
                    index_picker = gr.Dropdown(
                        choices=_ALL_MODEL_NAMES,
                        value=_DEFAULT_MODEL_NAME,
                        label='Index to Use',
                    )
                with gr.Row():
                    count_slider = gr.Slider(minimum=1, maximum=50, value=20, step=1, label="# of images")
                find_btn = gr.Button("Find similar images")

        # Bottom row: gallery showing the retrieved neighbours.
        with gr.Row():
            result_gallery = gr.Gallery(label="Similar images", columns=[5])

        # Input order must match the positional parameters of ``search``.
        search_inputs = [index_picker, image_box, text_box, count_slider]
        find_btn.click(fn=search, inputs=search_inputs, outputs=[result_gallery])

    demo.queue().launch()
|
|