Spaces:
Sleeping
Sleeping
"""
RAG search chatbot web application - API route definitions.
"""
import os | |
import json | |
import logging | |
import tempfile | |
import requests | |
from flask import request, jsonify, render_template, send_from_directory, session, redirect, url_for | |
from datetime import datetime | |
from werkzeug.utils import secure_filename | |
# Get the module-level logger
logger = logging.getLogger(__name__) | |
def register_routes(app, login_required, llm_interface, retriever, stt_client, DocumentProcessor, base_retriever, app_ready, ADMIN_USERNAME, ADMIN_PASSWORD, DEVICE_SERVER_URL): | |
"""Flask ์ ํ๋ฆฌ์ผ์ด์ ์ ๋ผ์ฐํธ ๋ฑ๋ก""" | |
# Helper functions
def allowed_audio_file(filename):
    """Return True when *filename* carries an allowed audio extension.

    The extension is the text after the last dot, compared
    case-insensitively against mp3, wav, ogg and m4a.

    Args:
        filename: Name of the uploaded file.  Anything that is not a string
            (for example ``None``) is rejected instead of raising
            ``TypeError``.

    Returns:
        bool: True for an allowed extension, False otherwise.
    """
    ALLOWED_AUDIO_EXTENSIONS = {'mp3', 'wav', 'ogg', 'm4a'}
    if not isinstance(filename, str):
        return False
    # rpartition leaves the separator empty when the name contains no dot.
    _, dot, extension = filename.rpartition('.')
    return bool(dot) and extension.lower() in ALLOWED_AUDIO_EXTENSIONS
def allowed_doc_file(filename):
    """Return True when *filename* carries an allowed document extension.

    The extension is the text after the last dot, compared
    case-insensitively against txt, md, pdf, docx and csv.

    Args:
        filename: Name of the uploaded file.  Anything that is not a string
            (for example ``None``) is rejected instead of raising
            ``TypeError``.

    Returns:
        bool: True for an allowed extension, False otherwise.
    """
    ALLOWED_DOC_EXTENSIONS = {'txt', 'md', 'pdf', 'docx', 'csv'}
    if not isinstance(filename, str):
        return False
    # rpartition leaves the separator empty when the name contains no dot.
    _, dot, extension = filename.rpartition('.')
    return bool(dot) and extension.lower() in ALLOWED_DOC_EXTENSIONS
# Embedding persistence helper
def save_embeddings(base_retriever, file_path):
    """Pickle the retriever together with a timestamp into a gzip file.

    Args:
        base_retriever: Object to persist; it must be picklable.
        file_path: Destination path of the gzip-compressed pickle.  A bare
            file name (no directory part) is written to the current
            directory.

    Returns:
        bool: True when the file was written, False when any step failed.
        Failures are logged, never raised.
    """
    import pickle
    import gzip

    temp_path = f"{file_path}.tmp"
    try:
        # os.makedirs('') raises, so only create a directory when the path
        # actually has one.
        target_dir = os.path.dirname(file_path)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)

        save_data = {
            'timestamp': datetime.now().isoformat(),
            'retriever': base_retriever,
        }

        # Write next to the destination and swap it in afterwards, so a
        # failed dump never truncates a previously saved file.
        with gzip.open(temp_path, 'wb') as f:
            pickle.dump(save_data, f)
        os.replace(temp_path, file_path)

        logger.info(f"์๋ฒ ๋ฉ ๋ฐ์ดํฐ๋ฅผ {file_path}์ ์์ถํ์ฌ ์ ์ฅํ์ต๋๋ค.")
        return True
    except Exception as e:
        logger.error(f"์๋ฒ ๋ฉ ์ ์ฅ ์ค ์ค๋ฅ ๋ฐ์: {e}")
        # Best-effort cleanup of the partial temp file; it may not exist.
        try:
            os.remove(temp_path)
        except OSError:
            pass
        return False
