# hybridRAG / app.py
# Residue from the hosting site's file-viewer page (not part of the program):
#   soojeongcrystal's picture | Update app.py | dd3cae2 verified | raw | history blame | 8.04 kB
import csv
import io
import os

import gradio as gr
import matplotlib.font_manager as fm
import matplotlib.pyplot as plt
import networkx as nx
import pandas as pd
from neo4j import GraphDatabase
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
# Load the KoSentence-BERT model used for Korean text processing.
model = SentenceTransformer('jhgan/ko-sbert-sts')

# NanumBarunGothic font setup so Korean text can be rendered in matplotlib.
font_path = "NanumBarunGothic.ttf"  # font file stored in the Hugging Face repo root
# Register the font file with matplotlib's font manager first: selecting a
# family by name (as plt.rc does below) only works for fonts the manager knows.
fm.fontManager.addfont(font_path)
fontprop = fm.FontProperties(fname=font_path)
plt.rc('font', family=fontprop.get_name())
# Neo4j database connection class
class Neo4jConnection:
    """Small wrapper around a Neo4j driver that runs one query per session."""

    def __init__(self, uri, user, pwd):
        """Create the driver for ``uri`` using ``(user, pwd)`` as credentials."""
        credentials = (user, pwd)
        self.driver = GraphDatabase.driver(uri, auth=credentials)

    def close(self):
        """Close the underlying driver."""
        self.driver.close()

    def query(self, query, parameters=None, db=None):
        """Run ``query`` and return its records as a list.

        A session is opened on database ``db`` when ``db`` is truthy, otherwise
        on the driver's default. Any exception is reported on stdout and
        ``None`` is returned instead of a list; the session is closed either way.
        """
        active_session = None
        records = None
        # Only pass ``database`` when a database name was actually supplied.
        session_kwargs = {"database": db} if db else {}
        try:
            active_session = self.driver.session(**session_kwargs)
            records = list(active_session.run(query, parameters))
        except Exception as exc:
            print("Query failed:", exc)
        finally:
            if active_session:
                active_session.close()
        return records
# Neo4j connection settings. Each value is read from the environment when the
# variable is set, so real credentials do not have to be committed to the
# source; the fallbacks are the original hard-coded values.
conn = Neo4jConnection(
    uri=os.environ.get("NEO4J_URI", "bolt://localhost:7687"),
    user=os.environ.get("NEO4J_USER", "neo4j"),
    pwd=os.environ.get("NEO4J_PASSWORD", "your_password"),
)
# Save the recommendation results to an actual file on disk.
def save_recommendations_to_file(recommendations):
    """Write recommendation rows to ``recommendations.csv`` and return the path.

    Args:
        recommendations: Iterable of rows; each row is written as one CSV
            record beneath the fixed three-column header.

    Returns:
        The relative path of the written file (``"recommendations.csv"``).
    """
    header = ["Employee ID", "Employee Name", "Recommended Programs"]
    file_path = "recommendations.csv"
    # newline="" lets the csv module control line endings itself.
    with open(file_path, "w", encoding="utf-8", newline="") as handle:
        out = csv.writer(handle)
        out.writerow(header)
        # One CSV record per recommendation row.
        out.writerows(recommendations)
    return file_path
# Automatically match required column names to the columns of a DataFrame.
def auto_match_columns(df, required_cols):
    """Map each required column name to the best-matching column of ``df``.

    Labels are compared in a normalised form: converted to ``str``,
    lower-cased, with hyphens and runs of whitespace turned into single
    underscores. This lets ``"Employee ID"`` satisfy ``"employee_id"`` and
    keeps non-string labels (e.g. integers) from raising. An exact match on
    the normalised label wins; otherwise the first column whose normalised
    label contains the required name is used.

    Args:
        df: Object exposing a ``columns`` iterable (a pandas DataFrame).
        required_cols: Iterable of required column names.

    Returns:
        Dict mapping every required name to the original column label that
        matched it, or to ``None`` when no column matched.
    """
    def _normalize(label):
        # str() guards against non-string labels such as integer column names.
        return "_".join(str(label).lower().replace("-", " ").split())

    # Normalise every column label once instead of once per required name.
    normalized = [(col, _normalize(col)) for col in df.columns]

    matched_cols = {}
    for req_col in required_cols:
        wanted = _normalize(req_col)
        matched_col = None
        for col, norm in normalized:
            if norm == wanted:
                # Exact match: stop looking, nothing can beat it.
                matched_col = col
                break
            if matched_col is None and wanted in norm:
                # Remember the first substring match as a fallback.
                matched_col = col
        matched_cols[req_col] = matched_col
    return matched_cols
