|
from pathlib import Path |
|
|
|
from pydantic import BaseModel, Field |
|
from pydantic_ai import Agent, ModelRetry, RunContext |
|
from pydantic_ai.messages import ( |
|
ModelMessage, |
|
ModelRequest, |
|
ModelResponse, |
|
TextPart, |
|
ToolReturnPart, |
|
) |
|
from pydantic_ai.models.google import GoogleModel, GoogleModelSettings |
|
|
|
from app import models |
|
from app.tools import dailymed, literature |
|
|
|
|
|
class Context(BaseModel):
    """Intermediate state extracted from an agent run's message history.

    Built by ``get_context`` and consumed by ``create_response``; it is not
    part of the schema sent to the model.
    """

    # Text content of the model's TextParts, in message order.
    thoughts: list[str]

    # Source items returned by the citation tools, keyed by source ID.
    sources: dict[str, dict]
|
|
|
|
|
# A single claim in the agent's final answer, optionally citing source IDs.
# NOTE: intentionally documented with comments instead of a class docstring —
# this model is part of the output schema given to the LLM, and pydantic would
# copy a class docstring into that JSON schema as its description.
class Statement(BaseModel):
    # The statement text as written by the model.
    text: str

    # IDs of the supporting sources; create_response resolves each one against
    # the sources collected during the run. None means no citation.
    sources: list[str] | None = Field(
        default=None, description="ID of the sources that support this statement."
    )
|
|
|
|
|
# Tools whose return values are citable sources. Their results are indexed by
# an "id" key below — assumed to be a list of dicts; confirm against app.tools.
_SOURCE_TOOL_NAMES = frozenset({"search_medical_literature", "find_drug_set_ids"})


def _collect_sources(content: object, sources: dict[str, dict]) -> None:
    """Add each dict item of *content* that has an ``"id"`` key to *sources*.

    Args:
        content: The raw return value of a source tool.
        sources: Mapping updated in place, keyed by the item's ID as a string.
    """
    # Only list/tuple payloads are treated as collections of sources: iterating
    # a str or dict would yield characters/keys and crash on ``item["id"]``,
    # and None (no results) is not iterable at all.
    if not isinstance(content, (list, tuple)):
        return
    for item in content:
        if isinstance(item, dict) and "id" in item:
            # Normalise to str: Context.sources is keyed by str, and the model
            # cites sources by string ID in create_response.
            sources[str(item["id"])] = item


def get_context(messages: list[ModelMessage]) -> Context:
    """Collect model text and citable tool results from a message history.

    Args:
        messages: The run's message history, in order.

    Returns:
        A ``Context`` whose ``thoughts`` are the contents of every ``TextPart``
        in the model responses (in order) and whose ``sources`` map each source
        ID (as a string) to the item returned by a source tool. If the same ID
        is returned more than once, the latest item wins.
    """
    thoughts: list[str] = []
    sources: dict[str, dict] = {}

    for message in messages:
        if isinstance(message, ModelResponse):
            thoughts.extend(
                part.content for part in message.parts if isinstance(part, TextPart)
            )
        elif isinstance(message, ModelRequest):
            for part in message.parts:
                if (
                    isinstance(part, ToolReturnPart)
                    and part.tool_name in _SOURCE_TOOL_NAMES
                ):
                    _collect_sources(part.content, sources)

    return Context(thoughts=thoughts, sources=sources)
|
|
|
|
|
# Output function for the agent: resolves the source IDs cited by each
# Statement against the sources gathered during the run (via get_context) and
# builds a models.Statements from the resolved statements plus the model's
# text parts joined with blank lines.
#
# Raises ModelRetry naming *every* unknown source ID at once, so the model can
# fix all bad citations in a single retry instead of one per round trip.
#
# NOTE: documented with comments rather than a docstring on purpose —
# pydantic_ai uses an output function's docstring as the output tool
# description sent to the model, so a docstring here would alter the prompt.
def create_response(ctx: RunContext, output: list[Statement]) -> models.Statements:
    context = get_context(ctx.messages)

    statements = []
    missing: list[str] = []
    for statement in output:
        sources = []
        for source_id in statement.sources or []:
            source = context.sources.get(source_id)
            if source is None:
                missing.append(source_id)
            else:
                sources.append(source)
        statements.append({"text": statement.text, "sources": sources})

    if missing:
        # dict.fromkeys de-duplicates while keeping first-seen order.
        unknown = ", ".join(f"'{source_id}'" for source_id in dict.fromkeys(missing))
        raise ModelRetry(
            f"Source ID(s) {unknown} not found in the retrieved sources. "
            "Cite only IDs returned by search_medical_literature or "
            "find_drug_set_ids."
        )

    return models.Statements.model_validate(
        {
            "statements": statements,
            "thoughts": "\n\n".join(context.thoughts),
        }
    )
|
|
|
|
|
# Gemini model backing the agent. Thinking is enabled with a budget of 1024
# (tokens, per the Gemini thinking config), and the model is asked to include
# its thoughts in responses.
model = GoogleModel("gemini-2.5-flash-preview-05-20")
settings = GoogleModelSettings(
    google_thinking_config={"thinking_budget": 1024, "include_thoughts": True},
)

# The agent's final output is produced by the create_response output function;
# the listed tools supply drug-label and literature data.
agent = Agent(
    model=model,
    name="elna",
    model_settings=settings,
    output_type=create_response,
    # Decode explicitly as UTF-8: without an encoding, read_text() uses the
    # platform locale encoding, which can mis-decode non-ASCII prompt text.
    system_prompt=(Path(__file__).parent / "system_instruction.txt").read_text(
        encoding="utf-8"
    ),
    tools=[
        dailymed.find_drug_set_ids,
        dailymed.find_drug_instruction,
        literature.search_medical_literature,
    ],
)
|
|