matsuap commited on
Commit
ba0a056
·
1 Parent(s): 7963713

環境変数の設定を追加し、住所比較機能やベクトル検索機能を実装。Gradioタブを用いて新しいエンドポイントを作成し、テキスト処理機能を強化。必要なライブラリをrequirements.txtに追加。

Browse files
Files changed (3) hide show
  1. .gitignore +2 -1
  2. app.py +244 -81
  3. requirements.txt +5 -1
.gitignore CHANGED
@@ -1,4 +1,5 @@
1
  .venv/
2
  embeddings/
3
  embeddings_/
4
- __pycache__/
 
 
1
  .venv/
2
  embeddings/
3
  embeddings_/
4
+ __pycache__/
5
+ .env
app.py CHANGED
@@ -1,4 +1,7 @@
1
  import gradio as gr
 
 
 
2
  import time
3
  import requests
4
  import pandas as pd
@@ -10,24 +13,9 @@ from dotenv import load_dotenv
10
  # .envファイルを読み込む
11
  load_dotenv()
12
 
13
- app = FastAPI()
14
-
15
- @app.post("/replace-circle")
16
- def replace_circle(input_text):
17
- output_text = input_text.replace('◯', '0')
18
- return output_text
19
-
20
- @app.post("/remove-filler")
21
- def remove_filler(input_text):
22
- output_text = input_text
23
- return output_text
24
-
25
- @app.post("/preprocess")
26
- def preprocess(input_text):
27
- output_text = replace_circle(input_text)
28
- output_text = remove_filler(output_text)
29
- return output_text
30
-
31
  # 環境変数からHUGGING_FACE_TOKENを取得
32
  HUGGING_FACE_TOKEN = os.environ.get('HUGGING_FACE_TOKEN')
33
  EMBEDDING_MODEL_ENDPOINT = os.environ.get('EMBEDDING_MODEL_ENDPOINT')
@@ -36,14 +24,99 @@ ABRG_ENDPOINT = os.environ.get('ABRG_ENDPOINT')
36
  VECTOR_SEARCH_ENDPOINT = os.environ.get('VECTOR_SEARCH_ENDPOINT')
37
  VECTOR_SEARCH_TOKEN = os.environ.get('VECTOR_SEARCH_TOKEN')
38
  VECTOR_SEARCH_COLLECTION_NAME = os.environ.get('VECTOR_SEARCH_COLLECTION_NAME')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
- def init_milvus():
41
- milvus_client = MilvusClient(uri=VECTOR_SEARCH_ENDPOINT, token=VECTOR_SEARCH_TOKEN)
42
- print(f"Connected to DB: {VECTOR_SEARCH_ENDPOINT} successfully")
43
 
44
- return milvus_client
 
 
 
 
45
 
46
- MILVUS_CLIENT = init_milvus()
 
 
 
 
 
 
 
 
 
 
 
47
 
48
  # 47都道府県のリスト
