# fromdk's picture
# Upload folder using huggingface_hub
# c8b9067 verified
import ast
import logging
import re
from decimal import Decimal
from enum import Enum
from typing import Annotated, List, TypedDict

import gradio as gr
from langchain.agents.agent_toolkits import create_retriever_tool
from langchain_community.tools import QuerySQLDatabaseTool
from langchain_community.utilities import SQLDatabase
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_ollama import OllamaEmbeddings
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langgraph.graph import START, StateGraph
from pydantic import BaseModel, Field
##################################################################
# ν™˜κ²½ μ„€μ • / λ°μ΄ν„°λ² μ΄μŠ€ μ—°κ²°
##################################################################
from dotenv import load_dotenv
from pathlib import Path

# Load API keys (e.g. for OpenAI) from a local .env file.
load_dotenv()

# Locate the SQLite file. Look in the working directory first (the original
# lookup), then in the directory containing this script. Checking up front
# matters because SQLite silently creates an empty database for a missing path.
# The failure would then only surface later, as a confusing
# "no such table: ETFs" error.
_DB_FILENAME = "etf_database.db"
_db_path = Path(_DB_FILENAME)
if not _db_path.exists():
    _db_path = Path(__file__).resolve().parent / _DB_FILENAME
if not _db_path.exists():
    raise FileNotFoundError(
        f"SQLite database '{_DB_FILENAME}' not found in the working directory "
        f"or next to {Path(__file__).name}"
    )
# For the working-directory case this is exactly "sqlite:///etf_database.db".
db = SQLDatabase.from_uri(f"sqlite:///{_db_path}")
##################################################################
# 고유λͺ…사 DB 검색
##################################################################
def query_as_list(db, query):
    """Run *query* and return its distinct proper-noun values.

    ``db.run`` returns the rows as the string repr of a list of tuples, or an
    empty string when the query yields no rows. Every non-empty cell is
    flattened into a single list. Standalone integers are removed
    (e.g. ``"KODEX 200"`` -> ``"KODEX"``), whitespace is collapsed, and
    duplicates are dropped.

    Args:
        db: Object exposing ``run(query) -> str`` (here a ``SQLDatabase``).
        query: SQL statement to execute.

    Returns:
        Sorted list of unique, non-empty strings.
    """
    raw = db.run(query)
    if not raw:
        # No rows: "" is not a valid literal for ast.literal_eval.
        return []
    values = (cell for row in ast.literal_eval(raw) for cell in row if cell)
    cleaned = {
        # Strip standalone numbers, then squeeze the gaps they leave behind.
        " ".join(re.sub(r"\b\d+\b", "", str(value)).split())
        for value in values
    }
    # Filter *after* stripping digits: purely numeric values collapse to "".
    # Sorting keeps the order (and hence vector-store contents) deterministic.
    return sorted(value for value in cleaned if value)
# Distinct proper nouns from the ETFs table, used for fuzzy entity lookup.
etfs = query_as_list(db, "SELECT DISTINCT μ’…λͺ©λͺ… FROM ETFs")
fund_managers = query_as_list(db, "SELECT DISTINCT μš΄μš©μ‚¬ FROM ETFs")
underlying_assets = query_as_list(db, "SELECT DISTINCT κΈ°μ΄ˆμ§€μˆ˜ FROM ETFs")
# Embedding model used to vectorise the proper nouns.
embeddings = OpenAIEmbeddings(model="text-embedding-3-large")
# In-memory vector store holding one entry per proper noun.
vector_store = InMemoryVectorStore(embeddings)
# Embed ETF names, fund managers and underlying indices. Each list is already
# unique, but the three lists can share values. De-duplicate across them
# (keeping first-seen order) so each noun is embedded once. This also stops the
# same noun from filling several of the top-k retrieval slots.
_ = vector_store.add_texts(
    list(dict.fromkeys(etfs + fund_managers + underlying_assets))
)
retriever = vector_store.as_retriever(search_kwargs={"k": 10})
# Tool description shown to the model.
description = (
    "Use to look up values to filter on. Input is an approximate spelling "
    "of the proper noun, output is valid proper nouns. Use the noun most "
    "similar to the search."
)
# Retriever tool. write_query calls it to fetch proper nouns similar to the
# user's question.
entity_retriever_tool = create_retriever_tool(
    retriever,
    name="search_proper_nouns",
    description=description,
)
##################################################################
# μƒνƒœ 정보 νƒ€μž… μ •μ˜
##################################################################
class State(TypedDict):
    """Shared state passed between the LangGraph nodes of the ETF pipeline."""

    question: str  # user's input question
    user_profile: dict  # investment profile produced by analyze_profile
    query: str  # SQL query produced by write_query
    # Raw SQL result text: QuerySQLDatabaseTool returns a string, not a list.
    candidates: str
    rankings: list  # ranked ETF entries from rank_etfs
    explanation: str  # explanation of the generated SQL query (from write_query)
    # {"explanation": <dict>, "markdown": <str>} from generate_explanation.
    final_answer: dict
