Spaces:
Sleeping
Sleeping
import csv
import io
import os

import gradio as gr
import matplotlib.font_manager as fm
import matplotlib.pyplot as plt
import networkx as nx
import pandas as pd
from neo4j import GraphDatabase
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
# Load the KoSentence-BERT checkpoint used to embed Korean skill descriptions.
_KO_SBERT_CHECKPOINT = 'jhgan/ko-sbert-sts'
model = SentenceTransformer(_KO_SBERT_CHECKPOINT)
# Configure the NanumBarunGothic font so Korean text renders in matplotlib.
font_path = "NanumBarunGothic.ttf"  # font file stored at the Space's repository root
# FontProperties(fname=...) alone does not add the font to matplotlib's font
# manager, so looking it up by family name via rcParams would not find it and
# would fall back to the default font (Korean glyphs drawn as boxes).
# Register the file explicitly before referring to it by family name.
fm.fontManager.addfont(font_path)
fontprop = fm.FontProperties(fname=font_path)
plt.rc('font', family=fontprop.get_name())
# Neo4j database connection wrapper.
class Neo4jConnection:
    """Thin wrapper around a Neo4j driver that opens one session per query."""

    def __init__(self, uri, user, pwd):
        # Authentication is passed to the driver as a (user, password) tuple.
        self.driver = GraphDatabase.driver(uri, auth=(user, pwd))

    def close(self):
        """Close the underlying driver."""
        self.driver.close()

    def query(self, query, parameters=None, db=None):
        """Run *query* in a fresh session and return its records as a list.

        Args:
            query: query text passed to ``session.run``.
            parameters: optional query parameters passed to ``session.run``.
            db: optional database name; when falsy the driver's default
                session is used.

        Returns:
            A list of the result records, or None if opening the session or
            running the query raised (the error is printed, not re-raised).
        """
        records = None
        session = None
        try:
            session_kwargs = {"database": db} if db else {}
            session = self.driver.session(**session_kwargs)
            records = list(session.run(query, parameters))
        except Exception as exc:
            print("Query failed:", exc)
        finally:
            # Always release the session, whether or not the query succeeded.
            if session:
                session.close()
        return records
# Neo4j connection settings. Values are read from the environment so real
# credentials need not be committed to source; the fallbacks reproduce the
# previously hard-coded values.
conn = Neo4jConnection(
    uri=os.environ.get("NEO4J_URI", "bolt://localhost:7687"),
    user=os.environ.get("NEO4J_USER", "neo4j"),
    pwd=os.environ.get("NEO4J_PASSWORD", "your_password"),
)
# Save the recommendation results to a CSV file on disk.
def save_recommendations_to_file(recommendations, file_path="recommendations.csv"):
    """Write recommendation rows to a UTF-8 CSV file and return its path.

    Args:
        recommendations: iterable of rows, each
            ``[employee_id, employee_name, recommended_programs]``.
        file_path: destination path. Defaults to ``recommendations.csv`` in the
            current working directory, which is overwritten on every call.

    Returns:
        The path that was written.
    """
    with open(file_path, mode='w', newline='', encoding='utf-8') as file:
        writer = csv.writer(file)
        writer.writerow(["Employee ID", "Employee Name", "Recommended Programs"])
        writer.writerows(recommendations)
    return file_path
# Automatically match required logical column names to actual column labels.
def auto_match_columns(df, required_cols):
    """Map each required column name to the first matching column of *df*.

    Both the required name and each column label are normalized (converted
    to ``str``, surrounding whitespace stripped, lower-cased, spaces and
    hyphens turned into underscores). A column matches when the normalized
    required name is a substring of its normalized label, so ``"Employee ID"``
    matches ``"employee_id"``. Every column that matched before normalization
    still matches.

    Args:
        df: DataFrame whose ``columns`` are searched in order.
        required_cols: iterable of required logical column names.

    Returns:
        dict mapping each required name to the first matching original column
        label, or to None when no column matches.
    """
    def normalize(label):
        # str() guards against non-string labels (e.g. integer headers), which
        # would otherwise raise AttributeError on .lower().
        return str(label).strip().lower().replace(" ", "_").replace("-", "_")

    normalized = [(col, normalize(col)) for col in df.columns]
    return {
        req_col: next((col for col, norm in normalized if normalize(req_col) in norm), None)
        for req_col in required_cols
    }
