matsuap commited on
Commit
ec20668
·
1 Parent(s): b0f617e

Milvusを使用したベクトル検索機能の追加と、環境変数からの設定読み込みを実装。埋め込みモデルの推論エンドポイントへのリトライ機能を追加。

Browse files
Files changed (3) hide show
  1. .env +5 -0
  2. app.py +91 -37
  3. requirements.txt +2 -6
.env ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ EMBEDDING_MODEL_ENDPOINT=https://osyd05gsoix24h2p.us-east-1.aws.endpoints.huggingface.cloud
2
+ ABRG_ENDPOINT=https://abrg-api-770258656166.asia-northeast1.run.app
3
+ VECTOR_SEARCH_ENDPOINT=https://in03-61f450c72e52352.serverless.gcp-us-west1.cloud.zilliz.com
4
+ VECTOR_SEARCH_TOKEN=87ae2d391f7ef8595b0f37bbde8785cb267d001284492ceaec2d0b5967ba5af5c1733a61f3e151381c908477b05d88a20696876a
5
+ VECTOR_SEARCH_COLLECTION_NAME=japanese_address
app.py CHANGED
@@ -1,15 +1,14 @@
1
  import gradio as gr
 
2
  import requests
3
  import pandas as pd
4
- import faiss
5
- from tqdm import tqdm
6
  import os
7
- import numpy as np
8
- from sentence_transformers import SentenceTransformer
9
- from huggingface_hub import snapshot_download
10
  from fastapi import FastAPI
 
 
11
 
12
- CUSTOM_PATH = "/gradio"
 
13
 
14
  app = FastAPI()
15
 
@@ -19,15 +18,20 @@ app = FastAPI()
19
 
20
  # 環境変数からHUGGING_FACE_TOKENを取得
21
  HUGGING_FACE_TOKEN = os.environ.get('HUGGING_FACE_TOKEN')
22
- ABRG_ENDPOINT = 'https://abrg-api-770258656166.asia-northeast1.run.app'
 
23
 
24
- repo_id = 'AtPeak/japanese-address-machiaza-vector'
25
- # repo_id = 'AtPeak/japanese-address-resident-number-vector'
26
- data_dir = 'embeddings'
27
- if not os.path.exists(data_dir):
28
- snapshot_download(repo_id=repo_id, local_dir=data_dir, use_auth_token=HUGGING_FACE_TOKEN)
29
 
30
- model = SentenceTransformer('intfloat/multilingual-e5-large', device='cuda')
 
 
 
 
 
 
31
 
32
  # 47都道府県のリスト