##################################################################
# μ‚¬μš©μž ν”„λ‘œν•„ 뢄석
##################################################################
# Investor risk appetite. The string values are the only values accepted for
# InvestmentProfile.risk_tolerance, so they form part of the structured-output
# schema the LLM fills in. The str mixin makes each member compare equal to its
# plain string value.
class RiskTolerance(str, Enum):
    CONSERVATIVE = "conservative"
    MODERATE = "moderate"
    AGGRESSIVE = "aggressive"
# Intended investment period. The string values are the only values accepted
# for InvestmentProfile.investment_horizon in the structured-output schema.
class InvestmentHorizon(str, Enum):
    SHORT = "short"
    MEDIUM = "medium"
    LONG = "long"
# Structured-output schema that analyze_profile asks the LLM to fill from the
# user's question. All fields are required. The (Korean) descriptions are part
# of the schema sent to the model. No class docstring is added, because that
# would become the schema/tool description and change what the model sees.
class InvestmentProfile(BaseModel):
    # conservative / moderate / aggressive
    risk_tolerance: Annotated[
        RiskTolerance,
        Field(description="투자자의 μœ„ν—˜ μ„±ν–₯ (conservative/moderate/aggressive)"),
    ]
    # short / medium / long
    investment_horizon: Annotated[
        InvestmentHorizon,
        Field(description="투자 κΈ°κ°„ (short/medium/long)"),
    ]
    # Free-text description of the main investment goal.
    investment_goal: Annotated[str, Field(description="투자의 μ£Όμš” λͺ©μ  μ„€λͺ…")]
    # Sectors the user wants exposure to.
    preferred_sectors: Annotated[
        List[str], Field(description="μ„ ν˜Έν•˜λŠ” 투자 μ„Ήν„° λͺ©λ‘")
    ]
    # Sectors the user wants to avoid.
    excluded_sectors: Annotated[
        List[str], Field(description="투자λ₯Ό μ›ν•˜μ§€ μ•ŠλŠ” μ„Ήν„° λͺ©λ‘")
    ]
    # Monthly amount available to invest, in KRW (원).
    monthly_investment: Annotated[int, Field(description="μ›” 투자 κ°€λŠ₯ κΈˆμ•‘ (원)")]
