from fastapi import FastAPI, UploadFile, File, HTTPException, status
from pydantic import BaseModel
import httpx
import os
from dotenv import load_dotenv
from langgraph.graph import StateGraph, END
from typing import Dict, List, Optional, Any, Union
import logging
import json
# Load environment overrides (python-dotenv) before any of the lookups below.
load_dotenv()

logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

app = FastAPI(title="Orchestrator (Generalized)")

# Base URL of each downstream agent; every value can be overridden through
# the environment variable of the same name.
AGENT_API_URL = os.environ.get("AGENT_API_URL", "http://localhost:8001")
AGENT_SCRAPING_URL = os.environ.get("AGENT_SCRAPING_URL", "http://localhost:8002")
AGENT_RETRIEVER_URL = os.environ.get("AGENT_RETRIEVER_URL", "http://localhost:8003")
AGENT_ANALYSIS_URL = os.environ.get("AGENT_ANALYSIS_URL", "http://localhost:8004")
AGENT_LANGUAGE_URL = os.environ.get("AGENT_LANGUAGE_URL", "http://localhost:8005")
AGENT_VOICE_URL = os.environ.get("AGENT_VOICE_URL", "http://localhost:8006")
class EarningsSurpriseRecordState(BaseModel):
    """A single earnings-surprise record as held in the workflow state.

    ``date`` and ``symbol`` are required strings. The remaining fields accept
    a float, an int, a string or ``None`` and default to ``None``.
    """

    date: str
    symbol: str
    actual: Optional[Union[float, int, str]] = None
    estimate: Optional[Union[float, int, str]] = None
    difference: Optional[Union[float, int, str]] = None
    surprisePercentage: Optional[Union[float, int, str]] = None
class MarketBriefState(BaseModel):
    """Shared state passed from node to node in the market-brief graph.

    Each node reads the fields produced by earlier nodes and fills in its own.
    ``errors`` and ``warnings`` accumulate human-readable messages across the
    whole run and are returned to the caller by the ``/market_brief`` endpoint.
    """

    # Pydantic v2 configuration. The nested ``class Config`` form is the v1
    # spelling and is deprecated in v2, which this file already relies on
    # (``model_dump``).
    model_config = {"arbitrary_types_allowed": True}

    # Raw uploaded audio and the transcript derived from it.
    audio_input: Optional[bytes] = None
    user_text: Optional[str] = None
    # Region/sector labels extracted from the transcript.
    nlu_results: Optional[Dict[str, str]] = None
    # Tickers selected by the NLU step for the data-fetching nodes.
    target_tickers_for_data_fetch: List[str] = []
    # Results gathered from the API and scraping agents.
    market_data: Optional[Dict[str, Dict[str, float]]] = None
    filings: Optional[Dict[str, List[EarningsSurpriseRecordState]]] = None
    # Retriever bookkeeping.
    indexed: bool = False
    retrieved_docs: Optional[List[str]] = None
    # Analysis result, generated brief text and synthesized audio.
    analysis: Optional[Dict[str, Any]] = None
    brief: Optional[str] = None
    audio_output: Optional[bytes] = None
    # Messages accumulated across nodes.
    errors: List[str] = []
    warnings: List[str] = []
# File names of the optional example-data JSON files. They are relative paths,
# so they resolve against the process's working directory.
# NOTE(review): EXAMPLE_METADATA_FILE is not referenced in the visible code.
EXAMPLE_PORTFOLIO_FILE = "example_portfolio.json"
EXAMPLE_METADATA_FILE = "example_metadata.json"
def load_example_data(file_path: str, default_data: Dict) -> Dict:
    """Load a JSON object from *file_path*, falling back to *default_data*.

    Args:
        file_path: Path of the JSON file to read.
        default_data: Value returned when the file is missing, unreadable,
            not valid JSON, or does not contain a JSON object.

    Returns:
        The decoded JSON object, or *default_data* (the same object, not a
        copy) when the file cannot be used.
    """
    try:
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(file_path, "r", encoding="utf-8") as f:
            loaded = json.load(f)
    except FileNotFoundError:
        # A missing file is the normal "no override" case: stay silent.
        return default_data
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        logger.warning(f"Could not load {file_path}: {e}. Using default.")
        return default_data
    if not isinstance(loaded, dict):
        # Callers iterate the result with .keys()/.items(); a list or scalar
        # would only fail later and far from the cause.
        logger.warning(
            f"Could not load {file_path}: top-level JSON value is "
            f"{type(loaded).__name__}, expected an object. Using default."
        )
        return default_data
    return loaded
