# Spaces: Sleeping
import gradio as gr | |
import zipfile | |
from tqdm import tqdm | |
from sklearn.metrics.pairwise import cosine_similarity | |
from pathlib import Path | |
import spacy | |
from normalize_japanese_addresses import normalize | |
from enum import Enum | |
import time | |
import requests | |
import pandas as pd | |
import os | |
from fastapi import FastAPI, Request | |
from pymilvus import MilvusClient | |
from dotenv import load_dotenv | |
import time | |
from contextlib import contextmanager | |
import numpy as np | |
import re | |
import os | |
from openai import AzureOpenAI | |
# Load the variables declared in the .env file into the process environment.
load_dotenv()

# =========================
# Global variables
# =========================
# Directory into which the downloaded registry archives are extracted.
TARGET_DIR = Path(os.getenv('TARGET_DIR'))

# Tokens and endpoints of the external services used by this module.
HUGGING_FACE_TOKEN = os.getenv('HUGGING_FACE_TOKEN')
EMBEDDING_MODEL_ENDPOINT = os.getenv('EMBEDDING_MODEL_ENDPOINT')
ABRG_ENDPOINT = os.getenv('ABRG_ENDPOINT')
VECTOR_SEARCH_ENDPOINT = os.getenv('VECTOR_SEARCH_ENDPOINT')
VECTOR_SEARCH_TOKEN = os.getenv('VECTOR_SEARCH_TOKEN')
VECTOR_SEARCH_COLLECTION_NAME = os.getenv('VECTOR_SEARCH_COLLECTION_NAME')
VECTOR_SEARCH_COLLECTION_NAME_V2 = os.getenv('VECTOR_SEARCH_COLLECTION_NAME_V2')
GOOGLE_SEARCH_API_KEY = os.getenv('GOOGLE_SEARCH_API_KEY')
GOOGLE_SEARCH_ENGINE_ID = os.getenv('GOOGLE_SEARCH_ENGINE_ID')

# Single shared Milvus client, created when the module is imported.
MILVUS_CLIENT = MilvusClient(
    uri=VECTOR_SEARCH_ENDPOINT,
    token=VECTOR_SEARCH_TOKEN,
)
print(f"Connected to DB: {VECTOR_SEARCH_ENDPOINT} successfully")
# The 47 prefectures of Japan.
# Order matters: get_pref_code() turns the list position (+1) into the
# prefecture code used in the registry file names.
prefs = [
    # Hokkaido / Tohoku
    '北海道', '青森県', '岩手県', '宮城県', '秋田県', '山形県', '福島県',
    # Kanto
    '茨城県', '栃木県', '群馬県', '埼玉県', '千葉県', '東京都', '神奈川県',
    # Chubu
    '新潟県', '富山県', '石川県', '福井県', '山梨県', '長野県', '岐阜県',
    '静岡県', '愛知県',
    # Kinki
    '三重県', '滋賀県', '京都府', '大阪府', '兵庫県', '奈良県', '和歌山県',
    # Chugoku
    '鳥取県', '島根県', '岡山県', '広島県', '山口県',
    # Shikoku
    '徳島県', '香川県', '愛媛県', '高知県',
    # Kyushu / Okinawa
    '福岡県', '佐賀県', '長崎県', '熊本県', '大分県', '宮崎県', '鹿児島県',
    '沖縄県',
]
# ----------------------------
# Azure OpenAI API
# ----------------------------
# Key and endpoint are read from the environment; the API version is pinned.
client = AzureOpenAI(
    azure_endpoint=os.environ.get("AZURE_OPENAI_ENDPOINT"),
    api_key=os.environ.get("AZURE_OPENAI_API_KEY"),
    api_version="2025-03-01-preview",
)
# ----------------------------------------------------
# Registry CSV downloads (executed when the module is imported)
# ----------------------------------------------------
_REGISTRY_BASE_URL = 'https://catalog.registries.digital.go.jp/rsc/address'
_DOWNLOAD_TIMEOUT_SEC = 60    # fail instead of hanging on a stalled connection
_DOWNLOAD_INTERVAL_SEC = 0.2  # pause between requests (200 ms)


def _fetch_zip(url, zip_path, required=False):
    """Download *url* and store the body at *zip_path*.

    Returns True when the archive was saved.  Any answer other than HTTP 200
    returns False, unless *required* is true, in which case it raises so that
    an error page is never written to disk and mistaken for a zip archive.
    """
    response = requests.get(url, timeout=_DOWNLOAD_TIMEOUT_SEC)
    if response.status_code != 200:
        if required:
            response.raise_for_status()
            raise RuntimeError(
                f'Unexpected HTTP status {response.status_code} for {url}')
        return False
    with open(zip_path, 'wb') as f:
        f.write(response.content)
    return True


def _extract_zip(zip_path, dest_dir):
    """Extract every member of the archive at *zip_path* into *dest_dir*."""
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(dest_dir)


temp_dir = Path('temp')
temp_dir.mkdir(exist_ok=True)

# ----------------------------
# Download mt_city_all.csv
# ----------------------------
city_all_url = f'{_REGISTRY_BASE_URL}/mt_city_all.csv.zip'
zip_file_path = temp_dir / 'mt_city_all.csv.zip'
# Skip the download when the archive is already present.
if not os.path.exists(zip_file_path):
    _fetch_zip(city_all_url, zip_file_path, required=True)
# Extract directly under TARGET_DIR.
_extract_zip(zip_file_path, TARGET_DIR)

# ------------------------------------
# Download mt_parcel_cityXXXXXX.csv
# ------------------------------------
city_all_file = TARGET_DIR / 'mt_city_all.csv'
city_all_df = pd.read_csv(city_all_file)
lg_codes = city_all_df['lg_code'].tolist()
print('lg_codes', len(lg_codes))
for lg_code in tqdm(lg_codes):
    parcel_name = f'mt_parcel_city{lg_code:06d}.csv'
    if os.path.exists(TARGET_DIR / 'parcel' / parcel_name):
        continue  # already extracted
    parcel_url = f'{_REGISTRY_BASE_URL}/{parcel_name}.zip'
    zip_file_path = temp_dir / f'{parcel_name}.zip'
    # Only continue when the URL actually exists (HTTP 200).
    if _fetch_zip(parcel_url, zip_file_path):
        _extract_zip(zip_file_path, TARGET_DIR / 'parcel')
    time.sleep(_DOWNLOAD_INTERVAL_SEC)

