import asyncio
import logging
import os
import time
from typing import List, Dict, Any, Union

import google.generativeai as genai
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, validator, Field
# Pull variables from a local .env file into the process environment before
# any configuration value below is read.
load_dotenv()

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="Language Agent (Gemini Pro - Generalized)")

# Runtime configuration. The model name falls back to a default; the API key
# has no fallback and is re-checked by the request handler on every call.
GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
GEMINI_MODEL_NAME = os.environ.get("GEMINI_MODEL_NAME", "gemini-1.5-flash-latest")

if GOOGLE_API_KEY:
    # Configure the SDK once at import time. A failure is logged rather than
    # raised so the application can still start.
    try:
        genai.configure(api_key=GOOGLE_API_KEY)
        logger.info(f"Google Generative AI configured for model {GEMINI_MODEL_NAME}.")
    except Exception as config_error:
        logger.error(f"Failed to configure Google Generative AI: {config_error}")
else:
    logger.warning("GOOGLE_API_KEY not found.")
class EarningsSummaryLLM(BaseModel):
    # One company's earnings result, as consumed by the prompt builder.

    # Company identifier; inserted verbatim into the prompt text.
    ticker: str
    # Signed surprise value. The prompt builder renders values >= 0 as a
    # "beat" and negative values as a "miss", printing the absolute value
    # followed by "%".
    surprise_pct: float
class AnalysisDataLLM(BaseModel):
    # Portfolio analysis figures that get summarised into the Gemini prompt.
    # Every field has a default, so an empty object is a valid payload.

    # Human-readable name of the segment being described.
    target_label: str = "the portfolio"

    # The prompt builder multiplies both allocations by 100 before printing
    # them with a "%" sign, so they are presumably fractions of AUM
    # (e.g. 0.22) -- TODO confirm against the producer of this payload.
    current_allocation: float = 0.0
    yesterday_allocation: float = 0.0

    # Signed day-over-day change, printed as-is in "percentage points".
    allocation_change_percentage_points: float = 0.0

    # Declared with the alias "earnings_surprises_for_target" for the incoming
    # payload; read in code as `earnings_surprises`.
    earnings_surprises: List[EarningsSummaryLLM] = Field(
        alias="earnings_surprises_for_target",
        default_factory=list,
    )
class BriefRequest(BaseModel):
    # Request body accepted by the /generate_brief endpoint.

    # The user's question, embedded verbatim in the prompt.
    user_query: str
    # Structured portfolio analysis to summarise.
    analysis: AnalysisDataLLM
    # Optional supporting passages; the endpoint only forwards the first two
    # entries to the model.
    retrieved_docs: List[str] = Field(default_factory=list)
def construct_gemini_prompt(
    user_query: str, analysis_data: AnalysisDataLLM, docs_context: str
) -> str:
    """Build the text prompt sent to Gemini for a morning market brief.

    Args:
        user_query: The user's question, embedded verbatim in the prompt.
        analysis_data: Allocation figures and earnings surprises. The
            allocation values are multiplied by 100 for display, while the
            change value is printed directly as percentage points.
        docs_context: Pre-joined supporting text (or a placeholder sentence)
            placed under the "Relevant Filings Context" heading.

    Returns:
        The complete prompt string.
    """
    change_points = analysis_data.allocation_change_percentage_points
    yesterday_pct = f"{analysis_data.yesterday_allocation * 100:.0f}%"

    # Moves within +/-0.01 percentage points are reported as "stable".
    if change_points > 0.01:
        alloc_change_str = (
            f"up by {change_points:.1f} percentage points "
            f"from yesterday (approx. {yesterday_pct})."
        )
    elif change_points < -0.01:
        alloc_change_str = (
            f"down by {abs(change_points):.1f} percentage points "
            f"from yesterday (approx. {yesterday_pct})."
        )
    else:
        alloc_change_str = f"remaining stable around {yesterday_pct} yesterday."

    analysis_summary_str = (
        f"For {analysis_data.target_label}, the current allocation is "
        f"{analysis_data.current_allocation * 100:.0f}% of AUM, {alloc_change_str}\n"
    )

    if analysis_data.earnings_surprises:
        earnings_parts = []
        for surprise in analysis_data.earnings_surprises:
            # A surprise of exactly zero is worded as a beat.
            direction = (
                "beat estimates by"
                if surprise.surprise_pct >= 0
                else "missed estimates by"
            )
            earnings_parts.append(
                f"{surprise.ticker} {direction} {abs(surprise.surprise_pct):.1f}%"
            )
        # A non-empty input always yields a non-empty parts list, so no
        # second "nothing to report" fallback is needed here.
        analysis_summary_str += (
            "Key earnings updates: " + ", ".join(earnings_parts) + "."
        )
    else:
        analysis_summary_str += (
            "No notable earnings surprises reported for this segment."
        )

    # Adjacent literals are concatenated as-is, so every sentence that is
    # followed by another must end with an explicit space or newline.
    prompt = (
        "You are a professional financial assistant. Based on the user's query and the provided data, "
        "deliver a concise, spoken-style morning market brief for a portfolio manager. "
        "The brief should start with 'Good morning.'\n\n"
        f"User Query: {user_query}\n\n"
        f"Key Portfolio and Market Analysis:\n{analysis_summary_str}\n\n"
        f"Relevant Filings Context (if any):\n{docs_context}\n\n"
        "If the user's query mentions a specific region or sector not covered by the 'Key Portfolio and Market Analysis', "
        "you can state that specific data for that exact query aspect was not available in the analysis provided. "
        "Mention any specific company earnings surprises from the analysis clearly (e.g., 'TSMC beat estimates by X%, Samsung missed by Y%'). "
        "If there's information about broad regional sentiment or rising yields in the 'docs_context', incorporate it naturally. Otherwise, focus on the provided analysis."
    )
    return prompt