# Portfolio used by the workflow: taken from EXAMPLE_PORTFOLIO_FILE when that
# file can be loaded, otherwise the built-in holdings below.
EXAMPLE_PORTFOLIO = load_example_data(
    file_path=EXAMPLE_PORTFOLIO_FILE,
    default_data={
        "TSM": {"weight": 0.22, "country": "Taiwan", "sector": "Technology"},
        "AAPL": {"weight": 0.15, "country": "USA", "sector": "Technology"},
        "MSFT": {"weight": 0.10, "country": "USA", "sector": "Technology"},
        "JNJ": {"weight": 0.08, "country": "USA", "sector": "Healthcare"},
        "BABA": {"weight": 0.05, "country": "China", "sector": "Technology"},
    },
)
async def call_agent(
    client: httpx.AsyncClient,
    url: str,
    method: str = "post",
    json_payload: Optional[Dict] = None,
    files_payload: Optional[Dict] = None,
    timeout: float = 60.0,
) -> Dict:
    """Send one request to a downstream agent and return its JSON object.

    Args:
        client: Open ``httpx.AsyncClient`` used to send the request.
        url: Full URL of the agent endpoint.
        method: ``"post"`` or ``"get"`` (case-insensitive).
        json_payload: JSON body for POST, or query parameters for GET.
        files_payload: Multipart files for POST; used only when
            ``json_payload`` is empty or ``None``.
        timeout: Per-request timeout in seconds.

    Returns:
        The decoded JSON body, guaranteed to be a ``dict``.

    Raises:
        HTTPException: 503 when the agent cannot be reached, 504 on timeout,
            502 on any other transport failure or when the body is not a JSON
            object, the agent's own status code on an HTTP error response,
            and 500 for anything unexpected (including invalid arguments).
    """
    try:
        logger.info(
            f"Calling agent at {url} with payload keys: {list(json_payload.keys()) if json_payload else 'N/A'}"
        )
        verb = method.lower()
        if verb == "post":
            if json_payload:
                response = await client.post(url, json=json_payload, timeout=timeout)
            elif files_payload:
                response = await client.post(url, files=files_payload, timeout=timeout)
            elif json_payload is not None:
                # An explicitly empty JSON object is still a valid body.
                response = await client.post(url, json=json_payload, timeout=timeout)
            else:
                raise ValueError("POST request requires json_payload or files_payload.")
        elif verb == "get":
            response = await client.get(url, params=json_payload, timeout=timeout)
        else:
            raise ValueError(f"Unsupported method: {method}")
        response.raise_for_status()
        logger.info(f"Agent at {url} returned status {response.status_code}.")
        body = response.json()
    except httpx.ConnectError as e:
        error_msg = f"Connection error calling agent at {url}: {e}"
        logger.error(error_msg)
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=error_msg
        ) from e
    except httpx.TimeoutException as e:
        # Only genuine timeouts are reported as 504.
        error_msg = f"Timeout calling agent at {url}: {e}"
        logger.error(error_msg)
        raise HTTPException(
            status_code=status.HTTP_504_GATEWAY_TIMEOUT, detail=error_msg
        ) from e
    except httpx.RequestError as e:
        # Remaining transport-level failures (read/write/protocol errors, ...).
        error_msg = f"Request error calling agent at {url}: {e}"
        logger.error(error_msg)
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY, detail=error_msg
        ) from e
    except httpx.HTTPStatusError as e:
        error_msg = f"HTTP error calling agent at {url}: {e.response.status_code} - {e.response.text[:200]}"
        logger.error(error_msg)
        raise HTTPException(
            status_code=e.response.status_code, detail=e.response.text
        ) from e
    except Exception as e:
        error_msg = f"An unexpected error occurred calling agent at {url}: {e}"
        logger.error(error_msg, exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=error_msg
        ) from e

    # Checked outside the try block so this HTTPException is not re-wrapped as
    # a 500 by the generic handler. Callers use dict operations on the result
    # and only catch HTTPException, so a list/scalar body must not leak out.
    if not isinstance(body, dict):
        error_msg = (
            f"Agent at {url} returned a non-object JSON body "
            f"({type(body).__name__})."
        )
        logger.error(error_msg)
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY, detail=error_msg
        )
    return body
