SmilingWolf committed on
Commit
5be0a61
·
verified ·
1 Parent(s): c2a197a

Add text support

Browse files
Files changed (3) hide show
  1. README.md +3 -3
  2. app.py +22 -11
  3. requirements.txt +2 -1
README.md CHANGED
@@ -1,13 +1,13 @@
1
  ---
2
- title: Search Anime Image By Image
3
  emoji: 👁
4
  colorFrom: blue
5
  colorTo: pink
6
  sdk: gradio
7
- sdk_version: 5.6.0
8
  app_file: app.py
9
  pinned: true
10
  license: openrail
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Search Anime Image By Image or Text
3
  emoji: 👁
4
  colorFrom: blue
5
  colorTo: pink
6
  sdk: gradio
7
+ sdk_version: 5.16.0
8
  app_file: app.py
9
  pinned: true
10
  license: openrail
11
  ---
12
 
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -12,10 +12,11 @@ from cheesechaser.datapool import YandeWebpDataPool, ZerochanWebpDataPool, Gelbo
12
  KonachanWebpDataPool, AnimePicturesWebpDataPool, DanbooruNewestWebpDataPool, Rule34WebpDataPool
13
  from hfutils.operate import get_hf_fs, get_hf_client
14
  from hfutils.utils import TemporaryDirectory
15
- from imgutils.tagging import wd14
16
 
17
  from pools import quick_webp_pool
18
 
 
19
  _REPO_ID = 'deepghs/anime_sites_indices'
20
 
21
  hf_fs = get_hf_fs()
@@ -95,15 +96,24 @@ def _get_index_info(repo_id: str, model_name: str):
95
  return image_ids, knn_index
96
 
97
 
98
- def search(model_name: str, img_input, n_neighbours: int):
99
  images_ids, knn_index = _get_index_info(_REPO_ID, model_name)
100
- embeddings = wd14.get_wd14_tags(
101
- img_input,
102
- model_name="SwinV2_v3",
103
- fmt="embedding",
104
- )
105
- embeddings = np.expand_dims(embeddings, 0)
106
- faiss.normalize_L2(embeddings)
 
 
 
 
 
 
 
 
 
107
 
108
  dists, indexes = knn_index.search(embeddings, k=n_neighbours)
109
  neighbours_ids = images_ids[indexes][0]
@@ -123,8 +133,8 @@ if __name__ == "__main__":
123
  with gr.Blocks() as demo:
124
  with gr.Row():
125
  with gr.Column():
126
- img_input = gr.Image(type="pil", label="Input")
127
-
128
  with gr.Column():
129
  with gr.Row():
130
  n_model = gr.Dropdown(
@@ -150,6 +160,7 @@ if __name__ == "__main__":
150
  inputs=[
151
  n_model,
152
  img_input,
 
153
  n_neighbours,
154
  ],
155
  outputs=[similar_images],
 
12
  KonachanWebpDataPool, AnimePicturesWebpDataPool, DanbooruNewestWebpDataPool, Rule34WebpDataPool
13
  from hfutils.operate import get_hf_fs, get_hf_client
14
  from hfutils.utils import TemporaryDirectory
15
+ from realutils.metrics import siglip
16
 
17
  from pools import quick_webp_pool
18
 
19
+ siglip._REPO_ID = "SmilingWolf/swinv2_siglip_beta"
20
  _REPO_ID = 'deepghs/anime_sites_indices'
21
 
22
  hf_fs = get_hf_fs()
 
96
  return image_ids, knn_index
97
 
98
 
99
+ def search(model_name: str, img_input, str_input: str, n_neighbours: int):
100
  images_ids, knn_index = _get_index_info(_REPO_ID, model_name)
101
+
102
+ if str_input == "":
103
+ embeddings = siglip.get_siglip_image_embedding(
104
+ img_input,
105
+ model_name="smilingwolf/siglip_swinv2_base_2025_02_08_13h25m57s",
106
+ fmt="embeddings",
107
+ )
108
+ else:
109
+ embeddings = siglip.get_siglip_text_embedding(
110
+ str_input,
111
+ model_name="smilingwolf/siglip_swinv2_base_2025_02_08_13h25m57s",
112
+ fmt="embeddings",
113
+ )
114
+
115
+ # In the model, the "embeddings" output node is already normalized.
116
+ # Ask for the "encodings" output if you want the raw logits
117
 
118
  dists, indexes = knn_index.search(embeddings, k=n_neighbours)
119
  neighbours_ids = images_ids[indexes][0]
 
133
  with gr.Blocks() as demo:
134
  with gr.Row():
135
  with gr.Column():
136
+ img_input = gr.Image(type="pil", label="Image input")
137
+ str_input = gr.Textbox(label="Text input (leave empty to use image input)")
138
  with gr.Column():
139
  with gr.Row():
140
  n_model = gr.Dropdown(
 
160
  inputs=[
161
  n_model,
162
  img_input,
163
+ str_input,
164
  n_neighbours,
165
  ],
166
  outputs=[similar_images],
requirements.txt CHANGED
@@ -2,5 +2,6 @@ pillow>=9.0.0
2
  faiss-cpu
3
  dghs-imgutils
4
  onnxruntime
5
- gradio==5.5.0
6
  cheesechaser>=0.1.6
 
 
2
  faiss-cpu
3
  dghs-imgutils
4
  onnxruntime
5
+ gradio==5.16.0
6
  cheesechaser>=0.1.6
7
+ dghs-realutils