# Resolve the required employee/program columns automatically and report the
# first one that could not be found.
def validate_and_get_columns(employee_df, program_df):
    """Resolve the required columns of both uploaded tables.

    Uses ``auto_match_columns`` to map ``employee_id``/``employee_name``/
    ``current_skills`` on the employee table and ``program_name``/
    ``skills_acquired``/``duration`` on the program table.

    Returns:
        ``(error_message, employee_cols, program_cols)``. On success the
        message is None and both mapping dicts are returned. Otherwise the
        message names the first unmatched column (the employee table is checked
        first) and both dicts are None.
    """
    employee_cols = auto_match_columns(
        employee_df, ["employee_id", "employee_name", "current_skills"]
    )
    program_cols = auto_match_columns(
        program_df, ["program_name", "skills_acquired", "duration"]
    )

    # First required employee column without a match, if any.
    key = next((name for name, col in employee_cols.items() if col is None), None)
    if key is not None:
        return f"μ§μ λ°μ΄ν°μμ '{key}' μ΄μ μ νν μ μμ΅λλ€. μ¬λ°λ₯Έ μ΄μ μ ννμΈμ.", None, None

    # First required program column without a match, if any.
    key = next((name for name, col in program_cols.items() if col is None), None)
    if key is not None:
        return f"νλ‘κ·Έλ¨ λ°μ΄ν°μμ '{key}' μ΄μ μ νν μ μμ΅λλ€. μ¬λ°λ₯Έ μ΄μ μ ννμΈμ.", None, None

    return None, employee_cols, program_cols
# Analyze employee data, recommend training programs, and build the result
# table, relationship graph and downloadable CSV.

# Cell text used when an employee has no recommended program (UI string, kept verbatim).
_NO_MATCH_TEXT = "μ ν©ν νλ‘κ·Έλ¨ μμ"


def _uploaded_path(uploaded):
    """Return the filesystem path of a Gradio upload value."""
    # Older Gradio passes a tempfile-like object exposing ``.name``; Gradio 4
    # passes the path string itself. Accept both.
    return getattr(uploaded, "name", uploaded)


def _skill_texts(df, column):
    """Return *column* of *df* as a list of strings, with missing cells as ''."""
    # The embedding model expects text; NaN (float) cells would break encoding.
    return df[column].fillna("").astype(str).tolist()


def _append_graph_recommendations(recommendation_rows, graph_rag_results):
    """Append GraphRAG program names to rows whose employee name matches a record."""
    rows_by_name = {}
    for row in recommendation_rows:
        rows_by_name.setdefault(row[1], []).append(row)
    for record in graph_rag_results:
        suggestion = f"{record['program_name']} (GraphRAG)"
        for row in rows_by_name.get(record['employee_name'], []):
            # Replace the "no programs" marker rather than appending after it.
            if row[2] == _NO_MATCH_TEXT:
                row[2] = suggestion
            else:
                row[2] += f", {suggestion}"


def _draw_relationship_graph(employee_names, program_names, edges):
    """Draw the employee/program relationship graph and return its Figure."""
    G = nx.Graph()
    G.add_nodes_from(employee_names, type='employee')
    G.add_nodes_from(program_names, type='program')
    G.add_edges_from(edges)

    fig, ax = plt.subplots(figsize=(10, 8))
    pos = nx.spring_layout(G)
    # networkx does not accept ``fontproperties``; node labels take the font via
    # ``font_family`` (its default 'sans-serif' would override the Korean rc font).
    nx.draw(G, pos, ax=ax, with_labels=True, node_color='lightblue', node_size=3000,
            font_size=10, font_weight='bold', edge_color='gray',
            font_family=fontprop.get_name())
    ax.set_title("μ§μκ³Ό νλ‘κ·Έλ¨ κ°μ κ΄κ³", fontsize=14, fontweight='bold', fontproperties=fontprop)
    fig.tight_layout()
    # Detach from pyplot's figure registry so repeated requests don't leak
    # figures. The Figure object itself can still be rendered by the caller.
    plt.close(fig)
    return fig