async def stt_node(state: MarketBriefState) -> MarketBriefState:
    """Transcribe ``state.audio_input`` through the voice agent's ``/stt`` endpoint.

    On success ``state.user_text`` holds the transcript. On any failure an
    entry is appended to ``state.errors`` and ``state.user_text`` is set to a
    string starting with ``"Error:"``, which later nodes use as a skip marker.

    Args:
        state: Current workflow state; mutated in place.

    Returns:
        The same ``state`` object.
    """
    # Validate the input before opening an HTTP client that may not be needed.
    if not state.audio_input:
        state.errors.append("STT Node: No audio input provided.")
        logger.error(state.errors[-1])
        state.user_text = "Error: No audio provided for STT."
        return state

    files = {"audio": ("input.wav", state.audio_input, "audio/wav")}
    async with httpx.AsyncClient() as client:
        try:
            response_data = await call_agent(
                client, f"{AGENT_VOICE_URL}/stt", files_payload=files
            )
        except HTTPException as e:
            state.errors.append(f"STT Node failed: {e.detail}")
            state.user_text = "Error: STT service unavailable or failed."
            return state

    # A present-but-null (or otherwise non-string) transcript is treated like a
    # missing one; slicing it for the log line would raise TypeError and abort
    # the whole graph.
    if "transcript" in response_data and isinstance(response_data["transcript"], str):
        state.user_text = response_data["transcript"]
        logger.info(f"STT successful. Transcript: {state.user_text[:50]}...")
    else:
        error_msg = f"STT agent response missing 'transcript': {response_data}"
        logger.error(error_msg)
        state.errors.append(error_msg)
        state.user_text = "Error: STT failed."
    return state