33
  prefs = [
@@ -66,30 +70,62 @@ examples = [
66
  '少し待ってください。',
67
  ]
68
 
69
- def init_faiss():
70
- index = faiss.IndexFlatIP(1024)
71
-
72
- all_addresses = []
73
- for pref in tqdm(prefs):
74
- with np.load(f'{data_dir}/pref/{pref}.npz') as data:
75
- address_embeds = data['embeds']
76
- addresses = data['addresses']
77
- faiss.normalize_L2(address_embeds)
78
- index.add(address_embeds)
79
-
80
- # 後で検索結果と照合する用
81
- all_addresses.extend(addresses.tolist()) # numpy.str_ -> str に変換される
82
-
83
- return index, all_addresses
84
-
85
  def preprocess(text):
86
  text = text.replace('◯', '0')
87
  return text
88
 
89
- def search_via_faiss(query_embed, top_k):
90
- faiss.normalize_L2(query_embed)
91
- D, I = index.search(query_embed, top_k)
92
- return [(k, d, all_addresses[i]) for k, (d, i) in enumerate(zip(D[0], I[0]), start=1)]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
 
94
  with gr.Blocks() as demo:
95
  with gr.Tab("デジ庁API"):
@@ -141,13 +177,31 @@ with gr.Blocks() as demo:
141
  search_button = gr.Button(value='検索', variant='primary')
142
  result_dataframe = gr.Dataframe(label="検索結果")
143
 
144
- index, all_addresses = init_faiss()
145
-
146
  def search_address(query_address, top_k):
147
  query_address = preprocess(query_address)
148
 
149
- query_embed = model.encode([query_address], convert_to_numpy=True)
150
- hits = search_via_faiss(query_embed, top_k)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  df = pd.DataFrame(hits, columns=['Top-k', '類似度', '住所'])
152
  return df
153
 
 
1
  import gradio as gr
2
+ import time
3
  import requests
4
  import pandas as pd
 
 
5
  import os
 
 
 
6
  from fastapi import FastAPI
7
+ from pymilvus import MilvusClient
8
+ from dotenv import load_dotenv
9
 
10
+ # .envファイルを読み込む
11
+ load_dotenv()
12
 
13
  app = FastAPI()
14
 
 
18
 
19
# Configuration is read from environment variables (populated via load_dotenv()).
HUGGING_FACE_TOKEN = os.environ.get('HUGGING_FACE_TOKEN')
EMBEDDING_MODEL_ENDPOINT = os.environ.get('EMBEDDING_MODEL_ENDPOINT')
ABRG_ENDPOINT = os.environ.get('ABRG_ENDPOINT')

# Milvus (Zilliz serverless) connection settings for vector search.
VECTOR_SEARCH_ENDPOINT = os.environ.get('VECTOR_SEARCH_ENDPOINT')
VECTOR_SEARCH_TOKEN = os.environ.get('VECTOR_SEARCH_TOKEN')
VECTOR_SEARCH_COLLECTION_NAME = os.environ.get('VECTOR_SEARCH_COLLECTION_NAME')


def init_milvus():
    """Create and return a MilvusClient connected to the configured endpoint."""
    milvus_client = MilvusClient(uri=VECTOR_SEARCH_ENDPOINT, token=VECTOR_SEARCH_TOKEN)
    print(f"Connected to DB: {VECTOR_SEARCH_ENDPOINT} successfully")

    return milvus_client

# Module-level singleton client, created once at import time.
MILVUS_CLIENT = init_milvus()
35
 
36
  # 47都道府県のリスト
37
  prefs = [
 
70
  '少し待ってください。',
71
  ]
72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
def preprocess(text):
    """Normalize a raw query string before embedding: map '◯' to '0'."""
    return text.replace('◯', '0')
76
 
77
+ from enum import Enum
78
+
79
class InferenceEndpointErrorCode(Enum):
    """HTTP-style error codes mapped from the embedding inference endpoint's error strings."""
    INVALID_STATE = 400        # endpoint is stopped; needs an admin restart
    SERVICE_UNAVAILABLE = 503  # endpoint is cold-starting; retryable
    UNKNOWN_ERROR = 520        # any other error string from the endpoint
83
+
84
class InferenceEndpointError(Exception):
    """Raised when the embedding inference endpoint reports an error.

    Carries a typed ``code`` so callers can decide whether the failure
    is retryable, plus a human-readable ``message``.
    """

    def __init__(self, code: InferenceEndpointErrorCode, message="エラー"):
        super().__init__(message)
        self.code = code
        self.message = message
89
+
90
def embed_via_multilingual_e5_large(query_addresses):
    """Embed address strings via the hosted multilingual-e5-large inference endpoint.

    Args:
        query_addresses: list of preprocessed address strings to embed.

    Returns:
        The endpoint's decoded JSON payload (a list of embedding vectors on success).

    Raises:
        InferenceEndpointError: with a code classifying the endpoint failure,
            so the caller's retry loop can distinguish retryable cold-starts
            from permanent errors.
    """
    headers = {
        "Accept" : "application/json",
        "Authorization": f"Bearer {HUGGING_FACE_TOKEN}",
        "Content-Type": "application/json"
    }

    response = requests.post(EMBEDDING_MODEL_ENDPOINT, headers=headers, json={"inputs": query_addresses})

    # Gateways in front of the endpoint can return non-JSON bodies (e.g. an
    # HTML 502/504 page); surface that through the typed-error protocol
    # instead of crashing with an unhandled JSON decode error.
    try:
        response_json = response.json()
    except ValueError:
        raise InferenceEndpointError(
            InferenceEndpointErrorCode.UNKNOWN_ERROR,
            f"Non-JSON response (HTTP {response.status_code})",
        )

    if 'error' in response_json:
        # Map the endpoint's known error strings onto typed codes.
        if response_json['error'] == 'Bad Request: Invalid state':
            raise InferenceEndpointError(InferenceEndpointErrorCode.INVALID_STATE, "Bad Request: Invalid state")
        elif response_json['error'] == '503 Service Unavailable':
            raise InferenceEndpointError(InferenceEndpointErrorCode.SERVICE_UNAVAILABLE, "Service Unavailable")
        else:
            raise InferenceEndpointError(InferenceEndpointErrorCode.UNKNOWN_ERROR, response_json['error'])

    return response_json
109
+
110
def search_via_milvus(query_vector, top_k):
    """Cosine-similarity search against the configured Milvus collection.

    Args:
        query_vector: a single embedding vector for the query address.
        top_k: number of nearest neighbours to return.

    Returns:
        A list of ``[rank, distance, address]`` rows, rank starting at 1.
    """
    # COSINE matches how the embeddings are indexed in the collection.
    search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}}

    top_hits = MILVUS_CLIENT.search(
        collection_name=VECTOR_SEARCH_COLLECTION_NAME,
        data=[query_vector],
        search_params=search_params,
        limit=top_k,
        anns_field='embedding',
        output_fields=['address'],
    )[0]  # one query vector -> take the first (only) result list

    return [
        [rank, hit['distance'], hit['entity'].get('address')]
        for rank, hit in enumerate(top_hits, start=1)
    ]