def hybrid_rag(employee_file, program_file, similarity_threshold=0.5):
    """Recommend training programs for each employee and visualize the matches.

    Step 1 (VectorRAG) embeds each employee's ``current_skills`` text and each
    program's ``skills_acquired`` text with the module-level ``model``. It
    recommends every program whose cosine similarity is strictly greater than
    *similarity_threshold*. Step 2 (GraphRAG) runs a Neo4j query through
    ``conn`` and appends the returned program names to rows with the same
    employee name.

    Args:
        employee_file: uploaded employee CSV. This is either an object exposing
            ``.name`` or a plain path string, or None if nothing was uploaded.
        program_file: uploaded training-program CSV, in the same forms.
        similarity_threshold: exclusive cosine-similarity cut-off (default 0.5).

    Returns:
        ``(result_df, figure, csv_path)``. On invalid input, ``result_df`` is a
        one-column ``Error`` table holding the message and the other two items
        are None, matching the three outputs the UI wires up.
    """
    if employee_file is None or program_file is None:
        return pd.DataFrame({"Error": ["Please upload both the employee and program CSV files."]}), None, None

    employee_df = pd.read_csv(_uploaded_path(employee_file))
    program_df = pd.read_csv(_uploaded_path(program_file))

    error_msg, employee_cols, program_cols = validate_and_get_columns(employee_df, program_df)
    if error_msg:
        return pd.DataFrame({"Error": [error_msg]}), None, None

    emp_id_col = employee_cols['employee_id']
    emp_name_col = employee_cols['employee_name']
    prog_name_col = program_cols['program_name']
    duration_col = program_cols['duration']

    # 1. VectorRAG: similarity computed with KoSentence-BERT embeddings.
    if len(employee_df) == 0 or len(program_df) == 0:
        # cosine_similarity rejects empty inputs; nothing can match anyway.
        similarities = [[] for _ in range(len(employee_df))]
    else:
        employee_embeddings = model.encode(_skill_texts(employee_df, employee_cols["current_skills"]))
        program_embeddings = model.encode(_skill_texts(program_df, program_cols["skills_acquired"]))
        similarities = cosine_similarity(employee_embeddings, program_embeddings)

    program_rows = list(program_df.iterrows())
    recommendation_rows = []  # rows shown in the table and saved to CSV
    edges = []  # (employee name, program name) pairs for the graph
    # Positional indices (enumerate) keep rows aligned with the similarity
    # matrix even when the DataFrame index is not 0..n-1.
    for i, (_, employee) in enumerate(employee_df.iterrows()):
        recommended_programs = []
        for j, (_, program) in enumerate(program_rows):
            if similarities[i][j] > similarity_threshold:
                recommended_programs.append(f"{program[prog_name_col]} ({program[duration_col]})")
                edges.append((employee[emp_name_col], program[prog_name_col]))
        recommendation_rows.append([
            employee[emp_id_col],
            employee[emp_name_col],
            ", ".join(recommended_programs) if recommended_programs else _NO_MATCH_TEXT,
        ])

    # 2. GraphRAG: fetch program recommendations from Neo4j.
    query = """
    MATCH (e:Employee)-[:HAS_SKILL]->(p:Program)
    RETURN e.name AS employee_name, p.name AS program_name, p.duration AS duration
    """
    # conn.query returns None when the query fails (e.g. no Neo4j server is
    # reachable); treat that as "no graph results" instead of crashing.
    graph_rag_results = conn.query(query) or []
    _append_graph_recommendations(recommendation_rows, graph_rag_results)

    figure = _draw_relationship_graph(employee_df[emp_name_col], program_df[prog_name_col], edges)

    # Save the recommendations to CSV for download.
    csv_output = save_recommendations_to_file(recommendation_rows)
    # Build the result table DataFrame.
    result_df = pd.DataFrame(recommendation_rows, columns=["Employee ID", "Employee Name", "Recommended Programs"])
    return result_df, figure, csv_output
# Gradio UI layout: inputs and table on the left, chart on the right.
with gr.Blocks(css=".gradio-button {background-color: #007bff; color: white;} .gradio-textbox {border-color: #6c757d;}") as demo:
    gr.Markdown("<h1 style='text-align: center; color: #2c3e50;'>πΌ HybridRAG μμ€ν </h1>")
    with gr.Row():
        with gr.Column(scale=1, min_width=300):
            gr.Markdown("<h3 style='color: #34495e;'>1. μ§μ λ° νλ‘κ·Έλ¨ λ°μ΄ν°λ₯Ό μ λ‘λνμΈμ</h3>")
            employee_file = gr.File(label="μ§μ λ°μ΄ν° μ λ‘λ", interactive=True)
            program_file = gr.File(label="κ΅μ‘ νλ‘κ·Έλ¨ λ°μ΄ν° μ λ‘λ", interactive=True)
            analyze_button = gr.Button("λΆμ μμ", elem_classes="gradio-button")
            output_table = gr.DataFrame(label="λΆμ κ²°κ³Ό (ν μ΄λΈ)")
            csv_download = gr.File(label="μΆμ² κ²°κ³Ό λ€μ΄λ‘λ")
        with gr.Column(scale=2, min_width=500):
            gr.Markdown("<h3 style='color: #34495e;'>2. λΆμ κ²°κ³Ό λ° μκ°ν</h3>")
            chart_output = gr.Plot(label="μκ°ν μ°¨νΈ")
    # Clicking the analyze button updates the table, chart and CSV download.
    analyze_button.click(hybrid_rag, inputs=[employee_file, program_file], outputs=[output_table, chart_output, csv_download])

# Launch the Gradio interface only when run as a script, so importing this
# module (e.g. for reload mode or tests) does not start a server.
if __name__ == "__main__":
    demo.launch()