matsuap commited on
Commit
65381d1
·
1 Parent(s): 7dffd7c

住所データをダウンロードする機能を追加し、環境変数からターゲットディレクトリを取得するように修正。新たに必要なディレクトリを.gitignoreに追加し、一時ディレクトリの管理を強化。

Browse files
Files changed (2) hide show
  1. .gitignore +3 -1
  2. app.py +72 -6
.gitignore CHANGED
@@ -2,4 +2,6 @@
2
  embeddings/
3
  embeddings_/
4
  __pycache__/
5
- .env
 
 
 
2
  embeddings/
3
  embeddings_/
4
  __pycache__/
5
+ .env
6
+ data/
7
+ digital_agency/
app.py CHANGED
@@ -1,4 +1,6 @@
1
  import gradio as gr
 
 
2
  from sklearn.metrics.pairwise import cosine_similarity
3
  from pathlib import Path
4
  import spacy
@@ -21,7 +23,8 @@ load_dotenv()
21
  # =========================
22
  # Global variables
23
  # =========================
24
- # 環境変数からHUGGING_FACE_TOKENを取得
 
25
  HUGGING_FACE_TOKEN = os.environ.get('HUGGING_FACE_TOKEN')
26
  EMBEDDING_MODEL_ENDPOINT = os.environ.get('EMBEDDING_MODEL_ENDPOINT')
27
  ABRG_ENDPOINT = os.environ.get('ABRG_ENDPOINT')
@@ -34,6 +37,71 @@ VECTOR_SEARCH_COLLECTION_NAME_V2 = os.environ.get('VECTOR_SEARCH_COLLECTION_NAME
34
  MILVUS_CLIENT = MilvusClient(uri=VECTOR_SEARCH_ENDPOINT, token=VECTOR_SEARCH_TOKEN)
35
  print(f"Connected to DB: {VECTOR_SEARCH_ENDPOINT} successfully")
36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  # =========================
38
  # Utilitiy functions
39
  # =========================
@@ -392,8 +460,7 @@ def create_vector_search_tab():
392
  }
393
  result_df = pd.DataFrame([splits.values()], columns=splits.keys())
394
  with measure('load city_all_file'):
395
- target_dir = Path(r'C:\Users\taish\Development\whispercustom\projects\abr-geocoder\temp\download')
396
- city_all_file = target_dir / 'mt_city_all.csv'
397
  city_all_df = pd.read_csv(city_all_file)
398
  city_all_df_temp = city_all_df[city_all_df['pref'] == splits['pref']]
399
  city_name1 = city_all_df_temp['county'].fillna('') + city_all_df_temp['city'].fillna('') + city_all_df_temp['ward'].fillna('')
@@ -403,7 +470,7 @@ def create_vector_search_tab():
403
  raise Exception('Too many lg_code')
404
  lg_code = lg_codes[0]
405
  with measure('load parcel_city_file'):
406
- parcel_city_file = target_dir / f'mt_parcel_city{lg_code:06d}.csv'
407
  if not os.path.exists(parcel_city_file):
408
  # raise gr.Error('Too many lg_code')
409
  raise Exception('Too many lg_code')
@@ -446,8 +513,7 @@ def create_vector_search_tab():
446
  ]
447
  with measure('load rsdtdsp_file'):
448
  pref_code = ('%06d' % lg_code)[0:2]
449
- rsdtdsp_dir = Path(rf'G:\マイドライブ\Development\Dataset\Misc\japanese_address\rsdt\original')
450
- rsdtdsp_file = rsdtdsp_dir / f'mt_rsdtdsp_rsdt_pref{pref_code}.csv\mt_rsdtdsp_rsdt_pref{pref_code}.csv'
451
  if not os.path.exists(rsdtdsp_file):
452
  # raise gr.Error(f'Not found: {rsdtdsp_file}')
453
  raise Exception(f'Not found: {rsdtdsp_file}')
 
1
  import gradio as gr
2
+ import zipfile
3
+ from tqdm import tqdm
4
  from sklearn.metrics.pairwise import cosine_similarity
5
  from pathlib import Path
6
  import spacy
 
23
  # =========================
24
  # Global variables
25
  # =========================
26
# Base directory for the downloaded address data (required).
# Use indexing, not .get(): a missing TARGET_DIR then raises a clear
# KeyError naming the variable instead of Path(None) -> TypeError.
TARGET_DIR = Path(os.environ['TARGET_DIR'])
27
+
28
  HUGGING_FACE_TOKEN = os.environ.get('HUGGING_FACE_TOKEN')
29
  EMBEDDING_MODEL_ENDPOINT = os.environ.get('EMBEDDING_MODEL_ENDPOINT')
30
  ABRG_ENDPOINT = os.environ.get('ABRG_ENDPOINT')
 
37
  MILVUS_CLIENT = MilvusClient(uri=VECTOR_SEARCH_ENDPOINT, token=VECTOR_SEARCH_TOKEN)
38
  print(f"Connected to DB: {VECTOR_SEARCH_ENDPOINT} successfully")
39
 
