# Spaces: Sleeping (hosting-page status text captured with the file; not program code)
import ast
import logging
import re
from decimal import Decimal
from enum import Enum
from typing import List, TypedDict, Annotated

import gradio as gr
from langchain.agents.agent_toolkits import create_retriever_tool
from langchain_community.tools import QuerySQLDatabaseTool
from langchain_community.utilities import SQLDatabase
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_ollama import OllamaEmbeddings
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langgraph.graph import START, StateGraph
from pydantic import BaseModel, Field
##################################################################
# Environment setup / database connection
##################################################################
from dotenv import load_dotenv

# Load variables from a .env file before the model clients further down are created.
load_dotenv()

# SQLite database that holds the ETF data queried by the rest of this module.
db = SQLDatabase.from_uri("sqlite:///etf_database.db")
##################################################################
# Proper-noun lookup from the DB
##################################################################
def query_as_list(db, query):
    """Run ``query`` and return its distinct, cleaned, non-empty string values.

    Args:
        db: Object whose ``run(query)`` returns the result as a string holding
            a Python literal list of row tuples (it is parsed with
            ``ast.literal_eval``), or an empty string when there are no rows.
        query: SQL text passed to ``db.run`` unchanged.

    Returns:
        Unique values in first-seen order, with standalone digit runs removed
        and surrounding whitespace stripped. Falsy cells and values that end up
        empty after cleaning are left out.
    """
    raw = db.run(query)
    # An empty result comes back as "", which literal_eval cannot parse.
    if not raw:
        return []

    cleaned = (
        re.sub(r"\b\d+\b", "", value).strip()
        for row in ast.literal_eval(raw)
        for value in row
        if value
    )
    # dict.fromkeys de-duplicates while keeping a deterministic order; the
    # second filter drops values that consisted only of digits.
    return list(dict.fromkeys(text for text in cleaned if text))
# Distinct values of three ETFs columns; these are the proper nouns users may misspell.
etfs = query_as_list(db, "SELECT DISTINCT μ’ λͺ©λͺ FROM ETFs")
fund_managers = query_as_list(db, "SELECT DISTINCT μ΄μ©μ¬ FROM ETFs")
underlying_assets = query_as_list(db, "SELECT DISTINCT κΈ°μ΄μ§μ FROM ETFs")

# Embedding model
embeddings = OpenAIEmbeddings(model="text-embedding-3-large")

# In-memory vector store backed by that embedding model
vector_store = InMemoryVectorStore(embeddings)

# Embed all three value lists (ETF names, fund managers, underlying assets)
_ = vector_store.add_texts(etfs + fund_managers + underlying_assets)

# Return the 10 nearest stored strings for a search
retriever = vector_store.as_retriever(search_kwargs={"k": 10})

# Tool description shown to the model
description = (
    "Use to look up values to filter on. Input is an approximate spelling "
    "of the proper noun, output is valid proper nouns. Use the noun most "
    "similar to the search."
)

# Search tool wrapping the retriever
entity_retriever_tool = create_retriever_tool(
    retriever,
    name="search_proper_nouns",
    description=description,
)
##################################################################
# State type definition
##################################################################
class State(TypedDict):
    """State shared by the nodes of the recommendation graph."""

    question: str  # user's input question
    user_profile: dict  # investment profile written by analyze_profile
    query: str  # generated SQL query
    # Whatever QuerySQLDatabaseTool returns for the query.
    # NOTE(review): that may be a string rather than a list -- confirm.
    candidates: list
    rankings: list  # ranked ETFs written by rank_etfs
    explanation: str  # explanation written by write_query alongside the SQL
    # generate_explanation stores a dict with "explanation" and "markdown"
    # keys here, so this is a dict rather than a str.
    final_answer: dict
##################################################################
# User profile analysis
##################################################################
# Risk appetite levels; the str mixin makes members compare equal to their values.
class RiskTolerance(str, Enum):
    CONSERVATIVE = "conservative"
    MODERATE = "moderate"
    AGGRESSIVE = "aggressive"