async def nlu_node(state: MarketBriefState) -> MarketBriefState:
    """Extract region/sector intent (simulated NLU) and choose target tickers.

    The transcript is scanned for a few keywords to pick a region and sector.
    ``EXAMPLE_PORTFOLIO`` is then filtered by those labels to produce
    ``state.target_tickers_for_data_fetch``. If the filter matches nothing, all
    portfolio tickers are used and ``region_effective`` / ``sector_effective``
    are recorded in ``state.nlu_results``.

    Args:
        state: Current workflow state; mutated in place.

    Returns:
        The same ``state`` object.
    """
    if not state.user_text or "Error:" in state.user_text:
        state.warnings.append(
            "NLU Node: Skipping due to missing or error in user_text."
        )
        state.nlu_results = {
            "region": "Global",
            "sector": "Overall Portfolio",
        }
        return state

    logger.info(f"NLU Node: Processing query: '{state.user_text}'")
    query_lower = state.user_text.lower()
    # Whole-word view of the query. "us" and "usa" must be matched as words:
    # as substrings they also hit "focus", "business", "thousand", etc.
    # NOTE: the pronoun "us" is still indistinguishable from the country here.
    query_words = set(
        "".join(ch if ch.isalnum() else " " for ch in query_lower).split()
    )

    region = "Global"
    sector = "Overall Portfolio"
    if "asia" in query_lower and "tech" in query_lower:
        region = "Asia"
        sector = "Technology"
        logger.info("NLU Simulation: Detected 'Asia' and 'Tech'.")
    elif (
        "us" in query_words
        or "usa" in query_words
        or "america" in query_lower  # substring on purpose: also "american"
    ):
        region = "USA"
        if "tech" in query_lower:
            sector = "Technology"
        elif "health" in query_lower:
            sector = "Healthcare"
        logger.info(f"NLU Simulation: Detected Region '{region}', Sector '{sector}'.")

    state.nlu_results = {"region": region, "sector": sector}
    logger.info(f"NLU Node: Results: {state.nlu_results}")

    target_tickers = []
    portfolio_keys = list(EXAMPLE_PORTFOLIO.keys())
    sector_is_generic = sector in ("Overall Portfolio", "Overall Market")
    if region == "Global" and sector_is_generic:
        target_tickers = portfolio_keys
    else:
        asian_countries = ("Taiwan", "China", "Korea", "Japan", "India")
        for ticker, details in EXAMPLE_PORTFOLIO.items():
            country = details.get("country")
            matches_region = (
                region == "Global"
                or (region == "Asia" and country in asian_countries)
                or (region == "USA" and country == "USA")
            )
            matches_sector = (
                sector_is_generic
                or sector.lower() == details.get("sector", "").lower()
            )
            if matches_region and matches_sector:
                target_tickers.append(ticker)
        if not target_tickers and portfolio_keys:
            logger.warning(
                f"NLU filtering yielded no specific tickers for {region}/{sector}, defaulting to all portfolio tickers."
            )
            target_tickers = portfolio_keys
            state.nlu_results["region_effective"] = "Global"
            state.nlu_results["sector_effective"] = "Overall Portfolio"

    # De-duplicate while keeping portfolio order; a set would make the order
    # (and therefore downstream requests and logs) vary between runs.
    state.target_tickers_for_data_fetch = list(dict.fromkeys(target_tickers))
    logger.info(
        f"NLU Node: Target tickers for data fetch: {state.target_tickers_for_data_fetch}"
    )
    if not state.target_tickers_for_data_fetch:
        state.warnings.append(
            "NLU Node: No target tickers identified for data fetching based on query and portfolio."
        )
    return state
async def api_agent_node(state: MarketBriefState) -> MarketBriefState:
    """Fetch market data for the target tickers from the API agent.

    Stores the response's ``result`` in ``state.market_data`` and copies any
    errors/warnings reported by the agent into ``state.warnings``. When there
    are no target tickers, or the call fails, ``state.market_data`` is ``{}``.
    """
    tickers = state.target_tickers_for_data_fetch
    if not tickers:
        state.warnings.append(
            "API Agent Node: No target tickers to fetch market data for. Skipping."
        )
        state.market_data = {}
        return state

    request_body = {"tickers": tickers, "data_type": "adjClose"}
    async with httpx.AsyncClient() as client:
        try:
            reply = await call_agent(
                client, f"{AGENT_API_URL}/get_market_data", json_payload=request_body
            )
        except HTTPException as exc:
            state.errors.append(
                f"API Agent Node failed for tickers {state.target_tickers_for_data_fetch}: {exc.detail}"
            )
            state.market_data = {}
        else:
            state.market_data = reply.get("result", {})
            fetched = list(state.market_data.keys()) if state.market_data else "None"
            logger.info(f"API Agent successful. Fetched data for tickers: {fetched}")

            reported_errors = reply.get("errors")
            if reported_errors:
                state.warnings.append(f"API Agent reported errors: {reported_errors}")
            if reply.get("warnings"):
                state.warnings.extend(reply.get("warnings", []))
    return state