49
  prefs = [
@@ -56,34 +129,6 @@ prefs = [
56
  '熊本県', '大分県', '宮崎県', '鹿児島県', '沖縄県'
57
  ]
58
 
59
- examples = [
60
- '私の住所は京都府京都市右京区太秦青木元町4-10です。',
61
- '京都府京都市右京区太秦青木元町4-10',
62
- '京都府京都市右京区太秦青木元町4-10ダックス101号室',
63
- '京都府宇治市伊勢田町名木1-1-4ダックス101号室',
64
- '東京都渋谷区道玄坂1-12-1',
65
- '私の住所は東京都渋谷区道玄坂1-12-1です。',
66
- '私の住所は東京都しぶや道玄坂1の12の1です。',
67
- '東京都渋谷区道玄坂1の12の1で契約しています。',
68
- '秋田県秋田市山王四丁目1番1号です。',
69
- '東京 墨田区 押上 1丁目1',
70
- '三重県伊勢市宇治館町',
71
- '住所は 030-0803 青森県青森市安方1丁目1−40になります。',
72
- '東京都大島町差木地 字クダッチ',
73
- '前橋市大手町1丁目1番地1',
74
- '東京都渋谷区表参道の3の5の6。',
75
- '琉球圏尾張町3の5の6に住んでます。',
76
- '3254987の場所です。',
77
- '大阪府でした。',
78
- '1940923の東京都渋谷区道玄坂一丁目。渋谷マークシティウェスト23階です。',
79
- '名前は山田太郎です。',
80
- 'はい。名古屋、あ、愛知県名古屋市南里2の3の4だと思います。',
81
- 'ー',
82
- '少し待ってください。',
83
- ]
84
-
85
- from enum import Enum
86
-
87
  class InferenceEndpointErrorCode(Enum):
88
  INVALID_STATE = 400
89
  SERVICE_UNAVAILABLE = 503
@@ -115,11 +160,11 @@ def embed_via_multilingual_e5_large(query_addresses):
115
 
116
  return response_json
117
 
118
- def search_via_milvus(query_vector, top_k):
119
  search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}} # MiniLM系はCOSINE推奨
120
 
121
  results = MILVUS_CLIENT.search(
122
- collection_name=VECTOR_SEARCH_COLLECTION_NAME,
123
  data=[query_vector],
124
  search_params=search_params,
125
  limit=top_k,
@@ -131,11 +176,129 @@ def search_via_milvus(query_vector, top_k):
131
  for i, result in enumerate(results, start=1):
132
  distance = result['distance']
133
  address = result['entity'].get('address')
134
- hits.append([i, distance, address])
 
 
135
 
136
  return hits
137
 
138
- with gr.Blocks() as demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
  with gr.Tab("デジ庁API"):
140
  with gr.Row():
141
  with gr.Column():
@@ -161,6 +324,9 @@ with gr.Blocks() as demo:
161
  'oaza_cho': result['oaza_cho'],
162
  'chome': result['chome'],
163
  'koaza': result['koaza'],
 
 
 
164
  'prc_num1': result['prc_num1'],
165
  'prc_num2': result['prc_num2'],
166
  'prc_num3': result['prc_num3'],
@@ -176,6 +342,7 @@ with gr.Blocks() as demo:
176
  outputs=[result_tb, result_df],
177
  )
178
 
 
179
  with gr.Tab("ベクトル検索"):
180
  with gr.Row():
181
  with gr.Column():
@@ -183,40 +350,36 @@ with gr.Blocks() as demo:
183
  gr.Examples(examples=examples, inputs=[address_input])
184
  top_k_input = gr.Slider(minimum=1, maximum=100, step=1, value=5, label='検索数top-k')
185
  search_button = gr.Button(value='検索', variant='primary')
186
- result_dataframe = gr.Dataframe(label="検索結果")
 
 
187
 
188
  def search_address(query_address, top_k):
189
- query_address = preprocess(query_address)
190
-
191
- wait_time = 30
192
- max_retries = 5
193
- for attempt in range(max_retries):
194
- try:
195
- query_embeds = embed_via_multilingual_e5_large([query_address])
196
- break # 成功した場合はループを抜ける
197
 
198
- except InferenceEndpointError as e:
199
- if e.code == InferenceEndpointErrorCode.SERVICE_UNAVAILABLE:
200
- if attempt < max_retries - 1:
201
- gr.Warning(f"{InferenceEndpointErrorCode.SERVICE_UNAVAILABLE}: 埋め込みモデルの推論エンドポイントが起動中です。{wait_time}秒後にリトライします。", duration=wait_time)
202
- time.sleep(wait_time) # 30秒待機
203
- else:
204
- raise gr.Error(f"{InferenceEndpointErrorCode.SERVICE_UNAVAILABLE}: 最大リトライ回数に達しました。しばらくしてから再度実行してみてください。")
205
-
206
- elif e.code == InferenceEndpointErrorCode.INVALID_STATE:
207
- raise gr.Error(f"{InferenceEndpointErrorCode.INVALID_STATE}: 埋め込みモデルの推論エンドポイントが停止中です。再起動するよう管理者に問い合わせてください。")
208
-
209
- elif e.code == InferenceEndpointErrorCode.UNKNOWN_ERROR:
210
- raise gr.Error(f"{InferenceEndpointErrorCode.UNKNOWN_ERROR}: {e.message}")
211
 