# Investment period buckets; the str mixin makes members compare equal to their values.
class InvestmentHorizon(str, Enum):
    SHORT = "short"
    MEDIUM = "medium"
    LONG = "long"
# Structured investment profile the LLM fills in from the user's question.
# Documented with comments rather than a docstring so the schema handed to the
# model stays exactly as it was.
class InvestmentProfile(BaseModel):
    # Risk appetite, one of the RiskTolerance values
    risk_tolerance: RiskTolerance = Field(
        description="ν¬μμμ μν μ±ν₯ (conservative/moderate/aggressive)"
    )
    # Investment period, one of the InvestmentHorizon values
    investment_horizon: InvestmentHorizon = Field(
        description="ν¬μ κΈ°κ° (short/medium/long)"
    )
    # Free-text investment goal
    investment_goal: str = Field(
        description="ν¬μμ μ£Όμ λͺ©μ μ€λͺ "
    )
    # Sectors to favour
    preferred_sectors: List[str] = Field(
        description="μ νΈνλ ν¬μ μΉν° λͺ©λ‘"
    )
    # Sectors to avoid
    excluded_sectors: List[str] = Field(
        description="ν¬μλ₯Ό μνμ§ μλ μΉν° λͺ©λ‘"
    )
    # Monthly amount available to invest
    monthly_investment: int = Field(
        description="μ ν¬μ κ°λ₯ κΈμ‘ (μ)"
    )
# Prompt for the user-profile analysis; {question} receives the user's message.
PROFILE_TEMPLATE = """
μ¬μ©μμ μ§λ¬Έμ λΆμνμ¬ ν¬μ νλ‘νμ μμ±ν΄μ£ΌμΈμ.
μ¬μ©μ μ§λ¬Έ: {question}
"""
profile_prompt = ChatPromptTemplate.from_template(PROFILE_TEMPLATE)

# Chat model shared by every node, plus a variant that returns an InvestmentProfile.
llm = ChatOpenAI(model="gpt-4.1-mini")
profile_llm = llm.with_structured_output(InvestmentProfile)
# User-profile analysis node
def analyze_profile(state: State) -> dict:
    """Analyze the user's question and build an investment profile.

    Returns:
        A state update whose "user_profile" is the profile as a plain dict.
    """
    prompt = profile_prompt.invoke({"question": state["question"]})
    response = profile_llm.invoke(prompt)
    # mode="json" turns the Enum fields into their plain string values, so
    # later prompts show 'conservative' instead of an Enum repr.
    return {"user_profile": response.model_dump(mode="json")}
##################################################################
# SQL query generation
##################################################################
# SQL query generation template. Placeholders: dialect, top_k, table_info,
# entity_info, user_profile and input (all supplied by write_query).
QUERY_TEMPLATE = """
Given an input question and investment profile, create a syntactically correct {dialect} query to run. Unless specified, limit your query to at most {top_k} results. Order the results by most relevant columns based on the investment profile.
Never query for all columns from a specific table, only ask for relevant columns given the question and investment criteria.
Pay attention to use only the column names you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Available tables:
{table_info}
Entity relationships:
{entity_info}
- Use exact matches when comparing entity names
- Check for historical name variations if available
- Apply case-sensitive matching for official names
- Handle both Korean and English entity names when present
Investment Profile:
{user_profile}
Question: {input}
Important:
1. Use only existing columns
2. Query only necessary columns (no SELECT *)
3. Follow correct table relationships
4. Consider performance and indexing
"""

# SQL query generation prompt built from the template above
query_prompt_template = ChatPromptTemplate.from_template(QUERY_TEMPLATE)
# Structured output schema for the SQL-writing call. The docstring and the
# Annotated descriptions are part of the schema given to the model.
class QueryOutput(TypedDict):
    """Generated SQL query."""

    query: Annotated[str, ..., "Syntactically valid SQL query."]
    explanation: Annotated[str, ..., "Query explanation and selection criteria (in νκ΅μ΄)"]