async def scraping_agent_node(state: MarketBriefState) -> MarketBriefState:
    """Collect earnings-surprise records for each target ticker.

    Calls the scraping agent once per ticker, in order. A ticker whose call
    fails or whose response lacks a list under ``"data"`` gets an empty list
    and an entry in ``state.errors``. The per-ticker results are stored in
    ``state.filings``.
    """
    if not state.target_tickers_for_data_fetch:
        state.warnings.append(
            "Scraping Agent Node: No target tickers to fetch earnings for. Skipping."
        )
        state.filings = {}
        return state

    endpoint = f"{AGENT_SCRAPING_URL}/get_filings"
    collected: Dict[str, List[Dict[str, Any]]] = {}
    async with httpx.AsyncClient() as client:
        for symbol in state.target_tickers_for_data_fetch:
            request_body = {"ticker": symbol, "filing_type": "earnings_surprise"}
            try:
                reply = await call_agent(client, endpoint, json_payload=request_body)
            except HTTPException as exc:
                state.errors.append(
                    f"Scraping Agent Node failed for {symbol}: {exc.detail}"
                )
                collected[symbol] = []
                continue

            has_record_list = "data" in reply and isinstance(reply["data"], list)
            if not has_record_list:
                collected[symbol] = []
                state.errors.append(
                    f"Scraping agent for {symbol} returned malformed data: {str(reply)[:100]}"
                )
                continue

            records = reply["data"]
            collected[symbol] = records
            logger.info(f"Scraping Agent got {len(records)} records for {symbol}.")
            if not records:
                logger.info(
                    f"Scraping Agent for {symbol} returned 0 earnings surprise records."
                )
        state.filings = collected
    return state
async def retriever_agent_node(state: MarketBriefState) -> MarketBriefState:
    """Prepare earnings documents for the retriever (indexing/retrieval are stubs).

    Builds one text document per ticker that has earnings records. The actual
    indexing and retrieval calls are not implemented in this file, so
    ``state.indexed`` is never set to ``True`` here and ``state.retrieved_docs``
    is only reset to ``[]`` when there is no user text.

    Args:
        state: Current workflow state; mutated in place.

    Returns:
        The same ``state`` object.
    """
    docs_to_index = []
    for ticker, records_list in (state.filings or {}).items():
        if not records_list:
            continue
        # ``filings`` is declared as lists of EarningsSurpriseRecordState, but
        # the scraping node stores plain dicts. Accept both so ``.get`` below
        # cannot fail on a model instance.
        rows = [
            record.model_dump() if isinstance(record, BaseModel) else record
            for record in records_list
        ]
        doc_content = f"Earnings surprise data for {ticker}:\n" + "\n".join(
            f"Date: {r.get('date', 'N/A')}, Symbol: {r.get('symbol', 'N/A')}, "
            f"Actual: {r.get('actual', 'N/A')}, Estimate: {r.get('estimate', 'N/A')}, "
            f"Surprise%: {r.get('surprisePercentage', 'N/A')}"
            for r in rows
        )
        docs_to_index.append(doc_content)

    if docs_to_index:
        # TODO: send docs_to_index to the retriever agent (AGENT_RETRIEVER_URL)
        # and set state.indexed from the outcome. Not implemented yet.
        logger.info(
            f"Retriever: {len(docs_to_index)} document(s) prepared; indexing is not implemented."
        )
    else:
        state.indexed = False
        logger.info("Retriever: No new documents to index.")

    if not state.user_text:
        state.retrieved_docs = []
    # TODO: when user_text is present, query the retriever agent and store the
    # hits in state.retrieved_docs. Until then the field is left unchanged.
    return state