# ------------------------------------
# Download mt_rsdtdsp_rsdt_prefXX.csv
# ------------------------------------
# The first two digits of the zero-padded lg_code are the prefecture code.
pref_codes = sorted({f'{lg_code:06d}'[:2] for lg_code in lg_codes})
for pref_code in tqdm(pref_codes):
    rsdt_name = f'mt_rsdtdsp_rsdt_pref{pref_code}.csv'
    if os.path.exists(TARGET_DIR / 'rsdt' / rsdt_name):
        continue  # already extracted
    rsdt_url = f'{_REGISTRY_BASE_URL}/{rsdt_name}.zip'
    zip_file_path = temp_dir / f'{rsdt_name}.zip'
    if _fetch_zip(rsdt_url, zip_file_path):
        _extract_zip(zip_file_path, TARGET_DIR / 'rsdt')
    time.sleep(_DOWNLOAD_INTERVAL_SEC)

# Remove the temporary directory together with the downloaded archives.
for file in temp_dir.iterdir():
    file.unlink()
temp_dir.rmdir()
# ========================= | |
# Utility functions
# ========================= | |
@contextmanager
def measure(label="処理"):
    """Context manager that prints how long the wrapped block took.

    The function is used throughout this file as ``with measure(...):``, so it
    has to be wrapped with ``contextlib.contextmanager``; a bare generator
    cannot be entered by a ``with`` statement.

    Parameters:
        label: Text printed in front of the elapsed time.
    """
    start = time.perf_counter()
    try:
        yield
    finally:
        # Report the elapsed time even when the wrapped block raises.
        elapsed = time.perf_counter() - start
        print(f"{label} 実行時間: {elapsed:.6f} 秒")
def get_spelling(query_address):
    """Return the spelling correction the Google Custom Search API suggests.

    Parameters:
        query_address: Text sent as the search query.

    Returns:
        The value of ``spelling.correctedQuery`` in the JSON answer, or an
        empty string when the answer carries no such entry.
    """
    # Pass the query through ``params`` so that characters such as '&', '#'
    # or spaces inside the address are URL-encoded instead of corrupting the
    # query string.
    response = requests.get(
        'https://www.googleapis.com/customsearch/v1',
        params={
            'key': GOOGLE_SEARCH_API_KEY,
            'cx': GOOGLE_SEARCH_ENGINE_ID,
            'q': query_address,
        },
        timeout=10,  # do not block the caller forever on a stalled request
    )
    results = response.json()
    return results.get('spelling', {}).get('correctedQuery', '')
# Full-width digits (U+FF10..U+FF19) and the full-width hyphen-minus (U+FF0D)
# mapped to their ASCII counterparts.  The code points are spelled out so the
# table cannot silently degrade into an ASCII-to-ASCII identity mapping.
_ZENKAKU_TO_HANKAKU = str.maketrans(
    ''.join(chr(0xFF10 + digit) for digit in range(10)) + '\uff0d',
    '0123456789' + '-',
)


def convert_zenkaku_to_hankaku(text):
    """Convert full-width digits and hyphens in *text* to half-width ASCII.

    Parameters:
        text: Input string.

    Returns:
        A copy of *text* in which the full-width digits and the full-width
        hyphen-minus are replaced; every other character is left untouched.
    """
    return text.translate(_ZENKAKU_TO_HANKAKU)
# Building blocks of ADDRESS_REGEX.  Full-width forms are written as escapes
# so they cannot be lost to text normalisation (which would otherwise leave a
# redundant second ASCII range behind).
_ADDR_DIGIT = r'[0-9\uff10-\uff19]'      # ASCII and full-width digits
_ADDR_HYPHEN = r'[-ー−–\uff0d]'           # hyphen-like separators
_ADDR_NUMBER = rf'{_ADDR_DIGIT}+(?:{_ADDR_HYPHEN}{_ADDR_DIGIT}+)*'

ADDRESS_REGEX = re.compile(
    r'^'
    r'(?P<address>'
    r'.+?[都道府県]'               # prefecture
    r'.+?[市区町村]'               # municipality
    r'.*?'                         # town name etc. (non-greedy)
    rf'{_ADDR_NUMBER}'             # first number, optionally "-number" repeated
    r'(?:(?:丁目|番地|番|号)'      # a unit word ...
    rf'(?:{_ADDR_NUMBER})?'        # ... optionally followed by another number
    r')*'                          # the unit may repeat any number of times
    r')'
    r'(?P<building>.*)'            # everything left is the building part
    r'$'
)


def split_address_building(address: str) -> dict:
    """Split *address* into its address part and its building part by regex.

    Parameters:
        address: Address string, optionally followed by a building name.

    Returns:
        A dict with the keys ``'address'`` and ``'building'`` (both stripped).
        When the pattern does not match, the whole input is returned as
        ``'address'`` and ``'building'`` is an empty string.
    """
    m = ADDRESS_REGEX.match(address)
    if m is None:
        return {'address': address, 'building': ''}
    return {
        'address': m.group('address').strip(),
        'building': m.group('building').strip(),
    }
def split_address_building_with_gpt(query_address: str) -> dict:
    """Split *query_address* into address and building parts using the LLM.

    Parameters:
        query_address: Free-form address text.

    Returns:
        A dict with the keys ``'address'`` and ``'building'``.  When the
        response carries no parsed object, the whole input is returned as
        ``'address'`` with an empty ``'building'`` (the same fallback shape as
        split_address_building).
    """
    class SplittedAddress(BaseModel):
        # Schema the model output is parsed into.
        address: str
        building: str

    response = client.responses.parse(
        model="gpt-4o-mini",
        input=[
            # NOTE(review): this system prompt talks about "event information"
            # although the task is address splitting - confirm it is intended.
            {"role": "system", "content": "Extract the event information."},
            {
                "role": "user",
                "content": f"与えられた住所をaddressとbuildingに分けろ:{query_address}",
            },
        ],
        text_format=SplittedAddress,
    )
    parsed = response.output_parsed
    if parsed is None:
        # No structured output: fall back instead of failing with
        # AttributeError on ``None.address``.
        return {'address': query_address, 'building': ''}
    return {
        'address': parsed.address,
        'building': parsed.building,
    }
def split_address(normalized_address):
    """Return what ``normalize()`` produces for *normalized_address*."""
    return normalize(normalized_address)
