import asyncio
import logging
import os
from typing import List

from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field

import google.generativeai as genai

load_dotenv()
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI(title="Language Agent (Gemini Pro - Generalized)")
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
GEMINI_MODEL_NAME = os.getenv("GEMINI_MODEL_NAME", "gemini-1.5-flash-latest")
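# Configuration comes from the environment (load_dotenv() above also picks up a local
# .env file). A minimal .env might look like (key value is a placeholder):
#   GOOGLE_API_KEY=<your Google AI Studio API key>
#   GEMINI_MODEL_NAME=gemini-1.5-flash-latest   # optional; this default is used when unset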
if not GOOGLE_API_KEY:
    logger.warning("GOOGLE_API_KEY not found.")
else:
    try:
        genai.configure(api_key=GOOGLE_API_KEY)
        logger.info(f"Google Generative AI configured for model {GEMINI_MODEL_NAME}.")
    except Exception as e:
        logger.error(f"Failed to configure Google Generative AI: {e}")
class EarningsSummaryLLM(BaseModel):
    ticker: str
    surprise_pct: float


class AnalysisDataLLM(BaseModel):
    target_label: str = "the portfolio"
    current_allocation: float = 0.0
    yesterday_allocation: float = 0.0
    allocation_change_percentage_points: float = 0.0
    earnings_surprises: List[EarningsSummaryLLM] = Field(
        default_factory=list, alias="earnings_surprises_for_target"
    )


class BriefRequest(BaseModel):
    user_query: str
    analysis: AnalysisDataLLM
    retrieved_docs: List[str] = Field(default_factory=list)
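# Illustrative /generate_brief request body matching the models above (ticker and
# numbers are hypothetical; note that earnings surprises are populated via the
# field alias "earnings_surprises_for_target"):
# {
#   "user_query": "What's our risk exposure in Asia tech stocks today?",
#   "analysis": {
#     "target_label": "Asia tech holdings",
#     "current_allocation": 0.22,
#     "yesterday_allocation": 0.18,
#     "allocation_change_percentage_points": 4.0,
#     "earnings_surprises_for_target": [{"ticker": "TSM", "surprise_pct": 4.0}]
#   },
#   "retrieved_docs": ["Excerpt from a recent filing..."]
# }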
def construct_gemini_prompt(
    user_query: str, analysis_data: AnalysisDataLLM, docs_context: str
) -> str:
    """Build the Gemini prompt from the user query, the analysed allocation/earnings data, and retrieved document context."""
    alloc_change_str = ""
    if analysis_data.allocation_change_percentage_points > 0.01:
        alloc_change_str = f"up by {analysis_data.allocation_change_percentage_points:.1f} percentage points from yesterday (approx. {analysis_data.yesterday_allocation*100:.0f}%)."
    elif analysis_data.allocation_change_percentage_points < -0.01:
        alloc_change_str = f"down by {abs(analysis_data.allocation_change_percentage_points):.1f} percentage points from yesterday (approx. {analysis_data.yesterday_allocation*100:.0f}%)."
    else:
        alloc_change_str = f"roughly unchanged from yesterday (approx. {analysis_data.yesterday_allocation*100:.0f}%)."
    analysis_summary_str = f"For {analysis_data.target_label}, the current allocation is {analysis_data.current_allocation*100:.0f}% of AUM, {alloc_change_str}\n"
    if analysis_data.earnings_surprises:
        earnings_parts = []
        for e in analysis_data.earnings_surprises:
            direction = (
                "beat estimates by" if e.surprise_pct >= 0 else "missed estimates by"
            )
            earnings_parts.append(f"{e.ticker} {direction} {abs(e.surprise_pct):.1f}%")
        if earnings_parts:
            analysis_summary_str += (
                "Key earnings updates: " + ", ".join(earnings_parts) + "."
            )
        else:
            analysis_summary_str += (
                "No specific earnings surprises to highlight for this segment."
            )
    else:
        analysis_summary_str += (
            "No notable earnings surprises reported for this segment."
        )
    prompt = (
        f"You are a professional financial assistant. Based on the user's query and the provided data, "
        f"deliver a concise, spoken-style morning market brief for a portfolio manager. "
        f"The brief should start with 'Good morning.'\n\n"
        f"User Query: {user_query}\n\n"
        f"Key Portfolio and Market Analysis:\n{analysis_summary_str}\n\n"
        f"Relevant Filings Context (if any):\n{docs_context}\n\n"
        f"If the user's query mentions a specific region or sector not covered by the 'Key Portfolio and Market Analysis', "
        f"you can state that specific data for that exact query aspect was not available in the analysis provided. "
        f"Mention any specific company earnings surprises from the analysis clearly (e.g., 'TSMC beat estimates by X%, Samsung missed by Y%'). "
        f"If there's information about broad regional sentiment or rising yields in the 'docs_context', incorporate it naturally. Otherwise, focus on the provided analysis."
    )
    return prompt
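# Quick sanity-check sketch for the prompt builder (hypothetical values); the resulting
# string starts with the "You are a professional financial assistant..." instruction:
#
#   _sample = AnalysisDataLLM(
#       target_label="Asia tech holdings",
#       current_allocation=0.22,
#       yesterday_allocation=0.18,
#       allocation_change_percentage_points=4.0,
#   )
#   print(construct_gemini_prompt("Any earnings surprises?", _sample, "No relevant context from documents found."))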
generation_config = genai.types.GenerationConfig(
    temperature=0.6, max_output_tokens=1024
)
safety_settings = [
    {"category": c, "threshold": "BLOCK_MEDIUM_AND_ABOVE"}
    for c in [
        "HARM_CATEGORY_HARASSMENT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_DANGEROUS_CONTENT",
    ]
]
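# The comprehension above expands to one dict per harm category, e.g.
#   {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"}
# so medium-or-higher-probability content in any of the four categories is blocked.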
@app.post("/generate_brief")
async def generate_brief(request: BriefRequest):
    if not GOOGLE_API_KEY:
        raise HTTPException(status_code=500, detail="Google API Key not configured.")
    logger.info(
        f"Generating brief for query: '{request.user_query}' using Gemini model {GEMINI_MODEL_NAME}"
    )
    docs_context = (
        "\n".join(request.retrieved_docs[:2])
        if request.retrieved_docs
        else "No relevant context from documents found."
    )
    full_prompt = construct_gemini_prompt(
        user_query=request.user_query,
        analysis_data=request.analysis,
        docs_context=docs_context,
    )
    logger.debug(f"Full prompt for Gemini:\n{full_prompt}")
    try:
        model = genai.GenerativeModel(
            model_name=GEMINI_MODEL_NAME,
            generation_config=generation_config,
            safety_settings=safety_settings,
        )
        max_retries = 1
        retry_delay_seconds = 10
        for attempt in range(max_retries + 1):
            try:
                response = await model.generate_content_async(full_prompt)
                if not response.parts:
                    if (
                        response.prompt_feedback
                        and response.prompt_feedback.block_reason
                    ):
                        block_reason_message = (
                            response.prompt_feedback.block_reason_message
                            or "Unknown safety block"
                        )
                        logger.error(
                            f"Gemini content generation blocked. Reason: {block_reason_message}"
                        )
                        raise HTTPException(
                            status_code=400,
                            detail=f"Content generation blocked: {block_reason_message}",
                        )
                    else:
                        logger.error("Gemini response has no parts (empty content).")
                        if attempt == max_retries:
                            raise HTTPException(
                                status_code=500,
                                detail="Gemini returned empty content after retries.",
                            )
                        else:
                            logger.warning(
                                f"Gemini returned empty content, attempt {attempt+1}/{max_retries+1}. Retrying..."
                            )
                            await asyncio.sleep(retry_delay_seconds)
                            continue
                brief_text = response.text
                logger.info("Gemini content generated successfully.")
                return {"brief": brief_text}
            except HTTPException:
                # Re-raise HTTPExceptions from the safety-block/empty-content paths above
                # so they are not swallowed by the generic handler below.
                raise
            except (
                genai.types.generation_types.BlockedPromptException,
                genai.types.generation_types.StopCandidateException,
            ) as sce_bpe:
                logger.error(
                    f"Gemini generation issue on attempt {attempt+1}: {sce_bpe}"
                )
                raise HTTPException(
                    status_code=400, detail=f"Gemini generation issue: {sce_bpe}"
                )
            except Exception as e:
                logger.error(
                    f"Error during Gemini generation on attempt {attempt+1}: {type(e).__name__} - {e}"
                )
                if (
                    "rate limit" in str(e).lower()
                    or "quota" in str(e).lower()
                    or "429" in str(e)
                    or "resource_exhausted" in str(e).lower()
                ):
                    if attempt < max_retries:
                        wait_time = retry_delay_seconds * (2**attempt)
                        logger.info(f"Rate limit likely. Retrying in {wait_time}s...")
                        await asyncio.sleep(wait_time)
                        continue
                    else:
                        logger.error("Max retries reached for rate limit.")
                        raise HTTPException(
                            status_code=429,
                            detail=f"Gemini API rate limit/quota exceeded: {e}",
                        )
                elif attempt < max_retries:
                    await asyncio.sleep(retry_delay_seconds)
                    continue
                else:
                    raise HTTPException(
                        status_code=500,
                        detail=f"Failed to generate brief with Gemini: {e}",
                    )
        raise HTTPException(
            status_code=500, detail="Brief generation failed after all attempts."
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Critical error in /generate_brief: {e}", exc_info=True)
        raise HTTPException(
            status_code=500, detail=f"Critical failure in brief generation: {e}"
        )
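if __name__ == "__main__":
    # Minimal local-run sketch (assumes uvicorn is installed; host and port are arbitrary choices).
    # Example request once running (payload values are hypothetical; empty "analysis" falls back
    # to the AnalysisDataLLM defaults):
    #   curl -X POST http://localhost:8000/generate_brief \
    #     -H "Content-Type: application/json" \
    #     -d '{"user_query": "Give me the morning market brief", "analysis": {}, "retrieved_docs": []}'
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)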