212
- hits = search_via_milvus(query_embeds[0], top_k)
213
- df = pd.DataFrame(hits, columns=['Top-k', '類似度', '住所'])
214
- return df
215
 
216
  search_button.click(
217
  fn=search_address,
218
  inputs=[address_input, top_k_input],
219
- outputs=[result_dataframe]
220
  )
221
 
 
 
 
 
 
222
  app = gr.mount_gradio_app(app, demo, path='/')
 
1
  import gradio as gr
2
+ import spacy
3
+ from normalize_japanese_addresses import normalize
4
+ from enum import Enum
5
  import time
6
  import requests
7
  import pandas as pd
 
13
  # .envファイルを読み込む
14
  load_dotenv()
15
 
16
+ # =========================
17
+ # Global variables
18
+ # =========================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  # 環境変数からHUGGING_FACE_TOKENを取得
20
  HUGGING_FACE_TOKEN = os.environ.get('HUGGING_FACE_TOKEN')
21
  EMBEDDING_MODEL_ENDPOINT = os.environ.get('EMBEDDING_MODEL_ENDPOINT')
 
24
  VECTOR_SEARCH_ENDPOINT = os.environ.get('VECTOR_SEARCH_ENDPOINT')
25
  VECTOR_SEARCH_TOKEN = os.environ.get('VECTOR_SEARCH_TOKEN')
26
  VECTOR_SEARCH_COLLECTION_NAME = os.environ.get('VECTOR_SEARCH_COLLECTION_NAME')
27
+ VECTOR_SEARCH_COLLECTION_NAME_V2 = os.environ.get('VECTOR_SEARCH_COLLECTION_NAME_V2')
28
+
29
+ MILVUS_CLIENT = MilvusClient(uri=VECTOR_SEARCH_ENDPOINT, token=VECTOR_SEARCH_TOKEN)
30
+ print(f"Connected to DB: {VECTOR_SEARCH_ENDPOINT} successfully")
31
+
32
+ # =========================
33
+ # Utilitiy functions
34
+ # =========================
35
+ def split_address(normalized_address):
36
+ splits = normalize(normalized_address)
37
+ return splits
38
+
39
+ def compare(normalized_address1, normalized_address2):
40
+ split1 = split_address(normalized_address1)
41
+ split2 = split_address(normalized_address2)
42
+
43
+ result = {
44
+ 'pref': False,
45
+ 'city': False,
46
+ 'town': False,
47
+ 'addr': False,
48
+ }
49
+
50
+ for key in result.keys():
51
+ if split1[key] == split2[key]:
52
+ result[key] = True
53
+
54
+ return all(result.values())
55
+
56
+ def vector_search(query_address, top_k):
57
+ wait_time = 30
58
+ max_retries = 5
59
+ for attempt in range(max_retries):
60
+ try:
61
+ query_embeds = embed_via_multilingual_e5_large([query_address])
62
+ break # 成功した場合はループを抜ける
63
+
64
+ except InferenceEndpointError as e:
65
+ if e.code == InferenceEndpointErrorCode.SERVICE_UNAVAILABLE:
66
+ if attempt < max_retries - 1:
67
+ gr.Warning(f"{InferenceEndpointErrorCode.SERVICE_UNAVAILABLE}: 埋め込みモデルの推論エンドポイントが起動中です。{wait_time}秒後にリトライします。", duration=wait_time)
68
+ time.sleep(wait_time) # 30秒待機
69
+ else:
70
+ raise gr.Error(f"{InferenceEndpointErrorCode.SERVICE_UNAVAILABLE}: 最大リトライ回数に達しました。しばらくしてから再度実行してみてください。")
71
+
72
+ elif e.code == InferenceEndpointErrorCode.INVALID_STATE:
73
+ raise gr.Error(f"{InferenceEndpointErrorCode.INVALID_STATE}: 埋め込みモデルの推論エンドポイントが停止中です。再起動するよう管理者に問い合わせてください。")
74
+
75
+ elif e.code == InferenceEndpointErrorCode.UNKNOWN_ERROR:
76
+ raise gr.Error(f"{InferenceEndpointErrorCode.UNKNOWN_ERROR}: {e.message}")
77
+
78
+ '''
79
+ hits = search_via_milvus(query_embeds[0], top_k, VECTOR_SEARCH_COLLECTION_NAME_V2)
80
+
81
+ if hits:
82
+ normalized = hits[0][-1]
83
+
84
+ else:
85
+ hits = search_via_milvus(query_embeds[0], top_k, VECTOR_SEARCH_COLLECTION_NAME)
86
+ normalized = hits[0][-1]
87
+ '''
88
+ hits = search_via_milvus(query_embeds[0], top_k, VECTOR_SEARCH_COLLECTION_NAME)
89
+ return hits
90
+
91
+ def replace_circle(input_text):
92
+ output_text = input_text.replace('◯', '0')
93
+ return output_text
94
+
95
+ def remove_filler(input_text: str) -> str:
96
+ """
97
+ GiNZAを用いて日本語テキストからフィラーを除去する関数。
98
 
99
+ Parameters:
100
+ text (str): 入力テキスト。
 
101
 
102
+ Returns:
103
+ str: フィラーを除去したテキスト。
104
+ """
105
+ # GiNZAモデルの読み込み
106
+ nlp = spacy.load("ja_ginza")
107
 