def compare(normalized_address1, normalized_address2):
    """Tell whether two addresses agree on pref, city, town and addr.

    Both inputs are split with split_address(); the result is True only when
    all four components are equal.
    """
    first = split_address(normalized_address1)
    second = split_address(normalized_address2)
    # Evaluate every component before reducing, so a missing key is reported
    # no matter which component differs first.
    matches = {
        component: first[component] == second[component]
        for component in ('pref', 'city', 'town', 'addr')
    }
    return all(matches.values())
def vector_search(query_address, top_k):
    """Embed *query_address* and return the *top_k* hits from Milvus.

    The embedding call is attempted up to five times; while the endpoint
    answers "service unavailable" a Gradio warning is shown and the call is
    retried after 30 seconds.  Every other endpoint error, and the last
    failed attempt, is raised as ``gr.Error``.
    """
    retry_delay_sec = 30
    attempt_limit = 5
    for attempt_no in range(attempt_limit):
        try:
            with measure('vector_search - embed_via_multilingual_e5_large'):
                query_embeds = embed_via_multilingual_e5_large([query_address])
        except InferenceEndpointError as err:
            if err.code == InferenceEndpointErrorCode.SERVICE_UNAVAILABLE:
                out_of_attempts = attempt_no >= attempt_limit - 1
                if out_of_attempts:
                    raise gr.Error(f"{InferenceEndpointErrorCode.SERVICE_UNAVAILABLE}: 最大リトライ回数に達しました。しばらくしてから再度実行してみてください。")
                gr.Warning(f"{InferenceEndpointErrorCode.SERVICE_UNAVAILABLE}: 埋め込みモデルの推論エンドポイントが起動中です。{retry_delay_sec}秒後にリトライします。", duration=retry_delay_sec)
                time.sleep(retry_delay_sec)  # wait before the next attempt
            elif err.code == InferenceEndpointErrorCode.INVALID_STATE:
                raise gr.Error(f"{InferenceEndpointErrorCode.INVALID_STATE}: 埋め込みモデルの推論エンドポイントが停止中です。再起動するよう管理者に問い合わせてください。")
            elif err.code == InferenceEndpointErrorCode.UNKNOWN_ERROR:
                raise gr.Error(f"{InferenceEndpointErrorCode.UNKNOWN_ERROR}: {err.message}")
        else:
            # Embedding succeeded: stop retrying.
            break
    with measure('vector_search - search_via_milvus'):
        hits = search_via_milvus(query_embeds[0], top_k, VECTOR_SEARCH_COLLECTION_NAME)
    return hits
def replace_circle(input_text):
    """Return *input_text* with every '◯' replaced by the digit '0'."""
    return input_text.replace('◯', '0')
# spaCy pipeline shared by every call; loaded on first use because loading
# the model is far more expensive than analysing one short text.
_GINZA_NLP = None


def _get_ginza_nlp():
    """Return the "ja_ginza" spaCy pipeline, loading it on the first call."""
    global _GINZA_NLP
    if _GINZA_NLP is None:
        _GINZA_NLP = spacy.load("ja_ginza")
    return _GINZA_NLP


def remove_filler(input_text: str) -> str:
    """Remove filler words from Japanese text using GiNZA.

    Parameters:
        input_text (str): Input text.

    Returns:
        str: The text of all tokens whose tag is not "感動詞-フィラー",
        joined without separators.
    """
    doc = _get_ginza_nlp()(input_text)
    return ''.join(token.text for token in doc if token.tag_ != "感動詞-フィラー")
def remove_left_of_pref(text):
    """Drop everything that precedes a prefecture name in *text*.

    Prefectures are tried in the order of ``prefs``; the first one found in
    the text decides the cut position.  The text is returned unchanged when
    it contains no prefecture name.
    """
    positions = (text.find(pref_name) for pref_name in prefs)
    cut_at = next((pos for pos in positions if pos != -1), None)
    if cut_at is None:
        return text
    return text[cut_at:]
def preprocess(input_text):
    """Clean a raw address string before it is searched.

    Applies, in this order: remove_left_of_pref, replace_circle and
    remove_filler.
    """
    trimmed = remove_left_of_pref(input_text)
    return remove_filler(replace_circle(trimmed))
class InferenceEndpointErrorCode(Enum):
    """Error categories reported for the embedding inference endpoint."""

    # Raised for the endpoint message 'Bad Request: Invalid state'.
    INVALID_STATE = 400
    # Raised for the endpoint message '503 Service Unavailable'.
    SERVICE_UNAVAILABLE = 503
    # Raised for any other error message.
    UNKNOWN_ERROR = 520
class InferenceEndpointError(Exception):
    """Failure reported by the embedding inference endpoint.

    Attributes:
        code: The InferenceEndpointErrorCode describing the failure.
        message: Human-readable error text (also the exception message).
    """

    def __init__(self, code: InferenceEndpointErrorCode, message="エラー"):
        super().__init__(message)
        self.code = code
        self.message = message
def embed_via_multilingual_e5_large(query_addresses):
    """Embed *query_addresses* through the remote embedding endpoint.

    The inputs are posted in chunks of 2048.  The JSON answer of every chunk
    is appended to the result; an answer containing 'error' is turned into an
    InferenceEndpointError.

    Returns:
        A list with the concatenated answers of all chunks.
    """
    chunk_size = 2048
    request_headers = {
        "Accept": "application/json",
        "Authorization": f"Bearer {HUGGING_FACE_TOKEN}",
        "Content-Type": "application/json"
    }
    # (endpoint error text, error code, message stored on the exception)
    known_errors = (
        ('Bad Request: Invalid state',
         InferenceEndpointErrorCode.INVALID_STATE, "Bad Request: Invalid state"),
        ('503 Service Unavailable',
         InferenceEndpointErrorCode.SERVICE_UNAVAILABLE, "Service Unavailable"),
    )

    embeddings = []
    for offset in range(0, len(query_addresses), chunk_size):
        batch = query_addresses[offset:offset + chunk_size]
        reply = requests.post(
            EMBEDDING_MODEL_ENDPOINT,
            headers=request_headers,
            json={"inputs": batch},
        )
        payload = reply.json()
        if 'error' in payload:
            detail = payload['error']
            for error_text, code, message in known_errors:
                if detail == error_text:
                    raise InferenceEndpointError(code, message)
            raise InferenceEndpointError(InferenceEndpointErrorCode.UNKNOWN_ERROR, detail)
        embeddings.extend(payload)
    return embeddings