async def analysis_agent_node(state: MarketBriefState) -> MarketBriefState:
    """Send portfolio, market and earnings data to the analysis agent.

    Skips with a warning when neither market data nor filings are available.
    On success the agent's response is stored in ``state.analysis``. On
    failure an entry is appended to ``state.errors`` and ``state.analysis`` is
    ``None``.

    Args:
        state: Current workflow state; mutated in place.

    Returns:
        The same ``state`` object.
    """
    if not state.market_data and not state.filings:
        state.warnings.append(
            "Analysis Agent Node: No market data or filings available. Skipping analysis."
        )
        state.analysis = None
        return state

    nlu_res = state.nlu_results if state.nlu_results else {}
    # Prefer the "effective" labels, which the NLU step sets when it had to
    # fall back to the whole portfolio.
    region_label = nlu_res.get("region_effective", nlu_res.get("region", "Global"))
    sector_label = nlu_res.get(
        "sector_effective", nlu_res.get("sector", "Overall Portfolio")
    )
    if region_label == "Global" and sector_label in (
        "Overall Portfolio",
        "Overall Market",
    ):
        target_label_for_analysis = "Overall Portfolio"
    else:
        target_label_for_analysis = (
            f"{region_label.replace('USA', 'US')} {sector_label} Stocks".strip()
        )

    # The portfolio may come from a user-supplied JSON file, so an entry may
    # lack "weight"; skip it instead of raising KeyError and aborting the graph.
    current_portfolio_weights = {}
    for ticker, details in EXAMPLE_PORTFOLIO.items():
        if "weight" in details:
            current_portfolio_weights[ticker] = details["weight"]
        else:
            logger.warning(
                f"Analysis Agent Node: portfolio entry '{ticker}' has no weight; omitting it."
            )

    # Records may be plain dicts or EarningsSurpriseRecordState instances (the
    # declared field type). Model instances are not JSON-serializable, so
    # convert them before building the request body.
    earnings_payload = {
        ticker: [
            record.model_dump() if isinstance(record, BaseModel) else record
            for record in (records or [])
        ]
        for ticker, records in (state.filings or {}).items()
    }

    payload = {
        "portfolio": current_portfolio_weights,
        "market_data": state.market_data if state.market_data else {},
        "earnings_data": earnings_payload,
        "target_tickers": state.target_tickers_for_data_fetch,
        "target_label": target_label_for_analysis,
    }
    async with httpx.AsyncClient() as client:
        try:
            response_data = await call_agent(
                client, f"{AGENT_ANALYSIS_URL}/analyze", json_payload=payload
            )
            state.analysis = response_data
            logger.info(
                f"Analysis Agent successful for '{response_data.get('target_label')}'."
            )
        except HTTPException as e:
            state.errors.append(f"Analysis Agent Node failed: {e.detail}")
            state.analysis = None
    return state