def login():
    """Render the login form, or check submitted credentials on POST.

    A successful POST stores ``logged_in`` and ``username`` in the session
    and redirects to the ``next`` query parameter when it points back at this
    host, otherwise to the ``index`` endpoint.  A failed POST re-renders
    ``login.html`` with an error message.  A non-POST request from a session
    that is already logged in is redirected to ``index``.

    Returns:
        A redirect response or the rendered ``login.html`` template.
    """
    import hmac
    from urllib.parse import urlsplit

    def is_safe_redirect_target(target):
        """Return True when *target* stays on the host serving this request."""
        if not target or target != target.strip():
            return False
        # Browsers read "//host", backslashes and control characters
        # differently from urlsplit, so refuse them before parsing.
        if target.startswith('//') or '\\' in target or any(ord(ch) < 32 for ch in target):
            return False
        parts = urlsplit(target)
        if parts.scheme or parts.netloc:
            # Absolute URLs are accepted only for this exact host.
            return parts.scheme in ('http', 'https') and parts.netloc == request.host
        return True

    def constant_time_equals(supplied, expected):
        """Compare two strings without leaking the mismatch position."""
        if not isinstance(supplied, str) or not isinstance(expected, str):
            return False
        # compare_digest rejects non-ASCII str, so compare the UTF-8 bytes.
        return hmac.compare_digest(supplied.encode('utf-8'), expected.encode('utf-8'))

    error = None
    next_url = request.args.get('next')
    logger.info(f"-------------- ๋ก๊ทธ์ธ ํ์ด์ง ์ ์ (Next: {next_url}) --------------")
    logger.info(f"Method: {request.method}")

    if request.method == 'POST':
        logger.info("๋ก๊ทธ์ธ ์๋ ๋ฐ์")
        username = request.form.get('username', '')
        password = request.form.get('password', '')
        logger.info(f"์ ๋ ฅ๋ ์ฌ์ฉ์๋ช : {username}")
        logger.info(f"๋น๋ฐ๋ฒํธ ์ ๋ ฅ ์ฌ๋ถ: {len(password) > 0}")

        # Expected credentials are the values handed to register_routes.
        valid_username = ADMIN_USERNAME
        valid_password = ADMIN_PASSWORD
        logger.info(f"๊ฒ์ฆ์ฉ ์ฌ์ฉ์๋ช : {valid_username}")
        logger.info(f"๊ฒ์ฆ์ฉ ๋น๋ฐ๋ฒํธ ์กด์ฌ ์ฌ๋ถ: {valid_password is not None and len(valid_password) > 0}")

        # Evaluate both comparisons so timing does not reveal which one failed.
        username_ok = constant_time_equals(username, valid_username)
        password_ok = constant_time_equals(password, valid_password)

        if username_ok and password_ok:
            logger.info(f"๋ก๊ทธ์ธ ์ฑ๊ณต: {username}")
            logger.debug(f"์ธ์ ์ค์ ์ : {session}")

            # Store the login state in the session.
            session.permanent = True
            session['logged_in'] = True
            session['username'] = username
            session.modified = True
            logger.info(f"์ธ์ ์ค์ ํ: {session}")
            logger.info("์ธ์ ์ค์ ์๋ฃ, ๋ฆฌ๋๋ ์ ์๋")

            # "next" comes from the query string, so only follow it when it
            # cannot send the user to another site.
            if is_safe_redirect_target(next_url):
                redirect_to = next_url
            else:
                if next_url:
                    logger.warning("Ignoring unsafe redirect target: %r", next_url)
                redirect_to = url_for('index')
            logger.info(f"๋ฆฌ๋๋ ์ ๋์: {redirect_to}")
            return redirect(redirect_to)

        logger.warning("๋ก๊ทธ์ธ ์คํจ: ์์ด๋ ๋๋ ๋น๋ฐ๋ฒํธ ๋ถ์ผ์น")
        if not username_ok:
            logger.warning("์ฌ์ฉ์๋ช ๋ถ์ผ์น")
        if not password_ok:
            logger.warning("๋น๋ฐ๋ฒํธ ๋ถ์ผ์น")
        error = '์์ด๋ ๋๋ ๋น๋ฐ๋ฒํธ๊ฐ ์ฌ๋ฐ๋ฅด์ง ์์ต๋๋ค.'
    else:
        logger.info("๋ก๊ทธ์ธ ํ์ด์ง GET ์์ฒญ")
        if 'logged_in' in session:
            logger.info("์ด๋ฏธ ๋ก๊ทธ์ธ๋ ์ฌ์ฉ์, ๋ฉ์ธ ํ์ด์ง๋ก ๋ฆฌ๋๋ ์ ")
            return redirect(url_for('index'))

    logger.info("---------- ๋ก๊ทธ์ธ ํ์ด์ง ๋ ๋๋ง ----------")
    return render_template('login.html', error=error, next=next_url)