def search_via_milvus(query_vector, top_k, collection_name, thresh=0.0):
    """Search *collection_name* for the vectors closest to *query_vector*.

    Parameters:
        query_vector: Embedding to search with.
        top_k: Maximum number of results requested from Milvus.
        collection_name: Collection to search.
        thresh: Results whose ``distance`` is below this value are dropped.

    Returns:
        A list of rows ``[rank, distance, address, pref, county, city, ward,
        oaza_cho, chome, koaza]``; ``rank`` is the 1-based position in the
        Milvus result, counted before the threshold filter.
    """
    entity_fields = ['address', 'pref', 'county', 'city', 'ward', 'oaza_cho', 'chome', 'koaza']
    matches = MILVUS_CLIENT.search(
        collection_name=collection_name,
        data=[query_vector],
        # The original note recommends COSINE for MiniLM-family models.
        search_params={"metric_type": "COSINE", "params": {"nprobe": 10}},
        limit=top_k,
        anns_field='embedding',
        output_fields=entity_fields,
    )[0]

    hits = []
    for rank, match in enumerate(matches, start=1):
        score = match['distance']
        row = [rank, score] + [match['entity'].get(name) for name in entity_fields]
        if score >= thresh:
            hits.append(row)
    return hits
def get_lg_code(pref, county, city, ward):
    """Look up the ``lg_code`` of a municipality in mt_city_all.csv.

    Parameters:
        pref: Prefecture name; must equal the ``pref`` column.
        county, city, ward: Name parts; their concatenation must equal the
            concatenated ``county``/``city``/``ward`` columns.  ``None`` is
            treated as an empty string.

    Returns:
        The ``lg_code`` value of the single matching row.

    Raises:
        IndexError: No row matches (same type as before, clearer message).
        Exception: More than one row matches.
    """
    city_all_df = pd.read_csv(TARGET_DIR / 'mt_city_all.csv')
    in_pref = city_all_df[city_all_df['pref'] == pref]
    registry_names = (
        in_pref['county'].fillna('')
        + in_pref['city'].fillna('')
        + in_pref['ward'].fillna('')
    )
    wanted_name = (county or '') + (city or '') + (ward or '')
    lg_codes = in_pref[registry_names == wanted_name]['lg_code'].values
    if len(lg_codes) == 0:
        raise IndexError(f'lg_code not found for {pref}{wanted_name}')
    if len(lg_codes) > 1:
        raise Exception('Too many lg_code')
    return lg_codes[0]
def get_addresses_with_parcel(pref, county, city, ward, oaza_cho, chome, koaza):
    """List the parcel-number addresses registered for one town block.

    Rows of the municipality's parcel CSV are selected where city+ward,
    oaza_cho+chome and koaza equal the given parts.  ``None`` parts are
    treated as empty strings.

    Returns:
        One address string per selected row (town name followed by
        ``prc_num1[-prc_num2[-prc_num3]]``), or a single-element list with
        just the concatenated name parts when no row is selected.

    Raises:
        gr.Error: The parcel CSV of the municipality does not exist.
    """
    def _as_text(column):
        # Missing numbers become ''.  Testing for NaN directly (instead of a
        # 9999 placeholder) keeps a genuine parcel number 9999 intact.
        return ['' if pd.isna(value) else str(int(value)) for value in column]

    county, city, ward, oaza_cho, chome, koaza = (
        part or '' for part in (county, city, ward, oaza_cho, chome, koaza)
    )

    lg_code = get_lg_code(pref, county, city, ward)
    parcel_city_file = TARGET_DIR / 'parcel' / f'mt_parcel_city{lg_code:06d}.csv'
    if not os.path.exists(parcel_city_file):
        # Put the path into the message itself; a second positional argument
        # of gr.Error is not part of the message.
        raise gr.Error(f'Not found: {parcel_city_file}')

    parcel_city_df = pd.read_csv(parcel_city_file)
    names = parcel_city_df[['city', 'ward', 'oaza_cho', 'chome', 'koaza']].fillna('')
    selected = (
        ((names['city'] + names['ward']) == (county + city + ward))
        & ((names['oaza_cho'] + names['chome']) == (oaza_cho + chome))
        & (names['koaza'] == koaza)
    )
    if not selected.any():
        return [pref + county + city + ward + oaza_cho + chome + koaza]

    rows = names[selected]
    numbers = parcel_city_df[selected]
    # Build the address strings.
    return [
        f"{pref}{_city}{_ward}{_oaza_cho}{_chome}{_koaza}{_prc_num1}"
        + (f"-{_prc_num2}" if _prc_num2 else '')
        + (f"-{_prc_num3}" if _prc_num3 else '')
        for _city, _ward, _oaza_cho, _chome, _koaza, _prc_num1, _prc_num2, _prc_num3 in zip(
            rows['city'], rows['ward'], rows['oaza_cho'], rows['chome'], rows['koaza'],
            _as_text(numbers['prc_num1']),
            _as_text(numbers['prc_num2']),
            _as_text(numbers['prc_num3']),
        )
    ]
def get_pref_code(pref):
    """Return the 1-based position of *pref* in ``prefs``.

    Raises:
        ValueError: *pref* is not in ``prefs``.
    """
    zero_based_position = prefs.index(pref)
    return zero_based_position + 1
