from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from typing import List
import google.generativeai as genai
import os
from dotenv import load_dotenv
import logging
import asyncio
load_dotenv()
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI(title="Language Agent (Gemini Pro - Generalized)")
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
GEMINI_MODEL_NAME = os.getenv("GEMINI_MODEL_NAME", "gemini-1.5-flash-latest")
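# Configure the Gemini client once at import time; if the key is missing, requests fail later with a 500.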
if not GOOGLE_API_KEY:
    logger.warning("GOOGLE_API_KEY not found.")
else:
    try:
        genai.configure(api_key=GOOGLE_API_KEY)
        logger.info(f"Google Generative AI configured for model {GEMINI_MODEL_NAME}.")
    except Exception as e:
        logger.error(f"Failed to configure Google Generative AI: {e}")
class EarningsSummaryLLM(BaseModel):
    ticker: str
    surprise_pct: float


class AnalysisDataLLM(BaseModel):
    target_label: str = "the portfolio"
    current_allocation: float = 0.0
    yesterday_allocation: float = 0.0
    allocation_change_percentage_points: float = 0.0
    earnings_surprises: List[EarningsSummaryLLM] = Field(
        default_factory=list, alias="earnings_surprises_for_target"
    )


class BriefRequest(BaseModel):
    user_query: str
    analysis: AnalysisDataLLM
    retrieved_docs: List[str] = Field(default_factory=list)
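# Turn the structured analysis and any retrieved filings excerpts into a single instruction prompt for Gemini.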
def construct_gemini_prompt(
    user_query: str, analysis_data: AnalysisDataLLM, docs_context: str
) -> str:
    alloc_change_str = ""
    if analysis_data.allocation_change_percentage_points > 0.01:
        alloc_change_str = f"up by {analysis_data.allocation_change_percentage_points:.1f} percentage points from yesterday (approx. {analysis_data.yesterday_allocation*100:.0f}%)."
    elif analysis_data.allocation_change_percentage_points < -0.01:
        alloc_change_str = f"down by {abs(analysis_data.allocation_change_percentage_points):.1f} percentage points from yesterday (approx. {analysis_data.yesterday_allocation*100:.0f}%)."
    else:
        alloc_change_str = f"remaining stable around {analysis_data.yesterday_allocation*100:.0f}% from yesterday."
    analysis_summary_str = f"For {analysis_data.target_label}, the current allocation is {analysis_data.current_allocation*100:.0f}% of AUM, {alloc_change_str}\n"
    if analysis_data.earnings_surprises:
        earnings_parts = []
        for e in analysis_data.earnings_surprises:
            direction = (
                "beat estimates by" if e.surprise_pct >= 0 else "missed estimates by"
            )
            earnings_parts.append(f"{e.ticker} {direction} {abs(e.surprise_pct):.1f}%")
        if earnings_parts:
            analysis_summary_str += (
                "Key earnings updates: " + ", ".join(earnings_parts) + "."
            )
        else:
            analysis_summary_str += (
                "No specific earnings surprises to highlight for this segment."
            )
    else:
        analysis_summary_str += (
            "No notable earnings surprises reported for this segment."
        )
    prompt = (
        f"You are a professional financial assistant. Based on the user's query and the provided data, "
        f"deliver a concise, spoken-style morning market brief for a portfolio manager. "
        f"The brief should start with 'Good morning.'\n\n"
        f"User Query: {user_query}\n\n"
        f"Key Portfolio and Market Analysis:\n{analysis_summary_str}\n\n"
        f"Relevant Filings Context (if any):\n{docs_context}\n\n"
        f"If the user's query mentions a specific region or sector not covered by the 'Key Portfolio and Market Analysis', "
        f"you can state that specific data for that exact query aspect was not available in the analysis provided. "
        f"Mention any specific company earnings surprises from the analysis clearly (e.g., 'TSMC beat estimates by X%, Samsung missed by Y%'). "
        f"If there's information about broad regional sentiment or rising yields in the 'docs_context', incorporate it naturally. Otherwise, focus on the provided analysis."
    )
    return prompt
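# Shared generation settings: moderate temperature for a conversational brief, capped output length, mid-level safety thresholds.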
generation_config = genai.types.GenerationConfig(
    temperature=0.6, max_output_tokens=1024
)
safety_settings = [
    {"category": c, "threshold": "BLOCK_MEDIUM_AND_ABOVE"}
    for c in [
        "HARM_CATEGORY_HARASSMENT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_DANGEROUS_CONTENT",
    ]
]
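# Single endpoint: turns the user's query, the analysis payload, and retrieved documents into a spoken-style brief.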
@app.post("/generate_brief")
async def generate_brief(request: BriefRequest):
    if not GOOGLE_API_KEY:
        raise HTTPException(status_code=500, detail="Google API Key not configured.")
    logger.info(
        f"Generating brief for query: '{request.user_query}' using Gemini model {GEMINI_MODEL_NAME}"
    )
    # Only the first two retrieved documents are passed along, to keep the prompt compact.
    docs_context = (
        "\n".join(request.retrieved_docs[:2])
        if request.retrieved_docs
        else "No relevant context from documents found."
    )
    full_prompt = construct_gemini_prompt(
        user_query=request.user_query,
        analysis_data=request.analysis,
        docs_context=docs_context,
    )
    logger.debug(f"Full prompt for Gemini:\n{full_prompt}")
    try:
        model = genai.GenerativeModel(
            model_name=GEMINI_MODEL_NAME,
            generation_config=generation_config,
            safety_settings=safety_settings,
        )
        max_retries = 1
        retry_delay_seconds = 10
        for attempt in range(max_retries + 1):
            try:
                response = await model.generate_content_async(full_prompt)
                if not response.parts:
                    if (
                        response.prompt_feedback
                        and response.prompt_feedback.block_reason
                    ):
                        block_reason_message = (
                            response.prompt_feedback.block_reason_message
                            or "Unknown safety block"
                        )
                        logger.error(
                            f"Gemini content generation blocked. Reason: {block_reason_message}"
                        )
                        raise HTTPException(
                            status_code=400,
                            detail=f"Content generation blocked: {block_reason_message}",
                        )
                    else:
                        logger.error("Gemini response has no parts (empty content).")
                        if attempt == max_retries:
                            raise HTTPException(
                                status_code=500,
                                detail="Gemini returned empty content after retries.",
                            )
                        else:
                            logger.warning(
                                f"Gemini returned empty content, attempt {attempt+1}/{max_retries+1}. Retrying..."
                            )
                            await asyncio.sleep(retry_delay_seconds)
                            continue
                brief_text = response.text
                logger.info("Gemini content generated successfully.")
                return {"brief": brief_text}
            except HTTPException:
                # Re-raise HTTP errors produced above (e.g. safety blocks) so the generic
                # handler below does not swallow them and convert them into 500s.
                raise
            except (
                genai.types.generation_types.BlockedPromptException,
                genai.types.generation_types.StopCandidateException,
            ) as sce_bpe:
                logger.error(
                    f"Gemini generation issue on attempt {attempt+1}: {sce_bpe}"
                )
                raise HTTPException(
                    status_code=400, detail=f"Gemini generation issue: {sce_bpe}"
                )
            except Exception as e:
                logger.error(
                    f"Error during Gemini generation on attempt {attempt+1}: {type(e).__name__} - {e}"
                )
                # Heuristic rate-limit detection; back off exponentially before retrying.
                if (
                    "rate limit" in str(e).lower()
                    or "quota" in str(e).lower()
                    or "429" in str(e)
                    or "resource_exhausted" in str(e).lower()
                ):
                    if attempt < max_retries:
                        wait_time = retry_delay_seconds * (2**attempt)
                        logger.info(f"Rate limit likely. Retrying in {wait_time}s...")
                        await asyncio.sleep(wait_time)
                        continue
                    else:
                        logger.error("Max retries reached for rate limit.")
                        raise HTTPException(
                            status_code=429,
                            detail=f"Gemini API rate limit/quota exceeded: {e}",
                        )
                elif attempt < max_retries:
                    await asyncio.sleep(retry_delay_seconds)
                    continue
                else:
                    raise HTTPException(
                        status_code=500,
                        detail=f"Failed to generate brief with Gemini: {e}",
                    )
        raise HTTPException(
            status_code=500, detail="Brief generation failed after all attempts."
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Critical error in /generate_brief: {e}", exc_info=True)
        raise HTTPException(
            status_code=500, detail=f"Critical failure in brief generation: {e}"
        )