def logout():
    """Drop the login keys from the session and redirect to the login page.

    Logging out without being logged in is only logged as a warning; the
    redirect to the ``login`` endpoint happens either way.
    """
    logger.info("-------------- ๋ก๊ทธ์์ ์์ฒญ --------------")
    logger.info(f"๋ก๊ทธ์์ ์ ์ธ์ ์ํ: {session}")

    if 'logged_in' not in session:
        logger.warning("๋ก๊ทธ์ธ๋์ง ์์ ์ํ์์ ๋ก๊ทธ์์ ์๋")
    else:
        username = session.get('username', 'unknown')
        logger.info(f"์ฌ์ฉ์ {username} ๋ก๊ทธ์์ ์ฒ๋ฆฌ ์์")
        # Remove only the keys that login() stores.
        for session_key in ('logged_in', 'username'):
            session.pop(session_key, None)
        session.modified = True
        logger.info(f"์ธ์ ์ ๋ณด ์ญ์ ์๋ฃ. ํ์ฌ ์ธ์ : {session}")

    logger.info("๋ก๊ทธ์ธ ํ์ด์ง๋ก ๋ฆฌ๋๋ ์ ")
    return redirect(url_for('login'))
def index(): | |
"""๋ฉ์ธ ํ์ด์ง""" | |
nonlocal app_ready | |
# ์ฑ ์ค๋น ์ํ ํ์ธ - 30์ด ์ด์ ์ง๋ฌ์ผ๋ฉด ๊ฐ์ ๋ก ready ์ํ๋ก ๋ณ๊ฒฝ | |
current_time = datetime.now() | |
start_time = datetime.fromtimestamp(os.path.getmtime(__file__)) | |
time_diff = (current_time - start_time).total_seconds() | |
if not app_ready and time_diff > 30: | |
logger.warning(f"์ฑ์ด 30์ด ์ด์ ์ด๊ธฐํ ์ค ์ํ์ ๋๋ค. ๊ฐ์ ๋ก ready ์ํ๋ก ๋ณ๊ฒฝํฉ๋๋ค.") | |
app_ready = True | |
if not app_ready: | |
logger.info("์ฑ์ด ์์ง ์ค๋น๋์ง ์์ ๋ก๋ฉ ํ์ด์ง ํ์") | |
return render_template('loading.html'), 503 # ์๋น์ค ์ค๋น ์๋จ ์ํ ์ฝ๋ | |
logger.info("๋ฉ์ธ ํ์ด์ง ์์ฒญ") | |
return render_template('index.html') | |
def app_status():
    """Report, as JSON, whether the application is flagged as ready."""
    state_label = 'Ready' if app_ready else 'Not Ready'
    logger.info(f"์ฑ ์ํ ํ์ธ ์์ฒญ: {state_label}")
    payload = {"ready": app_ready}
    return jsonify(payload)