def get_addresses_with_rsdtdsp(pref, county, city, ward, oaza_cho, chome, koaza):
    """List the residence-number addresses registered for one town block.

    Rows of the prefecture's rsdt CSV are selected where city+ward,
    oaza_cho+chome and koaza equal the given parts.  ``None`` parts are
    treated as empty strings.

    Returns:
        One address string per selected row (town name followed by
        ``blk_num[-rsdt_num[-rsdt_num2]]``), or a single-element list with
        just the concatenated name parts when no row is selected.

    Raises:
        gr.Error: The rsdt CSV of the prefecture does not exist.
    """
    def _as_text(column):
        # Missing numbers become ''.  Testing for NaN directly (instead of a
        # 9999 placeholder) keeps a genuine number 9999 intact.
        return ['' if pd.isna(value) else str(int(value)) for value in column]

    county, city, ward, oaza_cho, chome, koaza = (
        part or '' for part in (county, city, ward, oaza_cho, chome, koaza)
    )

    pref_code = get_pref_code(pref)
    rsdtdsp_file = TARGET_DIR / 'rsdt' / f'mt_rsdtdsp_rsdt_pref{pref_code:02d}.csv'
    if not os.path.exists(rsdtdsp_file):
        raise gr.Error(f'Not found: {rsdtdsp_file}')

    rsdtdsp_df = pd.read_csv(rsdtdsp_file)
    names = rsdtdsp_df[['city', 'ward', 'oaza_cho', 'chome', 'koaza']].fillna('')
    selected = (
        ((names['city'] + names['ward']) == (county + city + ward))
        & ((names['oaza_cho'] + names['chome']) == (oaza_cho + chome))
        & (names['koaza'] == koaza)
    )
    if not selected.any():
        return [pref + county + city + ward + oaza_cho + chome + koaza]

    rows = names[selected]
    numbers = rsdtdsp_df[selected]
    # Build the address strings.
    return [
        f"{pref}{_city}{_ward}{_oaza_cho}{_chome}{_koaza}{_blk_num}"
        + (f"-{_rsdt_num}" if _rsdt_num else '')
        + (f"-{_rsdt_num2}" if _rsdt_num2 else '')
        for _city, _ward, _oaza_cho, _chome, _koaza, _blk_num, _rsdt_num, _rsdt_num2 in zip(
            rows['city'], rows['ward'], rows['oaza_cho'], rows['chome'], rows['koaza'],
            _as_text(numbers['blk_num']),
            _as_text(numbers['rsdt_num']),
            _as_text(numbers['rsdt_num2']),
        )
    ]
def compare_two_addresses(address1, address2):
    """Tell whether two free-form addresses refer to the same address.

    Each input is preprocessed and resolved to its best vector-search hit;
    the addresses of the two hits are then compared with compare().

    Returns:
        True when the hits agree on every compared component, False when they
        differ or when either search produced no hit.
    """
    # Hit rows are [rank, distance, address, pref, ..., koaza]; the address is
    # column 2.  Indexing with -1 would compare the ``koaza`` column instead.
    address_column = 2

    preprocessed1 = preprocess(address1)
    preprocessed2 = preprocess(address2)
    hits1 = vector_search(preprocessed1, top_k=1)
    hits2 = vector_search(preprocessed2, top_k=1)
    if not hits1 or not hits2:
        # Nothing to compare against: the match cannot be confirmed.
        return False
    normalized1 = hits1[0][address_column]
    normalized2 = hits2[0][address_column]
    return compare(normalized1, normalized2)
def normalize_address(query_address):
    """Normalise a free-form address and return the best matching address.

    Steps: half-width conversion, address/building split via the LLM,
    preprocessing, vector search for the town block, expansion of that block
    into parcel and residence-number addresses, and a cosine-similarity
    ranking of those candidates against the address part.

    Returns:
        The most similar candidate followed by the building part.
    """
    with measure('convert_zenkaku_to_hankaku'):
        hankaku_query = convert_zenkaku_to_hankaku(query_address)
    with measure('split_address_building_with_gpt'):
        splitted = split_address_building_with_gpt(hankaku_query)
    with measure('preprocess'):
        cleaned = preprocess(splitted['address'])
    with measure('vector_search'):
        hits = vector_search(cleaned, 1)
    with measure('split_address'):
        top_hit = hits[0]
        splits = {
            'pref': top_hit[3],
            'county': top_hit[4],
            'city': top_hit[5],
            'ward': top_hit[6],
            'oaza_cho': top_hit[7],
            'chome': top_hit[8],
            'koaza': top_hit[9],
        }
    # The dict keeps insertion order, which is the positional order both
    # lookup functions expect.
    with measure('get_addresses_with_parcel'):
        candidates = get_addresses_with_parcel(*splits.values())
    with measure('get_addresses_with_rsdtdsp'):
        candidates = candidates + get_addresses_with_rsdtdsp(*splits.values())
    candidates = list(set(candidates))  # drop duplicates
    with measure('embed_via_multilingual_e5_large'):
        embeds = embed_via_multilingual_e5_large([splitted['address']] + candidates)
    query_embed = [embeds[0]]
    address_embeds = embeds[1:]
    with measure('cosine'):
        # Cosine similarity between the query and every candidate.
        similarities = cosine_similarity(query_embed, address_embeds)
    # The last index of the ascending argsort is the most similar candidate.
    best_index = np.argsort(similarities[0])[-1]
    return candidates[best_index] + splitted['building']
def convert_no_to_hyphen(query_address):
    """Replace each 'の' that sits directly between two digits with '-'."""
    between_digits = re.compile(r'(?<=\d)の(?=\d)')
    return between_digits.sub('-', query_address)
def normalize_address_v2(query_address, top_k=1):
    """Normalise an address and return the intermediate results as well.

    Compared with normalize_address this variant also applies the spelling
    correction from get_spelling and converts 'の' between digits to '-'.

    Parameters:
        query_address: Free-form address text.
        top_k: Number of best candidates to return; coerced to ``int`` because
            it is used as a slice bound.

    Returns:
        A tuple ``(splitted, hits, splits, best_addresses, best_similarities)``:
        the address/building split, the vector-search hits, the name parts of
        the first hit, the best candidates and their similarities.
    """
    top_k = int(top_k)
    with measure('convert_zenkaku_to_hankaku'):
        query_address = convert_zenkaku_to_hankaku(query_address)
    with measure('split_address_building_with_gpt'):
        splitted = split_address_building_with_gpt(query_address)
    with measure('get_spelling'):
        spelling = get_spelling(splitted['address'])
        if spelling:
            splitted['address'] = spelling
    # This step used to be timed under an empty label.
    with measure('convert_no_to_hyphen'):
        splitted['address'] = convert_no_to_hyphen(splitted['address'])
    with measure('preprocess'):
        preprocessed = preprocess(splitted['address'])
    with measure('vector_search'):
        hits = vector_search(preprocessed, 1)
    with measure('split_address'):
        splits = {
            'pref': hits[0][3],
            'county': hits[0][4],
            'city': hits[0][5],
            'ward': hits[0][6],
            'oaza_cho': hits[0][7],
            'chome': hits[0][8],
            'koaza': hits[0][9],
        }
    with measure('get_addresses_with_parcel'):
        addresses = get_addresses_with_parcel(
            splits['pref'], splits['county'], splits['city'], splits['ward'],
            splits['oaza_cho'], splits['chome'], splits['koaza'])
    with measure('get_addresses_with_rsdtdsp'):
        addresses += get_addresses_with_rsdtdsp(
            splits['pref'], splits['county'], splits['city'], splits['ward'],
            splits['oaza_cho'], splits['chome'], splits['koaza'])
    addresses = list(set(addresses))  # drop duplicates
    with measure('embed_via_multilingual_e5_large'):
        embeds = embed_via_multilingual_e5_large([splitted['address']] + addresses)
    query_embed = [embeds[0]]
    address_embeds = embeds[1:]
    with measure('cosine'):
        # Cosine similarity between the query and every candidate.
        similarities = cosine_similarity(query_embed, address_embeds)
    # Indices of the top_k most similar candidates, best first.
    best_match_indices = np.argsort(similarities[0])[-top_k:][::-1]
    best_addresses = [addresses[i] for i in best_match_indices]
    best_similarities = similarities[0][best_match_indices]
    return splitted, hits, splits, best_addresses, best_similarities
