Spaces:
Sleeping
Sleeping
住所データをダウンロードする機能を追加し、環境変数からターゲットディレクトリを取得するように修正。新たに必要なディレクトリを.gitignoreに追加し、一時ディレクトリの管理を強化。
Browse files- .gitignore +3 -1
- app.py +72 -6
.gitignore
CHANGED
@@ -2,4 +2,6 @@
|
|
2 |
embeddings/
|
3 |
embeddings_/
|
4 |
__pycache__/
|
5 |
-
.env
|
|
|
|
|
|
2 |
embeddings/
|
3 |
embeddings_/
|
4 |
__pycache__/
|
5 |
+
.env
|
6 |
+
data/
|
7 |
+
digital_agency/
|
app.py
CHANGED
@@ -1,4 +1,6 @@
|
|
1 |
import gradio as gr
|
|
|
|
|
2 |
from sklearn.metrics.pairwise import cosine_similarity
|
3 |
from pathlib import Path
|
4 |
import spacy
|
@@ -21,7 +23,8 @@ load_dotenv()
|
|
21 |
# =========================
|
22 |
# Global variables
|
23 |
# =========================
|
24 |
-
|
|
|
25 |
HUGGING_FACE_TOKEN = os.environ.get('HUGGING_FACE_TOKEN')
|
26 |
EMBEDDING_MODEL_ENDPOINT = os.environ.get('EMBEDDING_MODEL_ENDPOINT')
|
27 |
ABRG_ENDPOINT = os.environ.get('ABRG_ENDPOINT')
|
@@ -34,6 +37,71 @@ VECTOR_SEARCH_COLLECTION_NAME_V2 = os.environ.get('VECTOR_SEARCH_COLLECTION_NAME
|
|
34 |
MILVUS_CLIENT = MilvusClient(uri=VECTOR_SEARCH_ENDPOINT, token=VECTOR_SEARCH_TOKEN)
|
35 |
print(f"Connected to DB: {VECTOR_SEARCH_ENDPOINT} successfully")
|
36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
# =========================
|
38 |
# Utility functions
|
39 |
# =========================
|
@@ -392,8 +460,7 @@ def create_vector_search_tab():
|
|
392 |
}
|
393 |
result_df = pd.DataFrame([splits.values()], columns=splits.keys())
|
394 |
with measure('load city_all_file'):
|
395 |
-
|
396 |
-
city_all_file = target_dir / 'mt_city_all.csv'
|
397 |
city_all_df = pd.read_csv(city_all_file)
|
398 |
city_all_df_temp = city_all_df[city_all_df['pref'] == splits['pref']]
|
399 |
city_name1 = city_all_df_temp['county'].fillna('') + city_all_df_temp['city'].fillna('') + city_all_df_temp['ward'].fillna('')
|
@@ -403,7 +470,7 @@ def create_vector_search_tab():
|
|
403 |
raise Exception('Too many lg_code')
|
404 |
lg_code = lg_codes[0]
|
405 |
with measure('load parcel_city_file'):
|
406 |
-
parcel_city_file =
|
407 |
if not os.path.exists(parcel_city_file):
|
408 |
# raise gr.Error('Too many lg_code')
|
409 |
raise Exception('Too many lg_code')
|
@@ -446,8 +513,7 @@ def create_vector_search_tab():
|
|
446 |
]
|
447 |
with measure('load rsdtdsp_file'):
|
448 |
pref_code = ('%06d' % lg_code)[0:2]
|
449 |
-
|
450 |
-
rsdtdsp_file = rsdtdsp_dir / f'mt_rsdtdsp_rsdt_pref{pref_code}.csv\mt_rsdtdsp_rsdt_pref{pref_code}.csv'
|
451 |
if not os.path.exists(rsdtdsp_file):
|
452 |
# raise gr.Error(f'Not found: {rsdtdsp_file}')
|
453 |
raise Exception(f'Not found: {rsdtdsp_file}')
|
|
|
1 |
import gradio as gr
|
2 |
+
import zipfile
|
3 |
+
from tqdm import tqdm
|
4 |
from sklearn.metrics.pairwise import cosine_similarity
|
5 |
from pathlib import Path
|
6 |
import spacy
|
|
|
23 |
# =========================
|
24 |
# Global variables
|
25 |
# =========================
|
26 |
+
# Root directory for the downloaded address data, read from the TARGET_DIR
# environment variable. NOTE: os.environ.get() returns None when the variable
# is unset, and Path(None) then raises TypeError at import time.
TARGET_DIR = Path(os.environ.get('TARGET_DIR'))
|
27 |
+
|
28 |
HUGGING_FACE_TOKEN = os.environ.get('HUGGING_FACE_TOKEN')
|
29 |
EMBEDDING_MODEL_ENDPOINT = os.environ.get('EMBEDDING_MODEL_ENDPOINT')
|
30 |
ABRG_ENDPOINT = os.environ.get('ABRG_ENDPOINT')
|
|
|
37 |
MILVUS_CLIENT = MilvusClient(uri=VECTOR_SEARCH_ENDPOINT, token=VECTOR_SEARCH_TOKEN)
|
38 |
print(f"Connected to DB: {VECTOR_SEARCH_ENDPOINT} successfully")
|
39 |
|
40 |
+
# ------------------------------------------------------------------
# Download the address master CSVs (city list, parcel, rsdt) at
# start-up and unpack them under TARGET_DIR.
# ------------------------------------------------------------------
# Seconds to wait on a stalled connection before giving up, so that a hung
# request cannot block module import forever.
_REQUEST_TIMEOUT_SEC = 60


