"""FastAPI "Language Agent" that turns portfolio analysis into a spoken-style
morning market brief using Google Gemini.

Configuration is read from the environment (optionally via a ``.env`` file):

* ``GOOGLE_API_KEY`` -- needed by ``POST /generate_brief``; when it is missing
  the endpoint answers HTTP 500.
* ``GEMINI_MODEL_NAME`` -- model to call (default ``gemini-1.5-flash-latest``).
"""

import asyncio
import logging
import os
import time
from typing import Any, Dict, List, Union

import google.generativeai as genai
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field, validator

load_dotenv()
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="Language Agent (Gemini Pro - Generalized)")

GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
GEMINI_MODEL_NAME = os.getenv("GEMINI_MODEL_NAME", "gemini-1.5-flash-latest")

if not GOOGLE_API_KEY:
    logger.warning("GOOGLE_API_KEY not found.")
else:
    # Best-effort: a configuration failure is logged, not fatal, so the app
    # still starts; calls to the endpoint will then fail at generation time.
    try:
        genai.configure(api_key=GOOGLE_API_KEY)
        logger.info(f"Google Generative AI configured for model {GEMINI_MODEL_NAME}.")
    except Exception as e:
        logger.error(f"Failed to configure Google Generative AI: {e}")


class EarningsSummaryLLM(BaseModel):
    """One company earnings result to mention in the brief."""

    ticker: str
    # Signed surprise in percent: >= 0 is rendered as a beat, < 0 as a miss.
    surprise_pct: float


class AnalysisDataLLM(BaseModel):
    """Pre-computed portfolio analysis that the brief is built from."""

    target_label: str = "the portfolio"
    # Allocations are rendered as ``value * 100`` percent, i.e. treated as
    # fractions of AUM (assumed 0..1 -- confirm with the producing agent).
    current_allocation: float = 0.0
    yesterday_allocation: float = 0.0
    # Already in percentage points; rendered as-is (not multiplied by 100).
    allocation_change_percentage_points: float = 0.0
    # Read from the request key "earnings_surprises_for_target" (pydantic alias).
    earnings_surprises: List[EarningsSummaryLLM] = Field(
        default_factory=list, alias="earnings_surprises_for_target"
    )


class BriefRequest(BaseModel):
    """Request body for ``POST /generate_brief``."""

    user_query: str
    analysis: AnalysisDataLLM
    # Retrieved document snippets; only the first _MAX_DOCS_IN_PROMPT are used.
    retrieved_docs: List[str] = Field(default_factory=list)


# Allocation moves no larger than this (in percentage points) count as "stable".
_ALLOCATION_CHANGE_THRESHOLD_PP = 0.01


def _describe_allocation_change(analysis_data: AnalysisDataLLM) -> str:
    """Return the clause describing how the allocation moved since yesterday."""
    change = analysis_data.allocation_change_percentage_points
    yesterday_pct = analysis_data.yesterday_allocation * 100
    if change > _ALLOCATION_CHANGE_THRESHOLD_PP:
        return (
            f"up by {change:.1f} percentage points from yesterday "
            f"(approx. {yesterday_pct:.0f}%)."
        )
    if change < -_ALLOCATION_CHANGE_THRESHOLD_PP:
        return (
            f"down by {abs(change):.1f} percentage points from yesterday "
            f"(approx. {yesterday_pct:.0f}%)."
        )
    return f"remaining stable around {yesterday_pct:.0f}% yesterday."


def _describe_earnings(analysis_data: AnalysisDataLLM) -> str:
    """Return one sentence listing earnings beats/misses, or a 'none' note."""
    if not analysis_data.earnings_surprises:
        return "No notable earnings surprises reported for this segment."
    parts = []
    for earning in analysis_data.earnings_surprises:
        direction = (
            "beat estimates by" if earning.surprise_pct >= 0 else "missed estimates by"
        )
        parts.append(f"{earning.ticker} {direction} {abs(earning.surprise_pct):.1f}%")
    return "Key earnings updates: " + ", ".join(parts) + "."


def construct_gemini_prompt(
    user_query: str, analysis_data: AnalysisDataLLM, docs_context: str
) -> str:
    """Build the full Gemini prompt for a morning market brief.

    Args:
        user_query: The portfolio manager's question, embedded verbatim.
        analysis_data: Allocation and earnings figures summarized into prose.
        docs_context: Retrieved filing text, or a placeholder when none exists.

    Returns:
        The prompt string to send to the model.
    """
    analysis_summary_str = (
        f"For {analysis_data.target_label}, the current allocation is "
        f"{analysis_data.current_allocation * 100:.0f}% of AUM, "
        f"{_describe_allocation_change(analysis_data)}\n"
    ) + _describe_earnings(analysis_data)

    return (
        f"You are a professional financial assistant. Based on the user's query and the provided data, "
        f"deliver a concise, spoken-style morning market brief for a portfolio manager. "
        f"The brief should start with 'Good morning.'\n\n"
        f"User Query: {user_query}\n\n"
        f"Key Portfolio and Market Analysis:\n{analysis_summary_str}\n\n"
        f"Relevant Filings Context (if any):\n{docs_context}\n\n"
        f"If the user's query mentions a specific region or sector not covered by the 'Key Portfolio and Market Analysis', "
        f"you can state that specific data for that exact query aspect was not available in the analysis provided. "
        # Trailing space keeps this sentence from running into the next one.
        f"Mention any specific company earnings surprises from the analysis clearly (e.g., 'TSMC beat estimates by X%, Samsung missed by Y%'). "
        # Name the section as the model sees it, not the Python variable.
        f"If there's information about broad regional sentiment or rising yields in the 'Relevant Filings Context', incorporate it naturally. Otherwise, focus on the provided analysis."
    )