# ========================= | |
# FastAPI definition | |
# ========================= | |
from fastapi import FastAPI | |
from pydantic import BaseModel, Field | |
from typing import Literal | |
# FastAPI application object; the metadata below is passed to FastAPI as
# title, description and version.
app = FastAPI(
    version="1.0.0",
    title="住所処理API",
    description="住所の正規化・比較を行うAPIです。",
)
# --------------------------- | |
# リクエスト・レスポンス定義 | |
# --------------------------- | |
class CompareAddressesRequest(BaseModel):
    """Request body carrying the two addresses to compare."""

    address1: str = Field(
        ...,
        description="比較する最初の住所",
        example="東京 墨田区 押上 1丁目1-1",
    )
    address2: str = Field(
        ...,
        description="比較する2番目の住所",
        example="東京 墨田区 押上 1-1-1",
    )
class CompareAddressesResponse(BaseModel):
    """Response body carrying the outcome of an address comparison."""

    result: Literal[True, False] = Field(
        ...,
        description="比較結果",
        example=True,
    )
class NormalizeAddressRequest(BaseModel):
    """Request body carrying the address to normalise."""

    query_address: str = Field(
        ...,
        description="正規化する住所",
        example="東京 墨田区 押上 1丁目1-1",
    )
class NormalizeAddressResponse(BaseModel):
    """Response body carrying the normalised address."""

    normalized: str = Field(
        ...,
        description="正規化された住所",
        example="東京都墨田区押上一丁目1-1",
    )
# --------------------------- | |
# エンドポイント定義 | |
# --------------------------- | |
# Register the handler on ``app``; without a route decorator the function is
# never reachable over HTTP.
# NOTE(review): the path mirrors the function name - confirm the URL clients
# expect.  The docstring is kept in its original wording because FastAPI
# publishes it as the operation description.
@app.post("/compare_two_addresses", response_model=CompareAddressesResponse)
async def compare_two_addresses_api(request: CompareAddressesRequest):
    """
    - **address1**: 比較する最初の住所
    - **address2**: 比較する2番目の住所
    """
    # NOTE(review): compare_two_addresses performs blocking calls inside this
    # coroutine.
    result = compare_two_addresses(request.address1, request.address2)
    return {"result": result}
# Register the handler on ``app``; without a route decorator the function is
# never reachable over HTTP.
# NOTE(review): the path mirrors the function name - confirm the URL clients
# expect.  The docstring is kept in its original wording because FastAPI
# publishes it as the operation description.
@app.post("/normalize_address", response_model=NormalizeAddressResponse)
async def normalize_address_api(request: NormalizeAddressRequest):
    """
    - **query_address**: 正規化する住所
    """
    # NOTE(review): normalize_address performs blocking calls inside this
    # coroutine.
    normalized = normalize_address(request.query_address)
    return {"normalized": normalized}
# Register the handler on ``app``; without a route decorator the function is
# never reachable over HTTP.
# NOTE(review): the path mirrors the function name - confirm the URL clients
# expect.  The docstring is kept in its original wording because FastAPI
# publishes it as the operation description.
@app.post("/normalize_address_v2", response_model=NormalizeAddressResponse)
async def normalize_address_v2_api(request: NormalizeAddressRequest):
    """
    - **query_address**: 正規化する住所
    """
    # Only the list of best addresses (4th element of the result) is needed.
    # NOTE(review): unlike normalize_address, the building part is not
    # appended here - confirm this is intended.
    best_addresses = normalize_address_v2(request.query_address)[3]
    return {"normalized": best_addresses[0]}
