# meisaicheck-api / routes/predict.py
# Committed by: vumichien
# Last commit (3020335): Refactor unit mapping in the prediction process by replacing
# UnitMapper with UnitSimilarityMapper for improved similarity calculations and error handling.
# (raw · history · blame · 22.8 kB)
import os
import time
import shutil
import pandas as pd
import traceback
import sys
from pathlib import Path
from fastapi import APIRouter, UploadFile, File, HTTPException, Depends, Body
from fastapi.responses import FileResponse
from custom_auth import get_current_user_from_token
from services.sentence_transformer_service import SentenceTransformerService, sentence_transformer_service
# Add the path to import modules from meisai-check-ai
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "meisai-check-ai"))
from mapping_lib.standard_subject_data_mapper import StandardSubjectDataMapper
from mapping_lib.subject_similarity_mapper import SubjectSimilarityMapper
from mapping_lib.sub_subject_similarity_mapper import SubSubjectSimilarityMapper
from mapping_lib.name_similarity_mapper import NameSimilarityMapper
from mapping_lib.sub_subject_and_name_data_mapper import SubSubjectAndNameDataMapper
from mapping_lib.abstract_similarity_mapper import AbstractSimilarityMapper
from mapping_lib.name_and_abstract_mapper import NameAndAbstractDataMapper
from mapping_lib.unit_similarity_mapper import UnitSimilarityMapper
from mapping_lib.standard_name_mapper import StandardNameMapper
from config import UPLOAD_DIR, OUTPUT_DIR
from models import (
EmbeddingRequest,
PredictRawRequest,
PredictRawResponse,
PredictRecord,
PredictResult,
)
# FastAPI router collecting the /predict, /embeddings and /predict-raw
# endpoints defined below.
router = APIRouter()
@router.post("/predict")
async def predict(
    current_user=Depends(get_current_user_from_token),
    file: UploadFile = File(...),
    sentence_service: SentenceTransformerService = Depends(
        lambda: sentence_transformer_service
    ),
):
    """
    Process an input CSV file and return standardized names (requires authentication)

    The upload is saved to ``UPLOAD_DIR`` and run through each mapping stage.
    A stage is skipped when its map data on ``sentence_service`` is ``None``.
    The result is written to ``OUTPUT_DIR`` as a ``utf_8_sig`` CSV and
    returned as a file download.

    Raises:
        HTTPException: 400 if the file name does not end in ``.csv``
            (case-insensitive). 500 if reading the CSV or any required mapping
            stage fails. The standard-subject and abstract stages are
            best-effort: their failures are logged and processing continues.
    """
    # ``filename`` can be None; accept ".csv" regardless of letter case.
    original_name = file.filename or ""
    if not original_name.lower().endswith(".csv"):
        raise HTTPException(status_code=400, detail="Only CSV files are supported")
    download_name = f"output_{Path(original_name).stem}.csv"

    # Save uploaded file
    timestamp = int(time.time())
    input_file_path = os.path.join(UPLOAD_DIR, f"input_{timestamp}_{current_user.username}.csv")
    output_file_path = os.path.join(OUTPUT_DIR, f"output_{timestamp}_{current_user.username}.csv")
    try:
        with open(input_file_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)
    finally:
        file.file.close()

    try:
        # Load input data
        start_time = time.time()
        df_input_data = pd.read_csv(input_file_path)

        # Ensure basic columns exist with default values
        basic_columns = {
            "シート名": "",
            "行": "",
            "科目": "",
            "中科目": "",
            "分類": "",
            "名称": "",
            "単位": "",
            "摘要": "",
            "備考": "",
        }
        for col, default_value in basic_columns.items():
            if col not in df_input_data.columns:
                df_input_data[col] = default_value

        # Subject mapping: predict once per unique 科目, then map back onto every row.
        try:
            if sentence_service.df_subject_map_data is not None:
                subject_similarity_mapper = SubjectSimilarityMapper(
                    cached_embedding_helper=sentence_service.subject_cached_embedding_helper,
                    df_map_data=sentence_service.df_subject_map_data,
                )
                list_input_subject = df_input_data["科目"].unique()
                df_subject_data = pd.DataFrame({"科目": list_input_subject})
                subject_similarity_mapper.predict_input(df_input_data=df_subject_data)
                output_subject_map = dict(
                    zip(df_subject_data["科目"], df_subject_data["出力_科目"])
                )
                df_input_data["標準科目"] = df_input_data["科目"].map(output_subject_map)
                df_input_data["出力_科目"] = df_input_data["科目"].map(output_subject_map)
        except Exception as e:
            print(f"Error processing SubjectSimilarityMapper: {e}")
            raise HTTPException(status_code=500, detail=str(e))

        # Standard subject mapping (best-effort: fall back to the unmapped data).
        try:
            if sentence_service.df_standard_subject_map_data is not None:
                standard_subject_data_mapper = StandardSubjectDataMapper(
                    df_map_data=sentence_service.df_standard_subject_map_data
                )
                df_output_data = standard_subject_data_mapper.map_data(
                    df_input_data=df_input_data,
                    input_key_columns=["出力_科目"],
                    in_place=True,
                )
            else:
                df_output_data = df_input_data.copy()
        except Exception as e:
            print(f"Error processing StandardSubjectDataMapper: {e}")
            # Continue with original data if standard subject mapping fails
            df_output_data = df_input_data.copy()

        # Sub subject mapping
        try:
            if sentence_service.df_sub_subject_map_data is not None:
                sub_subject_similarity_mapper = SubSubjectSimilarityMapper(
                    cached_embedding_helper=sentence_service.sub_subject_cached_embedding_helper,
                    df_map_data=sentence_service.df_sub_subject_map_data,
                )
                sub_subject_similarity_mapper.predict_input(df_input_data=df_output_data)
                df_output_data = df_output_data.fillna("")
        except Exception as e:
            print(f"Error processing SubSubjectSimilarityMapper: {e}")
            raise HTTPException(status_code=500, detail=str(e))

        # Name mapping
        try:
            if sentence_service.df_name_map_data is not None:
                name_sentence_mapper = NameSimilarityMapper(
                    cached_embedding_helper=sentence_service.name_cached_embedding_helper,
                    df_map_data=sentence_service.df_name_map_data,
                )
                name_sentence_mapper.predict_input(df_input_data=df_output_data)
        except Exception as e:
            print(f"Error processing NameSimilarityMapper: {e}")
            raise HTTPException(status_code=500, detail=str(e))

        # Sub subject and name mapping
        try:
            if sentence_service.df_sub_subject_and_name_map_data is not None:
                sub_subject_and_name_mapper = SubSubjectAndNameDataMapper(
                    df_map_data=sentence_service.df_sub_subject_and_name_map_data
                )
                sub_subject_and_name_mapper.map_data(df_input_data=df_output_data)
        except Exception as e:
            print(f"Error processing SubSubjectAndNameDataMapper: {e}")
            raise HTTPException(status_code=500, detail=str(e))

        # Abstract mapping (best-effort: errors are logged and processing continues).
        try:
            if sentence_service.df_abstract_map_data is not None:
                # Ensure required columns exist before AbstractSimilarityMapper
                required_columns_for_abstract = {
                    "標準科目": "",
                    "摘要グループ": "",
                    "確定": "未確定",
                    "摘要": "",
                    "備考": "",
                }
                for col, default_val in required_columns_for_abstract.items():
                    if col not in df_output_data.columns:
                        df_output_data[col] = default_val
                        print(
                            f"DEBUG: Added missing column '{col}' with default value '{default_val}'"
                        )
                # Normalise to strings. fillna must come first: astype(str)
                # would otherwise turn NaN into the literal text "nan".
                for col in required_columns_for_abstract:
                    df_output_data[col] = df_output_data[col].fillna("").astype(str)
                abstract_similarity_mapper = AbstractSimilarityMapper(
                    cached_embedding_helper=sentence_service.abstract_cached_embedding_helper,
                    df_map_data=sentence_service.df_abstract_map_data,
                )
                abstract_similarity_mapper.predict_input(df_input_data=df_output_data)
                print("DEBUG: AbstractSimilarityMapper completed successfully")
        except Exception as e:
            print(f"Error processing AbstractSimilarityMapper: {e}")
            print("DEBUG: Full error traceback:")
            traceback.print_exc()
            # Don't raise the exception, continue processing
            print("DEBUG: Continuing without AbstractSimilarityMapper...")

        # Name and abstract mapping
        try:
            if sentence_service.df_name_and_subject_map_data is not None:
                name_and_abstract_mapper = NameAndAbstractDataMapper(
                    df_map_data=sentence_service.df_name_and_subject_map_data
                )
                df_output_data = name_and_abstract_mapper.map_data(df_output_data)
        except Exception as e:
            print(f"Error processing NameAndAbstractDataMapper: {e}")
            raise HTTPException(status_code=500, detail=str(e))

        # Unit mapping
        try:
            if sentence_service.df_unit_map_data is not None:
                unit_mapper = UnitSimilarityMapper(
                    cached_embedding_helper=sentence_service.unit_cached_embedding_helper,
                    df_map_data=sentence_service.df_unit_map_data,
                )
                unit_mapper.predict_input(df_input_data=df_output_data)
        except Exception as e:
            print(f"Error processing UnitSimilarityMapper: {e}")
            raise HTTPException(status_code=500, detail=str(e))

        # Standard name mapping
        try:
            if sentence_service.df_standard_name_map_data is not None:
                standard_name_mapper = StandardNameMapper(
                    df_map_data=sentence_service.df_standard_name_map_data
                )
                df_output_data = standard_name_mapper.map_data(df_output_data)
        except Exception as e:
            print(f"Error processing StandardNameMapper: {e}")
            raise HTTPException(status_code=500, detail=str(e))

        # Add a 1-based ID column if the input did not supply one.
        if "ID" not in df_output_data.columns:
            df_output_data.reset_index(drop=False, inplace=True)
            df_output_data.rename(columns={"index": "ID"}, inplace=True)
            df_output_data["ID"] = df_output_data["ID"] + 1

        # Ensure required columns exist with default values
        required_columns = {
            "シート名": "",
            "行": "",
            "科目": "",
            "中科目": "",
            "分類": "",
            "名称": "",
            "単位": "",
            "摘要": "",
            "備考": "",
            "出力_科目": "",
            "出力_中科目": "",
            "出力_項目名": "",
            "出力_標準単位": "",
            "出力_集計用単位": "",
            "出力_確率度": 0.0,
        }
        for col, default_value in required_columns.items():
            if col not in df_output_data.columns:
                df_output_data[col] = default_value

        # 出力_中科目: prefer the standard sub-subject from the sub-subject mapper.
        if "出力_基準中科目" in df_output_data.columns:
            df_output_data["出力_中科目"] = df_output_data["出力_基準中科目"]
        elif "標準中科目" in df_output_data.columns:
            df_output_data["出力_中科目"] = df_output_data["標準中科目"]

        # 出力_項目名: keep values a mapper already produced. Otherwise use the
        # standard/base name. The column always exists here (default ""), so
        # empty strings must count as "no value", not only NaN.
        item_names = df_output_data["出力_項目名"]
        has_item_names = (item_names.notna() & item_names.astype(str).ne("")).any()
        if not has_item_names:
            if "出力_標準名称" in df_output_data.columns:
                df_output_data["出力_項目名"] = df_output_data["出力_標準名称"]
            elif "出力_基準名称" in df_output_data.columns:
                df_output_data["出力_項目名"] = df_output_data["出力_基準名称"]

        # 出力_標準単位 / 出力_集計用単位 are taken as-is from the unit mapper
        # (or the "" default above).

        # 出力_確率度: the first available similarity score, name similarity first.
        if "出力_名称類似度" in df_output_data.columns:
            df_output_data["出力_確率度"] = df_output_data["出力_名称類似度"]
        elif "出力_中科目類似度" in df_output_data.columns:
            df_output_data["出力_確率度"] = df_output_data["出力_中科目類似度"]
        elif "出力_摘要類似度" in df_output_data.columns:
            df_output_data["出力_確率度"] = df_output_data["出力_摘要類似度"]
        elif "出力_単位類似度" in df_output_data.columns:
            df_output_data["出力_確率度"] = df_output_data["出力_単位類似度"]
        else:
            df_output_data["出力_確率度"] = 0.0

        # Fill NaN values and ensure all output columns have proper values
        df_output_data = df_output_data.fillna("")
        print(f"Available columns after processing: {list(df_output_data.columns)}")

        # Final fallback: echo the raw input when no mapping produced anything.
        if df_output_data["出力_中科目"].eq("").all():
            df_output_data["出力_中科目"] = df_output_data.get("中科目", "")
        if df_output_data["出力_項目名"].eq("").all():
            df_output_data["出力_項目名"] = df_output_data.get("名称", "")

        # Define output columns in exact order as shown in Excel
        output_columns = [
            "ID",
            "シート名",
            "行",
            "科目",
            "中科目",
            "分類",
            "名称",
            "単位",
            "摘要",
            "備考",
            "出力_科目",
            "出力_中科目",
            "出力_項目名",
            "出力_確率度",
            "出力_標準単位",
            "出力_集計用単位",
        ]
        # Save with utf_8_sig encoding for Japanese Excel compatibility
        df_output_data[output_columns].to_csv(
            output_file_path, index=False, encoding="utf_8_sig"
        )

        # Save all caches
        sentence_service.save_all_caches()
        execution_time = time.time() - start_time
        print(f"Execution time: {execution_time} seconds")

        # Let FileResponse build Content-Disposition from ``filename``. It
        # RFC 5987-encodes non-ASCII names. Content-Type comes from media_type;
        # a manual header here would override it.
        return FileResponse(
            path=output_file_path,
            filename=download_name,
            media_type="text/csv",
        )
    except HTTPException:
        # Already carries the intended status/detail; don't re-wrap it.
        raise
    except Exception as e:
        print(f"Error processing file: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@router.post("/embeddings")
async def create_embeddings(
    request: EmbeddingRequest,
    current_user=Depends(get_current_user_from_token),
    sentence_service: SentenceTransformerService = Depends(
        lambda: sentence_transformer_service
    ),
):
    """
    Create embeddings for a list of input sentences (requires authentication)

    Delegates to ``sentence_service.sentenceTransformerHelper.create_embeddings``
    and returns ``{"embeddings": [[...], ...]}``. Any failure is logged and
    reported as HTTP 500.
    """
    try:
        began = time.time()
        vectors = sentence_service.sentenceTransformerHelper.create_embeddings(
            request.sentences
        )
        elapsed = time.time() - began
        print(f"Execution time: {elapsed} seconds")
        # The helper returns an object with ``tolist()``; nested lists are JSON-serialisable.
        return {"embeddings": vectors.tolist()}
    except Exception as e:
        print(f"Error creating embeddings: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@router.post("/predict-raw", response_model=PredictRawResponse)
async def predict_raw(
    request: PredictRawRequest,
    current_user=Depends(get_current_user_from_token),
    sentence_service: SentenceTransformerService = Depends(
        lambda: sentence_transformer_service
    ),
):
    """
    Process raw input records and return standardized names (requires authentication)

    Records are mapped through the subject, name and unit similarity stages.
    A stage is skipped when its map data on ``sentence_service`` is ``None``.
    An empty ``records`` list returns an empty result without running any model.

    Raises:
        HTTPException: 500 if any mapping stage or response building fails.
    """
    try:
        records = request.records
        if not records:
            return PredictRawResponse(results=[])

        # Convert input records to a DataFrame. 単位 is not supplied by raw
        # records. シート名/行 are placeholders (required by BaseNameData but unused).
        blanks = [""] * len(records)
        df_input_data = pd.DataFrame(
            {
                "科目": [r.subject for r in records],
                "中科目": [r.sub_subject for r in records],
                "分類": [r.name_category for r in records],
                "名称": [r.name for r in records],
                "単位": list(blanks),
                "摘要": [r.abstract or "" for r in records],
                "備考": [r.memo or "" for r in records],
                "シート名": list(blanks),
                "行": list(blanks),
            }
        )

        # Subject mapping: predict once per unique 科目, then map back onto every row.
        try:
            if sentence_service.df_subject_map_data is not None:
                subject_similarity_mapper = SubjectSimilarityMapper(
                    cached_embedding_helper=sentence_service.subject_cached_embedding_helper,
                    df_map_data=sentence_service.df_subject_map_data,
                )
                list_input_subject = df_input_data["科目"].unique()
                df_subject_data = pd.DataFrame({"科目": list_input_subject})
                subject_similarity_mapper.predict_input(df_input_data=df_subject_data)
                output_subject_map = dict(
                    zip(df_subject_data["科目"], df_subject_data["出力_科目"])
                )
                df_input_data["標準科目"] = df_input_data["科目"].map(output_subject_map)
                df_input_data["出力_科目"] = df_input_data["科目"].map(output_subject_map)
            else:
                df_input_data["標準科目"] = df_input_data["科目"]
                df_input_data["出力_科目"] = df_input_data["科目"]
        except Exception as e:
            print(f"Error processing SubjectSimilarityMapper: {e}")
            raise HTTPException(status_code=500, detail=str(e))

        # Name mapping (simplified for raw predict)
        try:
            if sentence_service.df_name_map_data is not None:
                name_sentence_mapper = NameSimilarityMapper(
                    cached_embedding_helper=sentence_service.name_cached_embedding_helper,
                    df_map_data=sentence_service.df_name_map_data,
                )
                name_sentence_mapper.predict_input(df_input_data=df_input_data)
        except Exception as e:
            print(f"Error processing NameSimilarityMapper: {e}")
            raise HTTPException(status_code=500, detail=str(e))

        # Unit mapping. NOTE(review): its outputs (出力_標準単位/出力_単位類似度)
        # are not included in PredictResult below. Confirm the stage is needed.
        try:
            if sentence_service.df_unit_map_data is not None:
                unit_mapper = UnitSimilarityMapper(
                    cached_embedding_helper=sentence_service.unit_cached_embedding_helper,
                    df_map_data=sentence_service.df_unit_map_data,
                )
                unit_mapper.predict_input(df_input_data=df_input_data)
        except Exception as e:
            print(f"Error processing UnitSimilarityMapper: {e}")
            raise HTTPException(status_code=500, detail=str(e))

        # Ensure result columns exist even when a stage was skipped.
        result_defaults = {
            "確定": "",
            "出力_標準名称": "",
            "出力_名称類似度": 0.0,
            "出力_標準単位": "",
            "出力_単位類似度": 0.0,
        }
        for col, default_value in result_defaults.items():
            if col not in df_input_data.columns:
                df_input_data[col] = default_value

        # Convert results to response format
        results = [
            PredictResult(
                subject=row["科目"],
                sub_subject=row["中科目"],
                name_category=row["分類"],
                name=row["名称"],
                abstract=row["摘要"],
                memo=row["備考"],
                confirmed=row.get("確定", ""),
                standard_subject=row.get("出力_科目", row["科目"]),
                standard_name=row.get("出力_標準名称", ""),
                similarity_score=float(row.get("出力_名称類似度", 0.0)),
            )
            for row in df_input_data.to_dict(orient="records")
        ]

        # Save all caches
        sentence_service.save_all_caches()
        return PredictRawResponse(results=results)
    except HTTPException:
        # Already carries the intended status/detail; don't re-wrap it.
        raise
    except Exception as e:
        print(f"Error processing records: {e}")
        raise HTTPException(status_code=500, detail=str(e))