# μ‚¬μš©μž ν”„λ‘œν•„ 뢄석 ν”„λ‘¬ν”„νŠΈ
PROFILE_TEMPLATE= """
μ‚¬μš©μžμ˜ μ§ˆλ¬Έμ„ λΆ„μ„ν•˜μ—¬ 투자 ν”„λ‘œν•„μ„ μƒμ„±ν•΄μ£Όμ„Έμš”.
μ‚¬μš©μž 질문: {question}
"""
profile_prompt = ChatPromptTemplate.from_template(PROFILE_TEMPLATE)
# μ‚¬μš©μž ν”„λ‘œν•„ 뢄석 λͺ¨λΈ 생성
llm = ChatOpenAI(model="gpt-4.1-mini")
profile_llm = llm.with_structured_output(InvestmentProfile)
# μ‚¬μš©μž ν”„λ‘œν•„ 뢄석 ν•¨μˆ˜
def analyze_profile(state: State) -> dict:
"""μ‚¬μš©μž μ§ˆλ¬Έμ„ λΆ„μ„ν•˜μ—¬ 투자 ν”„λ‘œν•„ 생성"""
prompt = profile_prompt.invoke({"question": state["question"]})
response = profile_llm.invoke(prompt)
return {"user_profile": dict(response)}
##################################################################
# SQL 쿼리 생성
##################################################################
# SQL Query Generation Template
# write_query fills in the placeholders: {dialect}, {top_k}, {table_info},
# {entity_info} (output of entity_retriever_tool for the question),
# {user_profile} and {input} (the raw question). Everything inside the string
# is sent verbatim to the model.
# NOTE(review): query_as_list strips standalone digits from the proper nouns
# ("KODEX 200" -> "KODEX"). The "Use exact matches" guidance below may
# therefore steer the model towards names that do not exist verbatim in the
# table. Confirm whether LIKE matching is intended.
QUERY_TEMPLATE = """
Given an input question and investment profile, create a syntactically correct {dialect} query to run. Unless specified, limit your query to at most {top_k} results. Order the results by most relevant columns based on the investment profile.
Never query for all columns from a specific table, only ask for relevant columns given the question and investment criteria.
Pay attention to use only the column names you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Available tables:
{table_info}
Entity relationships:
{entity_info}
- Use exact matches when comparing entity names
- Check for historical name variations if available
- Apply case-sensitive matching for official names
- Handle both Korean and English entity names when present
Investment Profile:
{user_profile}
Question: {input}
Important:
1. Use only existing columns
2. Query only necessary columns (no SELECT *)
3. Follow correct table relationships
4. Consider performance and indexing
"""
# SQL Query Generation Prompt Template (a single human message)
query_prompt_template = ChatPromptTemplate.from_template(QUERY_TEMPLATE)
# SQL Query Output
# Structured-output schema used by write_query. with_structured_output builds
# the schema from this class. Per LangChain's TypedDict convention,
# Annotated[type, ..., "description"] marks a required key and its description.
# The docstring and descriptions are therefore runtime text seen by the model.
class QueryOutput(TypedDict):
    """Generated SQL query."""
    # The SQL statement executed later by execute_query.
    query: Annotated[str, ..., "Syntactically valid SQL query."]
    # Written into State["explanation"] by write_query.
    explanation: Annotated[str, ..., "Query explanation and selection criteria (in ν•œκ΅­μ–΄)"]
def write_query(state: State):
    """Ask the LLM for a SQL query that answers the user's question.

    The prompt carries the DB dialect, a result limit of 10, the table schema,
    proper nouns retrieved for the question, and the user's profile.

    Args:
        state: Graph state; reads ``question`` and ``user_profile``.

    Returns:
        ``{"query": <sql>, "explanation": <Korean explanation>}``.
    """
    template_vars = {
        "dialect": db.dialect,
        "top_k": 10,
        "table_info": db.get_table_info(),
        "input": state["question"],
        "entity_info": entity_retriever_tool.invoke(state["question"]),
        "user_profile": state["user_profile"],
    }
    sql_writer = llm.with_structured_output(QueryOutput)
    generated = sql_writer.invoke(query_prompt_template.invoke(template_vars))
    return {key: generated[key] for key in ("query", "explanation")}
##################################################################
# 후보 ETF 검색
##################################################################
def execute_query(state: State) -> dict:
    """Run the generated SQL and store the raw result as the ETF candidates.

    ``QuerySQLDatabaseTool`` does not raise on SQL failures. Per
    langchain_community's ``SQLDatabase.run_no_throw``, it returns a string
    starting with ``"Error:"``, and an empty string when no rows match. Passing
    either one on to the ranking step would let the LLM invent ETFs. Both
    cases are therefore raised as exceptions, which surface in the chat as an
    error message.

    Args:
        state: Graph state; reads ``query``.

    Returns:
        ``{"candidates": <result text>}``.

    Raises:
        RuntimeError: The query failed to execute.
        ValueError: The query returned no rows.
    """
    execute_query_tool = QuerySQLDatabaseTool(db=db)
    results = execute_query_tool.invoke(state["query"])
    if isinstance(results, str):
        if results.startswith("Error:"):
            raise RuntimeError(f"SQL 쿼리 μ‹€ν–‰ μ‹€νŒ¨: {results}")
        if not results.strip():
            raise ValueError("쑰건에 λ§žλŠ” ETF 후보λ₯Ό μ°Ύμ§€ λͺ»ν–ˆμŠ΅λ‹ˆλ‹€.")
    return {"candidates": results}
