from fastapi import FastAPI, UploadFile, File, HTTPException, status
from pydantic import BaseModel
import httpx
import os
from dotenv import load_dotenv
from langgraph.graph import StateGraph, END
from typing import Dict, List, Optional, Any, Union
import logging
import json

load_dotenv()
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

app = FastAPI(title="Orchestrator (Generalized)")

AGENT_API_URL = os.getenv("AGENT_API_URL", "http://localhost:8001")
AGENT_SCRAPING_URL = os.getenv("AGENT_SCRAPING_URL", "http://localhost:8002")
AGENT_RETRIEVER_URL = os.getenv("AGENT_RETRIEVER_URL", "http://localhost:8003")
AGENT_ANALYSIS_URL = os.getenv("AGENT_ANALYSIS_URL", "http://localhost:8004")
AGENT_LANGUAGE_URL = os.getenv("AGENT_LANGUAGE_URL", "http://localhost:8005")
AGENT_VOICE_URL = os.getenv("AGENT_VOICE_URL", "http://localhost:8006")


class EarningsSurpriseRecordState(BaseModel):
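    """A single earnings-surprise record as returned by the scraping agent."""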
    date: str
    symbol: str
    actual: Union[float, int, str, None] = None
    estimate: Union[float, int, str, None] = None
    difference: Union[float, int, str, None] = None
    surprisePercentage: Union[float, int, str, None] = None


class MarketBriefState(BaseModel):
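    """Shared state passed between LangGraph nodes for one market-brief request."""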
    audio_input: Optional[bytes] = None
    user_text: Optional[str] = None
    nlu_results: Optional[Dict[str, str]] = None

    target_tickers_for_data_fetch: List[str] = []
    market_data: Optional[Dict[str, Dict[str, float]]] = None
    filings: Optional[Dict[str, List[EarningsSurpriseRecordState]]] = None

    indexed: bool = False
    retrieved_docs: Optional[List[str]] = None
    analysis: Optional[Dict[str, Any]] = None
    brief: Optional[str] = None
    audio_output: Optional[bytes] = None
    errors: List[str] = []
    warnings: List[str] = []

    class Config:
        arbitrary_types_allowed = True


EXAMPLE_PORTFOLIO_FILE = "example_portfolio.json"
EXAMPLE_METADATA_FILE = "example_metadata.json"


def load_example_data(file_path: str, default_data: Dict) -> Dict:
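    """Load JSON from file_path if it exists, otherwise return default_data."""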
    if os.path.exists(file_path):
        try:
            with open(file_path, "r") as f:
                return json.load(f)
        except Exception as e:
            logger.warning(f"Could not load {file_path}: {e}. Using default.")
    return default_data


EXAMPLE_PORTFOLIO = load_example_data(
    EXAMPLE_PORTFOLIO_FILE,
    {
        "TSM": {
            "weight": 0.22,
            "country": "Taiwan",
            "sector": "Technology",
        },
        "AAPL": {"weight": 0.15, "country": "USA", "sector": "Technology"},
        "MSFT": {"weight": 0.10, "country": "USA", "sector": "Technology"},
        "JNJ": {"weight": 0.08, "country": "USA", "sector": "Healthcare"},
        "BABA": {
            "weight": 0.05,
            "country": "China",
            "sector": "Technology",
        },
    },
)


async def call_agent(
    client: httpx.AsyncClient,
    url: str,
    method: str = "post",
    json_payload: Optional[Dict] = None,
    files_payload: Optional[Dict] = None,
    timeout: float = 60.0,
) -> Dict:
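    """Call a downstream agent service and return its parsed JSON response.

    Connection, timeout, HTTP-status, and unexpected errors are converted into
    HTTPException so callers can surface them uniformly.
    """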
    try:
        logger.info(
            f"Calling agent at {url} with payload keys: {list(json_payload.keys()) if json_payload else 'N/A'}"
        )
        if method == "post":
            if json_payload:
                response = await client.post(url, json=json_payload, timeout=timeout)
            elif files_payload:
                response = await client.post(url, files=files_payload, timeout=timeout)
            else:
                raise ValueError("POST request requires json_payload or files_payload.")
        elif method == "get":
            response = await client.get(url, params=json_payload, timeout=timeout)
        else:
            raise ValueError(f"Unsupported method: {method}")

        response.raise_for_status()
        logger.info(f"Agent at {url} returned status {response.status_code}.")
        return response.json()

    except httpx.ConnectError as e:
        error_msg = f"Connection error calling agent at {url}: {e}"
        logger.error(error_msg)
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=error_msg
        )
    except httpx.RequestError as e:
        error_msg = f"Request error calling agent at {url}: {e}"
        logger.error(error_msg)
        raise HTTPException(
            status_code=status.HTTP_504_GATEWAY_TIMEOUT, detail=error_msg
        )
    except httpx.HTTPStatusError as e:
        error_msg = f"HTTP error calling agent at {url}: {e.response.status_code} - {e.response.text[:200]}"
        logger.error(error_msg)
        raise HTTPException(status_code=e.response.status_code, detail=e.response.text)
    except Exception as e:
        error_msg = f"An unexpected error occurred calling agent at {url}: {e}"
        logger.error(error_msg, exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=error_msg
        )


async def stt_node(state: MarketBriefState) -> MarketBriefState:
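    """Transcribe the uploaded audio via the voice agent's /stt endpoint."""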
    async with httpx.AsyncClient() as client:
        if not state.audio_input:
            state.errors.append("STT Node: No audio input provided.")
            logger.error(state.errors[-1])
            state.user_text = "Error: No audio provided for STT."
            return state
        files = {"audio": ("input.wav", state.audio_input, "audio/wav")}
        try:
            response_data = await call_agent(
                client, f"{AGENT_VOICE_URL}/stt", files_payload=files
            )
            if "transcript" in response_data:
                state.user_text = response_data["transcript"]
                logger.info(f"STT successful. Transcript: {state.user_text[:50]}...")
            else:
                error_msg = f"STT agent response missing 'transcript': {response_data}"
                logger.error(error_msg)
                state.errors.append(error_msg)
                state.user_text = "Error: STT failed."
        except HTTPException as e:
            state.errors.append(f"STT Node failed: {e.detail}")
            state.user_text = "Error: STT service unavailable or failed."
    return state


async def nlu_node(state: MarketBriefState) -> MarketBriefState:
    """(NEW) Calls an NLU process (simulated here) to extract intent."""
    if not state.user_text or "Error:" in state.user_text:
        state.warnings.append(
            "NLU Node: Skipping due to missing or error in user_text."
        )
        state.nlu_results = {
            "region": "Global",
            "sector": "Overall Portfolio",
        }
        return state

    logger.info(f"NLU Node: Processing query: '{state.user_text}'")

    query_lower = state.user_text.lower()
    region = "Global"
    sector = "Overall Portfolio"

    if "asia" in query_lower and "tech" in query_lower:
        region = "Asia"
        sector = "Technology"
        logger.info("NLU Simulation: Detected 'Asia' and 'Tech'.")
    elif "us" in query_lower or "usa" in query_lower or "america" in query_lower:
        region = "USA"
        if "tech" in query_lower:
            sector = "Technology"
        elif "health" in query_lower:
            sector = "Healthcare"
        logger.info(f"NLU Simulation: Detected Region '{region}', Sector '{sector}'.")

    state.nlu_results = {"region": region, "sector": sector}
    logger.info(f"NLU Node: Results: {state.nlu_results}")

    target_tickers = []
    portfolio_keys = list(EXAMPLE_PORTFOLIO.keys())

    if region == "Global" and (
        sector == "Overall Portfolio" or sector == "Overall Market"
    ):
        target_tickers = portfolio_keys
    else:
        for ticker, details in EXAMPLE_PORTFOLIO.items():
            matches_region = region == "Global"
            if region == "Asia" and details.get("country") in [
                "Taiwan",
                "China",
                "Korea",
                "Japan",
                "India",
            ]:
                matches_region = True
            elif region == "USA" and details.get("country") == "USA":
                matches_region = True

            matches_sector = sector == "Overall Portfolio" or sector == "Overall Market"
            if sector.lower() == details.get("sector", "").lower():
                matches_sector = True

            if matches_region and matches_sector:
                target_tickers.append(ticker)

    if not target_tickers and portfolio_keys:
        logger.warning(
            f"NLU filtering yielded no specific tickers for {region}/{sector}, defaulting to all portfolio tickers."
        )
        target_tickers = portfolio_keys
        state.nlu_results["region_effective"] = "Global"
        state.nlu_results["sector_effective"] = "Overall Portfolio"

    state.target_tickers_for_data_fetch = list(set(target_tickers))
    logger.info(
        f"NLU Node: Target tickers for data fetch: {state.target_tickers_for_data_fetch}"
    )
    if not state.target_tickers_for_data_fetch:
        state.warnings.append(
            "NLU Node: No target tickers identified for data fetching based on query and portfolio."
        )

    return state


async def api_agent_node(state: MarketBriefState) -> MarketBriefState:
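    """Fetch market data for the target tickers from the API agent."""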
    if not state.target_tickers_for_data_fetch:
        state.warnings.append(
            "API Agent Node: No target tickers to fetch market data for. Skipping."
        )
        state.market_data = {}
        return state

    async with httpx.AsyncClient() as client:
        payload = {
            "tickers": state.target_tickers_for_data_fetch,
            "data_type": "adjClose",
        }
        try:
            response_data = await call_agent(
                client, f"{AGENT_API_URL}/get_market_data", json_payload=payload
            )
            state.market_data = response_data.get("result", {})
            logger.info(
                f"API Agent successful. Fetched data for tickers: {list(state.market_data.keys()) if state.market_data else 'None'}"
            )
            if response_data.get("errors"):
                state.warnings.append(
                    f"API Agent reported errors: {response_data['errors']}"
                )
            if response_data.get("warnings"):
                state.warnings.extend(response_data.get("warnings", []))

        except HTTPException as e:
            state.errors.append(
                f"API Agent Node failed for tickers {state.target_tickers_for_data_fetch}: {e.detail}"
            )
            state.market_data = {}
    return state


async def scraping_agent_node(state: MarketBriefState) -> MarketBriefState:
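    """Fetch earnings-surprise records for each target ticker from the scraping agent."""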
    if not state.target_tickers_for_data_fetch:
        state.warnings.append(
            "Scraping Agent Node: No target tickers to fetch earnings for. Skipping."
        )
        state.filings = {}
        return state

    async with httpx.AsyncClient() as client:
        filings_data: Dict[str, List[Dict[str, Any]]] = {}
        for ticker in state.target_tickers_for_data_fetch:
            payload = {"ticker": ticker, "filing_type": "earnings_surprise"}
            try:
                response_data = await call_agent(
                    client, f"{AGENT_SCRAPING_URL}/get_filings", json_payload=payload
                )

                if "data" in response_data and isinstance(response_data["data"], list):
                    filings_data[ticker] = response_data["data"]
                    logger.info(
                        f"Scraping Agent got {len(response_data['data'])} records for {ticker}."
                    )
                    if not response_data["data"]:
                        logger.info(
                            f"Scraping Agent for {ticker} returned 0 earnings surprise records."
                        )
                else:
                    filings_data[ticker] = []
                    state.errors.append(
                        f"Scraping agent for {ticker} returned malformed data: {str(response_data)[:100]}"
                    )
            except HTTPException as e:
                state.errors.append(
                    f"Scraping Agent Node failed for {ticker}: {e.detail}"
                )
                filings_data[ticker] = []
        state.filings = filings_data
    return state


async def retriever_agent_node(state: MarketBriefState) -> MarketBriefState:
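    """Index earnings documents with the retriever agent and retrieve context for the query.

    The endpoint paths and payload shapes used below are assumptions about the
    retriever agent's API, not confirmed by this module.
    """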
    async with httpx.AsyncClient() as client:
        docs_to_index = []
        if state.filings:
            for ticker, records_list in state.filings.items():
                if records_list:
                    doc_content = f"Earnings surprise data for {ticker}:\n" + "\n".join(
                        [
                            f"Date: {r.get('date', 'N/A')}, Symbol: {r.get('symbol', 'N/A')}, "
                            f"Actual: {r.get('actual', 'N/A')}, Estimate: {r.get('estimate', 'N/A')}, "
                            f"Surprise%: {r.get('surprisePercentage', 'N/A')}"
                            for r in records_list
                        ]
                    )
                    docs_to_index.append(doc_content)

        if docs_to_index:
            try:
                # NOTE: the /index path and payload shape are assumptions about the
                # retriever agent's API; adjust them to the real service contract.
                await call_agent(
                    client,
                    f"{AGENT_RETRIEVER_URL}/index",
                    json_payload={"documents": docs_to_index},
                )
                state.indexed = True
                logger.info(f"Retriever: Indexed {len(docs_to_index)} documents.")
            except Exception as e:
                state.errors.append(f"Retriever indexing failed: {e}")
                state.indexed = False
        else:
            state.indexed = False
            logger.info("Retriever: No new documents to index.")

        if state.user_text:
            try:
                # NOTE: the /retrieve path, payload, and "documents" response key
                # are likewise assumptions.
                response_data = await call_agent(
                    client,
                    f"{AGENT_RETRIEVER_URL}/retrieve",
                    json_payload={"query": state.user_text, "top_k": 3},
                )
                state.retrieved_docs = response_data.get("documents", [])
                logger.info(
                    f"Retriever: Retrieved {len(state.retrieved_docs)} documents."
                )
            except Exception as e:
                state.errors.append(f"Retriever retrieval failed: {e}")
                state.retrieved_docs = []
        else:
            state.retrieved_docs = []
    return state


async def analysis_agent_node(state: MarketBriefState) -> MarketBriefState:
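    """Send portfolio weights, market data, and earnings data to the analysis agent."""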
    if not state.market_data and not state.filings:
        state.warnings.append(
            "Analysis Agent Node: No market data or filings available. Skipping analysis."
        )
        state.analysis = None
        return state

    async with httpx.AsyncClient() as client:
        nlu_res = state.nlu_results if state.nlu_results else {}
        region_label = nlu_res.get("region_effective", nlu_res.get("region", "Global"))
        sector_label = nlu_res.get(
            "sector_effective", nlu_res.get("sector", "Overall Portfolio")
        )

        if region_label == "Global" and (
            sector_label == "Overall Portfolio" or sector_label == "Overall Market"
        ):
            target_label_for_analysis = "Overall Portfolio"
        else:
            target_label_for_analysis = (
                f"{region_label.replace('USA', 'US')} {sector_label} Stocks".strip()
            )

        analysis_target_tickers = state.target_tickers_for_data_fetch

        current_portfolio_weights = {
            ticker: details["weight"] for ticker, details in EXAMPLE_PORTFOLIO.items()
        }

        payload = {
            "portfolio": current_portfolio_weights,
            "market_data": state.market_data if state.market_data else {},
            "earnings_data": (state.filings if state.filings else {}),
            "target_tickers": analysis_target_tickers,
            "target_label": target_label_for_analysis,
        }
        try:
            response_data = await call_agent(
                client, f"{AGENT_ANALYSIS_URL}/analyze", json_payload=payload
            )

            state.analysis = response_data
            logger.info(
                f"Analysis Agent successful for '{response_data.get('target_label')}'."
            )
        except HTTPException as e:
            state.errors.append(f"Analysis Agent Node failed: {e.detail}")
            state.analysis = None
    return state


async def language_agent_node(state: MarketBriefState) -> MarketBriefState:
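    """Ask the language agent to turn the analysis into a narrative market brief."""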
    async with httpx.AsyncClient() as client:
        if not state.user_text or "Error:" in state.user_text:
            state.errors.append("Language Agent: Skipping due to no valid user text.")
            state.brief = (
                "I could not understand your query or there was an earlier error."
            )
            return state

        analysis_payload_for_llm: Dict[str, Any]
        if state.analysis and isinstance(state.analysis, dict):
            analysis_payload_for_llm = {
                "target_label": state.analysis.get("target_label", "the portfolio"),
                "current_allocation": state.analysis.get("current_allocation", 0.0),
                "yesterday_allocation": state.analysis.get("yesterday_allocation", 0.0),
                "allocation_change_percentage_points": state.analysis.get(
                    "allocation_change_percentage_points", 0.0
                ),
                "earnings_surprises_for_target": state.analysis.get(
                    "earnings_surprises_for_target", []
                ),
            }
        else:
            logger.warning(
                "Language Agent: Analysis data is missing or not a dict. Using defaults."
            )
            state.warnings.append(
                "Language Agent: Analysis data unavailable, brief will be general."
            )
            analysis_payload_for_llm = {
                "target_label": "the portfolio (analysis data missing)",
                "current_allocation": 0.0,
                "yesterday_allocation": 0.0,
                "allocation_change_percentage_points": 0.0,
                "earnings_surprises_for_target": [],
            }

        payload = {
            "user_query": state.user_text,
            "analysis": analysis_payload_for_llm,
            "retrieved_docs": state.retrieved_docs if state.retrieved_docs else [],
        }
        try:
            response_data = await call_agent(
                client, f"{AGENT_LANGUAGE_URL}/generate_brief", json_payload=payload
            )
            state.brief = response_data.get("brief")
            if state.brief:
                logger.info(f"Language Agent successful. Brief: {state.brief[:70]}...")
            else:
                state.errors.append(
                    f"Language Agent response missing 'brief': {str(response_data)[:100]}"
                )
                state.brief = "Sorry, I couldn't generate the brief at this time."
        except HTTPException as e:
            state.errors.append(f"Language Agent Node failed: {e.detail}")
            state.brief = "Sorry, I couldn't generate the brief at this time due to an internal error."
    return state


async def tts_node(state: MarketBriefState) -> MarketBriefState:
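    """Synthesize the brief (or an error summary) via the voice agent's /tts endpoint."""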
    brief_text_for_tts = state.brief
    if state.errors and (
        not state.brief
        or "sorry" in state.brief.lower()
        or "error" in state.brief.lower()
    ):
        error_count = len(state.errors)
        brief_text_for_tts = f"I encountered {error_count} error{'s' if error_count > 1 else ''} while processing your request. Please check the detailed report."
        logger.warning(
            f"TTS Node: Generating audio for error summary due to {error_count} errors in state."
        )
    elif not state.brief:
        brief_text_for_tts = "The market brief could not be generated."
        logger.warning("TTS Node: No brief text available from language agent.")
        state.warnings.append("TTS Node: No brief content to synthesize.")

    if not brief_text_for_tts:
        state.audio_output = None
        return state

    async with httpx.AsyncClient() as client:
        payload = {"text": brief_text_for_tts, "lang": "en"}
        try:
            response_data = await call_agent(
                client, f"{AGENT_VOICE_URL}/tts", json_payload=payload
            )
            if "audio" in response_data and isinstance(response_data["audio"], str):
                state.audio_output = bytes.fromhex(response_data["audio"])
                logger.info("TTS successful. Audio bytes received.")
            else:
                state.errors.append(
                    f"TTS Agent response missing or invalid 'audio': {str(response_data)[:100]}"
                )
                state.audio_output = None
        except HTTPException as e:
            state.errors.append(f"TTS Node failed: {e.detail}")
            state.audio_output = None
    return state


def build_market_brief_graph():
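    """Build the linear LangGraph pipeline:
    stt -> nlu -> api_agent -> scraping_agent -> retriever_agent -> analysis_agent -> language_agent -> tts.
    """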
    builder = StateGraph(MarketBriefState)
    builder.add_node("stt", stt_node)
    builder.add_node("nlu", nlu_node)
    builder.add_node("api_agent", api_agent_node)
    builder.add_node("scraping_agent", scraping_agent_node)
    builder.add_node("retriever_agent", retriever_agent_node)
    builder.add_node("analysis_agent", analysis_agent_node)
    builder.add_node("language_agent", language_agent_node)
    builder.add_node("tts", tts_node)

    builder.set_entry_point("stt")
    builder.add_edge("stt", "nlu")
    builder.add_edge("nlu", "api_agent")
    builder.add_edge("api_agent", "scraping_agent")
    builder.add_edge("scraping_agent", "retriever_agent")
    builder.add_edge("retriever_agent", "analysis_agent")
    builder.add_edge("analysis_agent", "language_agent")
    builder.add_edge("language_agent", "tts")
    builder.add_edge("tts", END)
    return builder.compile()


graph = build_market_brief_graph()


@app.post("/market_brief")
async def market_brief(audio: UploadFile = File(...)):
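    """Accept an audio query, run the agent workflow, and return the transcript,
    brief, synthesized audio (hex-encoded), and any errors or warnings."""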
    logger.info("Received request to /market_brief")
    if not audio.content_type or not audio.content_type.startswith("audio/"):
        raise HTTPException(
            status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE,
            detail="Invalid file type.",
        )

    current_run_state = MarketBriefState()
    try:
        current_run_state.audio_input = await audio.read()
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to read audio: {e}",
        )

    processed_state: MarketBriefState = current_run_state

    try:
        logger.info("Invoking LangGraph workflow...")

        initial_state_dict = current_run_state.model_dump(exclude_none=True)
        invocation_result = await graph.ainvoke(initial_state_dict)

        if isinstance(invocation_result, dict):
            processed_state = MarketBriefState(**invocation_result)
            logger.info("LangGraph execution finished. State updated.")
        else:
            logger.error(
                f"LangGraph ainvoke returned unexpected type: {type(invocation_result)}. Using partially updated state."
            )

            processed_state.errors.append(
                f"Internal graph error: result type {type(invocation_result)}"
            )

    except HTTPException as e:
        logger.error(
            f"Graph execution stopped due to HTTPException from an agent: {e.detail}"
        )
        processed_state.errors.append(f"Agent call failed: {e.detail}")
    except Exception as e:
        error_msg = f"An unexpected error occurred during graph execution: {e}"
        logger.error(error_msg, exc_info=True)
        processed_state.errors.append(error_msg)

    response_payload = {
        "transcript": processed_state.user_text,
        "brief": processed_state.brief,
        "audio": (
            processed_state.audio_output.hex() if processed_state.audio_output else None
        ),
        "errors": processed_state.errors,
        "warnings": processed_state.warnings,
        "status": "success" if not processed_state.errors else "failed",
        "message": "Market brief process completed."
        + (" With errors." if processed_state.errors else " Successfully."),
        "nlu_detected": processed_state.nlu_results,
        "analysis_details": processed_state.analysis,
    }
    logger.info(
        f"Request finished. Status: {response_payload['status']}. Errors: {len(response_payload['errors'])}. Warnings: {len(response_payload['warnings'])}."
    )
    return response_payload