def llm_api():
    """List the supported LLMs, or switch the active one on POST.

    ``POST`` expects a JSON object with an ``llm_id`` key and switches
    ``llm_interface`` to that model.  Every other method is treated as a
    read-only listing request, so a ``HEAD`` request (which Flask accepts
    automatically on GET routes) no longer falls through and returns None.

    Returns:
        A JSON response, paired with a status code on errors: 503 while the
        app is not ready, 400 for a missing/invalid ``llm_id``, 500 when the
        lookup or the switch fails.
    """
    if not app_ready:
        return jsonify({"error": "์ฑ์ด ์์ง ์ด๊ธฐํ ์ค์ ๋๋ค. ์ ์ ํ ๋ค์ ์๋ํด์ฃผ์ธ์."}), 503

    if request.method == 'POST':
        # silent=True turns a missing or malformed JSON body into None, so
        # the client gets the JSON 400 below instead of a framework error.
        data = request.get_json(silent=True)
        if not isinstance(data, dict) or 'llm_id' not in data:
            return jsonify({"error": "LLM ID๊ฐ ์ ๊ณต๋์ง ์์์ต๋๋ค."}), 400

        llm_id = data['llm_id']
        logger.info(f"LLM ๋ณ๊ฒฝ ์์ฒญ: {llm_id}")
        try:
            if not hasattr(llm_interface, 'set_llm') or not hasattr(llm_interface, 'llm_clients'):
                raise NotImplementedError("LLM ์ธํฐํ์ด์ค์ ํ์ํ ๋ฉ์๋/์์ฑ ์์")
            if llm_id not in llm_interface.llm_clients:
                return jsonify({"error": f"์ง์๋์ง ์๋ LLM ID: {llm_id}"}), 400

            if not llm_interface.set_llm(llm_id):
                logger.error(f"LLM ๋ณ๊ฒฝ ์คํจ (ID: {llm_id})")
                return jsonify({"error": "LLM ๋ณ๊ฒฝ ์ค ๋ด๋ถ ์ค๋ฅ ๋ฐ์"}), 500

            new_details = llm_interface.get_current_llm_details()
            logger.info(f"LLM์ด '{new_details.get('name', llm_id)}'๋ก ๋ณ๊ฒฝ๋์์ต๋๋ค.")
            return jsonify({
                "success": True,
                "message": f"LLM์ด '{new_details.get('name', llm_id)}'๋ก ๋ณ๊ฒฝ๋์์ต๋๋ค.",
                "current_llm": new_details
            })
        except Exception as e:
            logger.error(f"LLM ๋ณ๊ฒฝ ์ฒ๋ฆฌ ์ค ์ค๋ฅ: {e}", exc_info=True)
            return jsonify({"error": f"LLM ๋ณ๊ฒฝ ์ค ์ค๋ฅ ๋ฐ์: {str(e)}"}), 500

    # Read-only path: GET, plus HEAD and anything else the route accepts.
    logger.info("LLM ๋ชฉ๋ก ์์ฒญ")
    try:
        if hasattr(llm_interface, 'get_current_llm_details'):
            current_details = llm_interface.get_current_llm_details()
        else:
            current_details = {"id": "unknown", "name": "Unknown"}
        supported_llms_dict = llm_interface.SUPPORTED_LLMS if hasattr(llm_interface, 'SUPPORTED_LLMS') else {}
        # The mapping is iterated as display name -> model id.
        supported_list = [
            {
                "name": name,
                "id": model_id,
                "current": model_id == current_details.get("id"),
            }
            for name, model_id in supported_llms_dict.items()
        ]
        return jsonify({
            "supported_llms": supported_list,
            "current_llm": current_details
        })
    except Exception as e:
        logger.error(f"LLM ์ ๋ณด ์กฐํ ์ค๋ฅ: {e}")
        return jsonify({"error": "LLM ์ ๋ณด ์กฐํ ์ค ์ค๋ฅ ๋ฐ์"}), 500
