# Stable AI copywriter - embedding-based RAG system
# Build optimized for the Hugging Face Spaces environment
import streamlit as st
import pandas as pd
import numpy as np  # import numpy globally first; np is used later (e.g. np.argsort)
import pickle
import google.generativeai as genai
import time
import json
import os
import sys  # sys imported for debugging output
from datetime import datetime
# Environment settings (to avoid permission issues)
os.environ['STREAMLIT_BROWSER_GATHER_USAGE_STATS'] = 'false'

# Point the model caches at /tmp (the writable path recommended on Hugging Face Spaces)
TMP_DIR = "/tmp"
TRANSFORMERS_CACHE_DIR = os.path.join(TMP_DIR, '.cache', 'transformers')
SENTENCE_TRANSFORMERS_HOME_DIR = os.path.join(TMP_DIR, '.cache', 'sentence_transformers')
os.environ['TRANSFORMERS_CACHE'] = TRANSFORMERS_CACHE_DIR
os.environ['SENTENCE_TRANSFORMERS_HOME'] = SENTENCE_TRANSFORMERS_HOME_DIR

# Create the cache directories if they do not exist - /tmp is normally writable
try:
    os.makedirs(TRANSFORMERS_CACHE_DIR, exist_ok=True)
    os.makedirs(SENTENCE_TRANSFORMERS_HOME_DIR, exist_ok=True)
except PermissionError:
    st.warning(f"⚠️ No permission to create the cache directories: {TRANSFORMERS_CACHE_DIR} or {SENTENCE_TRANSFORMERS_HOME_DIR}. Model downloads may be slow.")
except Exception as e_mkdir:
    st.warning(f"⚠️ Error while creating the cache directories: {e_mkdir}")
# Page configuration
st.set_page_config(
    page_title="AI Copywriter | RAG-based ad copy generation",
    page_icon="✨",
    layout="wide",
    initial_sidebar_state="expanded"
)

# Title and description
st.title("✨ AI Copywriter")
st.markdown("### 🎯 A RAG system built on 37,671 real ad copy examples")
st.markdown("---")
# --- Runtime environment debugging (top of the app, or right before load_system) ---
st.sidebar.markdown("---")
st.sidebar.markdown("### ⚙️ Runtime environment info (for debugging)")
st.sidebar.text(f"Py Exec: {sys.executable}")
st.sidebar.text(f"Py Ver: {sys.version.split()[0]}")  # keep it short: version number only
# st.sidebar.text(f"sys.path: {sys.path}")  # too long, commented out for now
st.sidebar.text(f"PYTHONPATH: {os.environ.get('PYTHONPATH', 'Not Set')}")
try:
    # re-import numpy here and use it
    import numpy as np_runtime_check
    st.sidebar.text(f"NumPy Ver (Runtime): {np_runtime_check.__version__}")
    # try importing the core extension module
    import numpy.core._multiarray_umath
    st.sidebar.markdown("✅ NumPy core modules imported (Runtime)")
except Exception as e:
    st.sidebar.error(f"❌ NumPy import error (Runtime): {e}")
st.sidebar.markdown("---")
# --- end of debugging code ---
# Sidebar settings
st.sidebar.header("🎛️ Copy generation settings")

# API key input (an environment variable takes precedence)
default_api_key = os.getenv("GEMINI_API_KEY", "")
api_key = st.sidebar.text_input(
    "🔑 Gemini API key",
    value=default_api_key,
    type="password",
    help="Set the GEMINI_API_KEY environment variable to fill this in automatically"
)
if not api_key:
    st.warning("⚠️ Please enter a Gemini API key")
    st.info("💡 Set GEMINI_API_KEY under Settings → Repository secrets")
    st.stop()
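# Note: when running the app outside Spaces, the same key can be supplied as an environment
# variable before launch, e.g. `GEMINI_API_KEY="your-key" streamlit run app.py`
# (the entry-point name app.py is an assumption; use the actual script name).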
# System initialization (cached) - the embeddings are required!
@st.cache_resource(show_spinner=False)  # keep the loaded resources across Streamlit reruns, as the "(cached)" note above intends
def load_system():
    """Load the system components - embedding-based RAG system."""
    # --- debug info at the start of the function ---
    st.write("--- load_system() start ---")
    st.write(f"Python Executable (load_system): {sys.executable}")
    st.write(f"Python Version (load_system): {sys.version}")
    # st.write(f"sys.path (load_system): {sys.path}")  # too long, commented out
    st.write(f"PYTHONPATH (load_system): {os.environ.get('PYTHONPATH')}")
    try:
        import numpy as np_load_system_check  # use a fresh alias
        st.write(f"NumPy version (load_system start): {np_load_system_check.__version__}")
        import numpy.core._multiarray_umath
        st.write("load_system start: Successfully imported numpy.core._multiarray_umath")
    except Exception as e:
        st.write(f"load_system start: Error importing NumPy parts: {e}")
    # --- end of debug info ---

    progress_container = st.container()
    with progress_container:
        # overall progress bar
        total_progress = st.progress(0)
        status_text = st.empty()

        # Step 1: API setup (10%)
        status_text.text("🔑 Initializing the Gemini API...")
        try:
            genai.configure(api_key=api_key)
            model_llm = genai.GenerativeModel('gemini-1.5-flash')  # double-check the model name (previously gemini-2.0-flash)
            total_progress.progress(10)
            st.success("✅ Gemini API configured")
        except Exception as e:
            st.error(f"❌ Gemini API configuration failed: {e}")
            return None, None, None, None

        # Step 2: load the embedding model (40%)
        status_text.text("🤖 Loading the Korean embedding model... (takes 1-2 minutes)")
        embedding_model_instance = None  # renamed variable
        try:
            # keep the sentence-transformers import inside the function
            from sentence_transformers import SentenceTransformer
            # from sklearn.metrics.pairwise import cosine_similarity  # not needed yet at this point
            embedding_model_instance = SentenceTransformer('jhgan/ko-sbert-nli',
                                                           cache_folder=SENTENCE_TRANSFORMERS_HOME_DIR)  # use the corrected cache path
            total_progress.progress(40)
            st.success("✅ Korean embedding model loaded")
        except Exception as e:
            st.error(f"❌ Failed to load the embedding model: {e}")
            st.error("🚨 The RAG system cannot work without the embedding model!")
            return None, None, None, None
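        # Note: the query vectors produced by 'jhgan/ko-sbert-nli' must have the same dimensionality
        # as the precomputed vectors in copy_embeddings.pkl loaded below; generate_copy_with_rag()
        # checks this explicitly before computing cosine similarity.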
        # Step 3: load the data (60%)
        status_text.text("📚 Loading the copy database...")
        df_data = None  # renamed variable
        try:
            df_data = pd.read_excel('๊ด๊ณ ์นดํผ๋ฐ์ดํฐ_๋ธ๋๋์ถ์ถ์๋ฃ.xlsx')
            total_progress.progress(60)
            st.success(f"✅ Data loaded: {len(df_data):,} copies")
        except Exception as e:
            st.error(f"❌ Failed to load the data: {e}")
            return None, None, None, None

        # Step 4: load the embedding data (90%) - this is the key step!
        status_text.text("🔍 Loading the vector embeddings... (core of the RAG system)")
        embeddings_array = None  # renamed variable
        try:
            # --- NumPy debugging right before pickle.load() ---
            import numpy as np_pickle_check  # use a fresh alias
            st.write(f"[DEBUG] NumPy version just before pickle.load: {np_pickle_check.__version__}")
            import numpy.core._multiarray_umath
            st.write("[DEBUG] Successfully imported numpy.core._multiarray_umath before pickle.load")
            # --- end of debugging ---
            with open('copy_embeddings.pkl', 'rb') as f:
                embeddings_data = pickle.load(f)
            embeddings_array = embeddings_data['embeddings']
            total_progress.progress(90)
            st.success(f"✅ Embeddings loaded: {embeddings_array.shape[0]:,} rows × {embeddings_array.shape[1]} dimensions")
        except ModuleNotFoundError as mnfe:  # catch ModuleNotFoundError specifically
            st.error(f"❌ Failed to load the embeddings (ModuleNotFoundError): {mnfe}")
            st.error(f"🚨 The module could not be found. sys.path: {sys.path}")
            st.error("🚨 Semantic search is impossible without the embeddings!")
            # extra debugging: state of the numpy object currently loaded
            try:
                import numpy as np_final_check
                st.error(f"[DEBUG] NumPy object at failure: {np_final_check}")
                st.error(f"[DEBUG] NumPy __file__ at failure: {np_final_check.__file__}")
            except Exception as e_np_final:
                st.error(f"[DEBUG] Could not even import numpy at failure: {e_np_final}")
            return None, None, None, None
        except Exception as e:
            st.error(f"❌ Failed to load the embeddings (general error): {e}")
            st.error("🚨 Semantic search is impossible without the embeddings!")
            return None, None, None, None
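        # The pickle above is expected to hold a dict whose 'embeddings' entry is a NumPy array of
        # shape (n_copies, dim), aligned row-for-row with the Excel data from step 3 (the search code
        # indexes it by DataFrame row position). The dedicated ModuleNotFoundError handler exists
        # because arrays pickled under one NumPy build can fail to unpickle under another, typically
        # naming numpy.core._multiarray_umath - hence the NumPy debug output around pickle.load().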
        # Step 5: final validation (100%)
        status_text.text("✨ Validating the system...")
        if model_llm and embedding_model_instance and df_data is not None and embeddings_array is not None:
            total_progress.progress(100)
            status_text.text("🎉 RAG system loaded!")
            success_col1, success_col2, success_col3 = st.columns(3)
            with success_col1:
                st.metric("Copy data", f"{len(df_data):,}")
            with success_col2:
                st.metric("Embedding dimension", f"{embeddings_array.shape[1]}D")
            with success_col3:
                st.metric("Search engine", "Korean SBERT")
            time.sleep(1)
            total_progress.empty()
            status_text.empty()
            # return the function-local names to avoid clashing with the global variable names
            return model_llm, embedding_model_instance, df_data, embeddings_array
        else:
            st.error("❌ System loading failed: a required component is missing")
            return None, None, None, None
# Load the system (new names to avoid clashing with the variables used inside load_system)
loaded_model, loaded_embedding_model, loaded_df, loaded_embeddings = None, None, None, None
with st.spinner("🚀 Initializing the AI copywriter system..."):
    loaded_model, loaded_embedding_model, loaded_df, loaded_embeddings = load_system()

if loaded_model is None or loaded_embedding_model is None or loaded_df is None or loaded_embeddings is None:
    st.error("❌ The system could not be loaded. Refresh the page or contact the administrator.")
    st.stop()

# Sidebar status (after the system loaded successfully)
st.sidebar.success("🎉 RAG system ready!")

# Category selection
categories = ['All'] + sorted(loaded_df['์นดํ ๊ณ ๋ฆฌ'].unique().tolist())
selected_category = st.sidebar.selectbox(
    "📂 Category",
    categories,
    help="You can restrict the search to a specific category"
)

# Target audience
target_audience = st.sidebar.selectbox(
    "🎯 Target audience",
    ['20s', '30s', 'General', 'Teens', '40s', '50+', 'Men', 'Women', 'Office workers', 'Students', 'Homemakers'],
    help="The copy is written in a tone that fits the chosen target audience"
)

# Brand tone and manner
brand_tone = st.sidebar.selectbox(
    "🎨 Brand tone",
    ['Sophisticated', 'Friendly', 'Premium', 'Energetic', 'Trustworthy', 'Youthful', 'Warm', 'Professional'],
    help="Choose the brand image you want"
)

# Creativity level
creative_level = st.sidebar.select_slider(
    "🧠 Creativity level",
    options=['Conservative', 'Balanced', 'Creative'],
    value='Balanced',
    help="Conservative: safe expressions, Creative: bold expressions"
)
# Main input area
st.markdown("## 💭 What copy would you like to create?")
user_request = ""  # initialize
input_method = st.radio(
    "Choose an input method:",
    ["Direct input", "Choose a template"],
    horizontal=True,
    key="input_method_radio"  # unique key added
)
if input_method == "Direct input":
    user_request = st.text_area(
        "Describe the copy you need in detail:",
        placeholder="e.g. launch copy for a new premium skincare product aimed at working women in their 30s",
        height=100,
        key="user_request_direct"  # unique key added
    )
else:
    templates = {
        "New product launch": "Launch copy for a new {category} product",
        "Discount event": "Promotional copy for a {category} discount event",
        "Brand slogan": "A slogan for a {category} brand",
        "App/service relaunch": "Launch copy for the new version of {service_name}",
        "Seasonal special": "Special {season} limited-edition {category} copy"
    }
    selected_template = st.selectbox("Choose a template:", list(templates.keys()), key="template_selectbox")
    template_category = ""
    service_name = ""
    season = ""
    col1, col2 = st.columns(2)
    with col1:
        template_category = st.text_input("Product/service:", value="", key="template_category_input")
    with col2:
        if selected_template == "App/service relaunch":
            service_name = st.text_input("Service name:", placeholder="e.g. a delivery app, a banking app", key="template_service_name_input")
            user_request = templates[selected_template].format(service_name=service_name)
        elif selected_template == "Seasonal special":
            season = st.selectbox("Season:", ["Spring", "Summer", "Fall", "Winter", "Christmas", "New Year"], key="template_season_selectbox")
            user_request = templates[selected_template].format(season=season, category=template_category)
        else:
            user_request = templates[selected_template].format(category=template_category)
    st.text_area("Generated request:", value=user_request, height=80, disabled=True, key="generated_request_template")
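# Example of how the templates above are filled in: choosing "Seasonal special" with season
# "Christmas" and product/service "perfume" yields the request
# "Special Christmas limited-edition perfume copy" (the values are illustrative).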
# Advanced options
with st.expander("🔧 Advanced options"):
    col1_adv, col2_adv = st.columns(2)  # renamed variables
    with col1_adv:
        num_concepts = st.slider("Number of concepts to generate:", 1, 5, 3, key="num_concepts_slider")
        min_similarity = st.slider("Minimum similarity:", 0.0, 1.0, 0.3, 0.1, key="min_similarity_slider")
    with col2_adv:
        show_references = st.checkbox("Show reference copies", value=True, key="show_references_checkbox")
        num_references = st.slider("Number of reference copies:", 3, 10, 5, key="num_references_slider")
# RAG copy-generation function (the embeddings are required!)
def generate_copy_with_rag(user_req, category_filter, target_aud, brand_tn, creative_lvl, num_con):  # renamed parameters
    """RAG-based copy generation - always uses the embeddings."""
    if not user_req.strip():
        st.error("❌ Please enter a copy request")
        return None

    progress_bar = st.progress(0)
    status_text_gen = st.empty()  # renamed variable
    status_text_gen.text("🔍 Running semantic search... (core RAG step)")
    progress_bar.progress(20)
    try:
        search_query = f"{user_req} {target_aud} ad copy"
        from sklearn.metrics.pairwise import cosine_similarity  # imported inside generate_copy_with_rag
        query_embedding = loaded_embedding_model.encode([search_query])  # use the loaded model

        if category_filter != 'All':
            filtered_df_gen = loaded_df[loaded_df['์นดํ ๊ณ ๋ฆฌ'] == category_filter].copy()  # .copy() added
        else:
            filtered_df_gen = loaded_df.copy()  # .copy() added
        progress_bar.progress(40)

        if filtered_df_gen.empty:
            st.warning(f"⚠️ No data matches the selected category '{category_filter}'.")
            progress_bar.empty()
            status_text_gen.empty()
            return None

        filtered_indices = filtered_df_gen.index.tolist()
        # Before indexing into loaded_embeddings directly, check that filtered_indices are within its range
        valid_indices_for_embedding = [idx for idx in filtered_indices if idx < len(loaded_embeddings)]
        if not valid_indices_for_embedding:
            st.warning(f"⚠️ No valid indices found, so the similarity search cannot run. (category: {category_filter})")
            progress_bar.empty()
            status_text_gen.empty()
            return None

        # Use only the embeddings belonging to the valid indices.
        # These must be indices of the original DataFrame (loaded_df): filtered_df_gen's index is a
        # subset of loaded_df's, so those index values are used directly against loaded_embeddings.
        # Note: filtered_indices must be actual index values of loaded_df. If filtered_df_gen.index
        # were a fresh 0-based index, a mapping would be needed; the current code assumes
        # filtered_df_gen.index.tolist() keeps the original index.
        filtered_embeddings_for_search = loaded_embeddings[valid_indices_for_embedding]

        # Check that query_embedding and filtered_embeddings_for_search have matching dimensions
        if query_embedding.shape[1] != filtered_embeddings_for_search.shape[1]:
            st.error(f"❌ Embedding dimension mismatch: query ({query_embedding.shape[1]}D), documents ({filtered_embeddings_for_search.shape[1]}D)")
            return None

        similarities = cosine_similarity(query_embedding, filtered_embeddings_for_search)[0]

        # Select the top N (num_references) matches.
        # len(similarities) equals len(valid_indices_for_embedding);
        # top_similarity_indices are positions within the similarities array.
        num_to_select = min(num_references, len(similarities))
        top_similarity_indices = np.argsort(similarities)[::-1][:num_to_select]

        reference_copies = []
        for i in top_similarity_indices:
            # i is a position in the similarities array; map it back to the original
            # DataFrame index via valid_indices_for_embedding.
            original_df_idx = valid_indices_for_embedding[i]
            row = loaded_df.iloc[original_df_idx]  # fetch from the original df (assumes a default RangeIndex, so position == label)
            if similarities[i] >= min_similarity:
                reference_copies.append({
                    'copy': row['์นดํผ ๋ด์ฉ'],
                    'brand': row['๋ธ๋๋'],
                    'similarity': float(similarities[i])  # cast to float (for JSON serialization)
                })
        progress_bar.progress(60)

        if not reference_copies:
            st.warning(f"⚠️ No reference copies found with similarity {min_similarity} or higher. Try lowering the threshold.")
            # still let the LLM generate even without reference copies (optional behavior)
            # progress_bar.empty()
            # status_text_gen.empty()
            # return None
            references_text_for_prompt = "No reference copies with sufficiently high similarity were found."
        else:
            references_text_for_prompt = "\n".join([
                f"{j+1}. \"{ref['copy']}\" - {ref['brand']} (similarity: {ref['similarity']:.3f})"
                for j, ref in enumerate(reference_copies)
            ])

        status_text_gen.text("🤖 Generating copy with the AI...")
        progress_bar.progress(80)

        creativity_guidance = {
            "Conservative": "using safe, proven expressions",
            "Balanced": "creative yet at an appropriate level",
            "Creative": "with bold and innovative expressions"
        }

        prompt = f"""
You are a professional Korean advertising copywriter.

**Request:** {user_req}
**Target audience:** {target_aud}
**Brand tone:** {brand_tn}
**Creativity level:** {creative_lvl} ({creativity_guidance[creative_lvl]})

**Reference copies (selected by semantic similarity):**
{references_text_for_prompt}

**Writing guidelines:**
1. Analyze the style and tone of the reference copies above and write {num_con} new copies tailored to the request.
2. If there are no reference copies, focus only on the request, target audience, brand tone, and creativity level.
3. Each copy must read naturally and be compelling in Korean.
4. Use expressions that resonate with {target_aud}.
5. Keep the {brand_tn} tone and manner.

**Output format (include a short explanation for each copy):**
1. [Generated copy 1]
   - Explanation: (why this copy is effective, or the intent behind it)
2. [Generated copy 2]
   - Explanation: (why this copy is effective, or the intent behind it)
(repeat for the requested number of concepts)

**Recommended copy:** (the single copy above you recommend most, and why)
"""
        response = loaded_model.generate_content(prompt)
        progress_bar.progress(100)
        status_text_gen.text("✅ Done!")
        time.sleep(0.5)
        progress_bar.empty()
        status_text_gen.empty()

        return {
            'references': reference_copies,
            'generated_content': response.text,
            'search_info': {
                'query': search_query,
                'total_candidates': len(filtered_df_gen),
                'selected_references': len(reference_copies)
            },
            'settings': {
                'category': category_filter,
                'target': target_aud,
                'tone': brand_tn,
                'creative': creative_lvl
            }
        }
    except Exception as e_gen:
        st.error(f"❌ Copy generation failed: {e_gen}")
        st.error(f"Error type: {type(e_gen)}")  # print the error type
        import traceback  # full traceback for debugging
        st.error(traceback.format_exc())
        progress_bar.empty()
        status_text_gen.empty()
        return None
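# The retrieval step above involves some index bookkeeping (similarity-array positions vs. original
# DataFrame indices). The helper below is a minimal, self-contained sketch of the same idea - encode
# a query, score it against a matrix of document embeddings with cosine similarity, and return the
# top-k positions. It is not called anywhere in this app; the names top_k_semantic_search,
# doc_embeddings and k are illustrative only.
def top_k_semantic_search(query_text, encoder, doc_embeddings, k=5):
    """Return (positions, scores) of the k rows of doc_embeddings most similar to query_text."""
    from sklearn.metrics.pairwise import cosine_similarity
    query_vec = encoder.encode([query_text])                  # shape (1, dim)
    scores = cosine_similarity(query_vec, doc_embeddings)[0]  # shape (n_docs,)
    top_positions = np.argsort(scores)[::-1][:k]              # best-first row positions
    return top_positions, scores[top_positions]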
# Generate button
if st.button("🚀 Generate copy", type="primary", use_container_width=True, key="generate_button"):
    if not user_request or not user_request.strip():
        st.error("❌ Please enter a copy request")
    else:
        result = generate_copy_with_rag(
            user_req=user_request,
            category_filter=selected_category,
            target_aud=target_audience,
            brand_tn=brand_tone,
            creative_lvl=creative_level,
            num_con=num_concepts
        )
        if result:
            st.markdown("## 📝 Generated copy")
            st.markdown("---")
            st.info(f"🔍 **Search info**: {result['search_info']['selected_references']} reference copies selected "
                    f"out of {result['search_info']['total_candidates']:,} candidates")
            if show_references and result['references']:
                with st.expander("📚 Reference copies (selected by semantic similarity)"):
                    for i, ref in enumerate(result['references'], 1):
                        st.markdown(f"**{i}.** \"{ref['copy']}\"")
                        st.markdown(f"   - Brand: {ref['brand']}")
                        st.markdown(f"   - Similarity: {ref['similarity']:.3f}")
                        st.markdown("")
            st.markdown("### ✨ Copy written by the AI:")
            st.markdown(result['generated_content'])
            try:
                result_json = json.dumps({
                    'timestamp': datetime.now().isoformat(),
                    'request': user_request,
                    'settings': result['settings'],
                    'search_info': result['search_info'],
                    'generated_content': result['generated_content'],
                    'references': result['references']  # include the reference copies in the JSON
                }, ensure_ascii=False, indent=2)
                st.download_button(
                    label="💾 Download result (JSON)",
                    data=result_json,
                    file_name=f"copy_result_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json",
                    mime="application/json",
                    key="download_button"
                )
            except Exception as e_json:
                st.error(f"❌ Failed to build the download file: {e_json}")
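# The downloaded JSON mirrors the dict built above, roughly (values illustrative):
# {
#   "timestamp": "2024-01-01T12:00:00",
#   "request": "...",
#   "settings": {"category": "...", "target": "...", "tone": "...", "creative": "..."},
#   "search_info": {"query": "...", "total_candidates": 1234, "selected_references": 5},
#   "generated_content": "...",
#   "references": [{"copy": "...", "brand": "...", "similarity": 0.812}]
# }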
# System info (bottom of the sidebar)
st.sidebar.markdown("---")
st.sidebar.markdown("### 📊 RAG system info")
if loaded_df is not None and loaded_embeddings is not None:
    st.sidebar.markdown(f"**Copy data**: {len(loaded_df):,}")
    st.sidebar.markdown(f"**Categories**: {loaded_df['์นดํ ๊ณ ๋ฆฌ'].nunique()}")
    st.sidebar.markdown(f"**Brands**: {loaded_df['๋ธ๋๋'].nunique()}")
    st.sidebar.markdown(f"**Embedding dimension**: {loaded_embeddings.shape[1]}")  # use the loaded embeddings
    st.sidebar.markdown("**Search engine**: Korean SBERT")
    st.sidebar.markdown("**Hosting**: 🤗 Hugging Face")

# Usage guide
with st.expander("💡 How to use the RAG system"):
    st.markdown("""
    ### 🎯 Tips for effective use
    (same as the existing content)
    """)
# Footer
st.markdown("---")
st.markdown(
    "💡 **AI Copywriter** | built on 37,671 real ad copy examples | "
    "RAG (retrieval-augmented generation) system powered by Korean SBERT + Gemini AI"
)

# Performance monitoring (for developers)
if os.getenv("DEBUG_MODE") == "true":  # compare the environment variable against the string "true"
    st.sidebar.markdown("### 🔧 Debug info (enabled)")
    if 'loaded_embeddings' in locals() and loaded_embeddings is not None:  # use the loaded variable
        st.sidebar.write(f"Embedding memory: {loaded_embeddings.nbytes / (1024*1024):.1f}MB")
    st.sidebar.write(f"Streamlit version: {st.__version__}")
    st.sidebar.write(f"Pandas version: {pd.__version__}")
    st.sidebar.write(f"NumPy version (global): {np.__version__ if 'np' in globals() else 'Not imported globally'}")
    st.sidebar.write(f"Torch version: {torch.__version__ if 'torch' in globals() else 'Torch not directly used here'}")  # torch is used internally by sentence-transformers
    st.sidebar.write(f"google-generativeai version: {genai.__version__}")