##################################################################
# ETF μˆœμœ„ λ§€κΈ°κΈ°
##################################################################
# Prompt for the ranking node; rank_etfs fills {user_profile} and {candidates}.
# Ranking criteria: 수읡λ₯  (return), 변동성 (volatility), μˆœμžμ‚°μ΄μ•‘ (net
# assets / fund size), and fit with the user profile. Each criterion is listed
# once so the model does not over-weight volatility.
RANKING_TEMPLATE = """
Rank the following ETF candidates based on the user's investment profile and return the top 3(three) ETFs.
Consider these factors when ranking:
1. 수읡λ₯ 
2. 변동성
3. μˆœμžμ‚°μ΄μ•‘
4. User Profile matching score
User Profile:
{user_profile}
Candidate ETFs:
{candidates}
"""
# Chat prompt wrapping RANKING_TEMPLATE as a single human message.
ranking_prompt = ChatPromptTemplate.from_template(template=RANKING_TEMPLATE)
# ETF Ranking Output
# Schema for one ranked ETF. Per LangChain's TypedDict convention,
# Annotated[type, ..., "description"] marks a required key and its description.
# The docstring and descriptions are sent to the model, so they must agree
# with RANKING_TEMPLATE, which asks for the top 3 ETFs.
class ETFRanking(TypedDict):
    """Individual ETF ranking result"""
    rank: Annotated[int, ..., "Ranking position (1-3)"]
    etf_code: Annotated[str, ..., "ETF μ’…λͺ©μ½”λ“œ (6-digit)"]
    etf_name: Annotated[str, ..., "ETF μ’…λͺ©λͺ…"]
    score: Annotated[float, ..., "Composite score (0-100)"]
    ranking_reason: Annotated[str, ..., "Explanation for the ranking (in ν•œκ΅­μ–΄)"]
# Top-level structured-output schema for rank_etfs: the model returns an object
# with a single "rankings" key holding ETFRanking entries. The docstring is
# part of the schema sent to the model.
class ETFRankingResult(TypedDict):
    """Ranked ETFs"""
    rankings: List[ETFRanking]
# ETF Ranking Function
def rank_etfs(state: State) -> dict:
    """Rank ETF candidates based on the user's investment profile.

    Args:
        state: Graph state; reads ``user_profile`` and ``candidates``.

    Returns:
        ``{"rankings": [ETFRanking, ...]}``: the ranked entries as a list,
        ordered by their ``rank`` field.
    """
    prompt = ranking_prompt.invoke(
        {
            "user_profile": state["user_profile"],
            "candidates": state["candidates"],
        }
    )
    structured_llm = llm.with_structured_output(ETFRankingResult)
    results = structured_llm.invoke(prompt)
    # The structured output is the wrapper {"rankings": [...]}. Store the list
    # itself, as State.rankings declares, in rank order.
    rankings = results.get("rankings", [])
    return {"rankings": sorted(rankings, key=lambda item: item.get("rank", 0))}
##################################################################
# μΆ”μ²œ 이유 μ„€λͺ…
##################################################################
# Prompt for the final explanation node. generate_explanation fills
# {user_profile} and {rankings}. The bracketed guideline sections steer the
# content of the RecommendationExplanation structured output. The string is
# sent verbatim to the model.
EXPLANATION_TEMPLATE = """
Please provide a comprehensive explanation for the recommended ETFs based on the user's investment profile.
[GUIDELINES]
1. ETF Characteristics
- Investment strategy and approach
- Historical performance overview
- Fee structure and efficiency
- Underlying assets and diversification
2. Profile Fit Analysis
- Alignment with risk tolerance
- Match with investment horizon
- Sector preference compatibility
- Investment goal contribution
3. Portfolio Construction
- Recommended allocation percentages
- Diversification benefits
- Rebalancing considerations
- Implementation strategy
4. Risk Considerations
- Market risk factors
- Specific ETF risks
- Economic scenario impacts
- Monitoring requirements
--------------------------------------------
[User Profile]
{user_profile}
[Selected ETFs]
{rankings}
"""
# Recommendation-explanation prompt (a single human message)
explanation_prompt = ChatPromptTemplate.from_template(EXPLANATION_TEMPLATE)
# Output schema for one recommended ETF, nested inside
# RecommendationExplanation.recommendations. The docstring and field
# descriptions are part of the schema shown to the model, so they are kept
# verbatim. Every field is required.
class ETFRecommendation(BaseModel):
    """Individual ETF recommendation details"""
    etf_code: Annotated[str, Field(description="ETF μ’…λͺ©μ½”λ“œ (6-digit)")]
    etf_name: Annotated[str, Field(description="ETF μ’…λͺ©λͺ…")]
    # Percentage of the portfolio; rendered as "<value>%" in to_markdown.
    allocation: Annotated[
        Decimal, Field(description="Recommended allocation % (0-100)")
    ]
    description: Annotated[
        str,
        Field(description="ETF description and investment strategy (in ν•œκ΅­μ–΄)"),
    ]
    key_points: Annotated[
        List[str], Field(description="Key investment points (in ν•œκ΅­μ–΄)")
    ]
    risks: Annotated[
        List[str], Field(description="Risk factors to consider (in ν•œκ΅­μ–΄)")
    ]