def chat():
    """Answer a text query with retrieval-augmented generation.

    Expects a JSON object with a non-empty string ``query`` and an optional
    ``llm_id``.  The query is passed to ``retriever.search``, the hits are
    turned into a context by ``DocumentProcessor.prepare_rag_context``, and
    ``llm_interface.rag_generate`` produces the answer.  When no context
    could be built, a fixed "nothing found" answer is returned instead of
    calling the LLM.

    Returns:
        JSON with ``answer``, ``sources`` and ``llm``; or a JSON error with
        503 (not initialised), 400 (missing/invalid query) or 500 (any
        failure during processing).
    """

    def collect_sources(search_results):
        """Build the source list returned to the client from the search hits."""
        sources = []
        for result in search_results or []:
            if not isinstance(result, dict):
                logger.warning(f"์์์น ๋ชปํ ๊ฒ์ ๊ฒฐ๊ณผ ํ์: {type(result)}")
                continue
            if "source" not in result:
                continue

            source_info = {
                "source": result.get("source", "Unknown"),
                # Prefer the rerank score when the hit carries one.
                "score": result.get("rerank_score", result.get("score", 0))
            }

            # CSV hits: expose the first column of the first line as an id.
            if "text" in result and result.get("filetype") == "csv":
                try:
                    first_line = result["text"].strip().split('\n', 1)[0].strip()
                    if ',' in first_line:
                        first_column = first_line.split(',', 1)[0].strip()
                        source_info["id"] = first_column
                        logger.debug(f"CSV ์์ค ID ์ถ์ถ: {first_column} from {source_info['source']}")
                except (AttributeError, TypeError) as e:
                    # "text" was not a string; keep the source without an id.
                    logger.warning(f"CSV ์์ค ID ์ถ์ถ ์คํจ ({result.get('source')}): {e}")

            sources.append(source_info)
        return sources

    if not app_ready or retriever is None:
        return jsonify({"error": "์ฑ/๊ฒ์๊ธฐ๊ฐ ์์ง ์ด๊ธฐํ ์ค์ ๋๋ค. ์ ์ ํ ๋ค์ ์๋ํด์ฃผ์ธ์."}), 503

    # Validate the body before the broad try block so a malformed request is
    # reported as a client error (400) rather than a server error (500).
    data = request.get_json(silent=True)
    query = data.get('query') if isinstance(data, dict) else None
    if not isinstance(query, str) or not query.strip():
        return jsonify({"error": "์ฟผ๋ฆฌ๊ฐ ์ ๊ณต๋์ง ์์์ต๋๋ค."}), 400

    try:
        logger.info(f"ํ ์คํธ ์ฟผ๋ฆฌ ์์ : {query[:100]}...")

        # Retrieval step.
        if not hasattr(retriever, 'search'):
            raise NotImplementedError("Retriever์ search ๋ฉ์๋๊ฐ ์์ต๋๋ค.")
        search_results = retriever.search(query, top_k=5, first_stage_k=6)

        # Context preparation.
        if not hasattr(DocumentProcessor, 'prepare_rag_context'):
            raise NotImplementedError("DocumentProcessor์ prepare_rag_context ๋ฉ์๋๊ฐ ์์ต๋๋ค.")
        context = DocumentProcessor.prepare_rag_context(search_results, field="text")

        llm_id = data.get('llm_id', None)
        if not hasattr(llm_interface, 'rag_generate'):
            raise NotImplementedError("LLMInterface์ rag_generate ๋ฉ์๋๊ฐ ์์ต๋๋ค.")

        if context:
            answer = llm_interface.rag_generate(query, context, llm_id=llm_id)
            logger.info(f"LLM ์๋ต ์์ฑ ์๋ฃ (๊ธธ์ด: {len(answer)})")
        else:
            # Nothing retrieved: skip the LLM and return the fixed answer.
            logger.warning("๊ฒ์ ๊ฒฐ๊ณผ๊ฐ ์์ด ์ปจํ ์คํธ๋ฅผ ์์ฑํ์ง ๋ชปํจ.")
            answer = "์ฃ์กํฉ๋๋ค. ๊ด๋ จ ์ ๋ณด๋ฅผ ์ฐพ์ ์ ์์ต๋๋ค."
            logger.info("์ปจํ ์คํธ ์์ด ๊ธฐ๋ณธ ์๋ต ์์ฑ")

        response_data = {
            "answer": answer,
            "sources": collect_sources(search_results),
            "llm": llm_interface.get_current_llm_details() if hasattr(llm_interface, 'get_current_llm_details') else {}
        }
        return jsonify(response_data)
    except Exception as e:
        logger.error(f"์ฑํ ์ฒ๋ฆฌ ์ค ์ค๋ฅ ๋ฐ์: {e}", exc_info=True)
        return jsonify({"error": f"์ฒ๋ฆฌ ์ค ์ค๋ฅ๊ฐ ๋ฐ์ํ์ต๋๋ค: {str(e)}"}), 500