# Match the columns of the employee and program data automatically, or report
# which required column could not be found.
def validate_and_get_columns(employee_df, program_df):
    """Resolve the required employee/program columns or report the first gap.

    Args:
        employee_df: Employee table; must provide columns matching
            ``employee_id``, ``employee_name`` and ``current_skills``.
        program_df: Program table; must provide columns matching
            ``program_name``, ``skills_acquired`` and ``duration``.

    Returns:
        ``(error_message, None, None)`` when a required column is missing
        (employee columns are checked before program columns), otherwise
        ``(None, employee_cols, program_cols)`` where each mapping goes from
        required name to the actual column label.
    """
    required_employee_cols = ["employee_id", "employee_name", "current_skills"]
    required_program_cols = ["program_name", "skills_acquired", "duration"]

    employee_cols = auto_match_columns(employee_df, required_employee_cols)
    program_cols = auto_match_columns(program_df, required_program_cols)

    # The messages below are shown to the user, so they must be readable Korean.
    for key, value in employee_cols.items():
        if value is None:
            return f"직원 데이터에서 '{key}' 열을 선택할 수 없습니다. 올바른 열을 선택하세요.", None, None

    for key, value in program_cols.items():
        if value is None:
            return f"프로그램 데이터에서 '{key}' 열을 선택할 수 없습니다. 올바른 열을 선택하세요.", None, None

    return None, employee_cols, program_cols
