meisaicheck-api / routes /predict.py
vumichien's picture
update project structure
b77c0a2
raw
history blame
3.12 kB
import os
import time
import shutil
from pathlib import Path
from fastapi import APIRouter, UploadFile, File, HTTPException, Depends
from fastapi.responses import FileResponse
from auth import get_current_user
from services.sentence_transformer_service import SentenceTransformerService, sentence_transformer_service
from data_lib.input_name_data import InputNameData
from mapping_lib.name_mapping_helper import NameMappingHelper
from config import UPLOAD_DIR, OUTPUT_DIR
# Router object that the /predict endpoint in this module is registered on.
router = APIRouter()
@router.post("/predict")
async def predict(
    current_user=Depends(get_current_user),
    file: UploadFile = File(...),
    sentence_service: SentenceTransformerService = Depends(lambda: sentence_transformer_service),
):
    """
    Process an input CSV file and return standardized names (requires authentication).

    The upload is saved under UPLOAD_DIR, run through InputNameData and
    NameMappingHelper, and the selected input columns plus the three predicted
    output columns are written to a CSV under OUTPUT_DIR, which is returned as
    a file download.

    Args:
        current_user: Authenticated user resolved by ``get_current_user``; its
            ``username`` is embedded in the stored file names.
        file: Uploaded file; its name must end in ".csv" (case-insensitive).
        sentence_service: Shared service object supplying the standard-subject
            dictionary, the sentence-transformer helper, the sample data and
            the sample embeddings/similarities handed to NameMappingHelper.

    Returns:
        A FileResponse serving the generated CSV as an attachment named
        ``output_<stem of the uploaded file name>.csv``.

    Raises:
        HTTPException: 400 when the upload has no file name or is not a .csv
            file; 500 when processing the data or writing the output fails.
    """
    # UploadFile.filename may be None (e.g. a part sent without a file name);
    # treat that as an unsupported upload instead of crashing on .endswith().
    uploaded_name = file.filename or ""
    if not uploaded_name.lower().endswith(".csv"):
        raise HTTPException(status_code=400, detail="Only CSV files are supported")

    # Make sure the target directories exist before writing into them.
    os.makedirs(UPLOAD_DIR, exist_ok=True)
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # Save uploaded file
    timestamp = int(time.time())
    input_file_path = os.path.join(UPLOAD_DIR, f"input_{timestamp}_{current_user.username}.csv")
    output_file_path = os.path.join(OUTPUT_DIR, f"output_{timestamp}_{current_user.username}.csv")
    try:
        with open(input_file_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)
    finally:
        file.file.close()

    download_name = f"output_{Path(uploaded_name).stem}.csv"

    try:
        # Process input data
        input_data = InputNameData(sentence_service.dic_standard_subject)
        input_data.load_data_from_csv(input_file_path)
        input_data.process_data()

        # Map standard names
        name_mapping_helper = NameMappingHelper(
            sentence_service.sentenceTransformerHelper,
            input_data,
            sentence_service.sampleData,
            sentence_service.sample_name_sentence_embeddings,
            sentence_service.sample_name_sentence_similarities,
        )
        df_predicted = name_mapping_helper.map_standard_names()

        # Create output dataframe and save to CSV
        column_to_keep = ['シート名', '行', '科目', '分類', '名称', '摘要', '備考']
        output_df = input_data.dataframe[column_to_keep].copy()
        # drop=False keeps the previous index as a regular column in the output.
        output_df.reset_index(drop=False, inplace=True)
        for predicted_column in ("出力_科目", "出力_項目名", "出力_確率度"):
            output_df.loc[:, predicted_column] = df_predicted[predicted_column]

        # Save with utf_8_sig encoding for Japanese Excel compatibility
        output_df.to_csv(output_file_path, index=False, encoding="utf_8_sig")

        # Let FileResponse derive Content-Disposition from `filename`: it
        # percent-encodes non-ASCII names (filename*=utf-8''...), whereas a
        # hand-written header value must be latin-1 encodable and therefore
        # fails for Japanese file names. Content-Type comes from media_type;
        # the previous explicit "application/x-www-form-urlencoded" header
        # overrode it with a type that does not describe a CSV body.
        return FileResponse(
            path=output_file_path,
            filename=download_name,
            media_type="text/csv",
        )
    except Exception as e:
        print(f"Error processing file: {e}")
        # Chain the original exception so the traceback keeps the root cause.
        raise HTTPException(status_code=500, detail=str(e)) from e