40
# ----------------------------
# Download mt_city_all.csv
# ----------------------------
# Archives are fetched into a local temp dir, then extracted under TARGET_DIR.
temp_dir = Path('temp')
temp_dir.mkdir(exist_ok=True)

city_all_url = 'https://catalog.registries.digital.go.jp/rsc/address/mt_city_all.csv.zip'
zip_file_path = temp_dir / 'mt_city_all.csv.zip'

# Skip the download when the archive already exists locally.
if not os.path.exists(zip_file_path):
    # Download the ZIP archive.
    response = requests.get(city_all_url)
    with open(zip_file_path, 'wb') as f:
        f.write(response.content)

    # Extract directly under TARGET_DIR.
    with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
        zip_ref.extractall(TARGET_DIR)

# ------------------------------------
# Download mt_parcel_cityXXXXXX.csv
# ------------------------------------
city_all_file = TARGET_DIR / 'mt_city_all.csv'
city_all_df = pd.read_csv(city_all_file)
lg_codes = city_all_df['lg_code'].tolist()
print('lg_codes', len(lg_codes))

for lg_code in tqdm(lg_codes):
    parcel_url = f'https://catalog.registries.digital.go.jp/rsc/address/mt_parcel_city{lg_code:06d}.csv.zip'
    zip_file_path = temp_dir / f'mt_parcel_city{lg_code:06d}.csv.zip'

    # Skip codes whose extracted CSV is already present.
    if not os.path.exists(TARGET_DIR / 'parcel' / f'mt_parcel_city{lg_code:06d}.csv'):
        response = requests.get(parcel_url)
        if response.status_code == 200:  # Proceed only when the URL exists.
            with open(zip_file_path, 'wb') as f:
                f.write(response.content)
            with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
                zip_ref.extractall(TARGET_DIR / 'parcel')
        time.sleep(0.2)  # 200 ms pause between downloads to go easy on the server.

# ------------------------------------
# Download mt_rsdtdsp_rsdt_prefXX.csv
# ------------------------------------
# Prefecture code = first two digits of the zero-padded lg_code.
pref_codes = list(set([('%06d' % lg_code)[0:2] for lg_code in lg_codes]))
print('pref_codes', len(pref_codes))

for pref_code in tqdm(pref_codes):
    rsdt_url = f'https://catalog.registries.digital.go.jp/rsc/address/mt_rsdtdsp_rsdt_pref{pref_code}.csv.zip'
    zip_file_path = temp_dir / f'mt_rsdtdsp_rsdt_pref{pref_code}.csv.zip'

    # BUGFIX: test for the extracted .csv (not the .csv.zip), matching the
    # parcel loop above — the .zip never exists here, so the skip never fired.
    if not os.path.exists(TARGET_DIR / 'rsdt' / f'mt_rsdtdsp_rsdt_pref{pref_code}.csv'):
        # BUGFIX: was requests.get(parcel_url) — a copy/paste error that
        # re-fetched the last parcel archive instead of the rsdt one.
        response = requests.get(rsdt_url)
        if response.status_code == 200:  # Proceed only when the URL exists.
            with open(zip_file_path, 'wb') as f:
                f.write(response.content)
            with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
                zip_ref.extractall(TARGET_DIR / 'rsdt')
        time.sleep(0.2)  # 200 ms pause between downloads.

# Remove the temporary download directory and its contents.
for file in temp_dir.iterdir():
    file.unlink()
temp_dir.rmdir()
105
  # =========================
106
  # Utilitiy functions
107
  # =========================
 
460
  }
461
  result_df = pd.DataFrame([splits.values()], columns=splits.keys())
462
  with measure('load city_all_file'):
463
+ city_all_file = TARGET_DIR / 'mt_city_all.csv'
 
464
  city_all_df = pd.read_csv(city_all_file)
465
  city_all_df_temp = city_all_df[city_all_df['pref'] == splits['pref']]
466
  city_name1 = city_all_df_temp['county'].fillna('') + city_all_df_temp['city'].fillna('') + city_all_df_temp['ward'].fillna('')
 
470
  raise Exception('Too many lg_code')
471
  lg_code = lg_codes[0]
472
  with measure('load parcel_city_file'):
473
+ parcel_city_file = TARGET_DIR / 'parcel' / f'mt_parcel_city{lg_code:06d}.csv'
474
  if not os.path.exists(parcel_city_file):
475
  # raise gr.Error('Too many lg_code')
476
  raise Exception('Too many lg_code')
 
513
  ]
514
  with measure('load rsdtdsp_file'):
515
  pref_code = ('%06d' % lg_code)[0:2]
# BUGFIX: RSDTDSP_DIR is not defined anywhere; the download step extracts
# these archives into TARGET_DIR / 'rsdt', so resolve the file from there.
rsdtdsp_file = TARGET_DIR / 'rsdt' / f'mt_rsdtdsp_rsdt_pref{pref_code}.csv'
 
517
  if not os.path.exists(rsdtdsp_file):
518
  # raise gr.Error(f'Not found: {rsdtdsp_file}')
519
  raise Exception(f'Not found: {rsdtdsp_file}')