108
+ # テキストの解析
109
+ doc = nlp(input_text)
110
+
111
+ # フィラーを除去したテキストの生成
112
+ cleaned_text = ''.join([token.text for token in doc if token.tag_ != "感動詞-フィラー"])
113
+
114
+ return cleaned_text
115
+
116
+ def preprocess(input_text):
117
+ output_text = replace_circle(input_text)
118
+ output_text = remove_filler(output_text)
119
+ return output_text
120
 
121
  # 47都道府県のリスト
122
  prefs = [
 
129
  '熊本県', '大分県', '宮崎県', '鹿児島県', '沖縄県'
130
  ]
131
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
  class InferenceEndpointErrorCode(Enum):
133
  INVALID_STATE = 400
134
  SERVICE_UNAVAILABLE = 503
 
160
 
161
  return response_json
162
 
163
+ def search_via_milvus(query_vector, top_k, collection_name, thresh=0.9):
164
  search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}} # MiniLM系はCOSINE推奨
165
 
166
  results = MILVUS_CLIENT.search(
167
+ collection_name=collection_name,
168
  data=[query_vector],
169
  search_params=search_params,
170
  limit=top_k,
 
176
  for i, result in enumerate(results, start=1):
177
  distance = result['distance']
178
  address = result['entity'].get('address')
179
+
180
+ if distance >= thresh:
181
+ hits.append([i, distance, address])
182
 
183
  return hits
184
 
185
+ # =========================
186
+ # FastAPI definition
187
+ # =========================
188
+ app = FastAPI()
189
+
190
+ @app.post("/compare-two-addresses")
191
+ def compare_two_addresses(address1, address2):
192
+ preprocessed1 = preprocess(address1)
193
+ preprocessed2 = preprocess(address2)
194
+ hits1 = vector_search(preprocessed1, top_k=1)
195
+ hits2 = vector_search(preprocessed2, top_k=1)
196
+ normalized1 = hits1[0][-1]
197
+ normalized2 = hits2[0][-1]
198
+ result = compare(normalized1, normalized2)
199
+ return result
200
+
201
+ # =========================
202
+ # Gradio tabs definition
203
+ # =========================
204
+ def create_endpoint_test_tab():
205
+ def create_replace_circle_tab():
206
+ with gr.Tab("replace_circle"):
207
+ in_tb = gr.Textbox(label='インプット', placeholder='テキストを入力してください')
208
+ out_tb = gr.Textbox(label='アウトプット')
209
+ exe_button = gr.Button(value='実行', variant='primary')
210
+ exe_button.click(
211
+ fn=replace_circle,
212
+ inputs=[in_tb],
213
+ outputs=[out_tb],
214
+ )
215
+ def create_remove_filler_tab():
216
+ with gr.Tab("remove_filler"):
217
+ in_tb = gr.Textbox(label='インプット', placeholder='テキストを入力してください')
218
+ out_tb = gr.Textbox(label='アウトプット')
219
+ exe_button = gr.Button(value='実行', variant='primary')
220
+ exe_button.click(
221
+ fn=remove_filler,
222
+ inputs=[in_tb],
223
+ outputs=[out_tb],
224
+ )
225
+ def create_preprocess_tab():
226
+ with gr.Tab("preprocess"):
227
+ in_tb = gr.Textbox(label='インプット', placeholder='テキストを入力してください')
228
+ out_tb = gr.Textbox(label='アウトプット')
229
+ exe_button = gr.Button(value='実行', variant='primary')
230
+ exe_button.click(
231
+ fn=preprocess,
232
+ inputs=[in_tb],
233
+ outputs=[out_tb],
234
+ )
235
+ def create_compare_two_addresses_tab():
236
+ with gr.Tab("compare_two_addresses"):
237
+ in_tb1 = gr.Textbox(label='住所1 (顧客が発言した住所)', placeholder='住所を入力してください')
238
+ in_tb2 = gr.Textbox(label='住所2 (CRM 内に格納されている住所)', placeholder='住所を入力してください')
239
+ out_tb = gr.Textbox(label='アウトプット')
240
+ exe_button = gr.Button(value='実行', variant='primary')
241
+ exe_button.click(
242
+ fn=compare_two_addresses,
243
+ inputs=[in_tb1, in_tb2],
244
+ outputs=[out_tb],
245
+ )
246
+ def create_normalize_address_tab():
247
+ with gr.Tab("normalize_address"):
248
+ in_tb = gr.Textbox(label='住所', placeholder='住所を入力してください')
249
+ out_tb = gr.Textbox(label='アウトプット')
250
+ exe_button = gr.Button(value='実行', variant='primary')
251
+ exe_button.click(
252
+ fn=lambda address: vector_search(address, top_k=1)[0][-1],
253
+ inputs=[in_tb],
254
+ outputs=[out_tb],
255
+ )
256
+ def create_split_address_tab():
257
+ with gr.Tab("split_address"):
258
+ in_tb = gr.Textbox(label='住所', placeholder='住所を入力してください')
259
+ out_tb = gr.Textbox(label='アウトプット')
260
+ exe_button = gr.Button(value='実行', variant='primary')
261
+ exe_button.click(
262
+ fn=split_address,
263
+ inputs=[in_tb],
264
+ outputs=[out_tb],
265
+ )
266
+
267
+ with gr.Tab("関数テスト"):
268
+ create_compare_two_addresses_tab()
269
+ create_replace_circle_tab()
270
+ create_remove_filler_tab()
271
+ create_preprocess_tab()
272
+ create_normalize_address_tab()
273
+ create_split_address_tab()
274
+
275
+ examples = [
276
+ '私の住所は京都府京都市右京区太秦青木元町4-10です。',
277
+ '京都府京都市右京区太秦青木元町4-10',
278
+ '京都府京都市右京区太秦青木元町4-10ダックス101号室',
279
+ '京都府宇治市伊勢田町名木1-1-4ダックス101号室',
280
+ '東京都渋谷区道玄坂1-12-1',
281
+ '私の住所は東京都渋谷区道玄坂1-12-1です。',
282
+ '私の住所は東京都しぶや道玄坂1の12の1です。',
283
+ '東京都渋谷区道玄坂1の12の1で契約しています。',
284
+ '秋田県秋田市山王四丁目1番1号です。',
285
+ '東京 墨田区 押上 1丁目1',
286
+ '三重県伊勢市宇治館町',
287
+ '住所は 030-0803 青森県青森市安方1丁目1−40になります。',
288
+ '東京都大島町差木地 字クダッチ',
289
+ '前橋市大手町1丁目1番地1',
290
+ '東京都渋谷区表参道の3の5の6。',
291
+ '琉球圏尾張町3の5の6に住んでます。',
292
+ '3254987の場所です。',
293
+ '大阪府でした。',
294
+ '1940923の東京都渋谷区道玄坂一丁目。渋谷マークシティウェスト23階です。',
295
+ '名前は山田太郎です。',
296
+ 'はい。名古屋、あ、愛知県名古屋市南里2の3の4だと思います。',
297
+ 'ー',
298
+ '少し待ってください。',
299
+ ]
300
+
301
+ def create_digital_agency_tab():
302
  with gr.Tab("デジ庁API"):
303
  with gr.Row():
304
  with gr.Column():
 
324
  'oaza_cho': result['oaza_cho'],
325
  'chome': result['chome'],
326
  'koaza': result['koaza'],
327
+ 'blk_num': result['blk_num'],
328
+ 'rsdt_num': result['rsdt_num'],
329
+ 'rsdt_num2': result['rsdt_num2'],
330
  'prc_num1': result['prc_num1'],
331
  'prc_num2': result['prc_num2'],
332
  'prc_num3': result['prc_num3'],
 
342
  outputs=[result_tb, result_df],
343
  )