# Sampling settings shared by every Gemini call made from this module.
generation_config = genai.types.GenerationConfig(
    max_output_tokens=1024,
    temperature=0.6,
)

# Apply the same blocking threshold to each of the four listed harm categories.
safety_settings = [
    {"category": harm_category, "threshold": "BLOCK_MEDIUM_AND_ABOVE"}
    for harm_category in (
        "HARM_CATEGORY_HARASSMENT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_DANGEROUS_CONTENT",
    )
]
def _is_rate_limit_error(exc: Exception) -> bool:
    """Return True when the message of *exc* looks like a rate-limit/quota error."""
    message = str(exc).lower()
    return any(
        marker in message
        for marker in ("rate limit", "quota", "429", "resource_exhausted")
    )


@app.post("/generate_brief")
async def generate_brief(request: BriefRequest):
    """Generate a spoken-style market brief with Gemini.

    Builds a prompt from the user query, the analysis figures and at most the
    first two retrieved documents, then calls the configured Gemini model.
    A failed or empty generation is retried once after a delay.

    Returns:
        ``{"brief": <generated text>}`` on success.

    Raises:
        HTTPException: 400 when the prompt or candidate is blocked, 429 when
            rate-limit/quota errors persist after retrying, and 500 for a
            missing API key, persistently empty responses or any other failure.
    """
    if not GOOGLE_API_KEY:
        raise HTTPException(status_code=500, detail="Google API Key not configured.")

    logger.info(
        f"Generating brief for query: '{request.user_query}' using Gemini model {GEMINI_MODEL_NAME}"
    )

    # Only the first two retrieved documents are forwarded to the model.
    docs_context = (
        "\n".join(request.retrieved_docs[:2])
        if request.retrieved_docs
        else "No relevant context from documents found."
    )
    full_prompt = construct_gemini_prompt(
        user_query=request.user_query,
        analysis_data=request.analysis,
        docs_context=docs_context,
    )
    logger.debug(f"Full prompt for Gemini:\n{full_prompt}")

    try:
        model = genai.GenerativeModel(
            model_name=GEMINI_MODEL_NAME,
            generation_config=generation_config,
            safety_settings=safety_settings,
        )
        max_retries = 1
        retry_delay_seconds = 10

        for attempt in range(max_retries + 1):
            try:
                response = await model.generate_content_async(full_prompt)

                if not response.parts:
                    feedback = response.prompt_feedback
                    if feedback and feedback.block_reason:
                        # getattr tolerates response objects that lack this
                        # field -- TODO confirm which SDK versions expose it.
                        block_reason_message = (
                            getattr(feedback, "block_reason_message", None)
                            or "Unknown safety block"
                        )
                        logger.error(
                            f"Gemini content generation blocked. Reason: {block_reason_message}"
                        )
                        # A safety block is returned immediately, never retried.
                        raise HTTPException(
                            status_code=400,
                            detail=f"Content generation blocked: {block_reason_message}",
                        )

                    logger.error("Gemini response has no parts (empty content).")
                    if attempt == max_retries:
                        raise HTTPException(
                            status_code=500,
                            detail="Gemini returned empty content after retries.",
                        )
                    logger.warning(
                        f"Gemini returned empty content, attempt {attempt+1}/{max_retries+1}. Retrying..."
                    )
                    await asyncio.sleep(retry_delay_seconds)
                    continue

                brief_text = response.text
                logger.info("Gemini content generated successfully.")
                return {"brief": brief_text}

            except HTTPException:
                # HTTP errors raised deliberately above must reach the client
                # unchanged. Without this clause the generic handler below
                # would catch them, retry, and re-wrap them as a 500.
                raise
            except (
                genai.types.generation_types.BlockedPromptException,
                genai.types.generation_types.StopCandidateException,
            ) as sce_bpe:
                logger.error(
                    f"Gemini generation issue on attempt {attempt+1}: {sce_bpe}"
                )
                raise HTTPException(
                    status_code=400, detail=f"Gemini generation issue: {sce_bpe}"
                ) from sce_bpe
            except Exception as e:
                logger.error(
                    f"Error during Gemini generation on attempt {attempt+1}: {type(e).__name__} - {e}"
                )
                rate_limited = _is_rate_limit_error(e)

                if attempt < max_retries:
                    if rate_limited:
                        # Exponential back-off for suspected rate limiting.
                        wait_time = retry_delay_seconds * (2**attempt)
                        logger.info(f"Rate limit likely. Retrying in {wait_time}s...")
                    else:
                        wait_time = retry_delay_seconds
                    await asyncio.sleep(wait_time)
                    continue

                if rate_limited:
                    logger.error("Max retries reached for rate limit.")
                    raise HTTPException(
                        status_code=429,
                        detail=f"Gemini API rate limit/quota exceeded: {e}",
                    ) from e
                raise HTTPException(
                    status_code=500,
                    detail=f"Failed to generate brief with Gemini: {e}",
                ) from e

        # Defensive fallback: every loop iteration above returns, raises or
        # retries, so this line is not expected to be reached.
        raise HTTPException(
            status_code=500, detail="Brief generation failed after all attempts."
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Critical error in /generate_brief: {e}", exc_info=True)
        raise HTTPException(
            status_code=500, detail=f"Critical failure in brief generation: {e}"
        ) from e