elna / app /agent.py
David Chu
feat: reduce thinking budget to lower response time
5d93afc unverified
from pathlib import Path
from pydantic import BaseModel, Field
from pydantic_ai import Agent, ModelRetry, RunContext
from pydantic_ai.messages import (
ModelMessage,
ModelRequest,
ModelResponse,
TextPart,
ToolReturnPart,
)
from pydantic_ai.models.google import GoogleModel, GoogleModelSettings
from app import models
from app.tools import dailymed, literature
# Run history distilled by get_context(): the text the model produced and the
# source records returned by the source-producing tools during the run.
class Context(BaseModel):
    thoughts: list[str]  # TextPart contents from ModelResponse messages, in order
    sources: dict[str, dict]  # tool return items keyed by each item's "id" value
# One statement of the agent's answer, as produced by the model. It is the
# element type of create_response's ``output`` argument (the agent's
# output_type), so its field names, types and description string are
# presumably part of the schema the LLM is asked to fill in.
# NOTE(review): deliberately no class docstring -- pydantic includes a model's
# docstring in its JSON schema, which would change what the LLM is shown.
class Statement(BaseModel):
    text: str  # the statement text, passed through unchanged by create_response
    # Source IDs; create_response resolves each one against get_context().sources
    # and asks the model to retry when an ID is unknown.
    sources: list[str] | None = Field(
        default=None, description="ID of the sources that support this statement."
    )
def get_context(messages: list[ModelMessage]) -> Context:
thoughts: list[str] = []
sources: dict[str, dict] = {}
for message in messages:
if isinstance(message, ModelResponse):
for part in message.parts:
if isinstance(part, TextPart):
thoughts.append(part.content)
elif isinstance(message, ModelRequest):
for part in message.parts:
if isinstance(part, ToolReturnPart) and part.tool_name in {
"search_medical_literature",
"find_drug_set_ids",
}:
for item in part.content:
sources[item["id"]] = item
return Context(thoughts=thoughts, sources=sources)
# Output function for the agent (passed as ``output_type``): turns the model's
# statements, whose ``sources`` hold source IDs, into a ``models.Statements``
# payload. Each ID is replaced by the full source dict recorded from earlier
# tool calls in this run (see get_context), and the model's text parts are
# joined with blank lines into ``thoughts``.
#
# Raises ModelRetry when any cited ID was not returned by
# search_medical_literature or find_drug_set_ids during this run. All unknown
# IDs are reported in one retry, so the model can fix them together instead of
# spending one LLM round trip per bad ID.
#
# NOTE(review): documented with comments rather than a docstring because
# pydantic_ai may forward an output function's docstring to the LLM as the
# output tool's description -- confirm before converting.
def create_response(ctx: RunContext, output: list[Statement]) -> models.Statements:
    context = get_context(ctx.messages)
    statements = []
    unknown_ids: list[str] = []
    for statement in output:
        sources = []
        for source_id in statement.sources or []:
            if source_id in context.sources:
                sources.append(context.sources[source_id])
            else:
                unknown_ids.append(source_id)
        statements.append({"text": statement.text, "sources": sources})
    if unknown_ids:
        # dict.fromkeys de-duplicates while keeping first-seen order.
        listed = ", ".join(
            f"'{source_id}'" for source_id in dict.fromkeys(unknown_ids)
        )
        # Citable sources come from both literature search and DailyMed set-ID
        # lookups, so the message must not describe them as literature only.
        raise ModelRetry(
            f"Source ID(s) {listed} not found in the sources returned by "
            "search_medical_literature or find_drug_set_ids."
        )
    return models.Statements.model_validate(
        {
            "statements": statements,
            "thoughts": "\n\n".join(context.thoughts),
        }
    )
# Gemini 2.5 Flash (preview) model that backs the agent.
model = GoogleModel("gemini-2.5-flash-preview-05-20")
settings = GoogleModelSettings(
    # Thinking is capped at a small token budget to keep response latency down.
    # include_thoughts asks Gemini to return its thoughts as well.
    # NOTE(review): confirm which message part type those thoughts arrive as;
    # get_context only collects TextPart content.
    # Sampling temperature is left at the model default.
    google_thinking_config={"thinking_budget": 1024, "include_thoughts": True},
)
agent = Agent(
    model=model,
    name="elna",
    model_settings=settings,
    # The model's statements are source-resolved and validated by create_response.
    output_type=create_response,
    # Read once at import time from the file next to this module. Decode as
    # UTF-8 explicitly so the prompt does not depend on the platform's locale
    # encoding.
    system_prompt=(Path(__file__).parent / "system_instruction.txt").read_text(
        encoding="utf-8"
    ),
    tools=[
        dailymed.find_drug_set_ids,
        # Not one of the tools whose results get_context records as citable sources.
        dailymed.find_drug_instruction,
        literature.search_medical_literature,
    ],
)