Spaces:
Sleeping
Sleeping
Milvusを使用したベクトル検索機能の追加と、環境変数からの設定読み込みを実装。埋め込みモデルの推論エンドポイントへのリトライ機能を追加。
Browse files- .env +5 -0
- app.py +91 -37
- requirements.txt +2 -6
.env
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
EMBEDDING_MODEL_ENDPOINT=https://osyd05gsoix24h2p.us-east-1.aws.endpoints.huggingface.cloud
|
2 |
+
ABRG_ENDPOINT=https://abrg-api-770258656166.asia-northeast1.run.app
|
3 |
+
VECTOR_SEARCH_ENDPOINT=https://in03-61f450c72e52352.serverless.gcp-us-west1.cloud.zilliz.com
|
4 |
+
VECTOR_SEARCH_TOKEN=87ae2d391f7ef8595b0f37bbde8785cb267d001284492ceaec2d0b5967ba5af5c1733a61f3e151381c908477b05d88a20696876a
|
5 |
+
VECTOR_SEARCH_COLLECTION_NAME=japanese_address
|
app.py
CHANGED
@@ -1,15 +1,14 @@
|
|
1 |
import gradio as gr
|
|
|
2 |
import requests
|
3 |
import pandas as pd
|
4 |
-
import faiss
|
5 |
-
from tqdm import tqdm
|
6 |
import os
|
7 |
-
import numpy as np
|
8 |
-
from sentence_transformers import SentenceTransformer
|
9 |
-
from huggingface_hub import snapshot_download
|
10 |
from fastapi import FastAPI
|
|
|
|
|
11 |
|
12 |
-
|
|
|
13 |
|
14 |
app = FastAPI()
|
15 |
|
@@ -19,15 +18,20 @@ app = FastAPI()
|
|
19 |
|
20 |
# 環境変数からHUGGING_FACE_TOKENを取得
|
21 |
HUGGING_FACE_TOKEN = os.environ.get('HUGGING_FACE_TOKEN')
|
22 |
-
|
|
|
23 |
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
if not os.path.exists(data_dir):
|
28 |
-
snapshot_download(repo_id=repo_id, local_dir=data_dir, use_auth_token=HUGGING_FACE_TOKEN)
|
29 |
|
30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
|
32 |
# 47都道府県のリスト
|
33 |
prefs = [
|
@@ -66,30 +70,62 @@ examples = [
|
|
66 |
'少し待ってください。',
|
67 |
]
|
68 |
|
69 |
-
def init_faiss():
|
70 |
-
index = faiss.IndexFlatIP(1024)
|
71 |
-
|
72 |
-
all_addresses = []
|
73 |
-
for pref in tqdm(prefs):
|
74 |
-
with np.load(f'{data_dir}/pref/{pref}.npz') as data:
|
75 |
-
address_embeds = data['embeds']
|
76 |
-
addresses = data['addresses']
|
77 |
-
faiss.normalize_L2(address_embeds)
|
78 |
-
index.add(address_embeds)
|
79 |
-
|
80 |
-
# 後で検索結果と照合する用
|
81 |
-
all_addresses.extend(addresses.tolist()) # numpy.str_ -> str に変換される
|
82 |
-
|
83 |
-
return index, all_addresses
|
84 |
-
|
85 |
def preprocess(text):
|
86 |
text = text.replace('◯', '0')
|
87 |
return text
|
88 |
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
93 |
|
94 |
with gr.Blocks() as demo:
|
95 |
with gr.Tab("デジ庁API"):
|
@@ -141,13 +177,31 @@ with gr.Blocks() as demo:
|
|
141 |
search_button = gr.Button(value='検索', variant='primary')
|
142 |
result_dataframe = gr.Dataframe(label="検索結果")
|
143 |
|
144 |
-
index, all_addresses = init_faiss()
|
145 |
-
|
146 |
def search_address(query_address, top_k):
|
147 |
query_address = preprocess(query_address)
|
148 |
|
149 |
-
|
150 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
151 |
df = pd.DataFrame(hits, columns=['Top-k', '類似度', '住所'])
|
152 |
return df
|
153 |
|
|
|
1 |
import gradio as gr
|
2 |
+
import time
|
3 |
import requests
|
4 |
import pandas as pd
|
|
|
|
|
5 |
import os
|
|
|
|
|
|
|
6 |
from fastapi import FastAPI
|
7 |
+
from pymilvus import MilvusClient
|
8 |
+
from dotenv import load_dotenv
|
9 |
|
10 |
+
# .envファイルを読み込む
|
11 |
+
load_dotenv()
|
12 |
|
13 |
app = FastAPI()
|
14 |
|
|
|
18 |
|
19 |
# 環境変数からHUGGING_FACE_TOKENを取得
|
20 |
HUGGING_FACE_TOKEN = os.environ.get('HUGGING_FACE_TOKEN')
|
21 |
+
EMBEDDING_MODEL_ENDPOINT = os.environ.get('EMBEDDING_MODEL_ENDPOINT')
|
22 |
+
ABRG_ENDPOINT = os.environ.get('ABRG_ENDPOINT')
|
23 |
|
24 |
+
VECTOR_SEARCH_ENDPOINT = os.environ.get('VECTOR_SEARCH_ENDPOINT')
|
25 |
+
VECTOR_SEARCH_TOKEN = os.environ.get('VECTOR_SEARCH_TOKEN')
|
26 |
+
VECTOR_SEARCH_COLLECTION_NAME = os.environ.get('VECTOR_SEARCH_COLLECTION_NAME')
|
|
|
|
|
27 |
|
28 |
+
def init_milvus():
|
29 |
+
milvus_client = MilvusClient(uri=VECTOR_SEARCH_ENDPOINT, token=VECTOR_SEARCH_TOKEN)
|
30 |
+
print(f"Connected to DB: {VECTOR_SEARCH_ENDPOINT} successfully")
|
31 |
+
|
32 |
+
return milvus_client
|
33 |
+
|
34 |
+
MILVUS_CLIENT = init_milvus()
|
35 |
|
36 |
# 47都道府県のリスト
|
37 |
prefs = [
|
|
|
70 |
'少し待ってください。',
|
71 |
]
|
72 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
def preprocess(text):
|
74 |
text = text.replace('◯', '0')
|
75 |
return text
|
76 |
|
77 |
+
from enum import Enum
|
78 |
+
|
79 |
+
class InferenceEndpointErrorCode(Enum):
|
80 |
+
INVALID_STATE = 400
|
81 |
+
SERVICE_UNAVAILABLE = 503
|
82 |
+
UNKNOWN_ERROR = 520
|
83 |
+
|
84 |
+
class InferenceEndpointError(Exception):
|
85 |
+
def __init__(self, code: InferenceEndpointErrorCode, message="エラー"):
|
86 |
+
self.code = code
|
87 |
+
self.message = message
|
88 |
+
super().__init__(self.message)
|
89 |
+
|
90 |
+
def embed_via_multilingual_e5_large(query_addresses):
|
91 |
+
headers = {
|
92 |
+
"Accept" : "application/json",
|
93 |
+
"Authorization": f"Bearer {HUGGING_FACE_TOKEN}",
|
94 |
+
"Content-Type": "application/json"
|
95 |
+
}
|
96 |
+
|
97 |
+
response = requests.post(EMBEDDING_MODEL_ENDPOINT, headers=headers, json={"inputs": query_addresses})
|
98 |
+
response_json = response.json()
|
99 |
+
|
100 |
+
if 'error' in response_json:
|
101 |
+
if response_json['error'] == 'Bad Request: Invalid state':
|
102 |
+
raise InferenceEndpointError(InferenceEndpointErrorCode.INVALID_STATE, "Bad Request: Invalid state")
|
103 |
+
elif response_json['error'] == '503 Service Unavailable':
|
104 |
+
raise InferenceEndpointError(InferenceEndpointErrorCode.SERVICE_UNAVAILABLE, "Service Unavailable")
|
105 |
+
else:
|
106 |
+
raise InferenceEndpointError(InferenceEndpointErrorCode.UNKNOWN_ERROR, response_json['error'])
|
107 |
+
|
108 |
+
return response_json
|
109 |
+
|
110 |
+
def search_via_milvus(query_vector, top_k):
|
111 |
+
search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}} # MiniLM系はCOSINE推奨
|
112 |
+
|
113 |
+
results = MILVUS_CLIENT.search(
|
114 |
+
collection_name=VECTOR_SEARCH_COLLECTION_NAME,
|
115 |
+
data=[query_vector],
|
116 |
+
search_params=search_params,
|
117 |
+
limit=top_k,
|
118 |
+
anns_field='embedding',
|
119 |
+
output_fields=['address'],
|
120 |
+
)[0]
|
121 |
+
|
122 |
+
hits = []
|
123 |
+
for i, result in enumerate(results, start=1):
|
124 |
+
distance = result['distance']
|
125 |
+
address = result['entity'].get('address')
|
126 |
+
hits.append([i, distance, address])
|
127 |
+
|
128 |
+
return hits
|
129 |
|
130 |
with gr.Blocks() as demo:
|
131 |
with gr.Tab("デジ庁API"):
|
|
|
177 |
search_button = gr.Button(value='検索', variant='primary')
|
178 |
result_dataframe = gr.Dataframe(label="検索結果")
|
179 |
|
|
|
|
|
180 |
def search_address(query_address, top_k):
|
181 |
query_address = preprocess(query_address)
|
182 |
|
183 |
+
wait_time = 30
|
184 |
+
max_retries = 5
|
185 |
+
for attempt in range(max_retries):
|
186 |
+
try:
|
187 |
+
query_embeds = embed_via_multilingual_e5_large([query_address])
|
188 |
+
break # 成功した場合はループを抜ける
|
189 |
+
|
190 |
+
except InferenceEndpointError as e:
|
191 |
+
if e.code == InferenceEndpointErrorCode.SERVICE_UNAVAILABLE:
|
192 |
+
if attempt < max_retries - 1:
|
193 |
+
gr.Warning(f"{InferenceEndpointErrorCode.SERVICE_UNAVAILABLE}: 埋め込みモデルの推論エンドポイントが起動中です。{wait_time}秒後にリトライします。", duration=wait_time)
|
194 |
+
time.sleep(wait_time) # 30秒待機
|
195 |
+
else:
|
196 |
+
raise gr.Error(f"{InferenceEndpointErrorCode.SERVICE_UNAVAILABLE}: 最大リトライ回数に達しました。しばらくしてから再度実行してみてください。")
|
197 |
+
|
198 |
+
elif e.code == InferenceEndpointErrorCode.INVALID_STATE:
|
199 |
+
raise gr.Error(f"{InferenceEndpointErrorCode.INVALID_STATE}: 埋め込みモデルの推論エンドポイントが停止中です。再起動するよう管理者に問い合わせてください。")
|
200 |
+
|
201 |
+
elif e.code == InferenceEndpointErrorCode.UNKNOWN_ERROR:
|
202 |
+
raise gr.Error(f"{InferenceEndpointErrorCode.UNKNOWN_ERROR}: {e.message}")
|
203 |
+
|
204 |
+
hits = search_via_milvus(query_embeds[0], top_k)
|
205 |
df = pd.DataFrame(hits, columns=['Top-k', '類似度', '住所'])
|
206 |
return df
|
207 |
|
requirements.txt
CHANGED
@@ -1,11 +1,7 @@
|
|
1 |
-
--extra-index-url https://download.pytorch.org/whl/cu118
|
2 |
-
|
3 |
gradio
|
4 |
pandas
|
5 |
numpy
|
6 |
-
faiss-cpu
|
7 |
-
sentence-transformers
|
8 |
huggingface-hub
|
9 |
-
torch==2.5.1
|
10 |
fastapi
|
11 |
-
uvicorn
|
|
|
|
|
|
|
|
1 |
gradio
|
2 |
pandas
|
3 |
numpy
|
|
|
|
|
4 |
huggingface-hub
|
|
|
5 |
fastapi
|
6 |
+
uvicorn
|
7 |
+
pymilvus
|