# Analyse the employee data, recommend training programs, and build the result
# table, relationship graph and CSV download.
def hybrid_rag(employee_file, program_file, similarity_threshold=0.5):
    """Recommend training programs per employee and visualise the matches.

    Two sources are combined: embedding similarity between each employee's
    current skills and each program's acquired skills (VectorRAG), and
    employee/program pairs returned by the Neo4j query (GraphRAG).

    Args:
        employee_file: Uploaded employee CSV. Either an object whose ``name``
            attribute is the file path, or the path itself.
        program_file: Uploaded program CSV, same forms as ``employee_file``.
        similarity_threshold: A program is recommended when its cosine
            similarity to the employee's skills is strictly greater than this.

    Returns:
        ``(result_df, figure, csv_path)``: the recommendation table, the
        matplotlib figure of the employee/program graph, and the path of the
        CSV written by ``save_recommendations_to_file``.

    Raises:
        gr.Error: If a file is missing or a required column cannot be found.
            The handler has three outputs, so errors are raised rather than
            returned as an extra value.
    """
    if employee_file is None or program_file is None:
        raise gr.Error("직원 데이터와 교육 프로그램 데이터 파일을 모두 업로드하세요.")

    # 1. VectorRAG: similarity computed with KoSentence-BERT embeddings.
    # Accept both file-like upload objects (with .name) and plain path strings.
    employee_df = pd.read_csv(getattr(employee_file, "name", employee_file))
    program_df = pd.read_csv(getattr(program_file, "name", program_file))

    error_msg, employee_cols, program_cols = validate_and_get_columns(employee_df, program_df)
    if error_msg:
        raise gr.Error(error_msg)

    employee_ids = employee_df[employee_cols["employee_id"]].tolist()
    employee_names = employee_df[employee_cols["employee_name"]].tolist()
    program_names = program_df[program_cols["program_name"]].tolist()
    program_durations = program_df[program_cols["duration"]].tolist()

    # Missing skill cells become empty strings so the encoder only sees text.
    employee_skills = employee_df[employee_cols["current_skills"]].fillna("").astype(str).tolist()
    program_skills = program_df[program_cols["skills_acquired"]].fillna("").astype(str).tolist()

    if employee_skills and program_skills:
        similarities = cosine_similarity(model.encode(employee_skills), model.encode(program_skills))
    else:
        # With an empty table there is nothing to compare; the loops below
        # never index into this placeholder.
        similarities = []

    # One pass collects both the per-employee program labels and the graph
    # edges. Positions (not DataFrame index labels) address the matrix.
    vector_rows = []
    edges = []
    for emp_pos, (emp_id, emp_name) in enumerate(zip(employee_ids, employee_names)):
        labels = []
        for prog_pos, (prog_name, duration) in enumerate(zip(program_names, program_durations)):
            if similarities[emp_pos][prog_pos] > similarity_threshold:
                labels.append(f"{prog_name} ({duration})")
                edges.append((emp_name, prog_name))
        vector_rows.append((emp_id, emp_name, labels))

    # 2. GraphRAG: fetch program recommendations from Neo4j.
    query = """
    MATCH (e:Employee)-[:HAS_SKILL]->(p:Program)
    RETURN e.name AS employee_name, p.name AS program_name, p.duration AS duration
    """
    # conn.query returns None when the query fails; treat that as no results.
    graph_rag_results = conn.query(query) or []

    # Group GraphRAG programs by employee name, preserving record order.
    graph_programs = {}
    for record in graph_rag_results:
        graph_programs.setdefault(record['employee_name'], []).append(
            f"{record['program_name']} (GraphRAG)"
        )

    # Rows for the table and the CSV. The "no suitable program" placeholder is
    # used only when neither source produced a recommendation.
    recommendation_rows = []
    for emp_id, emp_name, labels in vector_rows:
        combined = labels + graph_programs.get(emp_name, [])
        programs_text = ", ".join(combined) if combined else "적합한 프로그램 없음"
        recommendation_rows.append([emp_id, emp_name, programs_text])

    # Relationship graph: employees and programs as nodes, matches as edges.
    G = nx.Graph()
    G.add_nodes_from(employee_names, type='employee')
    G.add_nodes_from(program_names, type='program')
    G.add_edges_from(edges)

    fig = plt.figure(figsize=(10, 8))
    pos = nx.spring_layout(G)
    # networkx label drawing takes a font family name (font_family), not a
    # FontProperties object.
    nx.draw(G, pos, with_labels=True, node_color='lightblue', node_size=3000, font_size=10, font_weight='bold', edge_color='gray', font_family=fontprop.get_name())
    plt.title("직원과 프로그램 간의 관계", fontsize=14, fontweight='bold', fontproperties=fontprop)
    plt.tight_layout()

    # Write the recommendations to a CSV file for download.
    csv_output = save_recommendations_to_file(recommendation_rows)

    # Build the result table.
    result_df = pd.DataFrame(recommendation_rows, columns=["Employee ID", "Employee Name", "Recommended Programs"])

    # Detach the figure from pyplot so repeated requests do not accumulate
    # open figures; the Figure object itself stays usable by the caller.
    plt.close(fig)
    return result_df, fig, csv_output
# Gradio Blocks layout: uploads, table and download on the left, chart on the right.
with gr.Blocks(css=".gradio-button {background-color: #007bff; color: white;} .gradio-textbox {border-color: #6c757d;}") as demo:
    gr.Markdown("<h1 style='text-align: center; color: #2c3e50;'>💼 HybridRAG 시스템</h1>")
    with gr.Row():
        with gr.Column(scale=1, min_width=300):
            gr.Markdown("<h3 style='color: #34495e;'>1. 직원 및 프로그램 데이터를 업로드하세요</h3>")
            employee_file = gr.File(label="직원 데이터 업로드", interactive=True)
            program_file = gr.File(label="교육 프로그램 데이터 업로드", interactive=True)
            analyze_button = gr.Button("분석 시작", elem_classes="gradio-button")
            output_table = gr.DataFrame(label="분석 결과 (테이블)")
            csv_download = gr.File(label="추천 결과 다운로드")
        with gr.Column(scale=2, min_width=500):
            gr.Markdown("<h3 style='color: #34495e;'>2. 분석 결과 및 시각화</h3>")
            chart_output = gr.Plot(label="시각화 차트")

    # Clicking the analyse button updates the table, the chart and the file
    # download. The handler is registered inside the Blocks context.
    analyze_button.click(hybrid_rag, inputs=[employee_file, program_file], outputs=[output_table, chart_output, csv_download])

# Launch the Gradio interface only when this file is run as a script, so
# importing the module does not start a server.
if __name__ == "__main__":
    demo.launch()