129
 
130
  with gr.Blocks() as demo:
131
  with gr.Tab("デジ庁API"):
 
177
  search_button = gr.Button(value='検索', variant='primary')
178
  result_dataframe = gr.Dataframe(label="検索結果")
179
 
 
 
180
        def search_address(query_address, top_k):
            """Embed the query (retrying while the endpoint cold-starts) and
            return the top-k nearest stored addresses as a DataFrame.

            Args:
                query_address: raw address text entered by the user.
                top_k: number of nearest neighbours to show.

            Returns:
                pandas.DataFrame with columns ['Top-k', '類似度', '住所'].

            Raises:
                gr.Error: if the endpoint is stopped, retries are exhausted,
                    or the endpoint reports an unknown error.
            """
            query_address = preprocess(query_address)

            # The inference endpoint scales to zero; retry while it spins up.
            wait_time = 30
            max_retries = 5
            for attempt in range(max_retries):
                try:
                    query_embeds = embed_via_multilingual_e5_large([query_address])
                    break  # success: leave the retry loop

                except InferenceEndpointError as e:
                    if e.code == InferenceEndpointErrorCode.SERVICE_UNAVAILABLE:
                        if attempt < max_retries - 1:
                            gr.Warning(f"{InferenceEndpointErrorCode.SERVICE_UNAVAILABLE}: 埋め込みモデルの推論エンドポイントが起動中です。{wait_time}秒後にリトライします。", duration=wait_time)
                            time.sleep(wait_time)  # wait 30 seconds before retrying
                        else:
                            raise gr.Error(f"{InferenceEndpointErrorCode.SERVICE_UNAVAILABLE}: 最大リトライ回数に達しました。しばらくしてから再度実行してみてください。")

                    elif e.code == InferenceEndpointErrorCode.INVALID_STATE:
                        # Endpoint is stopped; retrying cannot help.
                        raise gr.Error(f"{InferenceEndpointErrorCode.INVALID_STATE}: 埋め込みモデルの推論エンドポイントが停止中です。再起動するよう管理者に問い合わせてください。")

                    elif e.code == InferenceEndpointErrorCode.UNKNOWN_ERROR:
                        raise gr.Error(f"{InferenceEndpointErrorCode.UNKNOWN_ERROR}: {e.message}")

            hits = search_via_milvus(query_embeds[0], top_k)
            df = pd.DataFrame(hits, columns=['Top-k', '類似度', '住所'])
            return df
207
 
requirements.txt CHANGED
@@ -1,11 +1,7 @@
1
- --extra-index-url https://download.pytorch.org/whl/cu118
2
-
3
  gradio
4
  pandas
5
  numpy
6
- faiss-cpu
7
- sentence-transformers
8
  huggingface-hub
9
- torch==2.5.1
10
  fastapi
11
- uvicorn
 
 
 
 
1
  gradio
2
  pandas
3
  numpy
 
 
4
  huggingface-hub
 
5
  fastapi
6
6
+ uvicorn
7
+ pymilvus
8
+ requests
9
+ python-dotenv