# ========================= | |
# Gradio tabs definition | |
# ========================= | |
# Sample inputs offered by every gr.Examples widget in this file.
examples = [
    # Addresses, some followed by a building name or embedded in a sentence.
    '東京都中央区みなと3の12の10、プレサンスロゼ東京港301。',
    '東京都荒川区1−5−6荒川マンション102',
    '福岡市中央区天神1の11の2',
    '私の住所は京都府京都市右京区太秦青木元町4-10です。',
    '京都府京都市右京区太秦青木元町4-10',
    '京都府京都市右京区太秦青木元町4-10ダックス101号室',
    '京都府宇治市伊勢田町名木1-1-4ダックス101号室',
    '東京都渋谷区道玄坂1-12-1',
    '私の住所は東京都渋谷区道玄坂1-12-1です。',
    '私の住所は東京都しぶや道玄坂1の12の1です。',
    '東京都渋谷区道玄坂1の12の1で契約しています。',
    '秋田県秋田市山王四丁目1番1号です。',
    '東京 墨田区 押上 1丁目1',
    '三重県伊勢市宇治館町',
    '住所は 030-0803 青森県青森市安方1丁目1−40になります。',
    '東京都大島町差木地 字クダッチ',
    '前橋市大手町1丁目1番地1',
    '東京都渋谷区表参道の3の5の6。',
    # Inputs that are incomplete, wrong or not an address at all.
    '琉球圏尾張町3の5の6に住んでます。',
    '3254987の場所です。',
    '大阪府でした。',
    '1940923の東京都渋谷区道玄坂一丁目。渋谷マークシティウェスト23階です。',
    '名前は山田太郎です。',
    'はい。名古屋、あ、愛知県名古屋市南里2の3の4だと思います。',
    'ー',
    '少し待ってください。',
]
def create_function_test_tab():
    """Build the "関数テスト" tab with one sub-tab per helper function.

    Most sub-tabs share one layout (input textbox, examples, output widget,
    run button), which is created by the local helper ``add_tab`` instead of
    being repeated for every function.  Must be called inside a
    ``gr.Blocks`` context.
    """
    text_io = ('インプット', 'テキストを入力してください')
    address_io = ('住所', '住所を入力してください')
    hit_columns = ['Top-k', '類似度', '住所', '都道府県', '郡', '市区町村', '政令市区', '大字・町', '丁目', '小字']

    def add_tab(title, fn, in_label, in_placeholder, table_output=False, with_top_k=False):
        # One sub-tab: textbox -> fn -> textbox (or dataframe), optionally
        # with a top-k slider as second input.
        with gr.Tab(title):
            in_tb = gr.Textbox(label=in_label, placeholder=in_placeholder)
            gr.Examples(examples=examples, inputs=[in_tb])
            inputs = [in_tb]
            if with_top_k:
                inputs.append(gr.Slider(minimum=1, maximum=100, step=1, value=5, label='検索数top-k'))
            if table_output:
                out = gr.Dataframe(label="アウトプット", wrap=True)
            else:
                out = gr.Textbox(label='アウトプット')
            exe_button = gr.Button(value='実行', variant='primary')
            exe_button.click(fn=fn, inputs=inputs, outputs=[out])

    def add_compare_tab():
        # Two address inputs; the examples only fill the first one.
        with gr.Tab("compare_two_addresses"):
            in_tb1 = gr.Textbox(label='住所1 (顧客が発言した住所)', placeholder='住所を入力してください')
            gr.Examples(examples=examples, inputs=[in_tb1])
            in_tb2 = gr.Textbox(label='住所2 (CRM 内に格納されている住所)', placeholder='住所を入力してください')
            out_tb = gr.Textbox(label='アウトプット')
            exe_button = gr.Button(value='実行', variant='primary')
            exe_button.click(
                fn=compare_two_addresses,
                inputs=[in_tb1, in_tb2],
                outputs=[out_tb],
            )

    def run_normalize_address_v2(query_address):
        # Best address followed by the building part.
        splitted, _hits, _splits, bests, _similarities = normalize_address_v2(query_address)
        return bests[0] + splitted['building']

    def run_vector_search(query_address, top_k):
        with measure('preprocess'):
            preprocessed = preprocess(query_address)
        with measure('vector_search'):
            # The slider value may arrive as a float; the search needs an int.
            hits = vector_search(preprocessed, top_k=int(top_k))
        return pd.DataFrame(hits, columns=hit_columns)

    def run_get_addresses_with_parcel(query_address):
        with measure('preprocess'):
            preprocessed = preprocess(query_address)
        with measure('vector_search'):
            hits = vector_search(preprocessed, top_k=1)
        with measure('split_address'):
            splits = {
                'pref': hits[0][3],
                'county': hits[0][4],
                'city': hits[0][5],
                'ward': hits[0][6],
                'oaza_cho': hits[0][7],
                'chome': hits[0][8],
                'koaza': hits[0][9],
            }
        with measure('get_addresses_with_parcel'):
            addresses = get_addresses_with_parcel(
                splits['pref'], splits['county'], splits['city'], splits['ward'],
                splits['oaza_cho'], splits['chome'], splits['koaza'])
        return pd.DataFrame(addresses, columns=['住所'])

    with gr.Tab("関数テスト"):
        add_tab("normalize_address", normalize_address, *address_io)
        add_tab("normalize_address_v2", run_normalize_address_v2, *address_io)
        add_compare_tab()
        # The tab title below is kept exactly as it was.
        add_tab("create_get_spelling_tab", get_spelling, *address_io)
        add_tab("get_addresses_with_parcel", run_get_addresses_with_parcel,
                *address_io, table_output=True)
        add_tab("vector_search", run_vector_search,
                *address_io, table_output=True, with_top_k=True)
        add_tab("remove_left_of_pref", remove_left_of_pref, *text_io)
        add_tab("replace_circle", replace_circle, *text_io)
        add_tab("remove_filler", remove_filler, *text_io)
        add_tab("preprocess", preprocess, *text_io)
        add_tab("split_address", split_address, *address_io)
        add_tab("split_address_building_with_gpt", split_address_building_with_gpt, *address_io)
        add_tab("convert_zenkaku_to_hankaku", convert_zenkaku_to_hankaku, *address_io)
def create_digital_agency_tab():
    """Build the "デジ庁API" tab, which geocodes through the ABRG endpoint.

    Must be called inside a ``gr.Blocks`` context.
    """
    result_fields = [
        'pref', 'county', 'city', 'ward', 'oaza_cho', 'chome', 'koaza',
        'blk_num', 'rsdt_num', 'rsdt_num2', 'prc_num1', 'prc_num2', 'prc_num3',
    ]

    def normalize_address_via_abrg_geocode(query_address):
        """Return (normalised address, one-row DataFrame of its parts)."""
        query_address = preprocess(query_address)
        # Send the address through ``params`` so it is URL-encoded.
        response = requests.get(
            f'{ABRG_ENDPOINT}/geocode',
            params={'address': query_address},
        )
        candidates = response.json()
        if not candidates:
            # An empty answer would otherwise fail with a bare IndexError.
            raise gr.Error(f'Not found: {query_address}')
        result = candidates[0]['result']
        normalized = result['output']
        data = {name: result[name] for name in result_fields}
        data['others'] = ''.join(result['others'])
        return normalized, pd.DataFrame([data])

    # NOTE(review): the source had lost its indentation; the nesting of the
    # output widgets relative to the Row/Column is reconstructed.
    with gr.Tab("デジ庁API"):
        with gr.Row():
            with gr.Column():
                address_input_tab2 = gr.Textbox(label='住所', placeholder='検索したい住所を入力してください')
                gr.Examples(examples=examples, inputs=[address_input_tab2])
                search_button_tab2 = gr.Button(value='検索', variant='primary')
        result_tb = gr.Textbox(label='正規化後')
        result_df = gr.Dataframe(label="正規化後(分割)", wrap=True)
        search_button_tab2.click(
            fn=normalize_address_via_abrg_geocode,
            inputs=[address_input_tab2],
            outputs=[result_tb, result_df],
        )