async def language_agent_node(state: MarketBriefState) -> MarketBriefState:
    """Ask the language agent to turn the analysis into a spoken-style brief.

    Skips (with an error entry and a fixed apology in ``state.brief``) when
    there is no usable transcript. When analysis data is missing, neutral
    defaults are sent and a warning is recorded. On success ``state.brief``
    holds the generated text.

    Args:
        state: Current workflow state; mutated in place.

    Returns:
        The same ``state`` object.
    """
    # Guard first so no HTTP client is opened when the node is going to skip.
    if not state.user_text or "Error:" in state.user_text:
        state.errors.append("Language Agent: Skipping due to no valid user text.")
        state.brief = (
            "I could not understand your query or there was an earlier error."
        )
        return state

    analysis_payload_for_llm: Dict[str, Any]
    if state.analysis and isinstance(state.analysis, dict):
        analysis_payload_for_llm = {
            "target_label": state.analysis.get("target_label", "the portfolio"),
            "current_allocation": state.analysis.get("current_allocation", 0.0),
            "yesterday_allocation": state.analysis.get("yesterday_allocation", 0.0),
            "allocation_change_percentage_points": state.analysis.get(
                "allocation_change_percentage_points", 0.0
            ),
            "earnings_surprises_for_target": state.analysis.get(
                "earnings_surprises_for_target", []
            ),
        }
    else:
        logger.warning(
            "Language Agent: Analysis data is missing or not a dict. Using defaults."
        )
        state.warnings.append(
            "Language Agent: Analysis data unavailable, brief will be general."
        )
        analysis_payload_for_llm = {
            "target_label": "the portfolio (analysis data missing)",
            "current_allocation": 0.0,
            "yesterday_allocation": 0.0,
            "allocation_change_percentage_points": 0.0,
            "earnings_surprises_for_target": [],
        }

    payload = {
        "user_query": state.user_text,
        "analysis": analysis_payload_for_llm,
        "retrieved_docs": state.retrieved_docs if state.retrieved_docs else [],
    }
    async with httpx.AsyncClient() as client:
        try:
            response_data = await call_agent(
                client, f"{AGENT_LANGUAGE_URL}/generate_brief", json_payload=payload
            )
        except HTTPException as e:
            state.errors.append(f"Language Agent Node failed: {e.detail}")
            state.brief = "Sorry, I couldn't generate the brief at this time due to an internal error."
            return state

    brief = response_data.get("brief")
    if isinstance(brief, str):
        state.brief = brief
        logger.info(f"Language Agent successful. Brief: {state.brief[:70]}...")
    else:
        # A missing or null "brief" used to raise TypeError on the slice above
        # and abort the graph; record it as an error instead.
        error_msg = (
            f"Language Agent response missing or invalid 'brief': {str(response_data)[:100]}"
        )
        logger.error(error_msg)
        state.errors.append(error_msg)
        state.brief = None
    return state
async def tts_node(state: MarketBriefState) -> MarketBriefState:
    """Synthesize speech for the brief (or for a short error summary).

    The text sent to the voice agent is chosen as follows:

    * errors exist and the brief is empty or mentions "sorry"/"error":
      a one-sentence error summary;
    * no brief at all: a fixed "could not be generated" sentence, plus a
      warning;
    * otherwise the brief itself.

    The agent is expected to return hex-encoded audio under ``"audio"``. It is
    decoded into ``state.audio_output``. Any failure leaves
    ``state.audio_output`` as ``None`` and appends to ``state.errors``.

    Args:
        state: Current workflow state; mutated in place.

    Returns:
        The same ``state`` object.
    """
    brief_text_for_tts = state.brief
    if state.errors and (
        not state.brief
        or "sorry" in state.brief.lower()
        or "error" in state.brief.lower()
    ):
        error_count = len(state.errors)
        brief_text_for_tts = f"I encountered {error_count} error{'s' if error_count > 1 else ''} while processing your request. Please check the detailed report."
        logger.warning(
            f"TTS Node: Generating audio for error summary due to {error_count} errors in state."
        )
    elif not state.brief:
        brief_text_for_tts = "The market brief could not be generated."
        logger.warning("TTS Node: No brief text available from language agent.")
        state.warnings.append("TTS Node: No brief content to synthesize.")
    # brief_text_for_tts is always non-empty here: either branch above replaced
    # an empty brief, or the brief itself is non-empty.

    payload = {"text": brief_text_for_tts, "lang": "en"}
    async with httpx.AsyncClient() as client:
        try:
            response_data = await call_agent(
                client, f"{AGENT_VOICE_URL}/tts", json_payload=payload
            )
        except HTTPException as e:
            state.errors.append(f"TTS Node failed: {e.detail}")
            state.audio_output = None
            return state

    if "audio" in response_data and isinstance(response_data["audio"], str):
        try:
            state.audio_output = bytes.fromhex(response_data["audio"])
        except ValueError as e:
            # Malformed hex would otherwise escape the node and abort the graph.
            state.errors.append(f"TTS Agent returned audio that is not valid hex: {e}")
            state.audio_output = None
        else:
            logger.info("TTS successful. Audio bytes received.")
    else:
        state.errors.append(
            f"TTS Agent response missing or invalid 'audio': {str(response_data)[:100]}"
        )
        state.audio_output = None
    return state