def write_query(state: State) -> dict:
    """Generate SQL query to fetch information.

    Fills the query prompt with the database dialect and schema, the user's
    question, the proper nouns retrieved for that question and the investment
    profile, then asks the LLM for a structured ``QueryOutput``.

    Returns:
        A state update with the generated "query" and its "explanation".
    """
    question = state["question"]
    prompt = query_prompt_template.invoke(
        {
            "dialect": db.dialect,
            "top_k": 10,
            "table_info": db.get_table_info(),
            "input": question,
            # Stored proper nouns closest to the question, offered to the
            # model as valid values to filter on.
            "entity_info": entity_retriever_tool.invoke(question),
            "user_profile": state["user_profile"],
        }
    )
    structured_llm = llm.with_structured_output(QueryOutput)
    result = structured_llm.invoke(prompt)
    return {"query": result["query"], "explanation": result["explanation"]}
##################################################################
# Candidate ETF retrieval
##################################################################
def execute_query(state: State) -> dict:
    """Run the generated SQL query to fetch candidate ETFs.

    The SQL in ``state["query"]`` is written by an LLM from free-form user
    text, so it is treated as untrusted: only statements that begin with
    SELECT or WITH are executed. This is a coarse first-keyword check, not a
    SQL parser.

    Returns:
        A state update whose "candidates" is the tool's result.

    Raises:
        ValueError: If the query is not a string starting with SELECT or WITH.
    """
    query = state["query"]
    if not isinstance(query, str) or not re.match(
        r"\s*(select|with)\b", query, re.IGNORECASE
    ):
        raise ValueError(f"Refusing to execute non-SELECT SQL: {query!r}")

    execute_query_tool = QuerySQLDatabaseTool(db=db)
    results = execute_query_tool.invoke(query)
    return {"candidates": results}
##################################################################
# ETF ranking
##################################################################
# Ranking template. Placeholders: user_profile and candidates.
# NOTE(review): factors 2 and 4 are the same text -- one of them was probably
# meant to be a different factor; confirm against the intended list.
RANKING_TEMPLATE = """
Rank the following ETF candidates based on the user's investment profile and return the top 3(three) ETFs.
Consider these factors when ranking:
1. μμ΅λ₯
2. λ³λμ±
3. μμμ°μ΄μ‘
4. λ³λμ±
5. User Profile matching score
User Profile:
{user_profile}
Candidate ETFs:
{candidates}
"""

# ETF ranking prompt built from the template above
ranking_prompt = ChatPromptTemplate.from_template(RANKING_TEMPLATE)
# Structured output schema for one ranked ETF. The docstring and the Annotated
# descriptions are part of the schema given to the model.
class ETFRanking(TypedDict):
    """Individual ETF ranking result"""

    # The ranking prompt asks for the top 3 ETFs, so positions run from 1 to 3.
    rank: Annotated[int, ..., "Ranking position (1-3)"]
    etf_code: Annotated[str, ..., "ETF μ’ λͺ©μ½λ (6-digit)"]
    etf_name: Annotated[str, ..., "ETF μ’ λͺ©λͺ "]
    score: Annotated[float, ..., "Composite score (0-100)"]
    ranking_reason: Annotated[str, ..., "Explanation for the ranking (in νκ΅μ΄)"]
# Wrapper schema: the model returns its ranked ETFs under a single "rankings" key.
class ETFRankingResult(TypedDict):
    """Ranked ETFs"""

    rankings: List[ETFRanking]
# ETF ranking node
def rank_etfs(state: State) -> dict:
    """Rank ETF candidates based on user's investment profile.

    Returns:
        A state update whose "rankings" is the list of ranked ETFs.
    """
    prompt = ranking_prompt.invoke(
        {
            "user_profile": state["user_profile"],
            "candidates": state["candidates"],
        }
    )
    structured_llm = llm.with_structured_output(ETFRankingResult)
    results = structured_llm.invoke(prompt)
    # ETFRankingResult wraps the list in its own "rankings" key. Unwrap it so
    # the state holds the list itself, matching State.rankings, rather than
    # {"rankings": {"rankings": [...]}}.
    return {"rankings": results["rankings"]}