def create_vector_search_tab():
    """Build the "ベクトル検索" tab (vector search without spelling correction).

    Must be called inside a ``gr.Blocks`` context.
    """
    hit_columns = ['Top-k', '類似度', '住所', '都道府県', '郡', '市区町村', '政令市区', '大字・町', '丁目', '小字']

    def search_address(query_address, top_k):
        """Return (town-level hits, ranked candidates, best address, parts)."""
        # The slider value may arrive as a float; it is used as a search limit
        # and as a slice bound, both of which need an int.
        top_k = int(top_k)
        with measure('convert_zenkaku_to_hankaku'):
            query_address = convert_zenkaku_to_hankaku(query_address)
        with measure('split_address_building_with_gpt'):
            splitted = split_address_building_with_gpt(query_address)
        with measure('preprocess'):
            preprocessed = preprocess(splitted['address'])
        with measure('vector_search'):
            hits = vector_search(preprocessed, top_k)
        search_result_df = pd.DataFrame(hits, columns=hit_columns)
        with measure('split_address'):
            splits = {
                'pref': hits[0][3],
                'county': hits[0][4],
                'city': hits[0][5],
                'ward': hits[0][6],
                'oaza_cho': hits[0][7],
                'chome': hits[0][8],
                'koaza': hits[0][9],
            }
        result_df = pd.DataFrame([splits.values()], columns=splits.keys())
        with measure('get_addresses_with_parcel'):
            addresses = get_addresses_with_parcel(
                splits['pref'], splits['county'], splits['city'], splits['ward'],
                splits['oaza_cho'], splits['chome'], splits['koaza'])
        with measure('get_addresses_with_rsdtdsp'):
            addresses += get_addresses_with_rsdtdsp(
                splits['pref'], splits['county'], splits['city'], splits['ward'],
                splits['oaza_cho'], splits['chome'], splits['koaza'])
        addresses = list(set(addresses))  # drop duplicates
        with measure('embed_via_multilingual_e5_large'):
            embeds = embed_via_multilingual_e5_large([splitted['address']] + addresses)
        query_embed = [embeds[0]]
        address_embeds = embeds[1:]
        with measure('cosine'):
            # Cosine similarity between the query and every candidate.
            similarities = cosine_similarity(query_embed, address_embeds)
        # Indices of the top_k most similar candidates, best first.
        best_match_indices = np.argsort(similarities[0])[-top_k:][::-1]
        best_addresses = [addresses[i] for i in best_match_indices]
        best_similarities = similarities[0][best_match_indices]
        chiban_result_df = pd.DataFrame({
            'Top-k': range(1, len(best_similarities) + 1),
            '類似度': best_similarities,
            '住所': [best_address + splitted['building'] for best_address in best_addresses],
        })
        best_address = best_addresses[0] + splitted['building']
        return search_result_df, chiban_result_df, best_address, result_df

    # NOTE(review): the source had lost its indentation; the nesting of the
    # output widgets relative to the Row/Column is reconstructed.
    with gr.Tab("ベクトル検索"):
        with gr.Row():
            with gr.Column():
                address_input = gr.Textbox(label='住所', placeholder='検索したい住所を入力してください')
                gr.Examples(examples=examples, inputs=[address_input])
                top_k_input = gr.Slider(minimum=1, maximum=100, step=1, value=5, label='検索数top-k')
                search_button = gr.Button(value='検索', variant='primary')
        result_tb = gr.Textbox(label='正規化後')
        result_df = gr.Dataframe(label="正規化後(分割)", wrap=True)
        search_result_df = gr.Dataframe(label="町丁目まで検索結果")
        chiban_result_df = gr.Dataframe(label="地番・住居表示検索結果")
        search_button.click(
            fn=search_address,
            inputs=[address_input, top_k_input],
            outputs=[search_result_df, chiban_result_df, result_tb, result_df],
        )
def create_vector_search_v2_tab():
    """Build the "ベクトル検索V2" tab, backed by normalize_address_v2.

    Must be called inside a ``gr.Blocks`` context.
    """
    hit_columns = ['Top-k', '類似度', '住所', '都道府県', '郡', '市区町村', '政令市区', '大字・町', '丁目', '小字']

    def search_address(query_address, top_k):
        """Return (town-level hits, ranked candidates, best address, parts)."""
        # The slider value may arrive as a float; normalize_address_v2 uses it
        # as a slice bound, which needs an int.
        splitted, hits, splits, best_addresses, best_similarities = normalize_address_v2(
            query_address, int(top_k))
        search_result_df = pd.DataFrame(hits, columns=hit_columns)
        result_df = pd.DataFrame([splits.values()], columns=splits.keys())
        chiban_result_df = pd.DataFrame({
            'Top-k': range(1, len(best_similarities) + 1),
            '類似度': best_similarities,
            '住所': [best_address + splitted['building'] for best_address in best_addresses],
        })
        best_address = best_addresses[0] + splitted['building']
        return search_result_df, chiban_result_df, best_address, result_df

    # NOTE(review): the source had lost its indentation; the nesting of the
    # output widgets relative to the Row/Column is reconstructed.
    with gr.Tab("ベクトル検索V2"):
        with gr.Row():
            with gr.Column():
                address_input = gr.Textbox(label='住所', placeholder='検索したい住所を入力してください')
                gr.Examples(examples=examples, inputs=[address_input])
                top_k_input = gr.Slider(minimum=1, maximum=100, step=1, value=5, label='検索数top-k')
                search_button = gr.Button(value='検索', variant='primary')
        result_tb = gr.Textbox(label='正規化後')
        result_df = gr.Dataframe(label="正規化後(分割)", wrap=True)
        search_result_df = gr.Dataframe(label="町丁目まで検索結果")
        chiban_result_df = gr.Dataframe(label="地番・住居表示検索結果")
        search_button.click(
            fn=search_address,
            inputs=[address_input, top_k_input],
            outputs=[search_result_df, chiban_result_df, result_tb, result_df],
        )
# Assemble the Gradio UI; the tabs appear in the order they are created.
with gr.Blocks() as demo:
    create_function_test_tab()
    create_vector_search_tab()
    create_vector_search_v2_tab()
    create_digital_agency_tab()

# Mount the Gradio UI on the FastAPI application at the root path.
app = gr.mount_gradio_app(app, demo, path='/')