generation_config = genai.types.GenerationConfig(
    temperature=0.6, max_output_tokens=1024
)

safety_settings = [
    {"category": category, "threshold": "BLOCK_MEDIUM_AND_ABOVE"}
    for category in [
        "HARM_CATEGORY_HARASSMENT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_DANGEROUS_CONTENT",
    ]
]

# Number of retries after the first attempt (so _MAX_RETRIES + 1 calls total).
_MAX_RETRIES = 1
# Base delay between retries; rate-limit retries back off exponentially from it.
_RETRY_DELAY_SECONDS = 10
# Only the first few retrieved documents are placed in the prompt.
_MAX_DOCS_IN_PROMPT = 2
# Substrings (matched case-insensitively) that mark an error as rate limiting.
_RATE_LIMIT_MARKERS = ("rate limit", "quota", "429", "resource_exhausted")


def _is_rate_limit_error(exc: Exception) -> bool:
    """Heuristically decide from its message whether *exc* is a rate/quota error."""
    message = str(exc).lower()
    return any(marker in message for marker in _RATE_LIMIT_MARKERS)


def _raise_if_prompt_blocked(response: Any) -> None:
    """Raise HTTP 400 when Gemini's prompt feedback reports a block reason."""
    feedback = response.prompt_feedback
    if feedback and feedback.block_reason:
        # NOTE(review): ``block_reason_message`` is not present on every SDK
        # version's PromptFeedback, so read it defensively.
        block_reason_message = (
            getattr(feedback, "block_reason_message", None) or "Unknown safety block"
        )
        logger.error(f"Gemini content generation blocked. Reason: {block_reason_message}")
        raise HTTPException(
            status_code=400,
            detail=f"Content generation blocked: {block_reason_message}",
        )


@app.post("/generate_brief")
async def generate_brief(request: BriefRequest):
    """Generate a spoken-style morning brief for *request* with Gemini.

    Returns:
        ``{"brief": <generated text>}``.

    Raises:
        HTTPException: 500 if the API key is missing, generation keeps
            failing, or an unexpected error occurs; 400 if Gemini blocks the
            prompt or stops a candidate; 429 if rate limiting persists after
            retries.
    """
    if not GOOGLE_API_KEY:
        raise HTTPException(status_code=500, detail="Google API Key not configured.")

    logger.info(
        f"Generating brief for query: '{request.user_query}' using Gemini model {GEMINI_MODEL_NAME}"
    )
    docs_context = (
        "\n".join(request.retrieved_docs[:_MAX_DOCS_IN_PROMPT])
        if request.retrieved_docs
        else "No relevant context from documents found."
    )
    full_prompt = construct_gemini_prompt(
        user_query=request.user_query,
        analysis_data=request.analysis,
        docs_context=docs_context,
    )
    logger.debug(f"Full prompt for Gemini:\n{full_prompt}")

    try:
        model = genai.GenerativeModel(
            model_name=GEMINI_MODEL_NAME,
            generation_config=generation_config,
            safety_settings=safety_settings,
        )
        for attempt in range(_MAX_RETRIES + 1):
            is_last_attempt = attempt == _MAX_RETRIES
            try:
                response = await model.generate_content_async(full_prompt)
                if not response.parts:
                    _raise_if_prompt_blocked(response)
                    logger.error("Gemini response has no parts (empty content).")
                    if is_last_attempt:
                        raise HTTPException(
                            status_code=500,
                            detail="Gemini returned empty content after retries.",
                        )
                    logger.warning(
                        f"Gemini returned empty content, attempt {attempt+1}/{_MAX_RETRIES+1}. Retrying..."
                    )
                    await asyncio.sleep(_RETRY_DELAY_SECONDS)
                    continue
                brief_text = response.text
                logger.info("Gemini content generated successfully.")
                return {"brief": brief_text}
            except HTTPException:
                # Errors raised deliberately above must reach the client as-is;
                # otherwise the generic handler below would retry them and
                # rewrap them as a 500.
                raise
            except (
                genai.types.generation_types.BlockedPromptException,
                genai.types.generation_types.StopCandidateException,
            ) as sce_bpe:
                logger.error(f"Gemini generation issue on attempt {attempt+1}: {sce_bpe}")
                raise HTTPException(
                    status_code=400, detail=f"Gemini generation issue: {sce_bpe}"
                ) from sce_bpe
            except Exception as e:
                logger.error(
                    f"Error during Gemini generation on attempt {attempt+1}: {type(e).__name__} - {e}"
                )
                if _is_rate_limit_error(e):
                    if not is_last_attempt:
                        wait_time = _RETRY_DELAY_SECONDS * (2**attempt)
                        logger.info(f"Rate limit likely. Retrying in {wait_time}s...")
                        await asyncio.sleep(wait_time)
                        continue
                    logger.error("Max retries reached for rate limit.")
                    raise HTTPException(
                        status_code=429,
                        detail=f"Gemini API rate limit/quota exceeded: {e}",
                    ) from e
                if not is_last_attempt:
                    await asyncio.sleep(_RETRY_DELAY_SECONDS)
                    continue
                raise HTTPException(
                    status_code=500,
                    detail=f"Failed to generate brief with Gemini: {e}",
                ) from e
        # Defensive: every loop path above returns, raises, or continues.
        raise HTTPException(
            status_code=500, detail="Brief generation failed after all attempts."
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"Critical error in /generate_brief: {e}")
        raise HTTPException(
            status_code=500, detail=f"Critical failure in brief generation: {e}"
        ) from e