class RecommendationExplanation(BaseModel):
    """ETF recommendation explanation with markdown formatting"""
    overview: str = Field(..., description="Overall strategy explanation (in ν•œκ΅­μ–΄)")
    recommendations: List[ETFRecommendation] = Field(..., description="ETF details")
    considerations: List[str] = Field(..., description="Important considerations (in ν•œκ΅­μ–΄)")

    def to_markdown(self) -> str:
        """Render the recommendation as a Korean markdown report.

        Sections, in order: title, strategy overview, an allocation table
        (name / code / weight), per-ETF details with bullet lists of key points
        and risks, and a closing bullet list of considerations.
        """

        def bullets(items):
            # Each bullet is prefixed with a newline, so the block starts on a
            # fresh line after the preceding heading.
            return "".join(f"\n* {item}" for item in items)

        lines = [
            "# ETF 포트폴리였 μΆ”μ²œ",
            "",
            "## 투자 μ „λž΅ κ°œμš”",
            self.overview,
            "",
            "## μΆ”μ²œ ETF 포트폴리였",
            "",
            "| ETF | μ’…λͺ©μ½”λ“œ | μΆ”μ²œλΉ„μ€‘ |",
            "|-----|----------|----------|",
        ]
        lines += [
            f"| {rec.etf_name} | {rec.etf_code} | {rec.allocation}% |"
            for rec in self.recommendations
        ]

        lines.append("\n## ETF 상세 μ„€λͺ…\n")
        for rec in self.recommendations:
            lines += [
                f"### {rec.etf_name} ({rec.etf_code})",
                rec.description,
                "",
                "**μ£Όμš” 투자 포인트:**",
                bullets(rec.key_points),
                "",
                "**투자 μœ„ν—˜:**",
                bullets(rec.risks),
                "",
            ]

        lines += ["## 투자 μ‹œ 고렀사항", bullets(self.considerations), ""]
        return "\n".join(lines)
# Recommendation-explanation node
def generate_explanation(state: dict) -> dict:
    """Turn the ranked ETFs into a structured recommendation and markdown.

    Args:
        state: Graph state; reads ``rankings`` and ``user_profile``.

    Returns:
        ``{"final_answer": {"explanation": <model_dump dict>,
        "markdown": <str>}}``.
    """
    explainer = llm.with_structured_output(RecommendationExplanation)
    prompt_value = explanation_prompt.invoke(
        {"rankings": state["rankings"], "user_profile": state["user_profile"]}
    )
    explanation = explainer.invoke(prompt_value)

    final_answer = {
        "explanation": explanation.model_dump(),
        "markdown": explanation.to_markdown(),
    }
    return {"final_answer": final_answer}