344
 
345
+ def create_vector_search_tab():
346
  with gr.Tab("ベクトル検索"):
347
  with gr.Row():
348
  with gr.Column():
 
350
  gr.Examples(examples=examples, inputs=[address_input])
351
  top_k_input = gr.Slider(minimum=1, maximum=100, step=1, value=5, label='検索数top-k')
352
  search_button = gr.Button(value='検索', variant='primary')
353
+ result_tb = gr.Textbox(label='正規化後')
354
+ search_result_df = gr.Dataframe(label="検索結果")
355
+ result_df = gr.Dataframe(label="正規化後(分割)", wrap=True)
356
 
357
  def search_address(query_address, top_k):
358
+ preprocessed = preprocess(query_address)
359
+ hits = vector_search(preprocessed, top_k)
360
+ normalized = hits[0]
361
+ search_result_df = pd.DataFrame(hits, columns=['Top-k', '類似度', '住所'])
362
+ splits = split_address(normalized)
 
 
 
363
 
364
+ data = {
365
+ 'pref': splits['pref'],
366
+ 'city': splits['city'],
367
+ 'town': splits['town'],
368
+ 'addr': splits['addr'],
369
+ }
370
+ result_df = pd.DataFrame([data])
 
 
 
 
 
 
371
 
372
+ return search_result_df, normalized, result_df
 
 
373
 
374
  search_button.click(
375
  fn=search_address,
376
  inputs=[address_input, top_k_input],
377
+ outputs=[search_result_df, result_tb, result_df]
378
  )
379
 
380
+ with gr.Blocks() as demo:
381
+ create_endpoint_test_tab()
382
+ create_vector_search_tab()
383
+ create_digital_agency_tab()
384
+
385
  app = gr.mount_gradio_app(app, demo, path='/')
requirements.txt CHANGED
@@ -4,4 +4,8 @@ numpy
4
  huggingface-hub
5
  fastapi
6
  uvicorn
7
- pymilvus
 
 
 
 
 
4
  huggingface-hub
5
  fastapi
6
  uvicorn
7
+ pymilvus
8
+ spacy
9
+ normalize-japanese-addresses
10
+ ginza
11
+ ja-ginza