##################################################################
# Recommendation explanation
##################################################################
# Explanation template. Placeholders: user_profile and rankings.
EXPLANATION_TEMPLATE = """
Please provide a comprehensive explanation for the recommended ETFs based on the user's investment profile.
[GUIDELINES]
1. ETF Characteristics
- Investment strategy and approach
- Historical performance overview
- Fee structure and efficiency
- Underlying assets and diversification
2. Profile Fit Analysis
- Alignment with risk tolerance
- Match with investment horizon
- Sector preference compatibility
- Investment goal contribution
3. Portfolio Construction
- Recommended allocation percentages
- Diversification benefits
- Rebalancing considerations
- Implementation strategy
4. Risk Considerations
- Market risk factors
- Specific ETF risks
- Economic scenario impacts
- Monitoring requirements
--------------------------------------------
[User Profile]
{user_profile}
[Selected ETFs]
{rankings}
"""

# Recommendation explanation prompt built from the template above
explanation_prompt = ChatPromptTemplate.from_template(EXPLANATION_TEMPLATE)
# Output schema for the recommendation explanation: one entry per recommended
# ETF. The docstring and Field descriptions are part of the schema given to the model.
class ETFRecommendation(BaseModel):
    """Individual ETF recommendation details"""

    etf_code: str = Field(..., description="ETF μ’ λͺ©μ½λ (6-digit)")
    etf_name: str = Field(..., description="ETF μ’ λͺ©λͺ ")
    allocation: Decimal = Field(..., description="Recommended allocation % (0-100)")
    description: str = Field(..., description="ETF description and investment strategy (in νκ΅μ΄)")
    key_points: List[str] = Field(..., description="Key investment points (in νκ΅μ΄)")
    risks: List[str] = Field(..., description="Risk factors to consider (in νκ΅μ΄)")
class RecommendationExplanation(BaseModel):
    """ETF recommendation explanation with markdown formatting"""

    overview: str = Field(..., description="Overall strategy explanation (in νκ΅μ΄)")
    recommendations: List[ETFRecommendation] = Field(..., description="ETF details")
    considerations: List[str] = Field(..., description="Important considerations (in νκ΅μ΄)")

    # Render as markdown
    def to_markdown(self) -> str:
        """Convert explanation to markdown format.

        The document has four parts, in order: the overview, an allocation
        table with one row per recommendation, a detail section per
        recommendation (description, key points, risks), and the list of
        considerations. Parts are joined with newlines.
        """
        markdown = [
            "# ETF ν¬νΈν΄λ¦¬μ€ μΆμ²",
            "",
            "## ν¬μ μ λ΅ κ°μ",
            self.overview,
            "",
            "## μΆμ² ETF ν¬νΈν΄λ¦¬μ€",
            "",
        ]

        # Portfolio allocation table
        markdown.extend([
            "| ETF | μ’ λͺ©μ½λ | μΆμ²λΉμ€ |",
            "|-----|----------|----------|",
        ])
        for rec in self.recommendations:
            markdown.append(
                f"| {rec.etf_name} | {rec.etf_code} | {rec.allocation}% |"
            )

        # Per-ETF details; each bullet starts with its own newline, so a
        # bullet list begins on the line after its heading.
        markdown.append("\n## ETF μμΈ μ€λͺ \n")
        for rec in self.recommendations:
            markdown.extend([
                f"### {rec.etf_name} ({rec.etf_code})",
                rec.description,
                "",
                "**μ£Όμ ν¬μ ν¬μΈνΈ:**",
                "".join(f"\n* {point}" for point in rec.key_points),
                "",
                "**ν¬μ μν:**",
                "".join(f"\n* {risk}" for risk in rec.risks),
                "",
            ])

        # Considerations when investing
        markdown.extend([
            "## ν¬μ μ κ³ λ €μ¬ν",
            "".join(f"\n* {item}" for item in self.considerations),
            "",
        ])
        return "\n".join(markdown)