##################################################################
# ETF μΆ”μ²œ 봇 - μƒνƒœ κ·Έλž˜ν”„ 생성
##################################################################
# μƒνƒœ κ·Έλž˜ν”„ 생성
graph_builder = StateGraph(State).add_sequence(
[analyze_profile, write_query, execute_query, rank_etfs, generate_explanation]
)
graph_builder.add_edge(START, "analyze_profile")
graph = graph_builder.compile()
##################################################################
# ETF μΆ”μ²œ 봇 - 메인 ν•¨μˆ˜
##################################################################
def process_message(message: str) -> str:
    """Run the ETF recommendation graph for one user message.

    Args:
        message: The user's free-form question.

    Returns:
        The markdown recommendation. If any pipeline step raises, returns a
        Korean error message that includes the exception text instead.
    """
    try:
        etf_recommendation = graph.invoke({"question": message})
        return etf_recommendation["final_answer"]["markdown"]
    except Exception as e:
        # UI boundary: the user gets a friendly message, but the full traceback
        # is kept in the server log so the failure can be diagnosed.
        logging.getLogger(__name__).exception("ETF recommendation failed")
        return f"""
# 였λ₯˜κ°€ λ°œμƒν–ˆμŠ΅λ‹ˆλ‹€
μ£„μ†‘ν•©λ‹ˆλ‹€. μš”μ²­μ„ μ²˜λ¦¬ν•˜λŠ” 쀑에 λ¬Έμ œκ°€ λ°œμƒν–ˆμŠ΅λ‹ˆλ‹€.
였λ₯˜ λ‚΄μš©: {str(e)}
λ‹€μ‹œ μ‹œλ„ν•΄μ£Όμ‹œκ±°λ‚˜, μ§ˆλ¬Έμ„ λ‹€λ₯Έ λ°©μ‹μœΌλ‘œ μž‘μ„±ν•΄μ£Όμ„Έμš”.
"""
def answer_invoke(message: str, history: List) -> str:
    """Gradio ChatInterface callback.

    ``history`` is accepted to match the ChatInterface callback signature but
    is ignored: every message is answered independently.
    """
    reply = process_message(message)
    return reply
# Text shown under the chat title: what to include in a question, plus a
# sample request.
_DESCRIPTION = """
투자 μ„±ν–₯κ³Ό λͺ©ν‘œμ— λ§žλŠ” ETFλ₯Ό μΆ”μ²œν•΄λ“œλ¦½λ‹ˆλ‹€.
λ‹€μŒκ³Ό 같은 정보λ₯Ό ν¬ν•¨ν•˜μ—¬ μ§ˆλ¬Έν•΄μ£Όμ„Έμš”:
- 투자 λͺ©μ 
- 투자 κΈ°κ°„
- μœ„ν—˜ μ„±ν–₯
- μ„ ν˜Έ/μ œμ™Έ μ„Ήν„°
- μ›” 투자 κ°€λŠ₯ κΈˆμ•‘
μ˜ˆμ‹œ) "μ›” 100λ§Œμ› 정도λ₯Ό 3λ…„ 이상 μž₯κΈ° νˆ¬μžν•˜κ³  μ‹Άκ³ , IT와 ν—¬μŠ€μΌ€μ–΄ μ„Ήν„°λ₯Ό μ„ ν˜Έν•©λ‹ˆλ‹€.
보수적인 투자λ₯Ό μ„ ν˜Έν•˜λ©°, λ‹΄λ°° κ΄€λ ¨ 기업은 μ œμ™Έν•˜κ³  μ‹ΆμŠ΅λ‹ˆλ‹€."
"""

# Clickable example prompts shown in the chat UI.
_EXAMPLES = [
    """20λŒ€ ν›„λ°˜μ˜ λŒ€ν•™μƒμž…λ‹ˆλ‹€.
μ›” 50λ§Œμ› 정도λ₯Ό 1λ…„ 이상 μž₯κΈ° νˆ¬μžν•˜κ³  μ‹Άκ³ ,
ν™˜μœ¨κ³Ό κΈˆλ¦¬μ— 관심이 μžˆμŠ΅λ‹ˆλ‹€.
κ³ μœ„ν—˜ κ³ μˆ˜μ΅μ„ μΆ”κ΅¬ν•˜λ©°, ESG μš”μ†Œλ„ κ³ λ €ν•˜κ³  μ‹ΆμŠ΅λ‹ˆλ‹€.
μ ν•©ν•œ ETFλ₯Ό μΆ”μ²œν•΄μ£Όμ„Έμš”."""
]

# Chat UI wired to answer_invoke; history is passed in "messages" format.
demo = gr.ChatInterface(
    fn=answer_invoke,
    title="λ§žμΆ€ν˜• ETF μΆ”μ²œ μ–΄μ‹œμŠ€ν„΄νŠΈ",
    description=_DESCRIPTION,
    examples=_EXAMPLES,
    type="messages",
)

# Start the web server only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()