def build_market_brief_graph():
    """Assemble and compile the linear market-brief workflow.

    Nodes run in the order listed below, each feeding the next, and the last
    one connects to ``END``.
    """
    pipeline = (
        ("stt", stt_node),
        ("nlu", nlu_node),
        ("api_agent", api_agent_node),
        ("scraping_agent", scraping_agent_node),
        ("retriever_agent", retriever_agent_node),
        ("analysis_agent", analysis_agent_node),
        ("language_agent", language_agent_node),
        ("tts", tts_node),
    )

    builder = StateGraph(MarketBriefState)
    for node_name, handler in pipeline:
        builder.add_node(node_name, handler)

    builder.set_entry_point(pipeline[0][0])

    # Chain each node to its successor; the final hop goes to END.
    hops = [node_name for node_name, _ in pipeline] + [END]
    for source, target in zip(hops, hops[1:]):
        builder.add_edge(source, target)

    return builder.compile()
# Compiled once at import time; the /market_brief handler invokes it per request.
graph = build_market_brief_graph()
@app.post("/market_brief")
async def market_brief(audio: UploadFile = File(...)):
    """Run the full voice-to-brief workflow for one uploaded audio file.

    Rejects uploads whose content type is not ``audio/*`` with 415 and
    unreadable uploads with 500. Otherwise runs the graph and returns a JSON
    payload with the transcript, brief, hex-encoded audio, accumulated
    errors/warnings and an overall status. Failures inside the graph are
    reported in the payload rather than raised.
    """
    logger.info("Received request to /market_brief")

    content_type = audio.content_type
    if not (content_type and content_type.startswith("audio/")):
        raise HTTPException(
            status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE,
            detail="Invalid file type.",
        )

    initial_state = MarketBriefState()
    try:
        initial_state.audio_input = await audio.read()
    except Exception as exc:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to read audio: {exc}",
        )

    # Falls back to the initial state if the graph does not yield a usable result.
    final_state: MarketBriefState = initial_state
    try:
        logger.info("Invoking LangGraph workflow...")
        graph_input = initial_state.model_dump(exclude_none=True)
        graph_output = await graph.ainvoke(graph_input)
        if not isinstance(graph_output, dict):
            logger.error(
                f"LangGraph ainvoke returned unexpected type: {type(graph_output)}. Using partially updated state."
            )
            final_state.errors.append(
                f"Internal graph error: result type {type(graph_output)}"
            )
        else:
            final_state = MarketBriefState(**graph_output)
            logger.info("LangGraph execution finished. State updated.")
    except HTTPException as exc:
        logger.error(
            f"Graph execution stopped due to HTTPException from an agent: {exc.detail}"
        )
        final_state.errors.append(f"Agent call failed: {exc.detail}")
    except Exception as exc:
        error_msg = f"An unexpected error occurred during graph execution: {exc}"
        logger.error(error_msg, exc_info=True)
        final_state.errors.append(error_msg)

    if final_state.errors:
        run_status = "failed"
        outcome_suffix = " With errors."
    else:
        run_status = "success"
        outcome_suffix = " Successfully."

    audio_hex = None
    if final_state.audio_output:
        audio_hex = final_state.audio_output.hex()

    response_payload = {
        "transcript": final_state.user_text,
        "brief": final_state.brief,
        "audio": audio_hex,
        "errors": final_state.errors,
        "warnings": final_state.warnings,
        "status": run_status,
        "message": "Market brief process completed." + outcome_suffix,
        "nlu_detected": final_state.nlu_results,
        "analysis_details": final_state.analysis,
    }
    error_total = len(response_payload["errors"])
    warning_total = len(response_payload["warnings"])
    logger.info(
        f"Request finished. Status: {response_payload['status']}. Errors: {error_total}. Warnings: {warning_total}."
    )
    return response_payload