# Recommendation explanation node
def generate_explanation(state: State) -> dict:
    """Generate structured ETF recommendation explanation.

    Returns:
        A state update whose "final_answer" is a dict with two keys:
        "explanation" (the model output dumped to a dict) and "markdown"
        (the same output rendered by ``to_markdown``).
    """
    # Build the prompt
    prompt = explanation_prompt.invoke({
        "rankings": state["rankings"],
        "user_profile": state["user_profile"],
    })

    # Produce the structured output
    structured_llm = llm.with_structured_output(RecommendationExplanation)
    response = structured_llm.invoke(prompt)

    return {"final_answer": {
        "explanation": response.model_dump(),
        "markdown": response.to_markdown(),
    }}
##################################################################
# ETF recommendation bot - state graph
##################################################################
# Build the state graph: the five nodes are chained in the listed order,
# and START is wired to the first of them.
graph_builder = StateGraph(State).add_sequence(
    [analyze_profile, write_query, execute_query, rank_etfs, generate_explanation]
)
graph_builder.add_edge(START, "analyze_profile")
graph = graph_builder.compile()
##################################################################
# ETF recommendation bot - main function
##################################################################
def process_message(message: str) -> str:
    """Run the recommendation graph for ``message`` and return markdown.

    Args:
        message: The user's question.

    Returns:
        The "markdown" entry of the graph's final answer. If anything in the
        graph raises, a markdown error message containing the exception text
        is returned instead; this function does not raise.
    """
    try:
        etf_recommendation = graph.invoke(
            {"question": message}
        )
        return etf_recommendation["final_answer"]["markdown"]
    except Exception as e:
        # Top-level UI boundary: keep the chat alive, but record the traceback
        # so failures are not lost.
        logging.getLogger(__name__).exception("ETF recommendation failed")
        return f"""
# μ€λ₯κ° λ°μνμ΅λλ€
μ£μ‘ν©λλ€. μμ²μ μ²λ¦¬νλ μ€μ λ¬Έμ κ° λ°μνμ΅λλ€.
μ€λ₯ λ΄μ©: {str(e)}
λ€μ μλν΄μ£Όμκ±°λ, μ§λ¬Έμ λ€λ₯Έ λ°©μμΌλ‘ μμ±ν΄μ£ΌμΈμ.
"""
def answer_invoke(message: str, history: List) -> str:
    """Chat callback: answer ``message``; ``history`` is accepted but ignored."""
    return process_message(message)  # conversation history is not used
# Create Gradio interface
demo = gr.ChatInterface(
    fn=answer_invoke,
    title="λ§μΆ€ν ETF μΆμ² μ΄μμ€ν΄νΈ",
    description="""
ν¬μ μ±ν₯κ³Ό λͺ©νμ λ§λ ETFλ₯Ό μΆμ²ν΄λ립λλ€.
λ€μκ³Ό κ°μ μ 보λ₯Ό ν¬ν¨νμ¬ μ§λ¬Έν΄μ£ΌμΈμ:
- ν¬μ λͺ©μ
- ν¬μ κΈ°κ°
- μν μ±ν₯
- μ νΈ/μ μΈ μΉν°
- μ ν¬μ κ°λ₯ κΈμ‘
μμ) "μ 100λ§μ μ λλ₯Ό 3λ μ΄μ μ₯κΈ° ν¬μνκ³ μΆκ³ , ITμ ν¬μ€μΌμ΄ μΉν°λ₯Ό μ νΈν©λλ€.
보μμ μΈ ν¬μλ₯Ό μ νΈνλ©°, λ΄λ°° κ΄λ ¨ κΈ°μ μ μ μΈνκ³ μΆμ΅λλ€."
""",
    examples=[
        """20λ νλ°μ λνμμ λλ€.
μ 50λ§μ μ λλ₯Ό 1λ μ΄μ μ₯κΈ° ν¬μνκ³ μΆκ³ ,
νμ¨κ³Ό κΈλ¦¬μ κ΄μ¬μ΄ μμ΅λλ€.
κ³ μν κ³ μμ΅μ μΆκ΅¬νλ©°, ESG μμλ κ³ λ €νκ³ μΆμ΅λλ€.
μ ν©ν ETFλ₯Ό μΆμ²ν΄μ£ΌμΈμ."""
    ],
    type="messages",
)
# Launch the interface when the file is run as a script
if __name__ == "__main__":
    demo.launch()