def _download_and_extract(url, zip_path, extract_dir):
    """Download the ZIP at ``url`` to ``zip_path`` and unpack it into ``extract_dir``.

    Args:
        url: Address of the ZIP archive to fetch.
        zip_path: ``Path`` where the downloaded archive is stored temporarily.
        extract_dir: ``Path`` of the directory the archive is extracted into.

    Returns:
        True when the archive was downloaded and extracted; False when the
        server answered with a status other than 200 (treated as "this
        dataset does not exist" and skipped).
    """
    response = requests.get(url, timeout=_REQUEST_TIMEOUT_SEC)
    if response.status_code != 200:
        return False
    zip_path.write_bytes(response.content)
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(extract_dir)
    return True


temp_dir = Path('temp')
temp_dir.mkdir(exist_ok=True)

# ----------------------------
# Download mt_city_all.csv
# ----------------------------
city_all_url = 'https://catalog.registries.digital.go.jp/rsc/address/mt_city_all.csv.zip'
zip_file_path = temp_dir / 'mt_city_all.csv.zip'

# Skip the download when the archive is already present (e.g. left over from
# an interrupted run).
if not zip_file_path.exists():
    response = requests.get(city_all_url, timeout=_REQUEST_TIMEOUT_SEC)
    # Fail here with a clear HTTP error instead of writing an error page to
    # disk and hitting BadZipFile below.
    response.raise_for_status()
    zip_file_path.write_bytes(response.content)

# Extract directly under TARGET_DIR.
with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
    zip_ref.extractall(TARGET_DIR)

# ------------------------------------
# Download mt_parcel_cityXXXXXX.csv
# ------------------------------------
city_all_file = TARGET_DIR / 'mt_city_all.csv'
city_all_df = pd.read_csv(city_all_file)
lg_codes = city_all_df['lg_code'].tolist()
print('lg_codes', len(lg_codes))

parcel_dir = TARGET_DIR / 'parcel'
for lg_code in tqdm(lg_codes):
    parcel_name = f'mt_parcel_city{lg_code:06d}.csv'
    if (parcel_dir / parcel_name).exists():
        continue  # already extracted on a previous start-up
    parcel_url = f'https://catalog.registries.digital.go.jp/rsc/address/{parcel_name}.zip'
    zip_file_path = temp_dir / f'{parcel_name}.zip'
    _download_and_extract(parcel_url, zip_file_path, parcel_dir)
    time.sleep(0.2)  # wait 200 ms between requests to be gentle on the server

# ------------------------------------
# Download mt_rsdtdsp_rsdt_prefXX.csv
# ------------------------------------
# The first two digits of the zero-padded lg_code are the prefecture code;
# sorted() gives a deterministic download order.
pref_codes = sorted({('%06d' % lg_code)[0:2] for lg_code in lg_codes})
print('pref_codes', len(pref_codes))

rsdt_dir = TARGET_DIR / 'rsdt'
for pref_code in tqdm(pref_codes):
    rsdt_name = f'mt_rsdtdsp_rsdt_pref{pref_code}.csv'
    # Check for the extracted CSV (not the .zip, which only ever lives in
    # temp_dir). Assumes the archive unpacks to <rsdt_name>, mirroring the
    # parcel archives - TODO confirm.
    if (rsdt_dir / rsdt_name).exists():
        continue
    rsdt_url = f'https://catalog.registries.digital.go.jp/rsc/address/{rsdt_name}.zip'
    zip_file_path = temp_dir / f'{rsdt_name}.zip'
    # Fetch rsdt_url here; the previous code requested parcel_url by mistake.
    _download_and_extract(rsdt_url, zip_file_path, rsdt_dir)
    time.sleep(0.2)  # wait 200 ms between requests to be gentle on the server

# Remove the temporary directory and the archives downloaded into it.
for file in temp_dir.iterdir():
    file.unlink()
temp_dir.rmdir()
|
104 |
+
|
105 |
# =========================
|
106 |
# Utility functions
|
107 |
# =========================
|
|
|
460 |
}
|
461 |
result_df = pd.DataFrame([splits.values()], columns=splits.keys())
|
462 |
with measure('load city_all_file'):
|
463 |
+
city_all_file = TARGET_DIR / 'mt_city_all.csv'
|
|
|
464 |
city_all_df = pd.read_csv(city_all_file)
|
465 |
city_all_df_temp = city_all_df[city_all_df['pref'] == splits['pref']]
|
466 |
city_name1 = city_all_df_temp['county'].fillna('') + city_all_df_temp['city'].fillna('') + city_all_df_temp['ward'].fillna('')
|
|
|
470 |
raise Exception('Too many lg_code')
|
471 |
lg_code = lg_codes[0]
|
472 |
with measure('load parcel_city_file'):
|
473 |
+
parcel_city_file = TARGET_DIR / 'parcel' / f'mt_parcel_city{lg_code:06d}.csv'
|
474 |
if not os.path.exists(parcel_city_file):
|
475 |
# raise gr.Error('Too many lg_code')
|
476 |
raise Exception('Too many lg_code')
|
|
|
513 |
]
|
514 |
with measure('load rsdtdsp_file'):
|
515 |
pref_code = ('%06d' % lg_code)[0:2]
|
516 |
+
rsdtdsp_file = RSDTDSP_DIR / f'mt_rsdtdsp_rsdt_pref{pref_code}.csv'
|
|
|
517 |
if not os.path.exists(rsdtdsp_file):
|
518 |
# raise gr.Error(f'Not found: {rsdtdsp_file}')
|
519 |
raise Exception(f'Not found: {rsdtdsp_file}')
|